From 2be3c4f188ce03f57e075d42d7641768aeedc9de Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 7 Apr 2025 12:51:42 +0200 Subject: [PATCH 001/205] cg matrix op --- src/mrpro/algorithms/optimizers/cg.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index e9ca5840b..422ff85ec 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -73,8 +73,9 @@ def cg( ) -> tuple[torch.Tensor, ...] | tuple[torch.Tensor]: r"""(Preconditioned) Conjugate Gradient for solving :math:`Hx=b`. - This algorithm solves systems of the form :math:`H x = b`, where :math:`H` is a self-adjoint positive semidefinite - linear operator and :math:`b` is the right-hand side. + This algorithm solves systems of the form :math:`H x = b`, where :math:`H` is a self-adjoint linear operator + and :math:`b` is the right-hand side. The method can solve a batch of :math:`N` systems jointly, thereby taking + :math:`H` as a block-diagonal with blocks :math:`H_i` and :math:`b = [b_1, ..., b_N] ^T`. The method performs the following steps: @@ -98,13 +99,16 @@ def cg( If `preconditioner_inverse` is provided, it solves :math:`M^{-1}Hx = M^{-1}b` implicitly, where `preconditioner_inverse(r)` computes :math:`M^{-1}r`. + If `preconditioner_inverse` is provided, it solves :math:`M^{-1}Hx = M^{-1}b` + implicitly, where `preconditioner_inverse(r)` computes :math:`M^{-1}r`. + See [Hestenes1952]_, [Nocedal2006]_, and [WikipediaCG]_ for more information. Parameters ---------- operator - Self-adjoint operator :math:`H` + Self-adjoint operator :math:`H`. right_hand_side Right-hand-side :math:`b`. initial_value @@ -182,5 +186,11 @@ def cg( ) if continue_iterations is False: break - + if ( + isinstance(operator, LinearOperator) + and isinstance(right_hand_side, torch.Tensor) + and (initial_value is None or isinstance(initial_value, torch.Tensor)) + ): + # For backward compatibility if called with a single tensor and operator. + return solution[0] return solution From 1c31feda29736b3e5b4f05cb30a1441e2590efe0 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 6 Apr 2025 16:37:00 +0200 Subject: [PATCH 002/205] first draft --- src/mrpro/operators/OptimizerOp.py | 156 +++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 src/mrpro/operators/OptimizerOp.py diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py new file mode 100644 index 000000000..bd65061f7 --- /dev/null +++ b/src/mrpro/operators/OptimizerOp.py @@ -0,0 +1,156 @@ +"""Differentiable Minimization.""" + +import functools +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeAlias, Unpack + +import torch + +from mrpro.algorithms.optimizers.cg import cg +from mrpro.algorithms.optimizers.lbfgs import lbfgs +from mrpro.operators.Operator import Operator + +ArgumentType: TypeAlias = tuple[torch.Tensor, ...] +VariableType: TypeAlias = tuple[torch.Tensor, ...] +ObjectiveType: TypeAlias = Callable[[Unpack[VariableType]], tuple[torch.Tensor]] +FactoryType: TypeAlias = Callable[[Unpack[ArgumentType]], ObjectiveType] +OptimizeFunctionType: TypeAlias = Callable[[Unpack[VariableType]], VariableType] + +default_lbfgs = functools.partial( + lbfgs, + learning_rate=1.0, + max_iterations=20, + tolerance_change=1e-8, + tolerance_grad=1e-7, + history_size=20, + line_search_fn='strong_wolfe', +) +"""LBFGS Optimizer""" + + +class OptimizeCtx(torch.autograd.function.FunctionCtx): + """Rype hinting the CTX object.""" + + factory: FactoryType + len_x: int + needs_input_grad: tuple[bool, ...] + saved_tensors: tuple[torch.Tensor, ...] + + +class OptimizerImplicitBackward(torch.autograd.Function): + """Implicit Backward.""" + + @staticmethod + def forward( + ctx: OptimizeCtx, + factory: FactoryType, + x0: VariableType, + optimize: OptimizeFunctionType = default_lbfgs, + *parameters: Unpack[ArgumentType], + ) -> VariableType: + """Optimize.""" + ctx.factory = factory + f = factory(*parameters) + xprime = optimize(f, x0) + for xp in xprime: + xp.grad = None + ctx.save_for_backward(*xprime, *parameters) + ctx.len_x = len(x0) + return xprime + + @staticmethod + def backward(ctx: OptimizeCtx, grad: torch.Tensor) -> tuple[torch.Tensor]: + """Calculate the backward pass using implicit differentiation.""" + xprime: torch.Tensor = ctx.saved_tensors[: ctx.len_x] + parameters = ctx.saved_tensors[ctx.len_x :] + xprime = xprime.detach().clone().requires_grad_(True) + parameters = [ + p.detach().clone().requires_grad_(True) if ctx.needs_input_grad[i + 3] else p.detach() + for i, p in enumerate(parameters) + ] + dparams = [p for p in parameters if p.requires_grad] + + objective = ctx.factory(*parameters) + hessian_inverse_grad = cg(lambda v: torch.autograd.functional.vhp(objective, xprime, v=v)[1], grad) + with torch.enable_grad(): + dobjective_dxprime = torch.autograd.grad(objective(xprime), xprime, create_graph=True)[0] + # - d^2_obective / d_xprime d_params Hessian^-1_grad + grad_params = list(torch.autograd.grad(dobjective_dxprime, dparams, -hessian_inverse_grad)) + grad_inputs: list[torch.Tensor | None] = [None, None, None] # factory, x0, optimize + for need_grad in ctx.needs_input_grad[3:]: + if need_grad: + w = grad_params.pop(0) + grad_inputs.append(w) + else: + grad_inputs.append(grad_inputs) + + return tuple(grad_inputs) + + +class OptimizerOp(Operator): + def __init__(self, factory: FactoryType, optimize: OptimizeFunctionType, initializer): + r"""Initialize a differentiable Optimizer. + + Setup a differentiable argmin solver. + + This is one of the building blocks of PINQI [ZIMM2024]_ + + + Find :math:`x^*=argmin_x f_p(x) + + Example + ------- + Solving :math:`\|q(x)-y\|^2 + \lambda*\|x-x_reg\|^2` with + y, lambda and x_reg parameters. The solution x* should be differentiable with respect to these. + + Use:: + + def factory(y, lambda, x_reg): + return L2squared(y)@q+lambda*L2squared(x_reg) + def initializer(_y, _lambda, _xreg): + return (x_reg,) + + + Parameters + ---------- + factory + Function, that given the parameters of the problem returns an objective function + initialiazer + Function, that given the parameters of the problem creates initial values for the variable(s) + optimize + Function used to perform the optimization. + Use `functools.partial` to setup up all settings besides the objective function and the initial values. + + References + ---------- + .. [ZIMM2024] Zimmermann, Felix F., et al. (2024) PINQI. An End-to-End Physics-Informed Approach to Learned Quantitative + MRI Reconstruction. IEEE TCI. https://doi.org/10.1109/TCI.2024.3388869 + """ + self.factory = factory + self.optimize = optimize + self.initializer = initializer + + def forward(self, *parameters: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Solve the argmin problem. + + Parameters + ---------- + parameters + Parameters of the argmin problem. + """ + parameters = [] + for x in parameters: + if isinstance(x, torch.Tensor): + parameters.append(x.detach().clone()) + else: + raise NotImplementedError() + initial_values = self.initializer(*parameters) + initial_values = [x.copy() if x in parameters else x for x in initial_values] + initial_values = [x.detach().requires_grad_(True) for x in initial_values] + if TYPE_CHECKING: + # For mypy + result = OptimizerImplicitBackward.forward(self.factory, initial_values, self.optimize, *parameters) + else: + # For pytorch at runtime + result = OptimizerImplicitBackward.apply(self.factory, initial_values, self.optimize, *parameters) + return result From aaff916b9fb640c23c14462ae40458250a2ee8ce Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 7 Apr 2025 12:51:09 +0200 Subject: [PATCH 003/205] types --- src/mrpro/operators/OptimizerOp.py | 103 +++++++++++++++-------------- 1 file changed, 55 insertions(+), 48 deletions(-) diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index bd65061f7..446f47900 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -14,7 +14,7 @@ VariableType: TypeAlias = tuple[torch.Tensor, ...] ObjectiveType: TypeAlias = Callable[[Unpack[VariableType]], tuple[torch.Tensor]] FactoryType: TypeAlias = Callable[[Unpack[ArgumentType]], ObjectiveType] -OptimizeFunctionType: TypeAlias = Callable[[Unpack[VariableType]], VariableType] +OptimizeFunctionType: TypeAlias = Callable[[ObjectiveType, VariableType], VariableType] default_lbfgs = functools.partial( lbfgs, @@ -37,66 +37,88 @@ class OptimizeCtx(torch.autograd.function.FunctionCtx): saved_tensors: tuple[torch.Tensor, ...] -class OptimizerImplicitBackward(torch.autograd.Function): +class _OptimizerImplicitBackward(torch.autograd.Function): """Implicit Backward.""" @staticmethod def forward( ctx: OptimizeCtx, factory: FactoryType, - x0: VariableType, + initial_values: VariableType, optimize: OptimizeFunctionType = default_lbfgs, *parameters: Unpack[ArgumentType], ) -> VariableType: """Optimize.""" ctx.factory = factory f = factory(*parameters) - xprime = optimize(f, x0) + xprime = optimize(f, initial_values) for xp in xprime: xp.grad = None ctx.save_for_backward(*xprime, *parameters) - ctx.len_x = len(x0) + ctx.len_x = len(initial_values) return xprime @staticmethod - def backward(ctx: OptimizeCtx, grad: torch.Tensor) -> tuple[torch.Tensor]: + def backward(ctx: OptimizeCtx, *grad_outputs: torch.Tensor) -> tuple[torch.Tensor | None, ...]: """Calculate the backward pass using implicit differentiation.""" - xprime: torch.Tensor = ctx.saved_tensors[: ctx.len_x] + xprime = tuple(xp.detach().clone().requires_grad_(True) for xp in ctx.saved_tensors[: ctx.len_x]) parameters = ctx.saved_tensors[ctx.len_x :] - xprime = xprime.detach().clone().requires_grad_(True) - parameters = [ + parameters = tuple( p.detach().clone().requires_grad_(True) if ctx.needs_input_grad[i + 3] else p.detach() for i, p in enumerate(parameters) - ] + ) dparams = [p for p in parameters if p.requires_grad] objective = ctx.factory(*parameters) - hessian_inverse_grad = cg(lambda v: torch.autograd.functional.vhp(objective, xprime, v=v)[1], grad) + + def hvp(*v: torch.Tensor) -> tuple[torch.Tensor, ...]: + return torch.autograd.functional.vhp(objective, xprime, v=v)[1:] + + hessian_inverse_grad = cg(hvp, grad_outputs) with torch.enable_grad(): - dobjective_dxprime = torch.autograd.grad(objective(xprime), xprime, create_graph=True)[0] + dobjective_dxprime = torch.autograd.grad(objective(*xprime), xprime, create_graph=True)[0] # - d^2_obective / d_xprime d_params Hessian^-1_grad - grad_params = list(torch.autograd.grad(dobjective_dxprime, dparams, -hessian_inverse_grad)) + grad_params = list(torch.autograd.grad(dobjective_dxprime, dparams, hessian_inverse_grad)) grad_inputs: list[torch.Tensor | None] = [None, None, None] # factory, x0, optimize for need_grad in ctx.needs_input_grad[3:]: if need_grad: - w = grad_params.pop(0) - grad_inputs.append(w) + grad_inputs.append(-grad_params.pop(0)) else: - grad_inputs.append(grad_inputs) + grad_inputs.append(None) return tuple(grad_inputs) class OptimizerOp(Operator): - def __init__(self, factory: FactoryType, optimize: OptimizeFunctionType, initializer): - r"""Initialize a differentiable Optimizer. + """Differentiable Optimization Operator. - Setup a differentiable argmin solver. + One of the building blocks of PINQI [ZIMM2024]_ + Finds :math:`x^*=argmin_x f_p(x) - This is one of the building blocks of PINQI [ZIMM2024]_ + References + ---------- + .. [ZIMM2024] Zimmermann, Felix F., et al. (2024) PINQI. An End-to-End Physics-Informed Approach to Learned + Quantitative MRI Reconstruction. IEEE TCI. https://doi.org/10.1109/TCI.2024.3388869 + """ + def __init__( + self, + factory: FactoryType, + initializer: Callable[[Unpack[ArgumentType]], VariableType], + optimize: OptimizeFunctionType = default_lbfgs, + ): + r"""Initialize a differentiable argmin solver. - Find :math:`x^*=argmin_x f_p(x) + Parameters + ---------- + factory + Function, that given the parameters of the problem returns an objective function. + The objective function should be a callable that takes the variable(s) as input and returns a scalar. + initializer + Function, that given the parameters of the problem returns a tuple of initial values for the variable(s) + optimize + Function used to perform the optimization, for example `lbfgs`. + Use `functools.partial` to setup up all settings besides the objective function and the initial values. Example ------- @@ -110,47 +132,32 @@ def factory(y, lambda, x_reg): def initializer(_y, _lambda, _xreg): return (x_reg,) - - Parameters - ---------- - factory - Function, that given the parameters of the problem returns an objective function - initialiazer - Function, that given the parameters of the problem creates initial values for the variable(s) - optimize - Function used to perform the optimization. - Use `functools.partial` to setup up all settings besides the objective function and the initial values. - - References - ---------- - .. [ZIMM2024] Zimmermann, Felix F., et al. (2024) PINQI. An End-to-End Physics-Informed Approach to Learned Quantitative - MRI Reconstruction. IEEE TCI. https://doi.org/10.1109/TCI.2024.3388869 + Returns + ------- + The argmin `x^*` """ self.factory = factory self.optimize = optimize self.initializer = initializer def forward(self, *parameters: torch.Tensor) -> tuple[torch.Tensor, ...]: - """Solve the argmin problem. + """Find the argmin. Parameters ---------- parameters Parameters of the argmin problem. """ - parameters = [] - for x in parameters: - if isinstance(x, torch.Tensor): - parameters.append(x.detach().clone()) - else: - raise NotImplementedError() + parameters = tuple(p.detach().clone() for p in parameters) initial_values = self.initializer(*parameters) - initial_values = [x.copy() if x in parameters else x for x in initial_values] - initial_values = [x.detach().requires_grad_(True) for x in initial_values] + initial_values = tuple(x.clone() if x in parameters else x for x in initial_values) + initial_values = tuple(x.detach().requires_grad_(True) for x in initial_values) if TYPE_CHECKING: # For mypy - result = OptimizerImplicitBackward.forward(self.factory, initial_values, self.optimize, *parameters) + result = _OptimizerImplicitBackward.forward( + OptimizeCtx(), self.factory, initial_values, self.optimize, *parameters + ) else: - # For pytorch at runtime - result = OptimizerImplicitBackward.apply(self.factory, initial_values, self.optimize, *parameters) + # Actually used at runtime + result = _OptimizerImplicitBackward.apply(self.factory, initial_values, self.optimize, *parameters) return result From a336ded279b85e958f46d8eeb348abaa107777c3 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 16 Apr 2025 17:04:35 +0200 Subject: [PATCH 004/205] test --- src/mrpro/operators/OptimizerOp.py | 7 ++--- src/mrpro/operators/__init__.py | 2 ++ tests/operators/test_optimizer_op.py | 39 ++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 tests/operators/test_optimizer_op.py diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 446f47900..9cb833ad0 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -122,8 +122,9 @@ def __init__( Example ------- - Solving :math:`\|q(x)-y\|^2 + \lambda*\|x-x_reg\|^2` with - y, lambda and x_reg parameters. The solution x* should be differentiable with respect to these. + Solving :math:`\|q(x)-y\|^2 + \lambda*\|x-x_\mathrm{reg}\|^2` with + :math:`y`, :math:`\lambda` and :math:`x_\mathrm{reg}` parameters. The solution :math:`x^*` should be + differentiable with respect to these. Use:: @@ -150,7 +151,7 @@ def forward(self, *parameters: torch.Tensor) -> tuple[torch.Tensor, ...]: """ parameters = tuple(p.detach().clone() for p in parameters) initial_values = self.initializer(*parameters) - initial_values = tuple(x.clone() if x in parameters else x for x in initial_values) + initial_values = tuple(x.clone() if any(x is p for p in parameters) else x for x in initial_values) initial_values = tuple(x.detach().requires_grad_(True) for x in initial_values) if TYPE_CHECKING: # For mypy diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index c6cf5c312..e90ca9139 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -20,6 +20,7 @@ from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp from mrpro.operators.NonUniformFastFourierOp import NonUniformFastFourierOp +from mrpro.operators.OptimizerOp import OptimizerOp from mrpro.operators.PatchOp import PatchOp from mrpro.operators.PCACompressionOp import PCACompressionOp from mrpro.operators.PhaseOp import PhaseOp @@ -55,6 +56,7 @@ "MultiIdentityOp", "NonUniformFastFourierOp", "Operator", + "OptimizerOp", "PCACompressionOp", "PatchOp", "PhaseOp", diff --git a/tests/operators/test_optimizer_op.py b/tests/operators/test_optimizer_op.py new file mode 100644 index 000000000..422c379b5 --- /dev/null +++ b/tests/operators/test_optimizer_op.py @@ -0,0 +1,39 @@ +from collections.abc import Callable + +import torch +from mrpro.operators import OptimizerOp +from mrpro.operators.functionals import L2NormSquared +from mrpro.operators.models import InversionRecovery +from mrpro.utils import RandomGenerator + + +def test_optimizer_op(): + rng = RandomGenerator(seed=0) + + def factory( + m0_reg: torch.Tensor, + t1_reg: torch.Tensor, + lambda_m0: torch.Tensor, + lambda_t1: torch.Tensor, + signal: torch.Tensor, + ) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: + data_consistency = L2NormSquared(signal) @ InversionRecovery((0.5, 1.0, 1.5, 3)) + regularization = lambda_t1 * L2NormSquared(t1_reg) | lambda_m0 * L2NormSquared(m0_reg) + return data_consistency + regularization + + factory(torch.randn(10, 10), torch.randn(10, 10), torch.randn(1), torch.randn(1), torch.randn(10, 10)) + rng = RandomGenerator(seed=0) + true_m0 = rng.complex64_tensor(size=(10, 10)) + true_t1 = rng.float32_tensor(size=(10, 10), low=0.1, high=2) + (signal,) = InversionRecovery((0.5, 1.0, 1.5, 3))(true_m0, true_t1) + signal += rng.complex64_tensor(size=(10, 10), high=0.1) + t1_reg = true_t1 + rng.rand_like(true_t1, low=-0.1, high=0.1) + m0_reg = true_m0 + rng.rand_like(true_m0, low=0, high=0.5) + op = OptimizerOp(factory, initializer=lambda m0_reg, t1_reg, *_: (m0_reg, t1_reg)) + lambda_m0 = torch.tensor(1.0, requires_grad=True) + lambda_t1 = torch.tensor(1.0, requires_grad=True) + ret = op.forward(m0_reg, t1_reg, lambda_m0, lambda_t1, signal) + (loss,) = (L2NormSquared(true_m0) | L2NormSquared(m0_reg))(*ret) + loss.backward() + assert lambda_m0.grad is not None + assert lambda_t1.grad is not None From 5d6fe988d33e5c7195f016197ed6d6b712f26757 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Apr 2025 13:56:22 +0200 Subject: [PATCH 005/205] update --- src/mrpro/algorithms/optimizers/cg.py | 7 - src/mrpro/algorithms/optimizers/lbfgs.py | 4 +- src/mrpro/algorithms/optimizers/pdhg.py | 4 +- src/mrpro/operators/FourierOp.py | 3 +- src/mrpro/operators/Functional.py | 93 ++++--------- src/mrpro/operators/LinearOperatorMatrix.py | 20 ++- src/mrpro/operators/OptimizerOp.py | 88 +++++++----- .../ProximableFunctionalSeparableSum.py | 126 ++++++++++++++++-- src/mrpro/operators/ZeroOp.py | 5 + src/mrpro/operators/__init__.py | 7 +- src/mrpro/operators/functionals/L1Norm.py | 6 +- .../operators/functionals/L1NormViewAsReal.py | 4 +- .../operators/functionals/L2NormSquared.py | 6 +- src/mrpro/operators/functionals/SSIM.py | 4 +- .../operators/functionals/ZeroFunctional.py | 6 +- tests/algorithms/test_pdhg.py | 2 +- .../functionals/test_functional_arithmetic.py | 6 +- tests/operators/test_linearoperatormatrix.py | 10 ++ tests/operators/test_optimizer_op.py | 61 +++++---- 19 files changed, 292 insertions(+), 170 deletions(-) diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index 422ff85ec..a8c574210 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -186,11 +186,4 @@ def cg( ) if continue_iterations is False: break - if ( - isinstance(operator, LinearOperator) - and isinstance(right_hand_side, torch.Tensor) - and (initial_value is None or isinstance(initial_value, torch.Tensor)) - ): - # For backward compatibility if called with a single tensor and operator. - return solution[0] return solution diff --git a/src/mrpro/algorithms/optimizers/lbfgs.py b/src/mrpro/algorithms/optimizers/lbfgs.py index 444791168..d4e07d8fa 100644 --- a/src/mrpro/algorithms/optimizers/lbfgs.py +++ b/src/mrpro/algorithms/optimizers/lbfgs.py @@ -1,7 +1,7 @@ """LBFGS for solving non-linear minimization problems.""" from collections.abc import Callable, Sequence -from typing import Literal +from typing import Literal, Unpack import torch from torch.optim import LBFGS @@ -18,7 +18,7 @@ class LBFGSStatus(OptimizerStatus): def lbfgs( - f: OperatorType, + f: OperatorType | Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor]], initial_parameters: Sequence[torch.Tensor], learning_rate: float = 1.0, max_iterations: int = 100, diff --git a/src/mrpro/algorithms/optimizers/pdhg.py b/src/mrpro/algorithms/optimizers/pdhg.py index 6efb89973..a31db8367 100644 --- a/src/mrpro/algorithms/optimizers/pdhg.py +++ b/src/mrpro/algorithms/optimizers/pdhg.py @@ -143,7 +143,7 @@ def pdhg( # We always use a separable sum for homogeneous handling, even if it is just a ZeroFunctional if f is None: - f_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * n_rows) + f_sum: ProximableFunctionalSeparableSum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * n_rows) elif isinstance(f, ProximableFunctional): f_sum = ProximableFunctionalSeparableSum(f) else: @@ -153,7 +153,7 @@ def pdhg( raise ValueError('Number of rows in operator does not match number of functionals in f') if g is None: - g_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * n_columns) + g_sum: ProximableFunctionalSeparableSum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * n_columns) elif isinstance(g, ProximableFunctional): g_sum = ProximableFunctionalSeparableSum(g) else: diff --git a/src/mrpro/operators/FourierOp.py b/src/mrpro/operators/FourierOp.py index 1120159ad..eb94d65e1 100644 --- a/src/mrpro/operators/FourierOp.py +++ b/src/mrpro/operators/FourierOp.py @@ -1,6 +1,7 @@ """Fourier Operator.""" from collections.abc import Sequence +from functools import cached_property import torch from typing_extensions import Self @@ -160,7 +161,7 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: (x,) = self._non_uniform_fast_fourier_op.adjoint(x) return (x,) - @property + @cached_property def gram(self) -> LinearOperator: """Return the gram operator.""" return FourierGramOp(self) diff --git a/src/mrpro/operators/Functional.py b/src/mrpro/operators/Functional.py index 10ad307f3..4724f6f79 100644 --- a/src/mrpro/operators/Functional.py +++ b/src/mrpro/operators/Functional.py @@ -5,42 +5,38 @@ import math from abc import ABC, abstractmethod from collections.abc import Sequence +from typing import TypeAlias, TypeVarTuple, Unpack import torch import mrpro.operators from mrpro.operators.Operator import Operator +T = TypeVarTuple('T') +FunctionalType: TypeAlias = Operator[Unpack[T], tuple[torch.Tensor]] +"""An Operator that returns a single tensor.""" -class Functional(Operator[torch.Tensor, tuple[torch.Tensor]]): - """Functional Base Class.""" - def __rmul__(self, scalar: torch.Tensor | complex) -> Functional: - """Multiply functional with scalar.""" - if not isinstance(scalar, int | float | torch.Tensor): - return NotImplemented - return ScaledFunctional(self, scalar) - - def _throw_if_negative_or_complex( - self, x: torch.Tensor | complex, message: str = 'sigma must be real and contain only positive values' - ) -> None: - """Throw an ValueError if any element of x is negative or complex. +def throw_if_negative_or_complex( + x: torch.Tensor | complex, message: str = 'sigma must be real and contain only positive values' +) -> None: + """Throw an ValueError if any element of x is negative or complex. - Parameters - ---------- - x - input to be checked - message - error message that is raised if x contains negative or complex values - """ - if (isinstance(x, float | int) and x >= 0) or ( - isinstance(x, torch.Tensor) and not x.dtype.is_complex and (x >= 0).all() - ): - return - raise ValueError(message) + Parameters + ---------- + x + input to be checked + message + error message that is raised if x contains negative or complex values + """ + if (isinstance(x, float | int) and x >= 0) or ( + isinstance(x, torch.Tensor) and not x.dtype.is_complex and (x >= 0).all() + ): + return + raise ValueError(message) -class ElementaryFunctional(Functional): +class ElementaryFunctional(Operator[torch.Tensor, tuple[torch.Tensor]], ABC): r"""Elementary functional base class. Here, an 'elementary' functional is a functional that can be written as @@ -122,7 +118,7 @@ def _divide_by_n(self, x: torch.Tensor, shape: None | Sequence[int]) -> torch.Te return x / math.prod(size) -class ProximableFunctional(Functional, ABC): +class ProximableFunctional(Operator[torch.Tensor, tuple[torch.Tensor]], ABC): r"""ProximableFunctional Base Class. A proximable functional is a functional :math:`f(x)` that has a prox implementation, @@ -168,7 +164,7 @@ def prox_convex_conj(self, x: torch.Tensor, sigma: torch.Tensor | float = 1.0) - """ if not isinstance(sigma, torch.Tensor): sigma = torch.as_tensor(1.0 * sigma) - self._throw_if_negative_or_complex(sigma) + throw_if_negative_or_complex(sigma) sigma = sigma.clamp(min=1e-8) return (x - sigma * self.prox(x / sigma, 1 / sigma)[0],) @@ -178,7 +174,9 @@ def __rmul__(self, scalar: torch.Tensor | complex) -> ProximableFunctional: return NotImplemented return ScaledProximableFunctional(self, scalar) - def __or__(self, other: ProximableFunctional) -> mrpro.operators.ProximableFunctionalSeparableSum: + def __or__( + self, other: ProximableFunctional + ) -> mrpro.operators.ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor]: """Create a ProximableFunctionalSeparableSum object from two proximable functionals. Parameters @@ -207,41 +205,6 @@ class ElementaryProximableFunctional(ElementaryFunctional, ProximableFunctional) """ -class ScaledFunctional(Functional): - """Functional scaled by a scalar.""" - - def __init__(self, functional: Functional, scale: torch.Tensor | float) -> None: - r"""Initialize a scaled functional. - - A scaled functional is a functional that is scaled by a scalar factor :math:`\alpha`, - i.e. :math:`f(x) = \alpha g(x)`. - - Parameters - ---------- - functional - functional to be scaled - scale - scaling factor, must be real and positive - """ - super().__init__() - self.functional = functional - self.scale = torch.as_tensor(scale) - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: - """Forward method. - - Parameters - ---------- - x - input tensor - - Returns - ------- - scaled output of the functional - """ - return (self.scale * self.functional(x)[0],) - - class ScaledProximableFunctional(ProximableFunctional): """Proximable Functional scaled by a scalar.""" @@ -290,7 +253,7 @@ def prox(self, x: torch.Tensor, sigma: torch.Tensor | float = 1.0) -> tuple[torc ------- Proximal mapping applied to the input tensor """ - self._throw_if_negative_or_complex( + throw_if_negative_or_complex( self.scale, 'For prox to be defined, the scaling factor must be real and non-negative' ) return (self.functional.prox(x, sigma * self.scale)[0],) @@ -309,7 +272,7 @@ def prox_convex_conj(self, x: torch.Tensor, sigma: torch.Tensor | float = 1.0) - ------- Proximal mapping of the convex conjugate applied to the input tensor """ - self._throw_if_negative_or_complex( + throw_if_negative_or_complex( self.scale, 'For prox_convex_conj to be defined, the scaling factor must be real and non-negative' ) return (self.scale * self.functional.prox_convex_conj(x / self.scale, sigma / self.scale)[0],) diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py index 375753177..f9fb50954 100644 --- a/src/mrpro/operators/LinearOperatorMatrix.py +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -4,7 +4,7 @@ import operator from collections.abc import Callable, Iterator, Sequence -from functools import reduce +from functools import cached_property, reduce from types import EllipsisType from typing import cast @@ -225,6 +225,24 @@ def H(self) -> Self: # noqa N802 """Adjoints of the operators.""" return self.__class__([[op.H for op in row] for row in zip(*self._operators, strict=True)]) + @cached_property + def gram(self) -> Self: + """Gram matrix of the operators.""" + n, m = self.shape + if n != m: + raise ValueError('Gram is only defined for square operators.') + operators: list[list[LinearOperator]] = [[ZeroOp() for _ in range(n)] for _ in range(n)] + + for i in range(n): + operators[i][i] = reduce(operator.add, (self._operators[k][i].gram for k in range(n))) + # off-diagonals: only compute upper triangular part, then mirror + for j in range(i + 1, n): + operators[i][j] = reduce( + operator.add, (self._operators[k][i].H @ self._operators[k][j] for k in range(n)) + ) + operators[j][i] = operators[i][j].H + return self.__class__(operators) + def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: """Apply the adjoint of the operator to the input. diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 9cb833ad0..4cb9c9b9c 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -2,7 +2,7 @@ import functools from collections.abc import Callable -from typing import TYPE_CHECKING, TypeAlias, Unpack +from typing import TYPE_CHECKING, TypeVar, TypeVarTuple, Unpack, cast import torch @@ -10,16 +10,16 @@ from mrpro.algorithms.optimizers.lbfgs import lbfgs from mrpro.operators.Operator import Operator -ArgumentType: TypeAlias = tuple[torch.Tensor, ...] -VariableType: TypeAlias = tuple[torch.Tensor, ...] -ObjectiveType: TypeAlias = Callable[[Unpack[VariableType]], tuple[torch.Tensor]] -FactoryType: TypeAlias = Callable[[Unpack[ArgumentType]], ObjectiveType] -OptimizeFunctionType: TypeAlias = Callable[[ObjectiveType, VariableType], VariableType] +ArgumentType = TypeVarTuple('ArgumentType') +VariableType = TypeVar('VariableType', bound=tuple[torch.Tensor, ...]) +ObjectiveType = Callable[[VariableType], tuple[torch.Tensor]] +FactoryType = Callable[[Unpack[tuple[torch.Tensor, ...]]], Callable] +OptimizeFunctionType = Callable[[Callable, VariableType], VariableType] default_lbfgs = functools.partial( lbfgs, learning_rate=1.0, - max_iterations=20, + max_iterations=40, tolerance_change=1e-8, tolerance_grad=1e-7, history_size=20, @@ -31,31 +31,57 @@ class OptimizeCtx(torch.autograd.function.FunctionCtx): """Rype hinting the CTX object.""" - factory: FactoryType + factory: Callable[ + [Unpack[tuple[torch.Tensor, ...]]], Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor]] + ] len_x: int needs_input_grad: tuple[bool, ...] saved_tensors: tuple[torch.Tensor, ...] -class _OptimizerImplicitBackward(torch.autograd.Function): +class OptimizeFunction(torch.autograd.Function): """Implicit Backward.""" + if TYPE_CHECKING: + + @classmethod + def apply( + cls, + factory: Callable[ + [Unpack[tuple[torch.Tensor, ...]]], Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor]] + ], + initial_values: tuple[torch.Tensor, ...], + optimize: Callable[ + [Callable[[*tuple[Unpack[tuple[torch.Tensor, ...]]]], tuple[torch.Tensor]], tuple[torch.Tensor, ...]], + tuple[torch.Tensor, ...], + ] = default_lbfgs, + *parameters: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + """Apply the function. Only used for type hinting.""" + return super().apply(factory, initial_values, optimize, *parameters) + @staticmethod def forward( ctx: OptimizeCtx, - factory: FactoryType, - initial_values: VariableType, - optimize: OptimizeFunctionType = default_lbfgs, - *parameters: Unpack[ArgumentType], - ) -> VariableType: + factory: Callable[ + [Unpack[tuple[torch.Tensor, ...]]], Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor]] + ], + initial_values: tuple[torch.Tensor, ...], + optimize: Callable[ + [Callable[[*tuple[Unpack[tuple[torch.Tensor, ...]]]], tuple[torch.Tensor]], tuple[torch.Tensor, ...]], + tuple[torch.Tensor, ...], + ] = default_lbfgs, + *parameters: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: """Optimize.""" ctx.factory = factory + + parameters_ = tuple(p.detach().clone() for p in parameters if isinstance(p, torch.Tensor)) + initial_values_ = tuple(x.detach().requires_grad_(True) for x in initial_values if isinstance(x, torch.Tensor)) f = factory(*parameters) xprime = optimize(f, initial_values) - for xp in xprime: - xp.grad = None - ctx.save_for_backward(*xprime, *parameters) - ctx.len_x = len(initial_values) + ctx.save_for_backward(*xprime, *parameters_) + ctx.len_x = len(initial_values_) return xprime @staticmethod @@ -72,11 +98,11 @@ def backward(ctx: OptimizeCtx, *grad_outputs: torch.Tensor) -> tuple[torch.Tenso objective = ctx.factory(*parameters) def hvp(*v: torch.Tensor) -> tuple[torch.Tensor, ...]: - return torch.autograd.functional.vhp(objective, xprime, v=v)[1:] + return torch.autograd.functional.vhp(lambda *x: objective(*x)[0], xprime, v=v)[1] - hessian_inverse_grad = cg(hvp, grad_outputs) + hessian_inverse_grad = cg(hvp, grad_outputs, max_iterations=200, tolerance=1e-6) with torch.enable_grad(): - dobjective_dxprime = torch.autograd.grad(objective(*xprime), xprime, create_graph=True)[0] + dobjective_dxprime = torch.autograd.grad(objective(*xprime), xprime, create_graph=True) # - d^2_obective / d_xprime d_params Hessian^-1_grad grad_params = list(torch.autograd.grad(dobjective_dxprime, dparams, hessian_inverse_grad)) grad_inputs: list[torch.Tensor | None] = [None, None, None] # factory, x0, optimize @@ -89,7 +115,7 @@ def hvp(*v: torch.Tensor) -> tuple[torch.Tensor, ...]: return tuple(grad_inputs) -class OptimizerOp(Operator): +class OptimizerOp(Operator[Unpack[ArgumentType], VariableType]): """Differentiable Optimization Operator. One of the building blocks of PINQI [ZIMM2024]_ @@ -137,11 +163,12 @@ def initializer(_y, _lambda, _xreg): ------- The argmin `x^*` """ + super().__init__() self.factory = factory self.optimize = optimize self.initializer = initializer - def forward(self, *parameters: torch.Tensor) -> tuple[torch.Tensor, ...]: + def forward(self, *parameters: Unpack[ArgumentType]) -> VariableType: """Find the argmin. Parameters @@ -149,16 +176,9 @@ def forward(self, *parameters: torch.Tensor) -> tuple[torch.Tensor, ...]: parameters Parameters of the argmin problem. """ - parameters = tuple(p.detach().clone() for p in parameters) initial_values = self.initializer(*parameters) initial_values = tuple(x.clone() if any(x is p for p in parameters) else x for x in initial_values) - initial_values = tuple(x.detach().requires_grad_(True) for x in initial_values) - if TYPE_CHECKING: - # For mypy - result = _OptimizerImplicitBackward.forward( - OptimizeCtx(), self.factory, initial_values, self.optimize, *parameters - ) - else: - # Actually used at runtime - result = _OptimizerImplicitBackward.apply(self.factory, initial_values, self.optimize, *parameters) - return result + result = OptimizeFunction.apply( + self.factory, initial_values, self.optimize, *cast(tuple[torch.Tensor, ...], parameters) + ) + return cast(VariableType, result) diff --git a/src/mrpro/operators/ProximableFunctionalSeparableSum.py b/src/mrpro/operators/ProximableFunctionalSeparableSum.py index a5f1daeb2..443ec9ebf 100644 --- a/src/mrpro/operators/ProximableFunctionalSeparableSum.py +++ b/src/mrpro/operators/ProximableFunctionalSeparableSum.py @@ -5,16 +5,19 @@ import operator from collections.abc import Iterator from functools import reduce -from typing import cast +from typing import TypeVarTuple, cast, overload import torch -from typing_extensions import Self, Unpack +from typing_extensions import Unpack from mrpro.operators.Functional import ProximableFunctional from mrpro.operators.Operator import Operator +T = TypeVarTuple('T') +T2 = TypeVarTuple('T2') -class ProximableFunctionalSeparableSum(Operator[Unpack[tuple[torch.Tensor, ...]], tuple[torch.Tensor]]): + +class ProximableFunctionalSeparableSum(Operator[Unpack[T], tuple[torch.Tensor]]): r"""Separable Sum of Proximable Functionals. This is a separable sum of the functionals. The forward method returns the sum of the functionals @@ -23,6 +26,61 @@ class ProximableFunctionalSeparableSum(Operator[Unpack[tuple[torch.Tensor, ...]] functionals: tuple[ProximableFunctional, ...] + @overload + def __init__(self: ProximableFunctionalSeparableSum[torch.Tensor], f1: ProximableFunctional, /) -> None: ... + + @overload + def __init__( + self: ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor], + f1: ProximableFunctional, + f2: ProximableFunctional, + /, + ) -> None: ... + + @overload + def __init__( + self: ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor, torch.Tensor], + f1: ProximableFunctional, + f2: ProximableFunctional, + f3: ProximableFunctional, + /, + ) -> None: ... + + @overload + def __init__( + self: ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + f1: ProximableFunctional, + f2: ProximableFunctional, + f3: ProximableFunctional, + f4: ProximableFunctional, + /, + ) -> None: ... + + @overload + def __init__( + self: ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + f1: ProximableFunctional, + f2: ProximableFunctional, + f3: ProximableFunctional, + f4: ProximableFunctional, + f5: ProximableFunctional, + /, + ) -> None: ... + + @overload + def __init__( + self: ProximableFunctionalSeparableSum[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Unpack[tuple[torch.Tensor, ...]] + ], + f1: ProximableFunctional, + f2: ProximableFunctional, + f3: ProximableFunctional, + f4: ProximableFunctional, + f5: ProximableFunctional, + /, + *f: ProximableFunctional, + ) -> None: ... + def __init__(self, *functionals: ProximableFunctional) -> None: """Initialize the separable sum of proximable functionals. @@ -34,7 +92,7 @@ def __init__(self, *functionals: ProximableFunctional) -> None: super().__init__() self.functionals = functionals - def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor]: + def forward(self, *x: Unpack[T]) -> tuple[torch.Tensor]: """Apply the functionals to the inputs. Parameters @@ -48,10 +106,12 @@ def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor]: """ if len(x) != len(self.functionals): raise ValueError('The number of inputs must match the number of functionals.') - result = reduce(operator.add, (f(xi)[0] for f, xi in zip(self.functionals, x, strict=True))) + result = reduce( + operator.add, (f(xi)[0] for f, xi in zip(self.functionals, cast(tuple[torch.Tensor, ...], x), strict=True)) + ) return (result,) - def prox(self, *x: torch.Tensor, sigma: float | torch.Tensor = 1) -> tuple[torch.Tensor, ...]: + def prox(self, *x: Unpack[T], sigma: float | torch.Tensor = 1) -> tuple[Unpack[T]]: """Apply the proximal operators of the functionals to the inputs. Parameters @@ -68,9 +128,9 @@ def prox(self, *x: torch.Tensor, sigma: float | torch.Tensor = 1) -> tuple[torch prox_x = tuple( f.prox(xi, sigma)[0] for f, xi in zip(self.functionals, cast(tuple[torch.Tensor, ...], x), strict=True) ) - return prox_x + return cast(tuple[Unpack[T]], prox_x) - def prox_convex_conj(self, *x: torch.Tensor, sigma: float | torch.Tensor = 1) -> tuple[torch.Tensor, ...]: + def prox_convex_conj(self, *x: Unpack[T], sigma: float | torch.Tensor = 1) -> tuple[Unpack[T]]: """Apply the proximal operators of the convex conjugate of the functionals to the inputs. Parameters @@ -88,12 +148,48 @@ def prox_convex_conj(self, *x: torch.Tensor, sigma: float | torch.Tensor = 1) -> f.prox_convex_conj(xi, sigma)[0] for f, xi in zip(self.functionals, cast(tuple[torch.Tensor, ...], x), strict=True) ) - return prox_convex_conj_x + return cast(tuple[Unpack[T]], prox_convex_conj_x) + + @overload + def __or__( + self: ProximableFunctionalSeparableSum[Unpack[T]], other: ProximableFunctional + ) -> ProximableFunctionalSeparableSum[Unpack[T], torch.Tensor]: ... + + @overload + def __or__( + self: ProximableFunctionalSeparableSum[Unpack[T]], other: ProximableFunctionalSeparableSum[torch.Tensor] + ) -> ProximableFunctionalSeparableSum[Unpack[T], torch.Tensor]: ... + + @overload + def __or__( + self: ProximableFunctionalSeparableSum[Unpack[T]], + other: ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor], + ) -> ProximableFunctionalSeparableSum[Unpack[T], torch.Tensor, torch.Tensor]: ... + + @overload + def __or__( + self: ProximableFunctionalSeparableSum[Unpack[T]], + other: ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor, torch.Tensor], + ) -> ProximableFunctionalSeparableSum[Unpack[T], torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @overload + def __or__( + self: ProximableFunctionalSeparableSum[Unpack[T]], + other: ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ) -> ProximableFunctionalSeparableSum[Unpack[T], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... + + @overload + def __or__( + self: ProximableFunctionalSeparableSum[Unpack[T]], + other: ProximableFunctionalSeparableSum[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ) -> ProximableFunctionalSeparableSum[ + Unpack[T], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor + ]: ... def __or__( - self, + self: ProximableFunctionalSeparableSum, other: ProximableFunctional | ProximableFunctionalSeparableSum, - ) -> Self: + ) -> ProximableFunctionalSeparableSum: """Separable sum functionals.""" if isinstance(other, ProximableFunctionalSeparableSum): return self.__class__(*self.functionals, *other.functionals) @@ -102,10 +198,14 @@ def __or__( else: return NotImplemented # type: ignore[unreachable] - def __ror__(self, other: ProximableFunctional) -> Self: + def __ror__( + self: ProximableFunctionalSeparableSum[Unpack[T]], other: ProximableFunctional + ) -> ProximableFunctionalSeparableSum[torch.Tensor, Unpack[T]]: """Separable sum functionals.""" if isinstance(other, ProximableFunctional): - return self.__class__(other, *self.functionals) + return cast( + ProximableFunctionalSeparableSum[torch.Tensor, Unpack[T]], self.__class__(other, *self.functionals) + ) else: return NotImplemented # type: ignore[unreachable] diff --git a/src/mrpro/operators/ZeroOp.py b/src/mrpro/operators/ZeroOp.py index fcc8adb12..d7029c5f1 100644 --- a/src/mrpro/operators/ZeroOp.py +++ b/src/mrpro/operators/ZeroOp.py @@ -62,3 +62,8 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: return (torch.zeros_like(x),) else: return (torch.tensor(0),) + + @property + def H(self) -> LinearOperator: # noqa: N802 + """Adjoint of the Zero Operator.""" + return self diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index e90ca9139..2bc42b3d5 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -2,10 +2,11 @@ from mrpro.operators.Operator import Operator from mrpro.operators.LinearOperator import LinearOperator -from mrpro.operators.Functional import Functional, ProximableFunctional, ElementaryFunctional, ElementaryProximableFunctional, ScaledFunctional, ScaledProximableFunctional +from mrpro.operators.Functional import FunctionalType, ProximableFunctional, ElementaryFunctional, ElementaryProximableFunctional, ScaledProximableFunctional from mrpro.operators import functionals, models from mrpro.operators.AveragingOp import AveragingOp from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp +from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp from mrpro.operators.ConstraintsOp import ConstraintsOp from mrpro.operators.DensityCompensationOp import DensityCompensationOp from mrpro.operators.DictionaryMatchOp import DictionaryMatchOp @@ -37,6 +38,7 @@ __all__ = [ "AveragingOp", "CartesianSamplingOp", + "ConjugateGradientOp", "ConstraintsOp", "DensityCompensationOp", "DictionaryMatchOp", @@ -46,7 +48,7 @@ "FastFourierOp", "FiniteDifferenceOp", "FourierOp", - "Functional", + "FunctionalType", "GridSamplingOp", "IdentityOp", "Jacobian", @@ -63,7 +65,6 @@ "ProximableFunctional", "ProximableFunctionalSeparableSum", "RearrangeOp", - "ScaledFunctional", "ScaledProximableFunctional", "SensitivityOp", "SignalModel", diff --git a/src/mrpro/operators/functionals/L1Norm.py b/src/mrpro/operators/functionals/L1Norm.py index 55221b9bf..67ed52ebf 100644 --- a/src/mrpro/operators/functionals/L1Norm.py +++ b/src/mrpro/operators/functionals/L1Norm.py @@ -2,7 +2,7 @@ import torch -from mrpro.operators.Functional import ElementaryProximableFunctional +from mrpro.operators.Functional import ElementaryProximableFunctional, throw_if_negative_or_complex class L1Norm(ElementaryProximableFunctional): @@ -58,7 +58,7 @@ def prox(self, x: torch.Tensor, sigma: torch.Tensor | float = 1.0) -> tuple[torc ------- Proximal mapping applied to the input tensor """ - self._throw_if_negative_or_complex(sigma) + throw_if_negative_or_complex(sigma) diff = x - self.target threshold = self.weight * sigma threshold = self._divide_by_n(threshold, torch.broadcast_shapes(x.shape, threshold.shape)) @@ -86,7 +86,7 @@ def prox_convex_conj( ------- Proximal of the convex conjugate applied to the input tensor """ - self._throw_if_negative_or_complex(sigma) + throw_if_negative_or_complex(sigma) diff = x - sigma * self.target threshold = self._divide_by_n(self.weight.abs(), torch.broadcast_shapes(x.shape, self.weight.shape)) x_out = torch.sgn(diff) * torch.clamp_max(diff.abs(), threshold.abs()) diff --git a/src/mrpro/operators/functionals/L1NormViewAsReal.py b/src/mrpro/operators/functionals/L1NormViewAsReal.py index 4d9e8fbf7..304fb74bb 100644 --- a/src/mrpro/operators/functionals/L1NormViewAsReal.py +++ b/src/mrpro/operators/functionals/L1NormViewAsReal.py @@ -2,7 +2,7 @@ import torch -from mrpro.operators.Functional import ElementaryProximableFunctional +from mrpro.operators.Functional import ElementaryProximableFunctional, throw_if_negative_or_complex class L1NormViewAsReal(ElementaryProximableFunctional): @@ -69,7 +69,7 @@ def prox(self, x: torch.Tensor, sigma: torch.Tensor | float = 1.0) -> tuple[torc ------- Proximal mapping applied to the input tensor """ - self._throw_if_negative_or_complex(sigma) + throw_if_negative_or_complex(sigma) diff = x - self.target threshold = self._divide_by_n(self.weight * sigma, torch.broadcast_shapes(x.shape, self.weight.shape)) out = torch.sgn(diff.real) * torch.relu(diff.real.abs() - threshold.real.abs()) diff --git a/src/mrpro/operators/functionals/L2NormSquared.py b/src/mrpro/operators/functionals/L2NormSquared.py index 25fed7cfa..5714946d8 100644 --- a/src/mrpro/operators/functionals/L2NormSquared.py +++ b/src/mrpro/operators/functionals/L2NormSquared.py @@ -2,7 +2,7 @@ import torch -from mrpro.operators.Functional import ElementaryProximableFunctional +from mrpro.operators.Functional import ElementaryProximableFunctional, throw_if_negative_or_complex class L2NormSquared(ElementaryProximableFunctional): @@ -65,7 +65,7 @@ def prox( ------- Proximal mapping applied to the input tensor """ - self._throw_if_negative_or_complex(sigma) + throw_if_negative_or_complex(sigma) weight_square_2_sigma = self._divide_by_n( self.weight.conj() * self.weight * 2 * sigma, torch.broadcast_shapes(x.shape, self.target.shape, self.weight.shape), @@ -94,7 +94,7 @@ def prox_convex_conj( ------- Proximal of convex conjugate applied to the input tensor """ - self._throw_if_negative_or_complex(sigma) + throw_if_negative_or_complex(sigma) weight_square = self._divide_by_n( self.weight.conj() * self.weight, torch.broadcast_shapes(x.shape, self.target.shape, self.weight.shape) ) diff --git a/src/mrpro/operators/functionals/SSIM.py b/src/mrpro/operators/functionals/SSIM.py index c8a9b994d..663baf089 100644 --- a/src/mrpro/operators/functionals/SSIM.py +++ b/src/mrpro/operators/functionals/SSIM.py @@ -4,7 +4,7 @@ import torch -from mrpro.operators.Functional import Functional +from mrpro.operators.Operator import Operator from mrpro.utils.sliding_window import sliding_window @@ -152,7 +152,7 @@ def ssim3d( return ssim_map -class SSIM(Functional): +class SSIM(Operator[torch.Tensor, tuple[torch.Tensor]]): """(masked) SSIM functional.""" def __init__( diff --git a/src/mrpro/operators/functionals/ZeroFunctional.py b/src/mrpro/operators/functionals/ZeroFunctional.py index ea2981eb5..2cd24e740 100644 --- a/src/mrpro/operators/functionals/ZeroFunctional.py +++ b/src/mrpro/operators/functionals/ZeroFunctional.py @@ -4,7 +4,7 @@ import torch -from mrpro.operators import ElementaryProximableFunctional +from mrpro.operators.Functional import ElementaryProximableFunctional, throw_if_negative_or_complex class ZeroFunctional(ElementaryProximableFunctional): @@ -57,7 +57,7 @@ def prox(self, x: torch.Tensor, sigma: float | torch.Tensor = 1.0) -> tuple[torc ------- Result of the proximal operator applied to x """ - self._throw_if_negative_or_complex(sigma) + throw_if_negative_or_complex(sigma) dtype = torch.promote_types(torch.promote_types(x.dtype, self.weight.dtype), self.target.dtype) return (x.to(dtype=dtype),) @@ -80,7 +80,7 @@ def prox_convex_conj(self, x: torch.Tensor, sigma: float | torch.Tensor = 1.0) - ------- Result of the proximal operator of the convex conjugate applied to x """ - self._throw_if_negative_or_complex(sigma) + throw_if_negative_or_complex(sigma) sigma = torch.as_tensor(sigma) dtype = torch.promote_types(torch.promote_types(x.dtype, self.weight.dtype), self.target.dtype) result = torch.where(sigma == 0, x, torch.zeros_like(x)).to(dtype=dtype) diff --git a/tests/algorithms/test_pdhg.py b/tests/algorithms/test_pdhg.py index 372f72faa..483f0e81b 100644 --- a/tests/algorithms/test_pdhg.py +++ b/tests/algorithms/test_pdhg.py @@ -91,7 +91,7 @@ def test_fourier_l2_l1_() -> None: l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False) l1 = regularization_parameter * L1NormViewAsReal(divide_by_n=False) - f = ProximableFunctionalSeparableSum(l2, l1) + f = l2 | l1 g = ZeroFunctional() operator = LinearOperatorMatrix(((fourier_op,), (IdentityOp(),))) diff --git a/tests/operators/functionals/test_functional_arithmetic.py b/tests/operators/functionals/test_functional_arithmetic.py index 07eb225ef..63f02cbdb 100644 --- a/tests/operators/functionals/test_functional_arithmetic.py +++ b/tests/operators/functionals/test_functional_arithmetic.py @@ -3,7 +3,7 @@ import pytest import torch from mrpro.operators import ElementaryFunctional, ElementaryProximableFunctional, ProximableFunctional -from mrpro.operators.Functional import ScaledFunctional, ScaledProximableFunctional +from mrpro.operators.Functional import ScaledProximableFunctional from mrpro.utils import RandomGenerator from tests.operators.functionals.conftest import ( @@ -14,7 +14,7 @@ ) -@pytest.mark.parametrize('functional', FUNCTIONALS) +@pytest.mark.parametrize('functional', PROXIMABLE_FUNCTIONALS) @pytest.mark.parametrize('scale_type', ['negative', 'positive', 'tensor', 'int']) def test_functional_scaling_forward( functional: type[ElementaryFunctional], scale_type: Literal['negative', 'positive', 'tensor', 'int'] @@ -33,7 +33,7 @@ def test_functional_scaling_forward( case 'int': scale = 5 scaled_f = scale * f - assert isinstance(scaled_f, ScaledFunctional | ScaledProximableFunctional) + assert isinstance(scaled_f, ScaledProximableFunctional) torch.testing.assert_close(scaled_f(x)[0], scale * f(x)[0]) diff --git a/tests/operators/test_linearoperatormatrix.py b/tests/operators/test_linearoperatormatrix.py index a0bd29b52..847defae8 100644 --- a/tests/operators/test_linearoperatormatrix.py +++ b/tests/operators/test_linearoperatormatrix.py @@ -320,3 +320,13 @@ def test_linearoperatormatrix_from_diagonal(): actual = matrix(*xs) expected = tuple(op(x)[0] for op, x in zip(ops, xs, strict=False)) torch.testing.assert_close(actual, expected) + + +def test_linearoperatormatrix_gram(): + """Test gram of LinearOperatorMatrix.""" + rng = RandomGenerator(0) + matrix = random_linearoperatormatrix((2, 2), (4, 10), rng) + vector = tuple(rng.complex64_tensor((2, 10))) + result = matrix.gram(*vector) + expected = (matrix.H @ matrix)(*vector) + torch.testing.assert_close(result, expected) diff --git a/tests/operators/test_optimizer_op.py b/tests/operators/test_optimizer_op.py index 422c379b5..aa91c488b 100644 --- a/tests/operators/test_optimizer_op.py +++ b/tests/operators/test_optimizer_op.py @@ -1,14 +1,28 @@ -from collections.abc import Callable - import torch -from mrpro.operators import OptimizerOp +from mrpro.operators import ConstraintsOp, FunctionalType, OptimizerOp from mrpro.operators.functionals import L2NormSquared from mrpro.operators.models import InversionRecovery from mrpro.utils import RandomGenerator -def test_optimizer_op(): - rng = RandomGenerator(seed=0) +def test_optimizer_op_gradcheck() -> None: + """Test the optimizer op with gradcheck.""" + rng = RandomGenerator(seed=42) + constraints_op = ConstraintsOp( + bounds=( + (-1, 1), # M0 is not constrained + (0.001, 4.0), # T1 is constrained between 1 ms and 3 s + ) + ).double() # everything is double, otherwise the numerical derivative used in gradcheck gives wrong values + + rng = RandomGenerator(seed=1) + true_m0 = rng.complex128_tensor(size=(3, 2)) + true_t1 = rng.float64_tensor(size=(3, 2), low=0.1, high=2) + (signal,) = InversionRecovery(torch.tensor([0.5, 1.0, 1.5, 3], dtype=torch.float64))(true_m0, true_t1) + # signal += rng.complex128_tensor(size=(3, 2), high=0.0001) + t1_reg = true_t1 + rng.rand_like(true_t1, low=-0.01, high=0.01) + m0_reg = true_m0 + rng.rand_like(true_m0, high=0.01) + m0_reg.requires_grad = True def factory( m0_reg: torch.Tensor, @@ -16,24 +30,21 @@ def factory( lambda_m0: torch.Tensor, lambda_t1: torch.Tensor, signal: torch.Tensor, - ) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - data_consistency = L2NormSquared(signal) @ InversionRecovery((0.5, 1.0, 1.5, 3)) - regularization = lambda_t1 * L2NormSquared(t1_reg) | lambda_m0 * L2NormSquared(m0_reg) - return data_consistency + regularization + ) -> FunctionalType[torch.Tensor, torch.Tensor]: + dc = L2NormSquared(signal) @ InversionRecovery((0.5, 1.0, 1.5, 3)).double() + reg = lambda_m0 * L2NormSquared(m0_reg) | lambda_t1 * L2NormSquared(t1_reg) + return (dc + reg) @ constraints_op + + op = OptimizerOp( + factory=factory, + initializer=lambda m0_reg, t1_reg, *_: constraints_op.inverse(m0_reg, t1_reg), + ) + lambda_m0 = torch.tensor(1, requires_grad=True, dtype=torch.float64) + lambda_t1 = torch.tensor(1, requires_grad=True, dtype=torch.float64) + torch.autograd.gradcheck( + op, (m0_reg, t1_reg, lambda_m0, lambda_t1, signal), fast_mode=True, atol=1e-3, rtol=1e-2, eps=1e-3 + ) - factory(torch.randn(10, 10), torch.randn(10, 10), torch.randn(1), torch.randn(1), torch.randn(10, 10)) - rng = RandomGenerator(seed=0) - true_m0 = rng.complex64_tensor(size=(10, 10)) - true_t1 = rng.float32_tensor(size=(10, 10), low=0.1, high=2) - (signal,) = InversionRecovery((0.5, 1.0, 1.5, 3))(true_m0, true_t1) - signal += rng.complex64_tensor(size=(10, 10), high=0.1) - t1_reg = true_t1 + rng.rand_like(true_t1, low=-0.1, high=0.1) - m0_reg = true_m0 + rng.rand_like(true_m0, low=0, high=0.5) - op = OptimizerOp(factory, initializer=lambda m0_reg, t1_reg, *_: (m0_reg, t1_reg)) - lambda_m0 = torch.tensor(1.0, requires_grad=True) - lambda_t1 = torch.tensor(1.0, requires_grad=True) - ret = op.forward(m0_reg, t1_reg, lambda_m0, lambda_t1, signal) - (loss,) = (L2NormSquared(true_m0) | L2NormSquared(m0_reg))(*ret) - loss.backward() - assert lambda_m0.grad is not None - assert lambda_t1.grad is not None + ret = (constraints_op @ op)(m0_reg, t1_reg, lambda_m0, lambda_t1, signal) + torch.testing.assert_close(ret[0], true_m0, atol=1e-3, rtol=1e-2) + torch.testing.assert_close(ret[1], true_t1, atol=1e-3, rtol=1e-2) From dc6429d7d3704a3382288593dfb0493af4e4c32d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Apr 2025 13:56:54 +0200 Subject: [PATCH 006/205] update --- src/mrpro/operators/ConjugateGradientOp.py | 193 ++++++++++++++++++ tests/operators/test_conjugate_gradient_op.py | 72 +++++++ 2 files changed, 265 insertions(+) create mode 100644 src/mrpro/operators/ConjugateGradientOp.py create mode 100644 tests/operators/test_conjugate_gradient_op.py diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py new file mode 100644 index 000000000..8c185af12 --- /dev/null +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -0,0 +1,193 @@ +"""Regularized least squares operator.""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar + +import torch +from torch.autograd.function import once_differentiable + +from mrpro.algorithms.optimizers.cg import cg +from mrpro.operators.LinearOperator import LinearOperator +from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix + +LinearOperatorFactory = Callable[..., LinearOperator] +LinearOperatorMatrixFactory = Callable[..., LinearOperatorMatrix] +T = TypeVar('T', torch.Tensor, bool) + + +class ConjugateGradientCTX(torch.autograd.function.FunctionCtx): + """Only used for type hinting.""" + + saved_tensors: tuple[torch.Tensor, ...] + needs_input_grad: tuple[bool, ...] + len_solution: int + tolerance: float + max_iterations: int + rhs_factory: Callable[..., tuple[torch.Tensor, ...]] + operator_factory: Callable[..., LinearOperatorMatrix | LinearOperator] + + +class ConjugateGradientFunction(torch.autograd.Function): + """Autograd function for the regularized least squares operator.""" + + if TYPE_CHECKING: + + @classmethod + def apply( + cls, + operator_factory: Callable[..., LinearOperatorMatrix | LinearOperator], + rhs_factory: Callable[..., tuple[torch.Tensor, ...]], + *inputs: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + """Apply the function. Required for mypy.""" + return super().apply(operator_factory, rhs_factory, *inputs) + + @staticmethod + def forward( + ctx: ConjugateGradientCTX, + operator_factory: Callable[..., LinearOperatorMatrix | LinearOperator], + rhs_factory: Callable[..., tuple[torch.Tensor, ...]], + *inputs: torch.Tensor, + initial_value: tuple[torch.Tensor, ...] | None = None, + max_iterations: int = 10000, + tolerance: float = 1e-7, + ) -> tuple[torch.Tensor, ...]: + """Forward pass of the conjugate gradient operator.""" + operator = operator_factory(*inputs) + rhs = rhs_factory(*inputs) + rhs_norm = sum((r.abs().square().sum() for r in rhs), torch.tensor(0.0)).sqrt().item() + fwd_tol = tolerance * rhs_norm + if isinstance(operator, LinearOperator): + if len(rhs) != 1: + raise ValueError('LinearOperator requires a single right-hand side tensor.') + if initial_value is not None and len(initial_value) != 1: + raise ValueError('LinearOperator requires a single initial value tensor.') + solution = cg(operator, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=max_iterations) + else: + solution = cg(operator, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=max_iterations) + ctx.save_for_backward(*solution, *inputs) + ctx.len_solution = len(solution) + ctx.tolerance = tolerance + ctx.max_iterations = max_iterations + ctx.rhs_factory = rhs_factory + ctx.operator_factory = operator_factory + return solution + + @staticmethod + @once_differentiable + def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: + """Backward pass of the conjugate gradient operator.""" + solution, inputs = ( + ctx.saved_tensors[: ctx.len_solution], + ctx.saved_tensors[ctx.len_solution :], + ) + with torch.enable_grad(): + rhs = ctx.rhs_factory(*inputs) + operator = ctx.operator_factory(*inputs) + inputs_with_grad = tuple(i for i, need_grad in zip(inputs, ctx.needs_input_grad[2:], strict=True) if need_grad) + if inputs_with_grad: + with torch.no_grad(): + if isinstance(operator, LinearOperatorMatrix): + z = cg(operator.H, grad_output) + else: + z = cg(operator.H, grad_output[0]) + with torch.enable_grad(): + residual = tuple(r - ax for r, ax in zip(rhs, operator(*(s.detach() for s in solution)), strict=True)) + grads = torch.autograd.grad(outputs=residual, inputs=inputs_with_grad, grad_outputs=z, allow_unused=True) + grad_iter = iter(grads) + else: + grad_iter = iter(()) + grads = tuple(next(grad_iter) if need else None for need in ctx.needs_input_grad[2:]) + return (None, None, *grads) # operator_factory, rhs_factory, *inputs + + +class ConjugateGradientOp(torch.nn.Module): + r"""Solves a linear positive semidefinite system with the conjugate gradient method. + + Solves :math: `A x = b` where :math:`A` is a linear operator , :math:`b` is a tensor or a tuple of tensors. + + The operator is autograd differentiable using implicit differentiation. + If this is not needed, consider using `mrpro.algorithms.optimizers.cg` directly. + """ + + def __init__( + self, + operator_factory: Callable[..., LinearOperatorMatrix | LinearOperator], + rhs_factory: Callable[..., tuple[torch.Tensor, ...]], + implicit_backward: bool = True, + tolerance: float = 1e-8, + max_iterations: int = 10000, + ): + r"""Initialize a conjugate gradient operator. + + Both the operator and the right-hand side are given as factory functions. + The arguments given to the operator when calling it are passed to the factory functions. + + **Example: Regularized Least Squares** + + Consider the regularized least squares problem: + :math:`\min_x \|A x - b\|_2^2 + \alpha \|x - x_0\|_2^2`. + + The normal equations are :math:`(A^H A + \alpha I) x = A^H b + \alpha x_0`. + This can be solved using the ConjugateGradientOp as follows: + .. code-block:: python + operator_factory = lambda alpha, x0, b: A.gram + alpha + rhs_factory = lambda alpha, x0, b: A.H(b)[0] + alpha * x0 + op = ConjugateGradientOp(operator_factory, rhs_factory) + solution = op(alpha, x0, b) + + Parameters + ---------- + operator_factory + A factory function that returns the operator :math:`A`. + Should return either a `LinearOperatorMatrix` or a `LinearOperator`. + rhs_factory + A factory function that returns the right-hand side :math:`b` + Should return a tuple of tensors. + implicit_backward + If `True`, the backward pass is done using implicit differentiation. + If `False`, the backward pass is done using unrolling the CG loop. + tolerance + The tolerance for the conjugate gradient method. The tolerance is relative + to the norm of the right-hand side. The same tolerance is used in the backward pass + if using implicit differentiation. + max_iterations + The maximum number of iterations for the conjugate gradient method. + + .. warning:: + If implicit_backward is `True`, the problem has to converge, otherwise the backward + will be wrong. `tolerance` and `max_iter` should be chosen accordingly. + """ + super().__init__() + self.operator_factory = operator_factory + self.rhs_factory = rhs_factory + self.implicit_backward = implicit_backward + self.tolerance = tolerance + self.max_iterations = max_iterations + + def forward( + self, + *parameters: torch.Tensor, + initial_value: tuple[torch.Tensor, ...] | None = None, + ) -> tuple[torch.Tensor, ...]: + r"""Solve the linear system using the conjugate gradient method. + + Parameters + ---------- + parameters + The parameters passed to the operator and right-hand side factory functions. + """ + if self.implicit_backward: + solution = ConjugateGradientFunction.apply(self.operator_factory, self.rhs_factory, *parameters) + else: + op = self.operator_factory(*parameters) + rhs = self.rhs_factory(*parameters) + if isinstance(op, LinearOperator): + if len(rhs) != 1: + raise ValueError('LinearOperator requires a single right-hand side tensor.') + if initial_value is not None and len(initial_value) != 1: + raise ValueError('LinearOperator requires a single initial value tensor.') + solution = cg(op, rhs, initial_value=initial_value) + else: + solution = cg(op, rhs, initial_value=initial_value) + return solution diff --git a/tests/operators/test_conjugate_gradient_op.py b/tests/operators/test_conjugate_gradient_op.py new file mode 100644 index 000000000..f00e2ca65 --- /dev/null +++ b/tests/operators/test_conjugate_gradient_op.py @@ -0,0 +1,72 @@ +"""Tests the conjugate gradient operator.""" + +import torch +from mrpro.operators import ConjugateGradientOp, EinsumOp, LinearOperatorMatrix +from mrpro.utils import RandomGenerator + + +def random_linearop(size: tuple[int, int], rng: RandomGenerator): + """Create a random LinearOperator.""" + return EinsumOp(rng.complex128_tensor(size), '... i j, ... j -> ... i') + + +def test_conjugate_gradient_op_least_squares_matrix(sizes: tuple[int, int] = (10, 8), noise_level: float = 1e-3): + """Test the conjugate gradient operator for |Ax-y|^2+alpha*|x-x0|^2 with a Matrix of LinearOperators.""" + rng = RandomGenerator(0) + a = LinearOperatorMatrix.from_diagonal(*(random_linearop((s, s), rng) for s in sizes)) + x = tuple(rng.complex128_tensor((s,)) for s in sizes) + y = a(*x) + y = tuple((yi + rng.rand_like(yi) * noise_level).requires_grad_(True) for yi in y) + x0 = tuple((xi + rng.rand_like(xi) * noise_level).requires_grad_(True) for xi in x) + n_y = len(sizes) + op = ConjugateGradientOp( + lambda alpha, *_: a.gram + alpha, + lambda alpha, *y_x0: tuple(ahy + alpha * x0 for ahy, x0 in zip(a.H(*y_x0[:n_y]), y_x0[n_y:], strict=True)), + ) + alpha = torch.tensor(0.1, dtype=torch.float64, requires_grad=True) + solution = op(alpha, *y, *x0) + loss = sum((((si - xi).norm()) for si, xi in zip(solution, x, strict=True)), torch.tensor(0.0)) + assert loss.item() < 0.01 + loss.backward() + assert alpha.grad is not None + for x0i in x0: + assert x0i.grad is not None + for yi in y: + assert yi.grad is not None + + +def test_conjugate_gradient_op_least_squares_gradcheck_unrolled(size: int = 10, noise_level: float = 1e-2): + """Test the implicit differentiation of the conjugate gradient operator using |Ax-y|^2+alpha*|x-x0|^2.""" + rng = RandomGenerator(0) + a = random_linearop((size, size), rng) + x = rng.complex128_tensor((size,)) + (y,) = a(x) + y = y + rng.rand_like(y) * noise_level + y.requires_grad = True + x0 = x + rng.rand_like(x) * noise_level + x0.requires_grad = True + op = ConjugateGradientOp( + lambda alpha, _x0, _y: a.gram + alpha, + lambda alpha, x0, y: (a.H(y)[0] + alpha * x0,), + implicit_backward=False, + ) + alpha = torch.tensor(0.1, dtype=torch.float64, requires_grad=True) + torch.autograd.gradcheck(op, (alpha, y, x0), fast_mode=True) + + +def test_conjugate_gradient_op_least_squares_gradcheck_implicit(size: int = 10, noise_level: float = 1e-2): + """Test the implicit differentiation of the conjugate gradient operator using |Ax-y|^2+alpha*|x-x0|^2.""" + rng = RandomGenerator(0) + a = random_linearop((size, size), rng) + x = rng.complex128_tensor((size,)) + (y,) = a(x) + y = y + rng.rand_like(y) * noise_level + y.requires_grad = True + x0 = x + rng.rand_like(x) * noise_level + x0.requires_grad = True + op = ConjugateGradientOp( + lambda alpha, _x0, _y: a.gram + alpha, + lambda alpha, x0, y: (a.H(y)[0] + alpha * x0,), + ) + alpha = torch.tensor(0.1, dtype=torch.float64, requires_grad=True) + torch.autograd.gradcheck(op, (alpha, y, x0), fast_mode=True) From c89a25a9f9b7f0bb0648944eb43cd72c50c5be7f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Apr 2025 14:03:01 +0200 Subject: [PATCH 007/205] fix merge --- src/mrpro/algorithms/optimizers/cg.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/mrpro/algorithms/optimizers/cg.py b/src/mrpro/algorithms/optimizers/cg.py index a8c574210..31ae127ae 100644 --- a/src/mrpro/algorithms/optimizers/cg.py +++ b/src/mrpro/algorithms/optimizers/cg.py @@ -73,9 +73,8 @@ def cg( ) -> tuple[torch.Tensor, ...] | tuple[torch.Tensor]: r"""(Preconditioned) Conjugate Gradient for solving :math:`Hx=b`. - This algorithm solves systems of the form :math:`H x = b`, where :math:`H` is a self-adjoint linear operator - and :math:`b` is the right-hand side. The method can solve a batch of :math:`N` systems jointly, thereby taking - :math:`H` as a block-diagonal with blocks :math:`H_i` and :math:`b = [b_1, ..., b_N] ^T`. + This algorithm solves systems of the form :math:`H x = b`, where :math:`H` is a self-adjoint positive semidefinite + linear operator and :math:`b` is the right-hand side. The method performs the following steps: @@ -99,9 +98,6 @@ def cg( If `preconditioner_inverse` is provided, it solves :math:`M^{-1}Hx = M^{-1}b` implicitly, where `preconditioner_inverse(r)` computes :math:`M^{-1}r`. - If `preconditioner_inverse` is provided, it solves :math:`M^{-1}Hx = M^{-1}b` - implicitly, where `preconditioner_inverse(r)` computes :math:`M^{-1}r`. - See [Hestenes1952]_, [Nocedal2006]_, and [WikipediaCG]_ for more information. @@ -186,4 +182,5 @@ def cg( ) if continue_iterations is False: break + return solution From 391979ee6b286bad0c6ca65e54d87b89ab4b48de Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Apr 2025 00:38:31 +0200 Subject: [PATCH 008/205] cleanup --- src/mrpro/operators/ConjugateGradientOp.py | 5 ++++- tests/operators/test_optimizer_op.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 8c185af12..a5c8a3ec2 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -176,10 +176,13 @@ def forward( ---------- parameters The parameters passed to the operator and right-hand side factory functions. + initial_value + The initial value for the conjugate gradient method. + If `None`, the initial value is set to zero. """ if self.implicit_backward: solution = ConjugateGradientFunction.apply(self.operator_factory, self.rhs_factory, *parameters) - else: + else: # unrolled CG op = self.operator_factory(*parameters) rhs = self.rhs_factory(*parameters) if isinstance(op, LinearOperator): diff --git a/tests/operators/test_optimizer_op.py b/tests/operators/test_optimizer_op.py index aa91c488b..962e78889 100644 --- a/tests/operators/test_optimizer_op.py +++ b/tests/operators/test_optimizer_op.py @@ -19,7 +19,6 @@ def test_optimizer_op_gradcheck() -> None: true_m0 = rng.complex128_tensor(size=(3, 2)) true_t1 = rng.float64_tensor(size=(3, 2), low=0.1, high=2) (signal,) = InversionRecovery(torch.tensor([0.5, 1.0, 1.5, 3], dtype=torch.float64))(true_m0, true_t1) - # signal += rng.complex128_tensor(size=(3, 2), high=0.0001) t1_reg = true_t1 + rng.rand_like(true_t1, low=-0.01, high=0.01) m0_reg = true_m0 + rng.rand_like(true_m0, high=0.01) m0_reg.requires_grad = True From c3ffa9c211c87bf01d8103dd8259d0486133a234 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Apr 2025 00:40:50 +0200 Subject: [PATCH 009/205] fix miniml --- src/mrpro/operators/Functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mrpro/operators/Functional.py b/src/mrpro/operators/Functional.py index 4724f6f79..44bca44c4 100644 --- a/src/mrpro/operators/Functional.py +++ b/src/mrpro/operators/Functional.py @@ -5,9 +5,10 @@ import math from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import TypeAlias, TypeVarTuple, Unpack +from typing import TypeAlias import torch +from typing_extensions import TypeVarTuple, Unpack import mrpro.operators from mrpro.operators.Operator import Operator From d9c024a8b2ebaa4524cdbddee107145e08791a06 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Apr 2025 22:38:22 +0200 Subject: [PATCH 010/205] fix --- src/mrpro/operators/ConjugateGradientOp.py | 9 ++++++--- src/mrpro/operators/LinearOperator.py | 10 +++++++--- src/mrpro/operators/Operator.py | 5 +++-- src/mrpro/operators/OptimizerOp.py | 16 +++++++--------- tests/operators/test_optimizer_op.py | 13 ++++++------- 5 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index a5c8a3ec2..4bf5cfcb4 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -62,7 +62,9 @@ def forward( raise ValueError('LinearOperator requires a single right-hand side tensor.') if initial_value is not None and len(initial_value) != 1: raise ValueError('LinearOperator requires a single initial value tensor.') - solution = cg(operator, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=max_iterations) + solution: tuple[torch.Tensor, ...] = cg( + operator, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=max_iterations + ) else: solution = cg(operator, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=max_iterations) ctx.save_for_backward(*solution, *inputs) @@ -97,8 +99,9 @@ def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[tor grad_iter = iter(grads) else: grad_iter = iter(()) - grads = tuple(next(grad_iter) if need else None for need in ctx.needs_input_grad[2:]) - return (None, None, *grads) # operator_factory, rhs_factory, *inputs + + grad_input = tuple(next(grad_iter) if need else None for need in ctx.needs_input_grad[2:]) + return (None, None, *grad_input) # operator_factory, rhs_factory, *inputs class ConjugateGradientOp(torch.nn.Module): diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index 64c5a84f7..8556a6907 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -223,13 +223,17 @@ def __matmul__(self, other: LinearOperator) -> LinearOperator: ... @overload def __matmul__( - self, other: Operator[Unpack[Tin2], tuple[torch.Tensor,]] + self, other: Operator[Unpack[Tin2], tuple[torch.Tensor,]] | Operator[Unpack[Tin2], tuple[torch.Tensor, ...]] ) -> Operator[Unpack[Tin2], tuple[torch.Tensor,]]: ... def __matmul__( self, - other: Operator[Unpack[Tin2], tuple[torch.Tensor,]] | LinearOperator, - ) -> Operator[Unpack[Tin2], tuple[torch.Tensor,]] | LinearOperator: + other: Operator[Unpack[Tin2], tuple[torch.Tensor,]] + | LinearOperator + | Operator[Unpack[Tin2], tuple[torch.Tensor, ...]], + ) -> ( + Operator[Unpack[Tin2], tuple[torch.Tensor,]] | LinearOperator | Operator[Unpack[Tin2], tuple[torch.Tensor, ...]] + ): """Operator composition. Returns ``lambda x: self(other(x))`` diff --git a/src/mrpro/operators/Operator.py b/src/mrpro/operators/Operator.py index 91832fa92..d52b4aabc 100644 --- a/src/mrpro/operators/Operator.py +++ b/src/mrpro/operators/Operator.py @@ -43,13 +43,14 @@ def __call__(self, *args: Unpack[Tin]) -> Tout: return super().__call__(*args) def __matmul__( - self: Operator[Unpack[Tin], Tout], other: Operator[Unpack[Tin2], tuple[Unpack[Tin]]] + self: Operator[Unpack[Tin], Tout], + other: Operator[Unpack[Tin2], tuple[Unpack[Tin]]] | Operator[Unpack[Tin2], tuple[torch.Tensor, ...]], ) -> Operator[Unpack[Tin2], Tout]: """Operator composition. Returns ``lambda x: self(other(x))`` """ - return OperatorComposition(self, other) + return OperatorComposition(self, cast(Operator[Unpack[Tin2], tuple[Unpack[Tin]]], other)) def __radd__( self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 4cb9c9b9c..970d8f75f 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -2,7 +2,7 @@ import functools from collections.abc import Callable -from typing import TYPE_CHECKING, TypeVar, TypeVarTuple, Unpack, cast +from typing import TYPE_CHECKING, Any, TypeVar, TypeVarTuple, Unpack, cast import torch @@ -12,8 +12,8 @@ ArgumentType = TypeVarTuple('ArgumentType') VariableType = TypeVar('VariableType', bound=tuple[torch.Tensor, ...]) -ObjectiveType = Callable[[VariableType], tuple[torch.Tensor]] -FactoryType = Callable[[Unpack[tuple[torch.Tensor, ...]]], Callable] +ObjectiveType = Callable[..., tuple[torch.Tensor]] | Operator[Any, tuple[torch.Tensor]] +FactoryType = Callable[..., ObjectiveType] OptimizeFunctionType = Callable[[Callable, VariableType], VariableType] default_lbfgs = functools.partial( @@ -47,9 +47,7 @@ class OptimizeFunction(torch.autograd.Function): @classmethod def apply( cls, - factory: Callable[ - [Unpack[tuple[torch.Tensor, ...]]], Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor]] - ], + factory: Callable[..., Callable[..., tuple[torch.Tensor]]], initial_values: tuple[torch.Tensor, ...], optimize: Callable[ [Callable[[*tuple[Unpack[tuple[torch.Tensor, ...]]]], tuple[torch.Tensor]], tuple[torch.Tensor, ...]], @@ -100,7 +98,7 @@ def backward(ctx: OptimizeCtx, *grad_outputs: torch.Tensor) -> tuple[torch.Tenso def hvp(*v: torch.Tensor) -> tuple[torch.Tensor, ...]: return torch.autograd.functional.vhp(lambda *x: objective(*x)[0], xprime, v=v)[1] - hessian_inverse_grad = cg(hvp, grad_outputs, max_iterations=200, tolerance=1e-6) + hessian_inverse_grad = cg(hvp, grad_outputs, max_iterations=100, tolerance=1e-7) with torch.enable_grad(): dobjective_dxprime = torch.autograd.grad(objective(*xprime), xprime, create_graph=True) # - d^2_obective / d_xprime d_params Hessian^-1_grad @@ -177,8 +175,8 @@ def forward(self, *parameters: Unpack[ArgumentType]) -> VariableType: Parameters of the argmin problem. """ initial_values = self.initializer(*parameters) - initial_values = tuple(x.clone() if any(x is p for p in parameters) else x for x in initial_values) + initial_values_ = tuple(x.clone() if any(x is p for p in parameters) else x for x in initial_values) result = OptimizeFunction.apply( - self.factory, initial_values, self.optimize, *cast(tuple[torch.Tensor, ...], parameters) + self.factory, initial_values_, self.optimize, *cast(tuple[torch.Tensor, ...], parameters) ) return cast(VariableType, result) diff --git a/tests/operators/test_optimizer_op.py b/tests/operators/test_optimizer_op.py index 962e78889..d7ffb5196 100644 --- a/tests/operators/test_optimizer_op.py +++ b/tests/operators/test_optimizer_op.py @@ -16,8 +16,8 @@ def test_optimizer_op_gradcheck() -> None: ).double() # everything is double, otherwise the numerical derivative used in gradcheck gives wrong values rng = RandomGenerator(seed=1) - true_m0 = rng.complex128_tensor(size=(3, 2)) - true_t1 = rng.float64_tensor(size=(3, 2), low=0.1, high=2) + true_m0 = rng.complex128_tensor(size=(2, 2), low=0.5, high=1) + true_t1 = rng.float64_tensor(size=(2, 2), low=0.1, high=2) (signal,) = InversionRecovery(torch.tensor([0.5, 1.0, 1.5, 3], dtype=torch.float64))(true_m0, true_t1) t1_reg = true_t1 + rng.rand_like(true_t1, low=-0.01, high=0.01) m0_reg = true_m0 + rng.rand_like(true_m0, high=0.01) @@ -36,14 +36,13 @@ def factory( op = OptimizerOp( factory=factory, - initializer=lambda m0_reg, t1_reg, *_: constraints_op.inverse(m0_reg, t1_reg), + initializer=lambda m0_reg, t1_reg, _lambda_m0, _lambda_t1, _signal: constraints_op.inverse(m0_reg, t1_reg), ) lambda_m0 = torch.tensor(1, requires_grad=True, dtype=torch.float64) lambda_t1 = torch.tensor(1, requires_grad=True, dtype=torch.float64) torch.autograd.gradcheck( op, (m0_reg, t1_reg, lambda_m0, lambda_t1, signal), fast_mode=True, atol=1e-3, rtol=1e-2, eps=1e-3 ) - - ret = (constraints_op @ op)(m0_reg, t1_reg, lambda_m0, lambda_t1, signal) - torch.testing.assert_close(ret[0], true_m0, atol=1e-3, rtol=1e-2) - torch.testing.assert_close(ret[1], true_t1, atol=1e-3, rtol=1e-2) + m0, t1 = (constraints_op @ op)(m0_reg, t1_reg, lambda_m0, lambda_t1, signal) + torch.testing.assert_close(m0, true_m0, atol=1e-3, rtol=1e-2) + torch.testing.assert_close(t1, true_t1, atol=1e-3, rtol=1e-2) From 97f102b2364e20f3c633b1265a160500e18ee4ad Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Apr 2025 23:01:56 +0200 Subject: [PATCH 011/205] py310 --- src/mrpro/operators/OptimizerOp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 970d8f75f..8a38637c0 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -2,9 +2,10 @@ import functools from collections.abc import Callable -from typing import TYPE_CHECKING, Any, TypeVar, TypeVarTuple, Unpack, cast +from typing import TYPE_CHECKING, cast import torch +from typing_extensions import Any, TypeVar, TypeVarTuple, Unpack from mrpro.algorithms.optimizers.cg import cg from mrpro.algorithms.optimizers.lbfgs import lbfgs From 19087c94fc99c9c82ce5f2446e841d2e48e1d5e6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Apr 2025 01:21:58 +0200 Subject: [PATCH 012/205] py310 --- src/mrpro/algorithms/optimizers/lbfgs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mrpro/algorithms/optimizers/lbfgs.py b/src/mrpro/algorithms/optimizers/lbfgs.py index d4e07d8fa..2105bdf50 100644 --- a/src/mrpro/algorithms/optimizers/lbfgs.py +++ b/src/mrpro/algorithms/optimizers/lbfgs.py @@ -1,10 +1,11 @@ """LBFGS for solving non-linear minimization problems.""" from collections.abc import Callable, Sequence -from typing import Literal, Unpack +from typing import Literal import torch from torch.optim import LBFGS +from typing_extensions import Unpack from mrpro.algorithms.optimizers.OptimizerStatus import OptimizerStatus from mrpro.operators.Operator import OperatorType From 41b5aec6c32cc756ca51b969e232384b09687972 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Apr 2025 09:46:45 +0200 Subject: [PATCH 013/205] py310 --- src/mrpro/operators/OptimizerOp.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 8a38637c0..55d52d424 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -51,7 +51,10 @@ def apply( factory: Callable[..., Callable[..., tuple[torch.Tensor]]], initial_values: tuple[torch.Tensor, ...], optimize: Callable[ - [Callable[[*tuple[Unpack[tuple[torch.Tensor, ...]]]], tuple[torch.Tensor]], tuple[torch.Tensor, ...]], + [ + Callable[[Unpack[tuple[Unpack[tuple[torch.Tensor, ...]]]]], tuple[torch.Tensor]], + tuple[torch.Tensor, ...], + ], tuple[torch.Tensor, ...], ] = default_lbfgs, *parameters: torch.Tensor, @@ -67,7 +70,10 @@ def forward( ], initial_values: tuple[torch.Tensor, ...], optimize: Callable[ - [Callable[[*tuple[Unpack[tuple[torch.Tensor, ...]]]], tuple[torch.Tensor]], tuple[torch.Tensor, ...]], + [ + Callable[[Unpack[tuple[Unpack[tuple[torch.Tensor, ...]]]]], tuple[torch.Tensor]], + tuple[torch.Tensor, ...], + ], tuple[torch.Tensor, ...], ] = default_lbfgs, *parameters: torch.Tensor, From 3118792823a045408fdf9e426be9965b2c20b951 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Apr 2025 11:09:41 +0200 Subject: [PATCH 014/205] py310 --- src/mrpro/operators/ProximableFunctionalSeparableSum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/ProximableFunctionalSeparableSum.py b/src/mrpro/operators/ProximableFunctionalSeparableSum.py index 443ec9ebf..e205db307 100644 --- a/src/mrpro/operators/ProximableFunctionalSeparableSum.py +++ b/src/mrpro/operators/ProximableFunctionalSeparableSum.py @@ -5,10 +5,10 @@ import operator from collections.abc import Iterator from functools import reduce -from typing import TypeVarTuple, cast, overload +from typing import cast import torch -from typing_extensions import Unpack +from typing_extensions import TypeVarTuple, Unpack, overload from mrpro.operators.Functional import ProximableFunctional from mrpro.operators.Operator import Operator From 6454662dabde169b863e69be4a5a9a9b2ec0c419 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Apr 2025 21:16:21 +0200 Subject: [PATCH 015/205] pyr310 --- src/mrpro/operators/Functional.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/mrpro/operators/Functional.py b/src/mrpro/operators/Functional.py index 44bca44c4..bf362c48c 100644 --- a/src/mrpro/operators/Functional.py +++ b/src/mrpro/operators/Functional.py @@ -5,7 +5,7 @@ import math from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import TypeAlias +from typing import TYPE_CHECKING, TypeAlias import torch from typing_extensions import TypeVarTuple, Unpack @@ -13,9 +13,11 @@ import mrpro.operators from mrpro.operators.Operator import Operator -T = TypeVarTuple('T') -FunctionalType: TypeAlias = Operator[Unpack[T], tuple[torch.Tensor]] -"""An Operator that returns a single tensor.""" +if TYPE_CHECKING: + T = TypeVarTuple('T') + FunctionalType: TypeAlias = Operator[Unpack[T], tuple[torch.Tensor]] +else: # python 3.10 runtime compatibility. typing_extension + FunctionalType: TypeAlias = Operator def throw_if_negative_or_complex( From c8f51eccd7324cebb2c6c49a61fc11f585e93923 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Apr 2025 23:43:45 +0200 Subject: [PATCH 016/205] review --- src/mrpro/operators/ConjugateGradientOp.py | 2 +- src/mrpro/operators/LinearOperatorMatrix.py | 18 +++++++----------- src/mrpro/operators/OptimizerOp.py | 14 +++++++------- .../ProximableFunctionalSeparableSum.py | 1 - tests/operators/test_linearoperatormatrix.py | 2 +- tests/operators/test_optimizer_op.py | 1 + 6 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 4bf5cfcb4..6310d4f81 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -1,4 +1,4 @@ -"""Regularized least squares operator.""" +"""Conjugate gradient operator.""" from collections.abc import Callable from typing import TYPE_CHECKING, TypeVar diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py index f9fb50954..8ace21e9b 100644 --- a/src/mrpro/operators/LinearOperatorMatrix.py +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -229,18 +229,14 @@ def H(self) -> Self: # noqa N802 def gram(self) -> Self: """Gram matrix of the operators.""" n, m = self.shape - if n != m: - raise ValueError('Gram is only defined for square operators.') - operators: list[list[LinearOperator]] = [[ZeroOp() for _ in range(n)] for _ in range(n)] - - for i in range(n): - operators[i][i] = reduce(operator.add, (self._operators[k][i].gram for k in range(n))) - # off-diagonals: only compute upper triangular part, then mirror - for j in range(i + 1, n): - operators[i][j] = reduce( - operator.add, (self._operators[k][i].H @ self._operators[k][j] for k in range(n)) + operators: list[list[LinearOperator]] = [[ZeroOp() for _ in range(m)] for _ in range(m)] + for j in range(m): + operators[j][j] = reduce(operator.add, (self._operators[i][j].gram for i in range(n)), ZeroOp()) + for k in range(j + 1, m): + operators[j][k] = reduce( + operator.add, (self._operators[i][j].H @ self._operators[i][k] for i in range(n)), ZeroOp() ) - operators[j][i] = operators[i][j].H + operators[k][j] = operators[j][k].H return self.__class__(operators) def adjoint(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]: diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 55d52d424..19849c08f 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -83,8 +83,8 @@ def forward( parameters_ = tuple(p.detach().clone() for p in parameters if isinstance(p, torch.Tensor)) initial_values_ = tuple(x.detach().requires_grad_(True) for x in initial_values if isinstance(x, torch.Tensor)) - f = factory(*parameters) - xprime = optimize(f, initial_values) + objective = factory(*parameters) + xprime = optimize(objective, initial_values) ctx.save_for_backward(*xprime, *parameters_) ctx.len_x = len(initial_values_) return xprime @@ -153,15 +153,15 @@ def __init__( Example ------- - Solving :math:`\|q(x)-y\|^2 + \lambda*\|x-x_\mathrm{reg}\|^2` with - :math:`y`, :math:`\lambda` and :math:`x_\mathrm{reg}` parameters. The solution :math:`x^*` should be + Solving :math:`\|q(x)-y\|^2 + \alpha*\|x-x_\mathrm{reg}\|^2` with + :math:`y`, :math:`\alpha` and :math:`x_\mathrm{reg}` parameters. The solution :math:`x^*` should be differentiable with respect to these. Use:: - def factory(y, lambda, x_reg): - return L2squared(y)@q+lambda*L2squared(x_reg) - def initializer(_y, _lambda, _xreg): + def factory(y, alpha, x_reg): + return L2squared(y)@q+alpha*L2squared(x_reg) + def initializer(_y, _alpha, _xreg): return (x_reg,) Returns diff --git a/src/mrpro/operators/ProximableFunctionalSeparableSum.py b/src/mrpro/operators/ProximableFunctionalSeparableSum.py index e205db307..ee5465b49 100644 --- a/src/mrpro/operators/ProximableFunctionalSeparableSum.py +++ b/src/mrpro/operators/ProximableFunctionalSeparableSum.py @@ -14,7 +14,6 @@ from mrpro.operators.Operator import Operator T = TypeVarTuple('T') -T2 = TypeVarTuple('T2') class ProximableFunctionalSeparableSum(Operator[Unpack[T], tuple[torch.Tensor]]): diff --git a/tests/operators/test_linearoperatormatrix.py b/tests/operators/test_linearoperatormatrix.py index 847defae8..a0f07f905 100644 --- a/tests/operators/test_linearoperatormatrix.py +++ b/tests/operators/test_linearoperatormatrix.py @@ -325,7 +325,7 @@ def test_linearoperatormatrix_from_diagonal(): def test_linearoperatormatrix_gram(): """Test gram of LinearOperatorMatrix.""" rng = RandomGenerator(0) - matrix = random_linearoperatormatrix((2, 2), (4, 10), rng) + matrix = random_linearoperatormatrix((3, 2), (4, 10), rng) vector = tuple(rng.complex64_tensor((2, 10))) result = matrix.gram(*vector) expected = (matrix.H @ matrix)(*vector) diff --git a/tests/operators/test_optimizer_op.py b/tests/operators/test_optimizer_op.py index d7ffb5196..19c408c13 100644 --- a/tests/operators/test_optimizer_op.py +++ b/tests/operators/test_optimizer_op.py @@ -21,6 +21,7 @@ def test_optimizer_op_gradcheck() -> None: (signal,) = InversionRecovery(torch.tensor([0.5, 1.0, 1.5, 3], dtype=torch.float64))(true_m0, true_t1) t1_reg = true_t1 + rng.rand_like(true_t1, low=-0.01, high=0.01) m0_reg = true_m0 + rng.rand_like(true_m0, high=0.01) + t1_reg.requires_grad = True m0_reg.requires_grad = True def factory( From 227646af9f0e7209d4a92226805a612569b68eba Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Apr 2025 23:58:40 +0200 Subject: [PATCH 017/205] norm --- src/mrpro/operators/ConjugateGradientOp.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 6310d4f81..01b60de77 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -88,11 +88,13 @@ def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[tor operator = ctx.operator_factory(*inputs) inputs_with_grad = tuple(i for i, need_grad in zip(inputs, ctx.needs_input_grad[2:], strict=True) if need_grad) if inputs_with_grad: + rhs_norm = sum((r.abs().square().sum() for r in grad_output), torch.tensor(0.0)).sqrt().item() + bwd_tol = ctx.tolerance * rhs_norm with torch.no_grad(): if isinstance(operator, LinearOperatorMatrix): - z = cg(operator.H, grad_output) + z = cg(operator.H, grad_output, tolerance=bwd_tol, max_iterations=ctx.max_iterations) else: - z = cg(operator.H, grad_output[0]) + z = cg(operator.H, grad_output[0], tolerance=bwd_tol, max_iterations=ctx.max_iterations) with torch.enable_grad(): residual = tuple(r - ax for r, ax in zip(rhs, operator(*(s.detach() for s in solution)), strict=True)) grads = torch.autograd.grad(outputs=residual, inputs=inputs_with_grad, grad_outputs=z, allow_unused=True) @@ -152,10 +154,12 @@ def __init__( If `False`, the backward pass is done using unrolling the CG loop. tolerance The tolerance for the conjugate gradient method. The tolerance is relative - to the norm of the right-hand side. The same tolerance is used in the backward pass - if using implicit differentiation. + to the norm of the right-hand side. The same relative tolerance is used in the + backward pass if using implicit differentiation. max_iterations The maximum number of iterations for the conjugate gradient method. + The same maximum number of iterations is used in the backward pass if using + implicit differentiation. .. warning:: If implicit_backward is `True`, the problem has to converge, otherwise the backward @@ -188,12 +192,18 @@ def forward( else: # unrolled CG op = self.operator_factory(*parameters) rhs = self.rhs_factory(*parameters) + rhs_norm = sum((r.abs().square().sum() for r in rhs), torch.tensor(0.0)).sqrt().item() + fwd_tol = self.tolerance * rhs_norm if isinstance(op, LinearOperator): if len(rhs) != 1: raise ValueError('LinearOperator requires a single right-hand side tensor.') if initial_value is not None and len(initial_value) != 1: raise ValueError('LinearOperator requires a single initial value tensor.') - solution = cg(op, rhs, initial_value=initial_value) + solution = cg( + op, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=self.max_iterations + ) else: - solution = cg(op, rhs, initial_value=initial_value) + solution = cg( + op, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=self.max_iterations + ) return solution From 78b570d7f72f28bae89eef3ecfb61212d6ef24c1 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 30 Apr 2025 01:32:40 +0200 Subject: [PATCH 018/205] fix rhs norm zero --- src/mrpro/operators/ConjugateGradientOp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 01b60de77..27d6daee9 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -56,7 +56,7 @@ def forward( operator = operator_factory(*inputs) rhs = rhs_factory(*inputs) rhs_norm = sum((r.abs().square().sum() for r in rhs), torch.tensor(0.0)).sqrt().item() - fwd_tol = tolerance * rhs_norm + fwd_tol = tolerance * max(rhs_norm, 1e-6) # clip in case rhs is 0 if isinstance(operator, LinearOperator): if len(rhs) != 1: raise ValueError('LinearOperator requires a single right-hand side tensor.') @@ -89,12 +89,14 @@ def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[tor inputs_with_grad = tuple(i for i, need_grad in zip(inputs, ctx.needs_input_grad[2:], strict=True) if need_grad) if inputs_with_grad: rhs_norm = sum((r.abs().square().sum() for r in grad_output), torch.tensor(0.0)).sqrt().item() - bwd_tol = ctx.tolerance * rhs_norm + bwd_tol = ctx.tolerance * max(rhs_norm, 1e-6) # clip in case rhs is 0 with torch.no_grad(): if isinstance(operator, LinearOperatorMatrix): z = cg(operator.H, grad_output, tolerance=bwd_tol, max_iterations=ctx.max_iterations) else: z = cg(operator.H, grad_output[0], tolerance=bwd_tol, max_iterations=ctx.max_iterations) + if any(zi.isnan().any() for zi in z): + raise RuntimeError('NaN in ConjugateGradientFunction.backward') with torch.enable_grad(): residual = tuple(r - ax for r, ax in zip(rhs, operator(*(s.detach() for s in solution)), strict=True)) grads = torch.autograd.grad(outputs=residual, inputs=inputs_with_grad, grad_outputs=z, allow_unused=True) From 904f3c941e308fad00dbb0ec857a7c4b23b9060c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 10 May 2025 21:05:34 +0200 Subject: [PATCH 019/205] fix doc --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e97bdd9cc..f7da23874 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ docs = [ "sphinx-autodoc-typehints>=3, <3.1", "sphinx-copybutton>=0.5, <0.6", "sphinx-last-updated-by-git>=0.3, <0.4", + "snowballstemmer>=2.2, <3.0", ] notebooks = [ "zenodo_get", From a458855117baf3b21bdc0738230ecf60f7256bb2 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 10 May 2025 21:09:00 +0200 Subject: [PATCH 020/205] start --- src/mrpro/nn/NDModules.py | 173 ++++++++++++++++++++++++++++++ src/mrpro/nn/UNet.py | 49 +++++++++ src/mrpro/nn/Uformer.py | 38 +++++++ src/mrpro/nn/__init__,py | 19 ++++ src/mrpro/nn/layers.py | 218 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 497 insertions(+) create mode 100644 src/mrpro/nn/NDModules.py create mode 100644 src/mrpro/nn/UNet.py create mode 100644 src/mrpro/nn/Uformer.py create mode 100644 src/mrpro/nn/__init__,py create mode 100644 src/mrpro/nn/layers.py diff --git a/src/mrpro/nn/NDModules.py b/src/mrpro/nn/NDModules.py new file mode 100644 index 000000000..d3466ba29 --- /dev/null +++ b/src/mrpro/nn/NDModules.py @@ -0,0 +1,173 @@ +from abc import ABC +from collections.abc import Sequence +from functools import partial + +import torch +from einops import rearrange +from torch.nn import Identity, Linear, Module, Parameter, ReLU, Sequential, Sigmoid, SiLU + +from mrpro.utils.reshape import unsqueeze_tensors_right + + +class NDModule(Module, ABC): + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the module to the input tensor.""" + return super().__call__(x) + + def __forward__(self, x: torch.Tensor) -> torch.Tensor: + return self.module(x) + + +class ConvND(NDModule): + """N-dimensional convolution. + + Parameters + ---------- + dim + The dimension of the convolution. + """ + + def __init__( + self, + dim: int, + in_channels: int, + out_channels: int, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int = 1, + padding: str | Sequence[int] | int = 'same', + dilation: Sequence[int] | int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + ) -> None: + if not isinstance(kernel_size, int) and len(kernel_size) != dim: + raise ValueError(f'kernel_size must be an int or a sequence of length {dim}') + if stride is not None and not isinstance(stride, int) and len(stride) != dim: + raise ValueError(f'stride must be None, an int, or a sequence of length {dim}') + if padding != 'same' and not isinstance(padding, int) and len(padding) != dim: + raise ValueError(f'padding must be an int or a sequence of length {dim}') + try: + self.module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}[dim]( + in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode + ) + except KeyError: + raise NotImplementedError(f'ConvND for dim {dim} not implemented.') from None + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._inner(x) + + +class MaxPoolND(NDModule): + def __init__(self, dim: int) -> None: + super().__init__() + try: + self.module = {1: torch.nn.MaxPool1d, 2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}[dim] + except KeyError: + raise NotImplementedError(f'MaxPoolNd for dim {dim} not implemented.') + + +class AvgPoolND(NDModule): + """N-dimensional average pooling.""" + + def __init__( + self, + dim: int, + kernel_size: int | Sequence[int], + stride: int | Sequence[int] | None = None, + padding: int | Sequence[int] = 0, + ceil_mode: bool = False, + count_include_pad: bool = False, + divisor_override: int | None = None, + ) -> None: + """Parameters for AvgPoolNd. + + Parameters + ---------- + dim + The dimension of the input tensor. + kernel_size + The size of the kernel. + stride + The stride of the kernel. + padding + The padding of the kernel. + ceil_mode + Whether to use ceil instead of floor to compute the output shape. + count_include_pad + Whether to include the padding in the divisor. + divisor_override + Overwrite the default divisor of the number of elements in the pooling region. + """ + super().__init__() + if not isinstance(kernel_size, int) and len(kernel_size) != dim: + raise ValueError(f'kernel_size must be an int or a sequence of length {dim}') + if stride is not None and not isinstance(stride, int) and len(stride) != dim: + raise ValueError(f'stride must be None, an int, or a sequence of length {dim}') + if padding != 'same' and not isinstance(padding, int) and len(padding) != dim: + raise ValueError(f'padding must be an int or a sequence of length {dim}') + try: + module = {1: torch.nn.AvgPool1d, 2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d()}[dim] + except KeyError: + raise NotImplementedError(f'AvgPoolNd for dim {dim} not implemented.') from None + self.module = module(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + + +class AdaptiveAvgPoolND(NDModule): + """N-dimensional adaptive average pooling.""" + + def __init__(self, dim: int, output_size: int | None | Sequence[int] = None): + super().__init__() + if not isinstance(output_size, int) and len(output_size) != dim: + raise ValueError(f'output_size must be an int or a sequence of length {dim}') + try: + self.module = (torch.nn.AdaptiveAvgPool1d, torch.nn.AdaptiveAvgPool2d, torch.nn.AdaptiveAvgPool3d)[dim - 1]( + output_size + ) + except KeyError: + raise NotImplementedError(f'AdaptiveAvgPoolnD for dim {dim} not implemented.') from None + + +class MaxPoolND(NDModule): + """N-dimensional max pooling.""" + + def __init__( + self, + dim: int, + kernel_size: int | Sequence[int], + stride: int | Sequence[int] | None = None, + padding: int | Sequence[int] = 0, + dilation: int | Sequence[int] = 1, + ceil_mode: bool = False, + ) -> None: + """Initialize MaxPoolNd. + + Parameters + ---------- + dim + The dimension of the input tensor. + kernel_size + The size of the kernel. + stride + The stride of the kernel. + padding + The padding of the kernel. + dilation + The dilation of the kernel. + ceil_mode + Whether to use ceil instead of floor to compute the output shape. + """ + if not isinstance(kernel_size, int) and len(kernel_size) != dim: + raise ValueError(f'kernel_size must be an int or a sequence of length {dim}') + if stride is not None and not isinstance(stride, int) and len(stride) != dim: + raise ValueError(f'stride must be None, an int, or a sequence of length {dim}') + if not isinstance(padding, int) and len(padding) != dim: + raise ValueError(f'padding must be an int or a sequence of length {dim}') + if not isinstance(dilation, int) and len(dilation) != dim: + raise ValueError(f'dilation must be an int or a sequence of length {dim}') + super().__init__() + self.module = {1: torch.nn.MaxPool1d, 2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}[dim]( + kernel_size, stride, padding, dilation, ceil_mode + ) diff --git a/src/mrpro/nn/UNet.py b/src/mrpro/nn/UNet.py new file mode 100644 index 000000000..24ee7a6a9 --- /dev/null +++ b/src/mrpro/nn/UNet.py @@ -0,0 +1,49 @@ +from functools import partial + +import torch +from torch.nn import Module + +from mrpro.nn.layers import call_with_emb + + +class UNetBase(Module): + def __init__( + self, + in_channels: int, + out_channels: int, + channels_emb: int, + dim: int, + num_blocks: int, + ) -> None: ... + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + """Apply to Network.""" + call = partial(call_with_emb, emb=emb) + x = call(self.first, x) + xs = [] + for block, down, skip in zip(self.input_blocks, self.down_blocks, self.skip_blocks, strict=False): + x = call(block, x) + xs.append(call(skip, x)) + x = call(down, x) + x = call(self.middel_block, x) + for block, up in (self.output_blocks, self.up_blocks): + x = call(up, x) + x = torch.cat([x, xs.pop()], dim=1) + x = call(block, x) + return call(self.last, x) + + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + emb + The embedding tensor. + + Returns + ------- + The output tensor. + """ + return self(x, emb) diff --git a/src/mrpro/nn/Uformer.py b/src/mrpro/nn/Uformer.py new file mode 100644 index 000000000..d4dccc7d9 --- /dev/null +++ b/src/mrpro/nn/Uformer.py @@ -0,0 +1,38 @@ +class LeFF(nn.Module): + """Fast Locally-enhanced Feed-Forward Network.""" + + def __init__( + self, + dim: int = 32, + hidden_dim: int = 128, + act_layer: Callable[[], nn.Module] = nn.GELU, + ) -> None: + """ + Parameters + ---------- + dim : int + Input and output feature dimension. + hidden_dim : int + Hidden feature dimension. + act_layer : Callable + Activation function. + """ + super().__init__() + from torch_dwconv import DepthwiseConv2d # Local import for optional dependency + + self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), act_layer()) + self.dwconv = nn.Sequential( + DepthwiseConv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1), + act_layer(), + ) + self.linear2 = nn.Linear(hidden_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bs, hw, c = x.size() + hh = int(math.sqrt(hw)) + x = self.linear1(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=hh, w=hh) + x = self.dwconv(x) + x = rearrange(x, 'b c h w -> b (h w) c', h=hh, w=hh) + x = self.linear2(x) + return x diff --git a/src/mrpro/nn/__init__,py b/src/mrpro/nn/__init__,py new file mode 100644 index 000000000..02bb157d4 --- /dev/null +++ b/src/mrpro/nn/__init__,py @@ -0,0 +1,19 @@ +from mrpro.nn.layers import EmbMixin, EmbSequential, FiLM, GroupNorm32, ResBlock, SqueezeExcitation, TransposedAttention +from mrpro.nn.NDModules import AdaptiveAvgPoolND, AvgPoolND, ConvND, MaxPoolND, NDModule +from mrpro.nn.UNet import UNetBase + +__all__ = [ + 'AdaptiveAvgPoolND', + 'AvgPoolND', + 'ConvND', + 'EmbMixin', + 'EmbSequential', + 'FiLM', + 'GroupNorm32', + 'MaxPoolND', + 'NDModule', + 'ResBlock', + 'SqueezeExcitation', + 'TransposedAttention', + 'UNetBase', +] diff --git a/src/mrpro/nn/layers.py b/src/mrpro/nn/layers.py new file mode 100644 index 000000000..392f7f076 --- /dev/null +++ b/src/mrpro/nn/layers.py @@ -0,0 +1,218 @@ +import torch +from einops import rearrange +from torch.nn import Identity, Linear, Module, Parameter, ReLU, Sequential, Sigmoid, SiLU + +from mrpro.nn.NDModules import AdaptiveAvgPoolND, ConvND +from mrpro.utils.reshape import unsqueeze_tensors_right + + +class EmbMixin: ... + + +class SqueezeExcitation(Module): + """Squeeze-and-Excitation block. + + Sequeeze-and-Excitation block from [SE]_. + + References + ---------- + ..[SE] Hu, Jie, Li Shen, and Gang Sun. "Squeeze-and-excitation networks." CVPR 2018, https://arxiv.org/abs/1709.01507 + """ + + def __init__(self, dim: int, input_channels: int, squeeze_channels: int) -> None: + """Initialize SqueezeExcitation. + + Parameters + ---------- + dim + The dimension of the input tensor. + input_channels + The number of channels in the input tensor. + squeeze_channels + The number of channels in the squeeze tensor. + """ + super().__init__() + self.scale = Sequential( + AdaptiveAvgPoolND(dim, 1), + ConvND(dim, input_channels, squeeze_channels, 1), + ReLU(), + ConvND(dim, squeeze_channels, input_channels, 1), + Sigmoid(), + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation.""" + return x * self.scale(x) + + +class TransposedAttention(Module): + def __init__(self, dim: int, channels: int, num_heads: int): + """Transposed Self Attention from Restormer. + + Implements the transposed self-attention, i.e. channel-wise multihead self-attention, + layer from Restormer [ZAM22]_. + + References + ---------- + ..[ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + + Parameters + ---------- + dim + input dimension + channels + input channels + num_heads + number of attention heads + """ + super().__init__() + self.num_heads = num_heads + self.temperature = Parameter(torch.ones(num_heads, 1, 1)) + self.qkv = ConvND(dim, channels, channels * 3, kernel_size=1, bias=True) + self.qkv_dwconv = ConvND( + dim, + channels * 3, + channels * 3, + kernel_size=3, + groups=channels * 3, + bias=False, + ) + self.project_out = ConvND(dim, channels, channels, kernel_size=1, bias=True) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed attention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed Attention.""" + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = rearrange(qkv, 'b (qkv head c) ... -> qkv b head (...) c', head=self.num_heads, qkv=3) + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=self.temperature) + out = rearrange(out, '... head points c -> ... (head c) points').reshape(x.shape) + out = self.project_out(out) + return out + + +class GroupNorm32(torch.nn.GroupNorm): + """A 32-bit GroupNorm. + + Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. + """ + + def __init__(self, channels: int, groups: int | None = None): + """Initialize GroupNorm32. + + Parameters + ---------- + channels + The number of channels in the input tensor. + groups + The number of groups to use. If None, the number of groups is determined automatically as + a power of 2 that is less than or equal to 32 and leaves at least 4 channels per group. + """ + if groups is None: + groups_ = channels & -channels + while (groups_ >= channels // 4) or groups_ > 32: + groups_ //= 2 + else: + groups_ = groups + super().__init__(groups_, channels) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return super().__call__(x.float()).type(x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super(x.float).type(x.dtype) + + +class EmbSequential(Sequential): + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + return super().__call__(x, emb) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + for module in self: + if isinstance(module, EmbMixin): + x = module(x, emb) + else: + x = module(x) + return x + + +def call_with_emb(module: Module, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + if isinstance(module, EmbMixin): + return module(x, emb) + return module(x) + + +class FiLM(Module, EmbMixin): + def __init__(self, channels: int, channels_emb: int) -> None: + super().__init__() + self.project = Sequential( + SiLU(), + Linear(channels_emb, 2 * channels), + ) + + def __call__(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + return super().__call__(x, emb) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + emb = self.project(emb) + scale, shift = emb.chunk(2, dim=1) + scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) + return x * (1 + scale) + shift + + +class ResBlock(Module, EmbMixin): + def __init__(self, channels_in: int, channels_out: int, channels_emb: int, dim: int, dropout: float = 0.1) -> None: + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(1e-6)) + self.modules = EmbSequential( + GroupNorm32(channels_in), + SiLU(), + ConvND(dim, channels_in, channels_out, 3), + GroupNorm32(channels_out), + SiLU(), + ConvND(dim, channels_out, channels_out, 3), + ) + if channels_emb > 0: + self.modules.insert(-3, FiLM(channels_out, channels_emb)) + + if channels_out == channels_in: + self.skip_connection = Identity() + else: + self.skip_connection = ConvND(dim, channels_in, channels_out, 1) + + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + return super().__call__(x, emb) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + h = self.modules(x, emb) + x = self.skip_connection(x) + h + return x From 7e83be7af248fcc288d69b435de48db1e8ad703c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 12 May 2025 23:03:54 +0200 Subject: [PATCH 021/205] update --- pyproject.toml | 1 + src/mrpro/nn/AttentionGate.py | 64 +++++ src/mrpro/nn/EmbMixin.py | 21 ++ src/mrpro/nn/FiLM.py | 54 ++++ src/mrpro/nn/GroupNorm32.py | 47 ++++ src/mrpro/nn/NDModules.py | 319 +++++++++++----------- src/mrpro/nn/NeighborhoodSelfAttention.py | 211 ++++++++++++++ src/mrpro/nn/ResBlock.py | 69 +++++ src/mrpro/nn/Sequential.py | 36 +++ src/mrpro/nn/ShiftedWindowAttention.py | 84 ++++++ src/mrpro/nn/SqueezeExcitation.py | 57 ++++ src/mrpro/nn/TransposedAttention.py | 69 +++++ src/mrpro/nn/Uformer.py | 38 --- src/mrpro/nn/__init__,py | 34 ++- src/mrpro/nn/layers.py | 218 --------------- src/mrpro/nn/{ => nets}/UNet.py | 22 +- src/mrpro/nn/nets/Uformer.py | 141 ++++++++++ 17 files changed, 1060 insertions(+), 425 deletions(-) create mode 100644 src/mrpro/nn/AttentionGate.py create mode 100644 src/mrpro/nn/EmbMixin.py create mode 100644 src/mrpro/nn/FiLM.py create mode 100644 src/mrpro/nn/GroupNorm32.py create mode 100644 src/mrpro/nn/NeighborhoodSelfAttention.py create mode 100644 src/mrpro/nn/ResBlock.py create mode 100644 src/mrpro/nn/Sequential.py create mode 100644 src/mrpro/nn/ShiftedWindowAttention.py create mode 100644 src/mrpro/nn/SqueezeExcitation.py create mode 100644 src/mrpro/nn/TransposedAttention.py delete mode 100644 src/mrpro/nn/Uformer.py delete mode 100644 src/mrpro/nn/layers.py rename src/mrpro/nn/{ => nets}/UNet.py (65%) create mode 100644 src/mrpro/nn/nets/Uformer.py diff --git a/pyproject.toml b/pyproject.toml index f7da23874..51f20acc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,6 +225,7 @@ iy = "iy" arange = "arange" # torch.arange Ba = "Ba" wht = "wht" # Brainweb tissue class +ND = "ND" # Short for N-dimensional [tool.typos.files] extend-exclude = [ diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py new file mode 100644 index 000000000..53a1daec8 --- /dev/null +++ b/src/mrpro/nn/AttentionGate.py @@ -0,0 +1,64 @@ +"""Attention gate from Attention UNet.""" + +import torch +from torch.nn import Module, ReLU, Sequential, Sigmoid + +from mrpro.nn.NDModules import ConvND + + +class AttenionGate(Module): + """Attention gate from Attention UNet. + + The attention mechanism from the attention UNet [OKT18]_. + + References + ---------- + ..[OKT18] Oktay, Ozan, et al. "Attention u-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + def __init__(self, dim: int, channels_gate: int, channels_in: int, channels_hidden: int): + """Initialize the attention gate. + + Parameters + ---------- + dim + The dimension, i.e. 1, 2 or 3. + channels_gate + The number of channels in the gate tensor. + channels_in + The number of channels in the input tensor. + channels_hidden + The number of internal, hidden channels. + """ + super().__init__() + self.project_gate = ConvND(dim)(channels_gate, channels_hidden, kernel_size=1) + self.project_x = ConvND(dim)(channels_in, channels_hidden, kernel_size=1) + self.psi = Sequential( + ReLU(), + ConvND(dim)(channels_hidden, 1, kernel_size=1), + Sigmoid(), + ) + + def __call__(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Apply the attention gate. + + Parameters + ---------- + x + The input tensor. + gate + The gate tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, gate) + + def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Apply the attention gate.""" + gate = self.project_gate(gate) + x = self.project_x(x) + alpha = self.psi(gate + x) + return x * alpha diff --git a/src/mrpro/nn/EmbMixin.py b/src/mrpro/nn/EmbMixin.py new file mode 100644 index 000000000..44f9b2a0f --- /dev/null +++ b/src/mrpro/nn/EmbMixin.py @@ -0,0 +1,21 @@ +"""Base class for modules using an embedding.""" + +import torch +from torch.nn import Module + + +def call_with_emb(module: Module, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + if isinstance(module, EmbMixin): + return module(x, emb) + return module(x) + + +class EmbMixin(Module): + """Mixin for modules using an embedding. + + Used to determine if a module uses an embedding within a Sequential container. + """ + + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None, **kwargs) -> torch.Tensor: + """Apply the module to the input.""" + return super().__call__(x, emb, **kwargs) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py new file mode 100644 index 000000000..7a5ef634f --- /dev/null +++ b/src/mrpro/nn/FiLM.py @@ -0,0 +1,54 @@ +"""Feature-wise Linear Modulation.""" + +import torch +from torch.nn import Linear, Module, Sequential, SiLU + +from mrpro.nn.EmbMixin import EmbMixin +from mrpro.utils.reshape import unsqueeze_tensors_right + + +class FiLM(Module, EmbMixin): + """Feature-wise Linear Modulation. + + Feature-wise Linear Modulation from [FiLM]_ + + References + ---------- + ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "Film: Visual reasoning with a general conditioning layer." AAAI (2018). + https://arxiv.org/abs/1709.07871 + """ + + def __init__(self, channels: int, channels_emb: int) -> None: + """Initialize FiLM. + + Parameters + ---------- + channels + The number of channels in the input tensor. + channels_emb + The number of channels in the embedding tensor. + """ + super().__init__() + self.project = Sequential( + SiLU(), + Linear(channels_emb, 2 * channels), + ) + + def __call__(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + """Apply FiLM. + + Parameters + ---------- + x + The input tensor. + emb + The embedding tensor. + """ + return super().__call__(x, emb) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + """Apply FiLM.""" + emb = self.project(emb) + scale, shift = emb.chunk(2, dim=1) + scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) + return x * (1 + scale) + shift diff --git a/src/mrpro/nn/GroupNorm32.py b/src/mrpro/nn/GroupNorm32.py new file mode 100644 index 000000000..0baf66abe --- /dev/null +++ b/src/mrpro/nn/GroupNorm32.py @@ -0,0 +1,47 @@ +"""GroupNorm with 32-bit precision.""" + +import torch + + +class GroupNorm32(torch.nn.GroupNorm): + """A 32-bit GroupNorm. + + Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. + """ + + def __init__(self, channels: int, groups: int | None = None): + """Initialize GroupNorm32. + + Parameters + ---------- + channels + The number of channels in the input tensor. + groups + The number of groups to use. If None, the number of groups is determined automatically as + a power of 2 that is less than or equal to 32 and leaves at least 4 channels per group. + """ + if groups is None: + groups_ = channels & -channels + while (groups_ >= channels // 4) or groups_ > 32: + groups_ //= 2 + else: + groups_ = groups + super().__init__(groups_, channels) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm32. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x.float()).type(x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm32.""" + return super().__call__(x.float()).type(x.dtype) diff --git a/src/mrpro/nn/NDModules.py b/src/mrpro/nn/NDModules.py index d3466ba29..b4bf089ea 100644 --- a/src/mrpro/nn/NDModules.py +++ b/src/mrpro/nn/NDModules.py @@ -1,173 +1,176 @@ -from abc import ABC -from collections.abc import Sequence -from functools import partial +"""Helper functions to get the correct N-dimensional module.""" import torch -from einops import rearrange -from torch.nn import Identity, Linear, Module, Parameter, ReLU, Sequential, Sigmoid, SiLU -from mrpro.utils.reshape import unsqueeze_tensors_right +def ConvND(dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 + """Get the `dim`-dimensional convolution class. -class NDModule(Module, ABC): - def __call__(self, x: torch.Tensor) -> torch.Tensor: - """Apply the module to the input tensor.""" - return super().__call__(x) + Parameters + ---------- + dim + The dimension of the convolution. - def __forward__(self, x: torch.Tensor) -> torch.Tensor: - return self.module(x) + Returns + ------- + The convolution class. + """ + match dim: + case 1: + return torch.nn.Conv1d + case 2: + return torch.nn.Conv2d + case 3: + return torch.nn.Conv3d + case _: + raise NotImplementedError(f'ConvND for dim {dim} not implemented. Raise an issue if you need this.') + + +def ConvTransposeND( # noqa: N802 + dim: int, +) -> type[torch.nn.ConvTranspose1d] | type[torch.nn.ConvTranspose2d] | type[torch.nn.ConvTranspose3d]: + """Get the `dim`-dimensional transposed convolution class. + Parameters + ---------- + dim + The dimension of the transposed convolution. + + Returns + ------- + The transposed convolution class. + """ + match dim: + case 1: + return torch.nn.ConvTranspose1d + case 2: + return torch.nn.ConvTranspose2d + case 3: + return torch.nn.ConvTranspose3d + case _: + raise NotImplementedError( + f'ConvTransposeND for dim {dim} not implemented. Raise an issue if you need this.' + ) -class ConvND(NDModule): - """N-dimensional convolution. + +def MaxPoolND(dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 + """Get the `dim`-dimensional max pooling class. Parameters ---------- dim - The dimension of the convolution. + The dimension of the max pooling. + + Returns + ------- + The max pooling class. """ + match dim: + case 1: + return torch.nn.MaxPool1d + case 2: + return torch.nn.MaxPool2d + case 3: + return torch.nn.MaxPool3d + case _: + raise NotImplementedError(f'MaxPoolNd for dim {dim} not implemented. Raise an issue if you need this.') - def __init__( - self, - dim: int, - in_channels: int, - out_channels: int, - kernel_size: Sequence[int] | int, - stride: Sequence[int] | int = 1, - padding: str | Sequence[int] | int = 'same', - dilation: Sequence[int] | int = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', - ) -> None: - if not isinstance(kernel_size, int) and len(kernel_size) != dim: - raise ValueError(f'kernel_size must be an int or a sequence of length {dim}') - if stride is not None and not isinstance(stride, int) and len(stride) != dim: - raise ValueError(f'stride must be None, an int, or a sequence of length {dim}') - if padding != 'same' and not isinstance(padding, int) and len(padding) != dim: - raise ValueError(f'padding must be an int or a sequence of length {dim}') - try: - self.module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}[dim]( - in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode - ) - except KeyError: - raise NotImplementedError(f'ConvND for dim {dim} not implemented.') from None - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - return super().__call__(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self._inner(x) - - -class MaxPoolND(NDModule): - def __init__(self, dim: int) -> None: - super().__init__() - try: - self.module = {1: torch.nn.MaxPool1d, 2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}[dim] - except KeyError: - raise NotImplementedError(f'MaxPoolNd for dim {dim} not implemented.') - - -class AvgPoolND(NDModule): - """N-dimensional average pooling.""" - - def __init__( - self, - dim: int, - kernel_size: int | Sequence[int], - stride: int | Sequence[int] | None = None, - padding: int | Sequence[int] = 0, - ceil_mode: bool = False, - count_include_pad: bool = False, - divisor_override: int | None = None, - ) -> None: - """Parameters for AvgPoolNd. - - Parameters - ---------- - dim - The dimension of the input tensor. - kernel_size - The size of the kernel. - stride - The stride of the kernel. - padding - The padding of the kernel. - ceil_mode - Whether to use ceil instead of floor to compute the output shape. - count_include_pad - Whether to include the padding in the divisor. - divisor_override - Overwrite the default divisor of the number of elements in the pooling region. - """ - super().__init__() - if not isinstance(kernel_size, int) and len(kernel_size) != dim: - raise ValueError(f'kernel_size must be an int or a sequence of length {dim}') - if stride is not None and not isinstance(stride, int) and len(stride) != dim: - raise ValueError(f'stride must be None, an int, or a sequence of length {dim}') - if padding != 'same' and not isinstance(padding, int) and len(padding) != dim: - raise ValueError(f'padding must be an int or a sequence of length {dim}') - try: - module = {1: torch.nn.AvgPool1d, 2: torch.nn.AvgPool2d, 3: torch.nn.AvgPool3d()}[dim] - except KeyError: - raise NotImplementedError(f'AvgPoolNd for dim {dim} not implemented.') from None - self.module = module(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) - - -class AdaptiveAvgPoolND(NDModule): - """N-dimensional adaptive average pooling.""" - - def __init__(self, dim: int, output_size: int | None | Sequence[int] = None): - super().__init__() - if not isinstance(output_size, int) and len(output_size) != dim: - raise ValueError(f'output_size must be an int or a sequence of length {dim}') - try: - self.module = (torch.nn.AdaptiveAvgPool1d, torch.nn.AdaptiveAvgPool2d, torch.nn.AdaptiveAvgPool3d)[dim - 1]( - output_size + +def AvgPoolND(dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 + """Get the `dim`-dimensional average pooling class. + + Parameters + ---------- + dim + The dimension of the average pooling. + + Returns + ------- + The average pooling class. + """ + match dim: + case 1: + return torch.nn.AvgPool1d + case 2: + return torch.nn.AvgPool2d + case 3: + return torch.nn.AvgPool3d + case _: + raise NotImplementedError(f'AvgPoolNd for dim {dim} not implemented. Raise an issue if you need this.') + + +def AdaptiveAvgPoolND( # noqa: N802 + dim: int, +) -> type[torch.nn.AdaptiveAvgPool1d] | type[torch.nn.AdaptiveAvgPool2d] | type[torch.nn.AdaptiveAvgPool3d]: + """Get the `dim`-dimensional adaptive average pooling class. + + Parameters + ---------- + dim + The dimension of the adaptive average pooling. + + Returns + ------- + The adaptive average pooling class. + """ + match dim: + case 1: + return torch.nn.AdaptiveAvgPool1d + case 2: + return torch.nn.AdaptiveAvgPool2d + case 3: + return torch.nn.AdaptiveAvgPool3d + case _: + raise NotImplementedError( + f'AdaptiveAvgPoolNd for dim {dim} not implemented. Raise an issue if you need this.' ) - except KeyError: - raise NotImplementedError(f'AdaptiveAvgPoolnD for dim {dim} not implemented.') from None - - -class MaxPoolND(NDModule): - """N-dimensional max pooling.""" - - def __init__( - self, - dim: int, - kernel_size: int | Sequence[int], - stride: int | Sequence[int] | None = None, - padding: int | Sequence[int] = 0, - dilation: int | Sequence[int] = 1, - ceil_mode: bool = False, - ) -> None: - """Initialize MaxPoolNd. - - Parameters - ---------- - dim - The dimension of the input tensor. - kernel_size - The size of the kernel. - stride - The stride of the kernel. - padding - The padding of the kernel. - dilation - The dilation of the kernel. - ceil_mode - Whether to use ceil instead of floor to compute the output shape. - """ - if not isinstance(kernel_size, int) and len(kernel_size) != dim: - raise ValueError(f'kernel_size must be an int or a sequence of length {dim}') - if stride is not None and not isinstance(stride, int) and len(stride) != dim: - raise ValueError(f'stride must be None, an int, or a sequence of length {dim}') - if not isinstance(padding, int) and len(padding) != dim: - raise ValueError(f'padding must be an int or a sequence of length {dim}') - if not isinstance(dilation, int) and len(dilation) != dim: - raise ValueError(f'dilation must be an int or a sequence of length {dim}') - super().__init__() - self.module = {1: torch.nn.MaxPool1d, 2: torch.nn.MaxPool2d, 3: torch.nn.MaxPool3d}[dim]( - kernel_size, stride, padding, dilation, ceil_mode - ) + + +def InstanceNormND( # noqa: N802 + dim: int, +) -> type[torch.nn.InstanceNorm1d] | type[torch.nn.InstanceNorm2d] | type[torch.nn.InstanceNorm3d]: + """Get the `dim`-dimensional instance normalization class. + + Parameters + ---------- + dim + The dimension of the instance normalization. + + Returns + ------- + The instance normalization class. + """ + match dim: + case 1: + return torch.nn.InstanceNorm1d + case 2: + return torch.nn.InstanceNorm2d + case 3: + return torch.nn.InstanceNorm3d + case _: + raise NotImplementedError(f'InstanceNormNd for dim {dim} not implemented. Raise an issue if you need this.') + + +def BatchNormND( # noqa: N802 + dim: int, +) -> type[torch.nn.BatchNorm1d] | type[torch.nn.BatchNorm2d] | type[torch.nn.BatchNorm3d]: + """Get the `dim`-dimensional batch normalization class. + + Parameters + ---------- + dim + The dimension of the batch normalization. + + Returns + ------- + The batch normalization class. + """ + match dim: + case 1: + return torch.nn.BatchNorm1d + case 2: + return torch.nn.BatchNorm2d + case 3: + return torch.nn.BatchNorm3d + case _: + raise NotImplementedError(f'BatchNormNd for dim {dim} not implemented. Raise an issue if you need this.') diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py new file mode 100644 index 000000000..ccafc03ed --- /dev/null +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -0,0 +1,211 @@ +import math +from collections.abc import Sequence +from functools import cache, reduce +from typing import TypeVar + +import torch +from einops import rearrange +from torch.nn import Linear, Module +from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + +T = TypeVar('T') + + +def check_arg(length: int, arg: Sequence[T] | T) -> tuple[T, ...]: + """Standardize an argument to a fixed-length tuple. + + If the argument is a sequence, it checks if its length matches the + specified dimension. If it's a single value, it replicates it `dim` times. + + Parameters + ---------- + length + The expected length of the sequence. + arg + The argument to check. Can be a single value of type T or a + sequence of T. + + Returns + ------- + A tuple of length `dim` containing elements of type T. + + Raises + ------ + ValueError + If `arg` is a sequence and its length does not match `length`. + """ + if isinstance(arg, Sequence): + if not len(arg) == length: + raise ValueError(f'The arguments must be either single values or have length {length}. Got {arg}.') + return tuple(arg) + return (arg,) * length + + +@cache +def neighborhood_mask( + input_size: torch.Size, + kernel_size: int | Sequence[int], + dilation: int | Sequence[int] = 1, + circular: bool | Sequence[bool] = False, +) -> BlockMask: + """Create a flex attention block mask for neighborhood attention. + + This function defines which key/value pairs a query can attend to based + on a local neighborhood. The neighborhood is defined by `kernel_size` + and `dilation` and can be circular (wrapping around edges). + + Parameters + ---------- + input_size + The dimensions of the input tensor (e.g., (H, W) for 2D). + kernel_size + The size of the attention neighborhood window. Can be a single + integer for a symmetric window or a sequence of integers for + each dimension. + dilation + The dilation factor for the neighborhood + Can be a single integer for a symmetric window or a sequence + of integers for each dimension. + circular + Whether the neighborhood wraps around the edges (circular padding). + Can be a single boolean or a sequence of booleans. + + Returns + ------- + A mask object suitable for `flex_attention` that defines the + allowed attention connections. + """ + kernel_size_tuple, dilation_tuple, circular_tuple = ( + check_arg(len(input_size), x) for x in (kernel_size, dilation, circular) + ) + + def unravel_index(idx: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Convert a flat 1D index into multi-dimensional coordinates.""" + idx = idx.clone() + coords = [] + for dim in reversed(input_size): + coords.append(idx % dim) + idx = (idx / dim).floor().long() + coords.reverse() + return tuple(coords) + + def mask( + _batch: torch.Tensor, + _head: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + """Determine if a query can attend to a key/value pair.""" + q_coord = unravel_index(q_idx) + kv_coord = unravel_index(kv_idx) + + masks = [] + for input_, kernel_, dilation_, circular_, q_, kv_ in zip( + input_size, + kernel_size_tuple, + dilation_tuple, + circular_tuple, + q_coord, + kv_coord, + strict=False, + ): + masks.append((q_ % dilation_) == (kv_ % dilation_)) + kernel_dilation = kernel_ * dilation_ + window_left = kernel_dilation // 2 + window_right = (kernel_dilation // 2) + ((kernel_dilation % 2) - 1) + if circular_: + left = (q_ - kv_ + input_) % input_ + right = (kv_ - q_ + input_) % input_ + masks.append((left <= window_left) | (right <= window_right)) + else: + center = q_.clamp(window_left, input_ - 1 - window_right) + left = center - kv_ + right = kv_ - center + masks.append(((left >= 0) & (left <= window_left)) | ((right >= 0) & (right <= window_right))) + return reduce(lambda x, y: x & y, masks) + + qkv_len = input_size.numel() + return create_block_mask(mask, B=None, H=None, Q_LEN=qkv_len, KV_LEN=qkv_len, _compile=True) + + +class NeighborhoodSelfAttention(Module): + """Attention where each query attends to a neighborhood of the key and value. + + Neighborhood attention is a type of attention where each query attends to a neighborhood of the key and value. + It is a more efficient alternative to regular attention, especially for large input sizes [NAT]_. + + This implementation uses `~torch.nn.attention.flex_attention`. For a more efficient implementation, see also [NATTEN]_. + + + References + ---------- + .. [NAT] Hassani, A. et al. "Neighborhood Attention Transformer" CVPR, 2023, https://arxiv.org/abs/2204.07143 + .. [NATTEN] https://github.com/SHI-Labs/NATTEN/ + """ + + def __init__( + self, + channels: int, + n_head: int, + kernel_size: int | Sequence[int], + dilation: int | Sequence[int] = 1, + circular: bool | Sequence[bool] = False, + channel_last: bool = False, + ) -> None: + """Initialize a neighborhood attention module. + + The parameters `kernel_size`, `dilation`, and `circular` can either be sequences, interpreted as per-dimension + values, or scalars, interpreted as the same value for all dimensions. + + Parameters + ---------- + channels + The number of channels in the input tensor. + n_head + The number of attention heads. + kernel_size + The size of the attention neighborhood window. + dilation + The dilation factor for the neighborhood. + circular + Whether the neighborhood wraps around the edges (circular padding) + channel_last + Whether the channels are in the last dimension of the tensor, as common in transformers. + """ + super().__init__() + self.n_head = n_head + self.kernel_size = kernel_size + self.dilation = dilation + self.circular = circular + self.channel_last = channel_last + self.to_qkv = Linear(channels, 3 * channels * n_head) + self.to_out = Linear(channels * n_head, channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply neighborhood attention to the input tensor. + + Parameters + ---------- + x + The input tensor, with shape `batch, channels, *spatial_dims`. + + Returns + ------- + The output tensor after attention, with the same shape as the input tensor. + """ + if not self.channel_last: + x = x.moveaxis(1, -1) + spatial_shape = x.shape[2:-1] + qkv = self.to_qkv(x) + query, key, value = rearrange(qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channel') + # the mask depends on the input size. To be more flexible if used within CNNs, we compute it here. + # The computation is cached.. + mask = neighborhood_mask( + input_size=spatial_shape, kernel_size=self.kernel_size, dilation=self.dilation, circular=self.circular + ) + out: torch.Tensor = flex_attention(query.contiguous(), key.contiguous(), value.contiguous(), block_mask=mask) # type: ignore[assignment] # wrong type hints + out = self.to_out(out) + out = out.unflatten(-2, spatial_shape) + if not self.channel_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py new file mode 100644 index 000000000..042ba6f87 --- /dev/null +++ b/src/mrpro/nn/ResBlock.py @@ -0,0 +1,69 @@ +"""Residual convolution block with two convolutions.""" + +import torch +from torch.nn import Identity, Module, Sequential, SiLU + +from mrpro.nn.NDModules import ConvND +from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.GroupNorm32 import GroupNorm32 +from mrpro.nn.FiLM import FiLM + + +class ResBlock(Module, EmbMixin): + """Residual convolution block with two convolutions.""" + + def __init__(self, dim: int, channels_in: int, channels_out: int, channels_emb: int) -> None: + """Initialize the ResBlock. + + Parameters + ---------- + dim + The dimension, i.e. 1, 2 or 3. + channels_in + The number of channels in the input tensor. + channels_out + The number of channels in the output tensor. + channels_emb + The number of channels in the embedding tensor used in a FiLM embedding. + If set to 0 no FiLM embedding is used. + + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(1e-6)) + self.block = Sequential( + GroupNorm32(channels_in), + SiLU(), + ConvND(dim)(channels_in, channels_out, kernel_size=3), + GroupNorm32(channels_out), + SiLU(), + ConvND(dim)(channels_out, channels_out, kernel_size=3), + ) + if channels_emb > 0: + self.block.insert(-3, FiLM(channels_out, channels_emb)) + + if channels_out == channels_in: + self.skip_connection: Module = Identity() + else: + self.skip_connection = ConvND(dim)(channels_in, channels_out, kernel_size=1) + + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + """Apply the ResBlock. + + Parameters + ---------- + x + The input tensor. + emb + An embedding tensor to be used for FiLM. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, emb) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + """Apply the ResBlock.""" + h = self.block(x, emb) + x = self.skip_connection(x) + h + return x diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py new file mode 100644 index 000000000..0d96355eb --- /dev/null +++ b/src/mrpro/nn/Sequential.py @@ -0,0 +1,36 @@ +import torch + +from mrpro.operators import Operator +from mrpro.nn.EmbMixin import EmbMixin +from torch.nn import Module + + +class Sequential(torch.nn.Sequential): + """Sequential container with support for embedding and Operators.""" + + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input. + + Parameters + ---------- + x + The input tensor. + emb + The (optional) embedding tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, emb) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input.""" + for module in self: + if isinstance(module, EmbMixin): + x = module(x, emb) + elif isinstance(module, Operator): + (x,) = module(x) + else: + x = module(x) + return x diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py new file mode 100644 index 000000000..a8734940b --- /dev/null +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -0,0 +1,84 @@ +"""Shifted Window Attention.""" + +import torch +from einops import rearrange +from torch.nn import Module + +from mrpro.nn.NDModules import ConvND +from mrpro.utils.sliding_window import sliding_window + + +class ShiftedWindowAttention(Module): + """Shifted Window Attention. + + (Shifted) Window Attention calculates attention over windows of the input. + It was introduced in Swin Transformer [Swin] and is used in Uformer. + + References + ---------- + .. [SWIN] Liu, Ze, et al. "Swin transformer: Hierarchical vision transformer using shifted windows." ICCV 2021. + """ + + def __init__(self, dim: int, channels: int, n_heads: int, window_size: int = 7, shifted: bool = True): + """Initialize the ShiftedWindowAttention module. + + Parameters + ---------- + dim : int + The dimension of the input. + channels : int + The number of channels in the input. + n_heads : int + The number of attention heads. The number if channels per head is ``channels // n_heads``. + window_size : int + The size of the window. + shifted : bool + Whether to shift the window. + """ + super().__init__() + self.channels = channels + self.n_heads = n_heads + self.window_size = window_size + self.shifted = shifted + self.to_qkv = ConvND(dim)(channels, 3 * channels, 1) + self.dim = dim + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the ShiftedWindowAttention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the ShiftedWindowAttention.""" + if self.shifted: + x = torch.roll(x, (-(self.window_size // 2),) * self.dim, dims=tuple(range(-self.dim, 0))) + qkv = self.to_qkv(x) + windowed = sliding_window(qkv, window_shape=self.window_size, stride=self.window_size, dim=range(-self.dim, 0)) + flat = windowed.flatten(0, self.dim - 1).flatten(-self.dim) + q, k, v = rearrange( + flat, + 'spatial batch (qkv heads channels) window->qkv spatial batch heads window channels', + heads=self.n_heads, + qkv=3, + ) + result = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None) + result = rearrange(result, 'spatial batch head window channels->batch (head channels) spatial window') + result = result.unflatten(-2, windowed.shape[: self.dim]).unflatten(-1, (self.window_size,) * self.dim) + # permute (in 3d) batch channels z y x wz wy wx -> batch channels wz z wy y wx x + result = result.moveaxis(list(range(-self.dim, 0)), list(range(3, 3 + 2 * self.dim, 2))) + result = result.reshape(x.shape) + if self.shifted: + result = torch.roll(result, (self.window_size // 2,) * self.dim, dims=tuple(range(-self.dim, 0))) + return result + + +'' diff --git a/src/mrpro/nn/SqueezeExcitation.py b/src/mrpro/nn/SqueezeExcitation.py new file mode 100644 index 000000000..bccd9e73d --- /dev/null +++ b/src/mrpro/nn/SqueezeExcitation.py @@ -0,0 +1,57 @@ +"""Squeeze-and-Excitation block.""" + +from torch.nn import Module, ReLU, Sigmoid + +from mrpro.nn.NDModules import AdaptiveAvgPoolND, ConvND +from mrpro.nn.Sequential import Sequential +import torch + + +class SqueezeExcitation(Module): + """Squeeze-and-Excitation block. + + Sequeeze-and-Excitation block from [SE]_. + + References + ---------- + ..[SE] Hu, Jie, Li Shen, and Gang Sun. "Squeeze-and-excitation networks." CVPR 2018, https://arxiv.org/abs/1709.01507 + """ + + def __init__(self, dim: int, input_channels: int, squeeze_channels: int) -> None: + """Initialize SqueezeExcitation. + + Parameters + ---------- + dim + The dimension of the input tensor. + input_channels + The number of channels in the input tensor. + squeeze_channels + The number of channels in the squeeze tensor. + """ + super().__init__() + self.scale = Sequential( + AdaptiveAvgPoolND(dim)(1), + ConvND(dim)(input_channels, squeeze_channels, kernel_size=1), + ReLU(), + ConvND(dim)(squeeze_channels, input_channels, kernel_size=1), + Sigmoid(), + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation.""" + return x * self.scale(x) diff --git a/src/mrpro/nn/TransposedAttention.py b/src/mrpro/nn/TransposedAttention.py new file mode 100644 index 000000000..621508f75 --- /dev/null +++ b/src/mrpro/nn/TransposedAttention.py @@ -0,0 +1,69 @@ +"""Transposed Attention from Restormer.""" + +import torch +from einops import rearrange +from torch.nn import Identity, Linear, Module, Parameter, ReLU, Sequential, Sigmoid, SiLU + +from mrpro.nn.NDModules import AdaptiveAvgPoolND, ConvND, InstanceNormND +from mrpro.utils.reshape import unsqueeze_tensors_right +from mrpro.operators import Operator + + +class TransposedAttention(Module): + def __init__(self, dim: int, channels: int, num_heads: int): + """Transposed Self Attention from Restormer. + + Implements the transposed self-attention, i.e. channel-wise multihead self-attention, + layer from Restormer [ZAM22]_. + + References + ---------- + ..[ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + + Parameters + ---------- + dim + input dimension + channels + input channels + num_heads + number of attention heads + """ + super().__init__() + self.num_heads = num_heads + self.temperature = Parameter(torch.ones(num_heads, 1, 1)) + self.qkv = ConvND(dim)(channels, channels * 3, kernel_size=1, bias=True) + self.qkv_dwconv = ConvND(dim)( + channels * 3, + channels * 3, + kernel_size=3, + groups=channels * 3, + bias=False, + ) + self.project_out = ConvND(dim)(channels, channels, kernel_size=1, bias=True) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed attention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed Attention.""" + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = rearrange(qkv, 'b (qkv head c) ... -> qkv b head (...) c', head=self.num_heads, qkv=3) + q = torch.nn.functional.normalize(q, dim=-1) * self.temperature + k = torch.nn.functional.normalize(k, dim=-1) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0) + out = rearrange(out, '... head points c -> ... (head c) points').reshape(x.shape) + out = self.project_out(out) + return out diff --git a/src/mrpro/nn/Uformer.py b/src/mrpro/nn/Uformer.py deleted file mode 100644 index d4dccc7d9..000000000 --- a/src/mrpro/nn/Uformer.py +++ /dev/null @@ -1,38 +0,0 @@ -class LeFF(nn.Module): - """Fast Locally-enhanced Feed-Forward Network.""" - - def __init__( - self, - dim: int = 32, - hidden_dim: int = 128, - act_layer: Callable[[], nn.Module] = nn.GELU, - ) -> None: - """ - Parameters - ---------- - dim : int - Input and output feature dimension. - hidden_dim : int - Hidden feature dimension. - act_layer : Callable - Activation function. - """ - super().__init__() - from torch_dwconv import DepthwiseConv2d # Local import for optional dependency - - self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), act_layer()) - self.dwconv = nn.Sequential( - DepthwiseConv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1), - act_layer(), - ) - self.linear2 = nn.Linear(hidden_dim, dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - bs, hw, c = x.size() - hh = int(math.sqrt(hw)) - x = self.linear1(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=hh, w=hh) - x = self.dwconv(x) - x = rearrange(x, 'b c h w -> b (h w) c', h=hh, w=hh) - x = self.linear2(x) - return x diff --git a/src/mrpro/nn/__init__,py b/src/mrpro/nn/__init__,py index 02bb157d4..cbe2a9d10 100644 --- a/src/mrpro/nn/__init__,py +++ b/src/mrpro/nn/__init__,py @@ -1,19 +1,43 @@ -from mrpro.nn.layers import EmbMixin, EmbSequential, FiLM, GroupNorm32, ResBlock, SqueezeExcitation, TransposedAttention -from mrpro.nn.NDModules import AdaptiveAvgPoolND, AvgPoolND, ConvND, MaxPoolND, NDModule -from mrpro.nn.UNet import UNetBase +"""Neural network modules and utilities.""" + +from mrpro.nn.AttentionGate import AttentionGate +from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.EmbSequential import EmbSequential +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm32 import GroupNorm32 +from mrpro.nn.NDModules import ( + AdaptiveAvgPoolND, + AvgPoolND, + BatchNormND, + ConvND, + ConvTransposeND, + InstanceNormND, + MaxPoolND, +) +from mrpro.nn.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.SqueezeExcitation import SqueezeExcitation +from mrpro.nn.TransposedAttention import TransposedAttention __all__ = [ 'AdaptiveAvgPoolND', + 'AttentionGate', 'AvgPoolND', + 'BatchNormND', 'ConvND', + 'ConvTransposeND', 'EmbMixin', 'EmbSequential', 'FiLM', 'GroupNorm32', + 'InstanceNormND', 'MaxPoolND', - 'NDModule', + 'NeighborhoodSelfAttention', 'ResBlock', + 'Sequential', + 'ShiftedWindowAttention', 'SqueezeExcitation', 'TransposedAttention', - 'UNetBase', ] diff --git a/src/mrpro/nn/layers.py b/src/mrpro/nn/layers.py deleted file mode 100644 index 392f7f076..000000000 --- a/src/mrpro/nn/layers.py +++ /dev/null @@ -1,218 +0,0 @@ -import torch -from einops import rearrange -from torch.nn import Identity, Linear, Module, Parameter, ReLU, Sequential, Sigmoid, SiLU - -from mrpro.nn.NDModules import AdaptiveAvgPoolND, ConvND -from mrpro.utils.reshape import unsqueeze_tensors_right - - -class EmbMixin: ... - - -class SqueezeExcitation(Module): - """Squeeze-and-Excitation block. - - Sequeeze-and-Excitation block from [SE]_. - - References - ---------- - ..[SE] Hu, Jie, Li Shen, and Gang Sun. "Squeeze-and-excitation networks." CVPR 2018, https://arxiv.org/abs/1709.01507 - """ - - def __init__(self, dim: int, input_channels: int, squeeze_channels: int) -> None: - """Initialize SqueezeExcitation. - - Parameters - ---------- - dim - The dimension of the input tensor. - input_channels - The number of channels in the input tensor. - squeeze_channels - The number of channels in the squeeze tensor. - """ - super().__init__() - self.scale = Sequential( - AdaptiveAvgPoolND(dim, 1), - ConvND(dim, input_channels, squeeze_channels, 1), - ReLU(), - ConvND(dim, squeeze_channels, input_channels, 1), - Sigmoid(), - ) - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - """Apply SqueezeExcitation. - - Parameters - ---------- - x - The input tensor. - - Returns - ------- - The output tensor. - """ - return super().__call__(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply SqueezeExcitation.""" - return x * self.scale(x) - - -class TransposedAttention(Module): - def __init__(self, dim: int, channels: int, num_heads: int): - """Transposed Self Attention from Restormer. - - Implements the transposed self-attention, i.e. channel-wise multihead self-attention, - layer from Restormer [ZAM22]_. - - References - ---------- - ..[ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." - CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf - - Parameters - ---------- - dim - input dimension - channels - input channels - num_heads - number of attention heads - """ - super().__init__() - self.num_heads = num_heads - self.temperature = Parameter(torch.ones(num_heads, 1, 1)) - self.qkv = ConvND(dim, channels, channels * 3, kernel_size=1, bias=True) - self.qkv_dwconv = ConvND( - dim, - channels * 3, - channels * 3, - kernel_size=3, - groups=channels * 3, - bias=False, - ) - self.project_out = ConvND(dim, channels, channels, kernel_size=1, bias=True) - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - """Apply transposed attention. - - Parameters - ---------- - x - The input tensor. - - Returns - ------- - The output tensor. - """ - return super().__call__(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply transposed Attention.""" - qkv = self.qkv_dwconv(self.qkv(x)) - q, k, v = rearrange(qkv, 'b (qkv head c) ... -> qkv b head (...) c', head=self.num_heads, qkv=3) - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=self.temperature) - out = rearrange(out, '... head points c -> ... (head c) points').reshape(x.shape) - out = self.project_out(out) - return out - - -class GroupNorm32(torch.nn.GroupNorm): - """A 32-bit GroupNorm. - - Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. - """ - - def __init__(self, channels: int, groups: int | None = None): - """Initialize GroupNorm32. - - Parameters - ---------- - channels - The number of channels in the input tensor. - groups - The number of groups to use. If None, the number of groups is determined automatically as - a power of 2 that is less than or equal to 32 and leaves at least 4 channels per group. - """ - if groups is None: - groups_ = channels & -channels - while (groups_ >= channels // 4) or groups_ > 32: - groups_ //= 2 - else: - groups_ = groups - super().__init__(groups_, channels) - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - return super().__call__(x.float()).type(x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return super(x.float).type(x.dtype) - - -class EmbSequential(Sequential): - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: - return super().__call__(x, emb) - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: - for module in self: - if isinstance(module, EmbMixin): - x = module(x, emb) - else: - x = module(x) - return x - - -def call_with_emb(module: Module, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: - if isinstance(module, EmbMixin): - return module(x, emb) - return module(x) - - -class FiLM(Module, EmbMixin): - def __init__(self, channels: int, channels_emb: int) -> None: - super().__init__() - self.project = Sequential( - SiLU(), - Linear(channels_emb, 2 * channels), - ) - - def __call__(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - return super().__call__(x, emb) - - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - emb = self.project(emb) - scale, shift = emb.chunk(2, dim=1) - scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) - return x * (1 + scale) + shift - - -class ResBlock(Module, EmbMixin): - def __init__(self, channels_in: int, channels_out: int, channels_emb: int, dim: int, dropout: float = 0.1) -> None: - super().__init__() - self.rezero = torch.nn.Parameter(torch.tensor(1e-6)) - self.modules = EmbSequential( - GroupNorm32(channels_in), - SiLU(), - ConvND(dim, channels_in, channels_out, 3), - GroupNorm32(channels_out), - SiLU(), - ConvND(dim, channels_out, channels_out, 3), - ) - if channels_emb > 0: - self.modules.insert(-3, FiLM(channels_out, channels_emb)) - - if channels_out == channels_in: - self.skip_connection = Identity() - else: - self.skip_connection = ConvND(dim, channels_in, channels_out, 1) - - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: - return super().__call__(x, emb) - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: - h = self.modules(x, emb) - x = self.skip_connection(x) + h - return x diff --git a/src/mrpro/nn/UNet.py b/src/mrpro/nn/nets/UNet.py similarity index 65% rename from src/mrpro/nn/UNet.py rename to src/mrpro/nn/nets/UNet.py index 24ee7a6a9..6a5175bcb 100644 --- a/src/mrpro/nn/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -1,9 +1,9 @@ from functools import partial import torch -from torch.nn import Module +from torch.nn import Module, ModuleList -from mrpro.nn.layers import call_with_emb +from mrpro.nn.EmbMixin import call_with_emb, EmbMixin class UNetBase(Module): @@ -16,19 +16,29 @@ def __init__( num_blocks: int, ) -> None: ... + input_blocks: ModuleList + down_blocks: ModuleList + skip_blocks: ModuleList + middle_block: Module + output_blocks: ModuleList + up_blocks: ModuleList + concat_blocks: ModuleList + last: Module + first: Module + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: """Apply to Network.""" call = partial(call_with_emb, emb=emb) x = call(self.first, x) xs = [] - for block, down, skip in zip(self.input_blocks, self.down_blocks, self.skip_blocks, strict=False): + for block, down, skip in zip(self.input_blocks, self.down_blocks, self.skip_blocks, strict=True): x = call(block, x) xs.append(call(skip, x)) x = call(down, x) - x = call(self.middel_block, x) - for block, up in (self.output_blocks, self.up_blocks): + x = call(self.middle_block, x) + for block, up, concat in zip(self.output_blocks, self.up_blocks, self.concat_blocks, strict=True): x = call(up, x) - x = torch.cat([x, xs.pop()], dim=1) + x = concat(x, xs.pop()) x = call(block, x) return call(self.last, x) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py new file mode 100644 index 000000000..cce5e0c39 --- /dev/null +++ b/src/mrpro/nn/nets/Uformer.py @@ -0,0 +1,141 @@ +import torch +from torch.nn import Module, GELU, Linear, Sequential, Conv2d, ConvTranspose2d +from mrpro.nn.NDModules import ConvND +from mrpro.utils.sliding_window import sliding_window + +import torch +from mrpro.utils.sliding_window import sliding_window +from torch.nn import Module +from einops import rearrange +from mrpro.nn.NDModules import ConvND + + +class LeFF(Module): + """Locally-enhanced Feed-Forward Network. + + Part of the Uformer architecture. + """ + + def __init__( + self, + dim: int, + channels_in: int = 32, + channels_out: int = 32, + expand_ratio: float = 4, + ) -> None: + """Initialize the LeFF module. + + Parameters + ---------- + dim : int + 2 or 3, for 2D or 3D input + channels_in : int + Input feature dimension + channels_out : int + Output feature dimension + expand_ratio : float + Expansion ratio of the hidden dimension + """ + super().__init__() + hidden_dim = int(dim * expand_ratio) + self.block = Sequential( + ConvND(dim)(channels_in, hidden_dim, 1), + GELU(), + ConvND(dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), + GELU(), + ConvND(dim)(hidden_dim, channels_out, 1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + +class LeWinTransformerBlock(Module): + def __init__( + self, + dim, + channels, + input_resolution, + num_heads, + win_size=8, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + norm_layer=nn.LayerNorm, + token_projection='linear', + ): + super().__init__() + self.channels = channels + self.input_resolution = input_resolution + self.num_heads = num_heads + self.win_size = win_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.token_mlp = token_mlp + self.modulator = Embedding(win_size * win_size, channels) # modulator + self.norm1 = norm_layer(channels) + self.attn = WindowAttention( + channels, + win_size=to_2tuple(self.win_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + token_projection=token_projection, + ) + + self.norm2 = norm_layer(channels) + mlp_hidden_dim = int(channels * mlp_ratio) + self.mlp = LeFF(channels, mlp_hidden_dim) + + def extra_repr(self) -> str: + return ( + f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' + f'win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio},modulator={self.modulator}' + ) + + def forward(self, x, mask=None): + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + + ## input mask + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + shifted_x = x + x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C + x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C + wmsa_in = self.with_pos_embed(x_windows, self.modulator.weight) + attn_windows = self.attn(wmsa_in, mask=attn_mask) # nW*B, win_size*win_size, C + attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) + shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + return x + + +class SAM(Module): + """Spatial Attention Module. + + Part of the Uformer architecture. + """ + + def __init__(self, dim, channels): + super().__init__() + self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) + self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) + self.conv3 = conv(3, n_feat, kernel_size, bias=bias) + + def forward(self, x, x_img): + x1 = self.conv1(x) + img = self.conv2(x) + x_img + x2 = torch.sigmoid(self.conv3(img)) + x1 = x1 * x2 + x1 = x1 + x + return x1, img From 26467bf5714f6ff57883c714fed9e26583698aeb Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 13 May 2025 21:36:23 +0200 Subject: [PATCH 022/205] update --- src/mrpro/nn/ComplexAsChannel.py | 3 +++ src/mrpro/nn/ShiftedWindowAttention.py | 17 +++++++++++++-- tests/nn/test_shiftedwindowattention.py | 29 +++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 src/mrpro/nn/ComplexAsChannel.py create mode 100644 tests/nn/test_shiftedwindowattention.py diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py new file mode 100644 index 000000000..640501496 --- /dev/null +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -0,0 +1,3 @@ +from torch.nn import Module + +class ComplexAsChannel(Module): diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py index a8734940b..f35dd1a38 100644 --- a/src/mrpro/nn/ShiftedWindowAttention.py +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -5,6 +5,7 @@ from torch.nn import Module from mrpro.nn.NDModules import ConvND +from mrpro.utils.reshape import ravel_multi_index from mrpro.utils.sliding_window import sliding_window @@ -12,7 +13,7 @@ class ShiftedWindowAttention(Module): """Shifted Window Attention. (Shifted) Window Attention calculates attention over windows of the input. - It was introduced in Swin Transformer [Swin] and is used in Uformer. + It was introduced in Swin Transformer [SWIN]_ and is used in Uformer. References ---------- @@ -36,12 +37,23 @@ def __init__(self, dim: int, channels: int, n_heads: int, window_size: int = 7, Whether to shift the window. """ super().__init__() + if channels % n_heads: + raise ValueError('channels must be divisible by n_heads.') self.channels = channels self.n_heads = n_heads self.window_size = window_size self.shifted = shifted self.to_qkv = ConvND(dim)(channels, 3 * channels, 1) self.dim = dim + coords_1d = torch.arange(window_size) + coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * dim), indexing='ij'), 0).flatten(1) + rel_coords = coords_nd[:, :, None] - coords_nd[:, None, :] # (dim, window_size**dim, window_size**dim) + rel_coords += window_size - 1 # shift to >=0 + rel_position_index = ravel_multi_index(rel_coords, (2 * window_size - 1,) * dim) + self.register_buffer('rel_position_index', rel_position_index) + + self.relative_position_bias_table = torch.nn.Parameter(torch.empty((2 * window_size - 1) ** dim, n_heads)) + torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02, a=-0.04, b=0.04) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply the ShiftedWindowAttention. @@ -70,7 +82,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: heads=self.n_heads, qkv=3, ) - result = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None) + bias = rearrange(self.relative_position_bias_table[self.rel_position_index], 'wd1 wd2 heads -> 1 heads wd1 wd2') + result = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) result = rearrange(result, 'spatial batch head window channels->batch (head channels) spatial window') result = result.unflatten(-2, windowed.shape[: self.dim]).unflatten(-1, (self.window_size,) * self.dim) # permute (in 3d) batch channels z y x wz wy wx -> batch channels wz z wy y wx x diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py new file mode 100644 index 000000000..c14266d7f --- /dev/null +++ b/tests/nn/test_shiftedwindowattention.py @@ -0,0 +1,29 @@ +import pytest +import torch +from mrpro.nn import ShiftedWindowAttention + + +@pytest.mark.parametrize( + 'dim,window_size,shifted', + [ + (2, 4, False), + (2, 4, True), + (3, 2, False), + (3, 2, True), + ], +) +def test_shifted_window_attention_forward_and_grad(dim: int, window_size: int , shifted)->: + batch = 2 + channels = 8 + n_heads = 2 + spatial_shape = (window_size * 2,) * dim + x = torch.randn((batch, channels) + spatial_shape, requires_grad=True) + + attn = ShiftedWindowAttention(dim=dim, channels=channels, n_heads=n_heads, window_size=window_size, shifted=shifted) + + out = attn(x) + assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' + + # Check backward + out.sum().backward() + assert x.grad is not None, 'No gradient computed for input' From c39a9afcb1849005e635c16c74344296acc511d5 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 13 May 2025 22:27:48 +0200 Subject: [PATCH 023/205] update --- src/mrpro/__init__.py | 6 ++- src/mrpro/nn/AttentionGate.py | 2 +- src/mrpro/nn/ComplexAsChannel.py | 51 ++++++++++++++++++++++- src/mrpro/nn/EmbMixin.py | 2 +- src/mrpro/nn/FiLM.py | 2 +- src/mrpro/nn/NeighborhoodSelfAttention.py | 1 - src/mrpro/nn/ResBlock.py | 6 +-- src/mrpro/nn/Sequential.py | 5 +-- src/mrpro/nn/SqueezeExcitation.py | 2 +- src/mrpro/nn/TransposedAttention.py | 6 +-- src/mrpro/nn/__init__,py | 43 ------------------- src/mrpro/nn/nets/UNet.py | 2 +- src/mrpro/nn/nets/Uformer.py | 8 +--- tests/nn/test_shiftedwindowattention.py | 25 ++++++----- 14 files changed, 79 insertions(+), 82 deletions(-) delete mode 100644 src/mrpro/nn/__init__,py diff --git a/src/mrpro/__init__.py b/src/mrpro/__init__.py index 729ae188c..bbd401f1f 100644 --- a/src/mrpro/__init__.py +++ b/src/mrpro/__init__.py @@ -1,10 +1,12 @@ from mrpro._version import __version__ -from mrpro import algorithms, operators, data, phantoms, utils +from mrpro import algorithms, operators, data, phantoms, utils, nn + __all__ = [ "__version__", "algorithms", "data", + "nn", "operators", "phantoms", "utils" -] +] \ No newline at end of file diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py index 53a1daec8..bb99edd40 100644 --- a/src/mrpro/nn/AttentionGate.py +++ b/src/mrpro/nn/AttentionGate.py @@ -6,7 +6,7 @@ from mrpro.nn.NDModules import ConvND -class AttenionGate(Module): +class AttentionGate(Module): """Attention gate from Attention UNet. The attention mechanism from the attention UNet [OKT18]_. diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py index 640501496..e16b90bda 100644 --- a/src/mrpro/nn/ComplexAsChannel.py +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -1,3 +1,52 @@ +import torch +from einops import rearrange from torch.nn import Module -class ComplexAsChannel(Module): +from mrpro.nn.EmbMixin import EmbMixin + + +class ComplexAsChannel(EmbMixin, Module): + """Wrap module to treat complex numbers as a channel dimension.""" + + def __init__(self, module: Module): + """Initialize the ComplexAsChannel module. + + Wraps a module to treat complex numbers as a channel dimension. + If called with a complex tensor, real and imaginary parts are concatenated along the channel dimension. + as ``(batch, (channel real/imaginary), ...)``. + + + Parameters + ---------- + module : Module + The module to wrap. + """ + super().__init__() + self.module = module + + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + emb : torch.Tensor | None + The embedding tensor. + """ + return super().__call__(x, emb) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + if x.is_complex(): + x_real = torch.view_as_real(x) + x_real = rearrange(x_real, 'batch channel ... complex -> batch (channel complex) ...') + else: + x_real = x + + y = self.module(x_real) + + if x.is_complex(): + y = rearrange(y, 'b (c x y) 2 -> batch channel ... complex', complex=2).contiguous() + y = torch.view_as_complex(y) + return y diff --git a/src/mrpro/nn/EmbMixin.py b/src/mrpro/nn/EmbMixin.py index 44f9b2a0f..931f00683 100644 --- a/src/mrpro/nn/EmbMixin.py +++ b/src/mrpro/nn/EmbMixin.py @@ -5,7 +5,7 @@ def call_with_emb(module: Module, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: - if isinstance(module, EmbMixin): + if isinstance(EmbMixin, Module): return module(x, emb) return module(x) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index 7a5ef634f..b4fd567ab 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -7,7 +7,7 @@ from mrpro.utils.reshape import unsqueeze_tensors_right -class FiLM(Module, EmbMixin): +class FiLM(EmbMixin, Module): """Feature-wise Linear Modulation. Feature-wise Linear Modulation from [FiLM]_ diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index ccafc03ed..762c68ab3 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -1,4 +1,3 @@ -import math from collections.abc import Sequence from functools import cache, reduce from typing import TypeVar diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index 042ba6f87..239bf2ab7 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -3,13 +3,13 @@ import torch from torch.nn import Identity, Module, Sequential, SiLU -from mrpro.nn.NDModules import ConvND from mrpro.nn.EmbMixin import EmbMixin -from mrpro.nn.GroupNorm32 import GroupNorm32 from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm32 import GroupNorm32 +from mrpro.nn.NDModules import ConvND -class ResBlock(Module, EmbMixin): +class ResBlock(EmbMixin, Module): """Residual convolution block with two convolutions.""" def __init__(self, dim: int, channels_in: int, channels_out: int, channels_emb: int) -> None: diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 0d96355eb..0273e6b51 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -1,8 +1,7 @@ import torch -from mrpro.operators import Operator from mrpro.nn.EmbMixin import EmbMixin -from torch.nn import Module +from mrpro.operators import Operator class Sequential(torch.nn.Sequential): @@ -27,7 +26,7 @@ def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Te def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: """Apply all modules in series to the input.""" for module in self: - if isinstance(module, EmbMixin): + if isinstance(EmbMixin, Module): x = module(x, emb) elif isinstance(module, Operator): (x,) = module(x) diff --git a/src/mrpro/nn/SqueezeExcitation.py b/src/mrpro/nn/SqueezeExcitation.py index bccd9e73d..8dcb87b65 100644 --- a/src/mrpro/nn/SqueezeExcitation.py +++ b/src/mrpro/nn/SqueezeExcitation.py @@ -1,10 +1,10 @@ """Squeeze-and-Excitation block.""" +import torch from torch.nn import Module, ReLU, Sigmoid from mrpro.nn.NDModules import AdaptiveAvgPoolND, ConvND from mrpro.nn.Sequential import Sequential -import torch class SqueezeExcitation(Module): diff --git a/src/mrpro/nn/TransposedAttention.py b/src/mrpro/nn/TransposedAttention.py index 621508f75..7b42f794a 100644 --- a/src/mrpro/nn/TransposedAttention.py +++ b/src/mrpro/nn/TransposedAttention.py @@ -2,11 +2,9 @@ import torch from einops import rearrange -from torch.nn import Identity, Linear, Module, Parameter, ReLU, Sequential, Sigmoid, SiLU +from torch.nn import Module, Parameter -from mrpro.nn.NDModules import AdaptiveAvgPoolND, ConvND, InstanceNormND -from mrpro.utils.reshape import unsqueeze_tensors_right -from mrpro.operators import Operator +from mrpro.nn.NDModules import ConvND class TransposedAttention(Module): diff --git a/src/mrpro/nn/__init__,py b/src/mrpro/nn/__init__,py deleted file mode 100644 index cbe2a9d10..000000000 --- a/src/mrpro/nn/__init__,py +++ /dev/null @@ -1,43 +0,0 @@ -"""Neural network modules and utilities.""" - -from mrpro.nn.AttentionGate import AttentionGate -from mrpro.nn.EmbMixin import EmbMixin -from mrpro.nn.EmbSequential import EmbSequential -from mrpro.nn.FiLM import FiLM -from mrpro.nn.GroupNorm32 import GroupNorm32 -from mrpro.nn.NDModules import ( - AdaptiveAvgPoolND, - AvgPoolND, - BatchNormND, - ConvND, - ConvTransposeND, - InstanceNormND, - MaxPoolND, -) -from mrpro.nn.NeighborhoodSelfAttention import NeighborhoodSelfAttention -from mrpro.nn.ResBlock import ResBlock -from mrpro.nn.Sequential import Sequential -from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention -from mrpro.nn.SqueezeExcitation import SqueezeExcitation -from mrpro.nn.TransposedAttention import TransposedAttention - -__all__ = [ - 'AdaptiveAvgPoolND', - 'AttentionGate', - 'AvgPoolND', - 'BatchNormND', - 'ConvND', - 'ConvTransposeND', - 'EmbMixin', - 'EmbSequential', - 'FiLM', - 'GroupNorm32', - 'InstanceNormND', - 'MaxPoolND', - 'NeighborhoodSelfAttention', - 'ResBlock', - 'Sequential', - 'ShiftedWindowAttention', - 'SqueezeExcitation', - 'TransposedAttention', -] diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 6a5175bcb..cceb68dce 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -3,7 +3,7 @@ import torch from torch.nn import Module, ModuleList -from mrpro.nn.EmbMixin import call_with_emb, EmbMixin +from mrpro.nn.EmbMixin import call_with_emb class UNetBase(Module): diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index cce5e0c39..552a7c6d9 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -1,12 +1,6 @@ import torch -from torch.nn import Module, GELU, Linear, Sequential, Conv2d, ConvTranspose2d -from mrpro.nn.NDModules import ConvND -from mrpro.utils.sliding_window import sliding_window +from torch.nn import GELU, Module, Sequential -import torch -from mrpro.utils.sliding_window import sliding_window -from torch.nn import Module -from einops import rearrange from mrpro.nn.NDModules import ConvND diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index c14266d7f..49e10f8ad 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -1,29 +1,28 @@ import pytest -import torch from mrpro.nn import ShiftedWindowAttention +from mrpro.utils.RandomGenerator import RandomGenerator @pytest.mark.parametrize( - 'dim,window_size,shifted', + ('dim', 'window_size', 'shifted'), [ - (2, 4, False), - (2, 4, True), - (3, 2, False), - (3, 2, True), + (2, 8, False), + (4, 4, True), ], ) -def test_shifted_window_attention_forward_and_grad(dim: int, window_size: int , shifted)->: +def test_shifted_window_attentio(dim: int, window_size: int, shifted) -> None: batch = 2 channels = 8 n_heads = 2 - spatial_shape = (window_size * 2,) * dim - x = torch.randn((batch, channels) + spatial_shape, requires_grad=True) + spatial_shape = (window_size * 4,) * dim + rng = RandomGenerator(13) + x = rng.float32_tensor((batch, channels, *spatial_shape)).requires_grad_(True) - attn = ShiftedWindowAttention(dim=dim, channels=channels, n_heads=n_heads, window_size=window_size, shifted=shifted) + swin = ShiftedWindowAttention(dim=dim, channels=channels, n_heads=n_heads, window_size=window_size, shifted=shifted) - out = attn(x) + out = swin(x) assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' - - # Check backward out.sum().backward() assert x.grad is not None, 'No gradient computed for input' + assert swin.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert swin.relative_position_bias_table.grad is not None, 'No gradient computed for relative_position_bias_table' From 633682b1959c07dec272277ade9ea50c5b8898f5 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 14 May 2025 00:53:02 +0200 Subject: [PATCH 024/205] update --- src/mrpro/nn/AttentionGate.py | 8 ++++--- src/mrpro/nn/EmbMixin.py | 4 ++-- src/mrpro/nn/FiLM.py | 3 +++ src/mrpro/nn/ResBlock.py | 4 ++-- src/mrpro/nn/Sequential.py | 1 + src/mrpro/nn/ShiftedWindowAttention.py | 4 +++- tests/nn/test_attentiongate.py | 32 +++++++++++++++++++++++++ tests/nn/test_film.py | 30 +++++++++++++++++++++++ tests/nn/test_shiftedwindowattention.py | 4 ++-- tests/nn/test_sqeezeexcitation.py | 26 ++++++++++++++++++++ 10 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 tests/nn/test_attentiongate.py create mode 100644 tests/nn/test_film.py create mode 100644 tests/nn/test_sqeezeexcitation.py diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py index bb99edd40..a20f04396 100644 --- a/src/mrpro/nn/AttentionGate.py +++ b/src/mrpro/nn/AttentionGate.py @@ -58,7 +58,9 @@ def __call__(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: """Apply the attention gate.""" - gate = self.project_gate(gate) - x = self.project_x(x) - alpha = self.psi(gate + x) + projected_gate = self.project_gate(gate) + projected_x = self.project_x(x) + if gate.shape[2:] != x.shape[2:]: + projected_gate = torch.nn.functional.interpolate(projected_gate, size=x.shape[2:], mode='nearest') + alpha = self.psi(projected_gate + projected_x) return x * alpha diff --git a/src/mrpro/nn/EmbMixin.py b/src/mrpro/nn/EmbMixin.py index 931f00683..5188ae964 100644 --- a/src/mrpro/nn/EmbMixin.py +++ b/src/mrpro/nn/EmbMixin.py @@ -16,6 +16,6 @@ class EmbMixin(Module): Used to determine if a module uses an embedding within a Sequential container. """ - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None, **kwargs) -> torch.Tensor: + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: """Apply the module to the input.""" - return super().__call__(x, emb, **kwargs) + return super().__call__(x, emb) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index b4fd567ab..c74825e57 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -48,6 +48,9 @@ def __call__(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: """Apply FiLM.""" + if emb is None: + return x + emb = self.project(emb) scale, shift = emb.chunk(2, dim=1) scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index 239bf2ab7..35a13baf2 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -46,7 +46,7 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, channels_emb: else: self.skip_connection = ConvND(dim)(channels_in, channels_out, kernel_size=1) - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock. Parameters @@ -62,7 +62,7 @@ def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: """ return super().__call__(x, emb) - def forward(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock.""" h = self.block(x, emb) x = self.skip_connection(x) + h diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 0273e6b51..33b4eba33 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -1,4 +1,5 @@ import torch +from torch.nn import Module from mrpro.nn.EmbMixin import EmbMixin from mrpro.operators import Operator diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py index f35dd1a38..d29ba8238 100644 --- a/src/mrpro/nn/ShiftedWindowAttention.py +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -20,6 +20,8 @@ class ShiftedWindowAttention(Module): .. [SWIN] Liu, Ze, et al. "Swin transformer: Hierarchical vision transformer using shifted windows." ICCV 2021. """ + rel_position_index: torch.Tensor + def __init__(self, dim: int, channels: int, n_heads: int, window_size: int = 7, shifted: bool = True): """Initialize the ShiftedWindowAttention module. @@ -49,7 +51,7 @@ def __init__(self, dim: int, channels: int, n_heads: int, window_size: int = 7, coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * dim), indexing='ij'), 0).flatten(1) rel_coords = coords_nd[:, :, None] - coords_nd[:, None, :] # (dim, window_size**dim, window_size**dim) rel_coords += window_size - 1 # shift to >=0 - rel_position_index = ravel_multi_index(rel_coords, (2 * window_size - 1,) * dim) + rel_position_index = ravel_multi_index(tuple(rel_coords), (2 * window_size - 1,) * dim) self.register_buffer('rel_position_index', rel_position_index) self.relative_position_bias_table = torch.nn.Parameter(torch.empty((2 * window_size - 1) ** dim, n_heads)) diff --git a/tests/nn/test_attentiongate.py b/tests/nn/test_attentiongate.py new file mode 100644 index 000000000..d209d6483 --- /dev/null +++ b/tests/nn/test_attentiongate.py @@ -0,0 +1,32 @@ +"""Tests for AttentionGate module.""" + +import pytest +from mrpro.nn.AttentionGate import AttentionGate +from mrpro.utils.RandomGenerator import RandomGenerator + + +@pytest.mark.parametrize( + ('dim', 'channels_gate', 'channels_in', 'channels_hidden', 'input_shape', 'gate_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 32, 16, 16)), + (3, 32, 4, 8, (2, 4, 16, 16, 16), (2, 32, 16, 16, 16)), + ], +) +def test_attention_gate(dim, channels_gate, channels_in, channels_hidden, input_shape, gate_shape): + """Test AttentionGate output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + gate = rng.float32_tensor(gate_shape).requires_grad_(True) + attn = AttentionGate(dim=dim, channels_gate=channels_gate, channels_in=channels_in, channels_hidden=channels_hidden) + output = attn(x, gate) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert gate.grad is not None, 'No gradient computed for gate' + assert not x.isnan().any(), 'NaN values in input' + assert not gate.isnan().any(), 'NaN values in gate' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not gate.grad.isnan().any(), 'NaN values in gate gradients' + assert attn.project_gate.weight.grad is not None, 'No gradient computed for project_gate' + assert attn.project_x.weight.grad is not None, 'No gradient computed for project_x' + assert attn.psi[1].weight.grad is not None, 'No gradient computed for psi' diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py new file mode 100644 index 000000000..b329ab313 --- /dev/null +++ b/tests/nn/test_film.py @@ -0,0 +1,30 @@ +"""Tests for FiLM module.""" + +import pytest +from mrpro.nn.FiLM import FiLM +from mrpro.utils.RandomGenerator import RandomGenerator + + +@pytest.mark.parametrize( + ('channels', 'channels_emb', 'input_shape', 'emb_shape'), + [ + (64, 32, (1, 64, 32, 32), (1, 32)), + (32, 16, (2, 32, 16, 16), (2, 16)), + ], +) +def test_film(channels, channels_emb, input_shape, emb_shape): + """Test FiLM output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + emb = rng.float32_tensor(emb_shape).requires_grad_(True) + film = FiLM(channels=channels, channels_emb=channels_emb) + output = film(x, emb) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert emb.grad is not None, 'No gradient computed for embedding' + assert not x.isnan().any(), 'NaN values in input' + assert not emb.isnan().any(), 'NaN values in embedding' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not emb.grad.isnan().any(), 'NaN values in embedding gradients' + assert film.project[1].weight.grad is not None, 'No gradient computed for Linear layer' diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index 49e10f8ad..411c04d63 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -17,12 +17,12 @@ def test_shifted_window_attentio(dim: int, window_size: int, shifted) -> None: spatial_shape = (window_size * 4,) * dim rng = RandomGenerator(13) x = rng.float32_tensor((batch, channels, *spatial_shape)).requires_grad_(True) - swin = ShiftedWindowAttention(dim=dim, channels=channels, n_heads=n_heads, window_size=window_size, shifted=shifted) - out = swin(x) assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' + assert not out.isnan().any(), 'NaN values in output' out.sum().backward() assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' assert swin.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' assert swin.relative_position_bias_table.grad is not None, 'No gradient computed for relative_position_bias_table' diff --git a/tests/nn/test_sqeezeexcitation.py b/tests/nn/test_sqeezeexcitation.py new file mode 100644 index 000000000..b241aa56b --- /dev/null +++ b/tests/nn/test_sqeezeexcitation.py @@ -0,0 +1,26 @@ +"""Tests for SqueezeExcitation module.""" + +import pytest +from mrpro.nn.SqueezeExcitation import SqueezeExcitation +from mrpro.utils.RandomGenerator import RandomGenerator + + +@pytest.mark.parametrize( + ('dim', 'input_shape', 'squeeze_channels'), + [ + (2, (1, 64, 32, 32), 16), + (3, (1, 64, 16, 16, 16), 16), + ], +) +def test_squeeze_excitation(dim, input_shape, squeeze_channels): + """Test SqueezeExcitation output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + se = SqueezeExcitation(dim=dim, input_channels=input_shape[1], squeeze_channels=squeeze_channels) + output = se(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert se.scale[1].weight.grad is not None, 'No gradient computed for Conv' From 420cdc1b29a6604a91eec2b3578d6b5b6447ad83 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 14 May 2025 01:18:30 +0200 Subject: [PATCH 025/205] update --- src/mrpro/nn/ComplexAsChannel.py | 2 +- src/mrpro/nn/__init__.py | 42 +++++++ src/mrpro/nn/test.ipynb | 149 ++++++++++++++++++++++++ tests/nn/test_attentiongate.py | 19 ++- tests/nn/test_complexaschannel.py | 30 +++++ tests/nn/test_film.py | 17 ++- tests/nn/test_groupnorm32.py | 35 ++++++ tests/nn/test_resblock.py | 42 +++++++ tests/nn/test_sequential.py | 41 +++++++ tests/nn/test_shiftedwindowattention.py | 2 +- tests/nn/test_sqeezeexcitation.py | 4 +- tests/nn/test_transposedattention.py | 37 ++++++ 12 files changed, 406 insertions(+), 14 deletions(-) create mode 100644 src/mrpro/nn/__init__.py create mode 100644 src/mrpro/nn/test.ipynb create mode 100644 tests/nn/test_complexaschannel.py create mode 100644 tests/nn/test_groupnorm32.py create mode 100644 tests/nn/test_resblock.py create mode 100644 tests/nn/test_sequential.py create mode 100644 tests/nn/test_transposedattention.py diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py index e16b90bda..b64c00151 100644 --- a/src/mrpro/nn/ComplexAsChannel.py +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -47,6 +47,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Ten y = self.module(x_real) if x.is_complex(): - y = rearrange(y, 'b (c x y) 2 -> batch channel ... complex', complex=2).contiguous() + y = rearrange(y, 'b (c x y) ... complex -> batch channel ... complex', complex=2).contiguous() y = torch.view_as_complex(y) return y diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py new file mode 100644 index 000000000..b6e3f9b8d --- /dev/null +++ b/src/mrpro/nn/__init__.py @@ -0,0 +1,42 @@ +"""Neural network modules and utilities.""" + +from mrpro.nn.AttentionGate import AttentionGate +from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm32 import GroupNorm32 +from mrpro.nn.NDModules import ( + AdaptiveAvgPoolND, + AvgPoolND, + BatchNormND, + ConvND, + ConvTransposeND, + InstanceNormND, + MaxPoolND, +) +from mrpro.nn.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.SqueezeExcitation import SqueezeExcitation +from mrpro.nn.TransposedAttention import TransposedAttention + +__all__ = [ + 'AdaptiveAvgPoolND', + 'AttentionGate', + 'AvgPoolND', + 'BatchNormND', + 'ConvND', + 'ConvTransposeND', + 'EmbMixin', + 'EmbSequential', + 'FiLM', + 'GroupNorm32', + 'InstanceNormND', + 'MaxPoolND', + 'NeighborhoodSelfAttention', + 'ResBlock', + 'Sequential', + 'ShiftedWindowAttention', + 'SqueezeExcitation', + 'TransposedAttention', +] \ No newline at end of file diff --git a/src/mrpro/nn/test.ipynb b/src/mrpro/nn/test.ipynb new file mode 100644 index 000000000..ddfd6b8b3 --- /dev/null +++ b/src/mrpro/nn/test.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from mrpro.utils.sliding_window import sliding_window\n", + "from torch.nn import Module\n", + "from einops import rearrange\n", + "from mrpro.nn.NDModules import ConvND\n", + "class ShiftedWindowMSA(Module):\n", + " def __init__(self, dim, channels, n_heads, window_size=7, shifted=True):\n", + " super().__init__()\n", + " self.channels = channels\n", + " self.n_heads = n_heads\n", + " self.window_size = window_size\n", + " self.shifted = shifted\n", + " self.to_qkv = ConvND(dim)(channels, 3*channels, 1)\n", + " self.dim=dim\n", + " def forward(self, x):\n", + " if self.shifted:\n", + " x = torch.roll(x, (-(self.window_size//2),)*self.dim,dims=tuple(range(-self.dim, 0)))\n", + " qkv = self.to_qkv(x) \n", + " windowed = sliding_window(qkv, window_shape=self.window_size, stride=self.window_size, dim=range(-self.dim, 0))\n", + " flat = windowed.flatten(0,self.dim-1).flatten(-self.dim) \n", + " q,k,v = rearrange(flat, 'spatial batch (qkv heads channels) window->qkv spatial batch heads window channels', heads = self.n_heads, qkv=3)\n", + " result = torch.nn.functional.scaled_dot_product_attention(q,k,v, attn_mask=None)\n", + " result = rearrange(result, 'spatial batch head window channels->batch (head channels) spatial window')\n", + " result=result.unflatten(-2, windowed.shape[:self.dim]).unflatten(-1, (self.window_size,)*self.dim)\n", + " result=result.moveaxis(list(range(-self.dim, 0)), list(range(3, 3+2*self.dim, 2)))\n", + " result = result.reshape(x.shape)\n", + " if self.shifted:\n", + " result = torch.roll(result, (self.window_size//2,)*self.dim,dims=tuple(range(-self.dim, 0)))\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "m=ShiftedWindowMSA(dim=2, channels=16, n_heads=4, window_size=5, shifted=True).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [], + "source": [ + "x=torch.arange(2*16*20*30).reshape(2,16,20,30).float().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 16, 20, 30])" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m(x).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[100]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmatplotlib\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mplt\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mplt\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimshow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3\u001b[39m plt.show()\n\u001b[32m 4\u001b[39m plt.imshow(m(x)[\u001b[32m0\u001b[39m,\u001b[32m0\u001b[39m])\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/pyplot.py:3590\u001b[39m, in \u001b[36mimshow\u001b[39m\u001b[34m(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, colorizer, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)\u001b[39m\n\u001b[32m 3568\u001b[39m \u001b[38;5;129m@_copy_docstring_and_deprecators\u001b[39m(Axes.imshow)\n\u001b[32m 3569\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mimshow\u001b[39m(\n\u001b[32m 3570\u001b[39m X: ArrayLike | PIL.Image.Image,\n\u001b[32m (...)\u001b[39m\u001b[32m 3588\u001b[39m **kwargs,\n\u001b[32m 3589\u001b[39m ) -> AxesImage:\n\u001b[32m-> \u001b[39m\u001b[32m3590\u001b[39m __ret = \u001b[43mgca\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimshow\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 3591\u001b[39m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3592\u001b[39m \u001b[43m \u001b[49m\u001b[43mcmap\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3593\u001b[39m \u001b[43m \u001b[49m\u001b[43mnorm\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnorm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3594\u001b[39m \u001b[43m \u001b[49m\u001b[43maspect\u001b[49m\u001b[43m=\u001b[49m\u001b[43maspect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3595\u001b[39m \u001b[43m \u001b[49m\u001b[43minterpolation\u001b[49m\u001b[43m=\u001b[49m\u001b[43minterpolation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3596\u001b[39m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m=\u001b[49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3597\u001b[39m \u001b[43m \u001b[49m\u001b[43mvmin\u001b[49m\u001b[43m=\u001b[49m\u001b[43mvmin\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3598\u001b[39m \u001b[43m \u001b[49m\u001b[43mvmax\u001b[49m\u001b[43m=\u001b[49m\u001b[43mvmax\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3599\u001b[39m \u001b[43m \u001b[49m\u001b[43mcolorizer\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcolorizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3600\u001b[39m \u001b[43m \u001b[49m\u001b[43morigin\u001b[49m\u001b[43m=\u001b[49m\u001b[43morigin\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3601\u001b[39m \u001b[43m \u001b[49m\u001b[43mextent\u001b[49m\u001b[43m=\u001b[49m\u001b[43mextent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3602\u001b[39m \u001b[43m \u001b[49m\u001b[43minterpolation_stage\u001b[49m\u001b[43m=\u001b[49m\u001b[43minterpolation_stage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3603\u001b[39m \u001b[43m \u001b[49m\u001b[43mfilternorm\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfilternorm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3604\u001b[39m \u001b[43m \u001b[49m\u001b[43mfilterrad\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfilterrad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3605\u001b[39m \u001b[43m \u001b[49m\u001b[43mresample\u001b[49m\u001b[43m=\u001b[49m\u001b[43mresample\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3606\u001b[39m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m=\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3607\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdata\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m}\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3608\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3609\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3610\u001b[39m sci(__ret)\n\u001b[32m 3611\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m __ret\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/__init__.py:1521\u001b[39m, in \u001b[36m_preprocess_data..inner\u001b[39m\u001b[34m(ax, data, *args, **kwargs)\u001b[39m\n\u001b[32m 1518\u001b[39m \u001b[38;5;129m@functools\u001b[39m.wraps(func)\n\u001b[32m 1519\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34minner\u001b[39m(ax, *args, data=\u001b[38;5;28;01mNone\u001b[39;00m, **kwargs):\n\u001b[32m 1520\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m data \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1521\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1522\u001b[39m \u001b[43m \u001b[49m\u001b[43max\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1523\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[38;5;28;43mmap\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcbook\u001b[49m\u001b[43m.\u001b[49m\u001b[43msanitize_sequence\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1524\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43m{\u001b[49m\u001b[43mk\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mcbook\u001b[49m\u001b[43m.\u001b[49m\u001b[43msanitize_sequence\u001b[49m\u001b[43m(\u001b[49m\u001b[43mv\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mitems\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1526\u001b[39m bound = new_sig.bind(ax, *args, **kwargs)\n\u001b[32m 1527\u001b[39m auto_label = (bound.arguments.get(label_namer)\n\u001b[32m 1528\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m bound.kwargs.get(label_namer))\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/axes/_axes.py:5976\u001b[39m, in \u001b[36mAxes.imshow\u001b[39m\u001b[34m(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, colorizer, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)\u001b[39m\n\u001b[32m 5973\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m aspect \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 5974\u001b[39m \u001b[38;5;28mself\u001b[39m.set_aspect(aspect)\n\u001b[32m-> \u001b[39m\u001b[32m5976\u001b[39m \u001b[43mim\u001b[49m\u001b[43m.\u001b[49m\u001b[43mset_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 5977\u001b[39m im.set_alpha(alpha)\n\u001b[32m 5978\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m im.get_clip_path() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 5979\u001b[39m \u001b[38;5;66;03m# image does not already have clipping set, clip to Axes patch\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/image.py:685\u001b[39m, in \u001b[36m_ImageBase.set_data\u001b[39m\u001b[34m(self, A)\u001b[39m\n\u001b[32m 683\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(A, PIL.Image.Image):\n\u001b[32m 684\u001b[39m A = pil_to_array(A) \u001b[38;5;66;03m# Needed e.g. to apply png palette.\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m685\u001b[39m \u001b[38;5;28mself\u001b[39m._A = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_normalize_image_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 686\u001b[39m \u001b[38;5;28mself\u001b[39m._imcache = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 687\u001b[39m \u001b[38;5;28mself\u001b[39m.stale = \u001b[38;5;28;01mTrue\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/image.py:646\u001b[39m, in \u001b[36m_ImageBase._normalize_image_array\u001b[39m\u001b[34m(A)\u001b[39m\n\u001b[32m 640\u001b[39m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[32m 641\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_normalize_image_array\u001b[39m(A):\n\u001b[32m 642\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 643\u001b[39m \u001b[33;03m Check validity of image-like input *A* and normalize it to a format suitable for\u001b[39;00m\n\u001b[32m 644\u001b[39m \u001b[33;03m Image subclasses.\u001b[39;00m\n\u001b[32m 645\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m646\u001b[39m A = \u001b[43mcbook\u001b[49m\u001b[43m.\u001b[49m\u001b[43msafe_masked_invalid\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 647\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m A.dtype != np.uint8 \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np.can_cast(A.dtype, \u001b[38;5;28mfloat\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33msame_kind\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m 648\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mImage data of dtype \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mA.dtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m cannot be \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 649\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mconverted to float\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/cbook.py:684\u001b[39m, in \u001b[36msafe_masked_invalid\u001b[39m\u001b[34m(x, copy)\u001b[39m\n\u001b[32m 683\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34msafe_masked_invalid\u001b[39m(x, copy=\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[32m--> \u001b[39m\u001b[32m684\u001b[39m x = \u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msubok\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 685\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m x.dtype.isnative:\n\u001b[32m 686\u001b[39m \u001b[38;5;66;03m# If we have already made a copy, do the byteswap in place, else make a\u001b[39;00m\n\u001b[32m 687\u001b[39m \u001b[38;5;66;03m# copy with the byte order swapped.\u001b[39;00m\n\u001b[32m 688\u001b[39m \u001b[38;5;66;03m# Swap to native order.\u001b[39;00m\n\u001b[32m 689\u001b[39m x = x.byteswap(inplace=copy).view(x.dtype.newbyteorder(\u001b[33m'\u001b[39m\u001b[33mN\u001b[39m\u001b[33m'\u001b[39m))\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/torch/_tensor.py:1194\u001b[39m, in \u001b[36mTensor.__array__\u001b[39m\u001b[34m(self, dtype)\u001b[39m\n\u001b[32m 1192\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(Tensor.__array__, (\u001b[38;5;28mself\u001b[39m,), \u001b[38;5;28mself\u001b[39m, dtype=dtype)\n\u001b[32m 1193\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1194\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1195\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 1196\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.numpy().astype(dtype, copy=\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "\u001b[31mTypeError\u001b[39m: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first." + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAAGiCAYAAACGUJO6AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGwdJREFUeJzt3X9M3dX9x/EX0HKpsdA6xoWyq6x1/ralgmVYG+dyJ4kG1z8WmTWFEX9MZUZ7s9liW1Crpau2I7NoY9XpHzqqRo2xBKdMYlSWRloSnW1NpRVmvLclrtyOKrTc8/1j316HBcsH+dG3PB/J5w/OPud+zj1h9+m9vfeS4JxzAgDAmMSJXgAAACNBwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmeQ7Y22+/reLiYs2aNUsJCQl65ZVXTjqnublZl1xyiXw+n84++2w9/fTTI1gqAABf8xywnp4ezZs3T3V1dcM6f9++fbrmmmt05ZVXqq2tTXfddZduuukmvf76654XCwDAcQnf5ct8ExIS9PLLL2vx4sVDnrN8+XJt27ZNH374YXzs17/+tQ4dOqTGxsaRXhoAMMlNGesLtLS0KBgMDhgrKirSXXfdNeSc3t5e9fb2xn+OxWL64osv9IMf/EAJCQljtVQAwBhwzunw4cOaNWuWEhNH760XYx6wcDgsv98/YMzv9ysajerLL7/UtGnTTphTU1Oj++67b6yXBgAYR52dnfrRj340arc35gEbicrKSoVCofjP3d3dOvPMM9XZ2anU1NQJXBkAwKtoNKpAIKDp06eP6u2OecAyMzMViUQGjEUiEaWmpg767EuSfD6ffD7fCeOpqakEDACMGu1/Ahrzz4EVFhaqqalpwNgbb7yhwsLCsb40AOB7zHPA/vOf/6itrU1tbW2S/vs2+ba2NnV0dEj678t/paWl8fNvvfVWtbe36+6779bu3bv16KOP6vnnn9eyZctG5x4AACYlzwF7//33NX/+fM2fP1+SFAqFNH/+fFVVVUmSPv/883jMJOnHP/6xtm3bpjfeeEPz5s3Thg0b9MQTT6ioqGiU7gIAYDL6Tp8DGy/RaFRpaWnq7u7m38AAwJixegznuxABACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGDSiAJWV1ennJwcpaSkqKCgQNu3b//W82tra3Xuuedq2rRpCgQCWrZsmb766qsRLRgAAGkEAdu6datCoZCqq6u1Y8cOzZs3T0VFRTpw4MCg5z/33HNasWKFqqurtWvXLj355JPaunWr7rnnnu+8eADA5OU5YBs3btTNN9+s8vJyXXDBBdq8ebNOO+00PfXUU4Oe/95772nhwoVasmSJcnJydNVVV+n6668/6bM2AAC+jaeA9fX1qbW1VcFg8OsbSExUMBhUS0vLoHMuu+wytba2xoPV3t6uhoYGXX311UNep7e3V9FodMABAMD/muLl5K6uLvX398vv9w8Y9/v92r1796BzlixZoq6uLl1++eVyzunYsWO69dZbv/UlxJqaGt13331elgYAmGTG/F2Izc3NWrt2rR599FHt2LFDL730krZt26Y1a9YMOaeyslLd3d3xo7Ozc6yXCQAwxtMzsPT0dCUlJSkSiQwYj0QiyszMHHTO6tWrtXTpUt10002SpIsvvlg9PT265ZZbtHLlSiUmnthQn88nn8/nZWkAgEnG0zOw5ORk5eXlqampKT4Wi8XU1NSkwsLCQeccOXLkhEglJSVJkpxzXtcLAIAkj8/AJCkUCqmsrEz5+flasGCBamtr1dPTo/LycklSaWmpsrOzVVNTI0kqLi7Wxo0bNX/+fBUUFGjv3r1avXq1iouL4yEDAMArzwErKSnRwYMHVVVVpXA4rNzcXDU2Nsbf2NHR0THgGdeqVauUkJCgVatW6bPPPtMPf/hDFRcX68EHHxy9ewEAmHQSnIHX8aLRqNLS0tTd3a3U1NSJXg4AwIOxegznuxABACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGDSiAJWV1ennJwcpaSkqKCgQNu3b//W8w8dOqSKigplZWXJ5/PpnHPOUUNDw4gWDACAJE3xOmHr1q0KhULavHmzCgoKVFtbq6KiIu3Zs0cZGRknnN/X16df/OIXysjI0Isvvqjs7Gx9+umnmjFjxmisHwAwSSU455yXCQUFBbr00ku1adMmSVIsFlMgENAdd9yhFStWnHD+5s2b9dBDD2n37t2aOnXqiBYZjUaVlpam7u5upaamjug2AAATY6wewz29hNjX16fW1lYFg8GvbyAxUcFgUC0tLYPOefXVV1VYWKiKigr5/X5ddNFFWrt2rfr7+4e8Tm9vr6LR6IADAID/5SlgXV1d6u/vl9/vHzDu9/sVDocHndPe3q4XX3xR/f39amho0OrVq7VhwwY98MADQ16npqZGaWlp8SMQCHhZJgBgEhjzdyHGYjFlZGTo8ccfV15enkpKSrRy5Upt3rx5yDmVlZXq7u6OH52dnWO9TACAMZ7exJGenq6kpCRFIpEB45FIRJmZmYPOycrK0tSpU5WUlBQfO//88xUOh9XX16fk5OQT5vh8Pvl8Pi9LAwBMMp6egSUnJysvL09NTU3xsVgspqamJhUWFg46Z+HChdq7d69isVh87OOPP1ZWVtag8QIAYDg8v4QYCoW0ZcsWPfPMM9q1a5duu+029fT0qLy8XJJUWlqqysrK+Pm33XabvvjiC9155536+OOPtW3bNq1du1YVFRWjdy8AAJOO58+BlZSU6ODBg6qqqlI4HFZubq4aGxvjb+zo6OhQYuLXXQwEAnr99de1bNkyzZ07V9nZ2brzzju1fPny0bsXAIBJx/PnwCYCnwMDALtOic+BAQBwqiBgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwKQRBayurk45OTlKSUlRQUGBtm/fPqx59fX1SkhI0OLFi0dyWQAA4jwHbOvWrQqFQqqurtaOHTs0b948FRUV6cCBA986b//+/fr973+vRYsWjXixAAAc5zlgGzdu1M0336zy8nJdcMEF2rx5s0477TQ99dRTQ87p7+/XDTfcoPvuu0+zZ88+6TV6e3sVjUYHHAAA/C9PAevr61Nra6uCweDXN5CYqGAwqJaWliHn3X///crIyNCNN944rOvU1NQoLS0tfgQCAS/LBABMAp4C1tXVpf7+fvn9/gHjfr9f4XB40DnvvPOOnnzySW3ZsmXY16msrFR3d3f86Ozs9LJMAMAkMGUsb/zw4cNaunSptmzZovT09GHP8/l88vl8Y7gyAIB1ngKWnp6upKQkRSKRAeORSESZmZknnP/JJ59o//79Ki4ujo/FYrH/XnjKFO3Zs0dz5swZyboBAJOcp5cQk5OTlZeXp6ampvhYLBZTU1OTCgsLTzj/vPPO0wcffKC2trb4ce211+rKK69UW1sb/7YFABgxzy8hhkIhlZWVKT8/XwsWLFBtba16enpUXl4uSSotLVV2drZqamqUkpKiiy66aMD8GTNmSNIJ4wAAeOE5YCUlJTp48KCqqqoUDoeVm5urxsbG+Bs7Ojo6lJjIF3wAAMZWgnPOTfQiTiYajSotLU3d3d1KTU2d6OUAADwYq8dwnioBAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMCkEQWsrq5OOTk5SklJUUFBgbZv3z7kuVu2bNGiRYs0c+ZMzZw5U8Fg8FvPBwBgODwHbOvWrQqFQqqurtaOHTs0b948FRUV6cCBA4Oe39zcrOuvv15vvfWWWlpaFAgEdNVVV+mzzz77zosHAExeCc4552VCQUGBLr30Um3atEmSFIvFFAgEdMcdd2jFihUnnd/f36+ZM2dq06ZNKi0tHfSc3t5e9fb2xn+ORqMKBALq7u5Wamqql+UCACZYNBpVWlraqD+Ge3oG1tfXp9bWVgWDwa9vIDFRwWBQLS0tw7qNI0eO6OjRozrjjDOGPKempkZpaWnxIxAIeFkmAGAS8BSwrq4u9ff3y+/3Dxj3+/0Kh8PDuo3ly5dr1qxZAyL4TZWVleru7o4fnZ2dXpYJAJgEpoznxdatW6f6+no1NzcrJSVlyPN8Pp98Pt84rgwAYI2ngKWnpyspKUmRSGTAeCQSUWZm5rfOffjhh7Vu3Tq9+eabmjt3rveVAgDwPzy9hJicnKy8vDw1NTXFx2KxmJqamlRYWDjkvPXr12vNmjVqbGxUfn7+yFcLAMD/8/wSYigUUllZmfLz87VgwQLV1taqp6dH5eXlkqTS0lJlZ2erpqZGkvTHP/5RVVVVeu6555STkxP/t7LTTz9dp59++ijeFQDAZOI5YCUlJTp48KCqqqoUDoeVm5urxsbG+Bs7Ojo6lJj49RO7xx57TH19ffrVr3414Haqq6t17733frfVAwAmLc+fA5sIY/UZAgDA2DslPgcGAMCpgoABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAk0YUsLq6OuXk5CglJUUFBQXavn37t57/wgsv6LzzzlNKSoouvvhiNTQ0jGixAAAc5zlgW7duVSgUUnV1tXbs2KF58+apqKhIBw4cGPT89957T9dff71uvPFG7dy5U4sXL9bixYv14YcffufFAwAmrwTnnPMyoaCgQJdeeqk2bdokSYrFYgoEArrjjju0YsWKE84vKSlRT0+PXnvttfjYT3/6U+Xm5mrz5s2DXqO3t1e9vb3xn7u7u3XmmWeqs7NTqampXpYLAJhg0WhUgUBAhw4dUlpa2ujdsPOgt7fXJSUluZdffnnAeGlpqbv22msHnRMIBNyf/vSnAWNVVVVu7ty5Q16nurraSeLg4ODg+B4dn3zyiZfknNQUedDV1aX+/n75/f4B436/X7t37x50TjgcHvT8cDg85HUqKysVCoXiPx86dEhnnXWWOjo6Rrfe3zPH/yuHZ6rfjn06OfZoeNin4Tn+KtoZZ5wxqrfrKWDjxefzyefznTCelpbGL8kwpKamsk/DwD6dHHs0POzT8CQmju4b3z3dWnp6upKSkhSJRAaMRyIRZWZmDjonMzPT0/kAAAyHp4AlJycrLy9PTU1N8bFYLKampiYVFhYOOqewsHDA+ZL0xhtvDHk+AADD4fklxFAopLKyMuXn52vBggWqra1VT0+PysvLJUmlpaXKzs5WTU2NJOnOO+/UFVdcoQ0bNuiaa65RfX293n//fT3++OPDvqbP51N1dfWgLyvia+zT8LBPJ8ceDQ/7NDxjtU+e30YvSZs2bdJDDz2kcDis3Nxc/fnPf1ZBQYEk6Wc/+5lycnL09NNPx89/4YUXtGrVKu3fv18/+clPtH79el199dWjdicAAJPPiAIGAMBE47sQAQAmETAAgEkEDABgEgEDAJh0ygSMP9EyPF72acuWLVq0aJFmzpypmTNnKhgMnnRfvw+8/i4dV19fr4SEBC1evHhsF3iK8LpPhw4dUkVFhbKysuTz+XTOOedMiv/fed2n2tpanXvuuZo2bZoCgYCWLVumr776apxWOzHefvttFRcXa9asWUpISNArr7xy0jnNzc265JJL5PP5dPbZZw945/qwjeo3K45QfX29S05Odk899ZT75z//6W6++WY3Y8YMF4lEBj3/3XffdUlJSW79+vXuo48+cqtWrXJTp051H3zwwTivfHx53aclS5a4uro6t3PnTrdr1y73m9/8xqWlpbl//etf47zy8eN1j47bt2+fy87OdosWLXK//OUvx2exE8jrPvX29rr8/Hx39dVXu3feecft27fPNTc3u7a2tnFe+fjyuk/PPvus8/l87tlnn3X79u1zr7/+usvKynLLli0b55WPr4aGBrdy5Ur30ksvOUknfOH7N7W3t7vTTjvNhUIh99FHH7lHHnnEJSUlucbGRk/XPSUCtmDBAldRURH/ub+/382aNcvV1NQMev51113nrrnmmgFjBQUF7re//e2YrnOied2nbzp27JibPn26e+aZZ8ZqiRNuJHt07Ngxd9lll7knnnjClZWVTYqAed2nxx57zM2ePdv19fWN1xJPCV73qaKiwv385z8fMBYKhdzChQvHdJ2nkuEE7O6773YXXnjhgLGSkhJXVFTk6VoT/hJiX1+fWltbFQwG42OJiYkKBoNqaWkZdE5LS8uA8yWpqKhoyPO/D0ayT9905MgRHT16dNS/EfpUMdI9uv/++5WRkaEbb7xxPJY54UayT6+++qoKCwtVUVEhv9+viy66SGvXrlV/f/94LXvcjWSfLrvsMrW2tsZfZmxvb1dDQwNf3PANo/UYPuHfRj9ef6LFupHs0zctX75cs2bNOuEX5/tiJHv0zjvv6Mknn1RbW9s4rPDUMJJ9am9v19///nfdcMMNamho0N69e3X77bfr6NGjqq6uHo9lj7uR7NOSJUvU1dWlyy+/XM45HTt2TLfeeqvuueee8ViyGUM9hkejUX355ZeaNm3asG5nwp+BYXysW7dO9fX1evnll5WSkjLRyzklHD58WEuXLtWWLVuUnp4+0cs5pcViMWVkZOjxxx9XXl6eSkpKtHLlyiH/qvpk1dzcrLVr1+rRRx/Vjh079NJLL2nbtm1as2bNRC/te2nCn4HxJ1qGZyT7dNzDDz+sdevW6c0339TcuXPHcpkTyuseffLJJ9q/f7+Ki4vjY7FYTJI0ZcoU7dmzR3PmzBnbRU+AkfwuZWVlaerUqUpKSoqPnX/++QqHw+rr61NycvKYrnkijGSfVq9eraVLl+qmm26SJF188cXq6enRLbfcopUrV47638OyaqjH8NTU1GE/+5JOgWdg/ImW4RnJPknS+vXrtWbNGjU2Nio/P388ljphvO7Reeedpw8++EBtbW3x49prr9WVV16ptrY2BQKB8Vz+uBnJ79LChQu1d+/eeOAl6eOPP1ZWVtb3Ml7SyPbpyJEjJ0TqePQdXzsbN2qP4d7eXzI26uvrnc/nc08//bT76KOP3C233OJmzJjhwuGwc865pUuXuhUrVsTPf/fdd92UKVPcww8/7Hbt2uWqq6snzdvovezTunXrXHJysnvxxRfd559/Hj8OHz48UXdhzHndo2+aLO9C9LpPHR0dbvr06e53v/ud27Nnj3vttddcRkaGe+CBBybqLowLr/tUXV3tpk+f7v7617+69vZ297e//c3NmTPHXXfddRN1F8bF4cOH3c6dO93OnTudJLdx40a3c+dO9+mnnzrnnFuxYoVbunRp/Pzjb6P/wx/+4Hbt2uXq6ursvo3eOeceeeQRd+aZZ7rk5GS3YMEC949//CP+v11xxRWurKxswPnPP/+8O+ecc1xycrK78MIL3bZt28Z5xRPDyz6dddZZTtIJR3V19fgvfBx5/V36X5MlYM5536f33nvPFRQUOJ/P52bPnu0efPBBd+zYsXFe9fjzsk9Hjx519957r5szZ45LSUlxgUDA3X777e7f//73+C98HL311luDPtYc35uysjJ3xRVXnDAnNzfXJScnu9mzZ7u//OUvnq/Ln1MBAJg04f8GBgDASBAwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABg0v8Bc0z++5j1+JwAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.imshow(x[0,0])\n", + "plt.show()\n", + "plt.imshow(m(x)[0,0])\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/nn/test_attentiongate.py b/tests/nn/test_attentiongate.py index d209d6483..4b470be1c 100644 --- a/tests/nn/test_attentiongate.py +++ b/tests/nn/test_attentiongate.py @@ -2,9 +2,16 @@ import pytest from mrpro.nn.AttentionGate import AttentionGate -from mrpro.utils.RandomGenerator import RandomGenerator +from mrpro.utils import RandomGenerator +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) @pytest.mark.parametrize( ('dim', 'channels_gate', 'channels_in', 'channels_hidden', 'input_shape', 'gate_shape'), [ @@ -12,12 +19,14 @@ (3, 32, 4, 8, (2, 4, 16, 16, 16), (2, 32, 16, 16, 16)), ], ) -def test_attention_gate(dim, channels_gate, channels_in, channels_hidden, input_shape, gate_shape): +def test_attention_gate(dim, channels_gate, channels_in, channels_hidden, input_shape, gate_shape, device): """Test AttentionGate output shape and backpropagation.""" rng = RandomGenerator(seed=42) - x = rng.float32_tensor(input_shape).requires_grad_(True) - gate = rng.float32_tensor(gate_shape).requires_grad_(True) - attn = AttentionGate(dim=dim, channels_gate=channels_gate, channels_in=channels_in, channels_hidden=channels_hidden) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + gate = rng.float32_tensor(gate_shape).to(device).requires_grad_(True) + attn = AttentionGate( + dim=dim, channels_gate=channels_gate, channels_in=channels_in, channels_hidden=channels_hidden + ).to(device) output = attn(x, gate) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() diff --git a/tests/nn/test_complexaschannel.py b/tests/nn/test_complexaschannel.py new file mode 100644 index 000000000..e731da5dd --- /dev/null +++ b/tests/nn/test_complexaschannel.py @@ -0,0 +1,30 @@ +"""Tests for ComplexAsChannel module.""" + +from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.utils import RandomGenerator +from torch.nn import Linear +import pytest + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_complexaschannel(device): + """Test ComplexAsChannel output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + input_shape = (1, 32) + x = rng.complex64_tensor(input_shape).to(device).requires_grad_(True) + module = ComplexAsChannel(Linear(input_shape[1] * 2, input_shape[1] * 2)).to(device) + output = module(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.is_complex(), 'Output is not complex' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert module.module.weight.grad is not None, 'No gradient computed for weight' + assert module.module.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py index b329ab313..e069d4529 100644 --- a/tests/nn/test_film.py +++ b/tests/nn/test_film.py @@ -2,9 +2,16 @@ import pytest from mrpro.nn.FiLM import FiLM -from mrpro.utils.RandomGenerator import RandomGenerator +from mrpro.utils import RandomGenerator +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) @pytest.mark.parametrize( ('channels', 'channels_emb', 'input_shape', 'emb_shape'), [ @@ -12,12 +19,12 @@ (32, 16, (2, 32, 16, 16), (2, 16)), ], ) -def test_film(channels, channels_emb, input_shape, emb_shape): +def test_film(channels, channels_emb, input_shape, emb_shape, device): """Test FiLM output shape and backpropagation.""" rng = RandomGenerator(seed=42) - x = rng.float32_tensor(input_shape).requires_grad_(True) - emb = rng.float32_tensor(emb_shape).requires_grad_(True) - film = FiLM(channels=channels, channels_emb=channels_emb) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + emb = rng.float32_tensor(emb_shape).to(device).requires_grad_(True) + film = FiLM(channels=channels, channels_emb=channels_emb).to(device) output = film(x, emb) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() diff --git a/tests/nn/test_groupnorm32.py b/tests/nn/test_groupnorm32.py new file mode 100644 index 000000000..549dbf2b0 --- /dev/null +++ b/tests/nn/test_groupnorm32.py @@ -0,0 +1,35 @@ +"""Tests for GroupNorm32 module.""" + +import pytest + +from mrpro.nn import GroupNorm32 +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('channels', 'groups', 'input_shape'), + [ + (32, None, (1, 32, 32, 32)), + (64, 8, (2, 64, 16, 16, 16)), + ], +) +def test_groupnorm32(channels, groups, input_shape, device): + """Test GroupNorm32 output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = GroupNorm32(channels=channels, groups=groups).to(device) + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py new file mode 100644 index 000000000..e16acfbf9 --- /dev/null +++ b/tests/nn/test_resblock.py @@ -0,0 +1,42 @@ +"""Tests for ResBlock module.""" + +import pytest + +from mrpro.nn import ResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'channels_emb', 'input_shape', 'emb_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (3, 64, 32, 0, (2, 64, 16, 16, 16), None), + ], +) +def test_resblock(dim, channels_in, channels_out, channels_emb, input_shape, emb_shape, device): + """Test ResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + emb = rng.float32_tensor(emb_shape).to(device).requires_grad_(True) if emb_shape else None + res = ResBlock(dim=dim, channels_in=channels_in, channels_out=channels_out, channels_emb=channels_emb).to(device) + output = res(x, emb) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert res.block[2].weight.grad is not None, 'No gradient computed for first Conv' + assert res.block[5].weight.grad is not None, 'No gradient computed for second Conv' + if emb is not None: + assert emb.grad is not None, 'No gradient computed for embedding' + assert not emb.isnan().any(), 'NaN values in embedding' + assert not emb.grad.isnan().any(), 'NaN values in embedding gradients' diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py new file mode 100644 index 000000000..da3297cdb --- /dev/null +++ b/tests/nn/test_sequential.py @@ -0,0 +1,41 @@ +"""Tests for Sequential module.""" + +import pytest +from mrpro.nn import FiLM, Sequential +from mrpro.operators import FastFourierOp +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('input_shape', 'emb_shape'), + [ + ((1, 32), (1, 16)), + ((2, 64), None), + ], +) +def test_sequential(input_shape, emb_shape, device): + """Test Sequential output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + emb = rng.float32_tensor(emb_shape).to(device).requires_grad_(True) if emb_shape else None + seq = Sequential( + Linear(input_shape[1], 64), + FastFourierOp(), + FiLM(channels=64, channels_emb=16), + ).to(device) + output = seq(x, emb) + assert output.shape == (input_shape[0], 32), f'Output shape {output.shape} != expected {(input_shape[0], 32)}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert seq[0].weight.grad is not None, 'No gradient computed for first Linear' + assert seq[2].weight.grad is not None, 'No gradient computed for second Linear' diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index 411c04d63..545f90360 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -1,6 +1,6 @@ import pytest from mrpro.nn import ShiftedWindowAttention -from mrpro.utils.RandomGenerator import RandomGenerator +from mrpro.utils import RandomGenerator @pytest.mark.parametrize( diff --git a/tests/nn/test_sqeezeexcitation.py b/tests/nn/test_sqeezeexcitation.py index b241aa56b..8929b9868 100644 --- a/tests/nn/test_sqeezeexcitation.py +++ b/tests/nn/test_sqeezeexcitation.py @@ -1,8 +1,8 @@ """Tests for SqueezeExcitation module.""" import pytest -from mrpro.nn.SqueezeExcitation import SqueezeExcitation -from mrpro.utils.RandomGenerator import RandomGenerator +from mrpro.nn import SqueezeExcitation +from mrpro.utils import RandomGenerator @pytest.mark.parametrize( diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py new file mode 100644 index 000000000..e865fbc7a --- /dev/null +++ b/tests/nn/test_transposedattention.py @@ -0,0 +1,37 @@ +"""Tests for TransposedAttention module.""" + +import pytest + +from mrpro.nn import TransposedAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels', 'num_heads', 'input_shape'), + [ + (2, 32, 4, (1, 32, 32, 32)), + (3, 64, 8, (2, 64, 16, 16, 16)), + ], +) +def test_transposed_attention(dim, channels, num_heads, input_shape, device): + """Test TransposedAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + attn = TransposedAttention(dim=dim, channels=channels, num_heads=num_heads).to(device) + output = attn(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert attn.qkv.weight.grad is not None, 'No gradient computed for qkv' + assert attn.qkv_dwconv.weight.grad is not None, 'No gradient computed for qkv_dwconv' + assert attn.project_out.weight.grad is not None, 'No gradient computed for project_out' + assert attn.temperature.grad is not None, 'No gradient computed for temperature' From 9cfae55b33c34de10f606019ca3e0d03fc560ee0 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 14 May 2025 02:22:29 +0200 Subject: [PATCH 026/205] update --- .vscode/settings.json | 2 +- src/mrpro/nn/ComplexAsChannel.py | 2 +- src/mrpro/nn/FiLM.py | 4 +-- src/mrpro/nn/GroupNorm32.py | 2 +- src/mrpro/nn/ResBlock.py | 7 ++--- src/mrpro/nn/Sequential.py | 3 +-- src/mrpro/nn/__init__.py | 35 ++++++++++++------------- tests/nn/test_complexaschannel.py | 4 +-- tests/nn/test_groupnorm32.py | 1 - tests/nn/test_resblock.py | 2 -- tests/nn/test_shiftedwindowattention.py | 15 ++++++++--- tests/nn/test_transposedattention.py | 1 - 12 files changed, 41 insertions(+), 37 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index e1cbd013e..a63f276cc 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,7 +14,7 @@ }, "python.testing.pytestArgs": [ "tests", - "-m not cuda" + // "-m not cuda" ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py index b64c00151..5acce1245 100644 --- a/src/mrpro/nn/ComplexAsChannel.py +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -47,6 +47,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Ten y = self.module(x_real) if x.is_complex(): - y = rearrange(y, 'b (c x y) ... complex -> batch channel ... complex', complex=2).contiguous() + y = rearrange(y, 'b (channel complex) ... -> b channel ... complex', complex=2).contiguous() y = torch.view_as_complex(y) return y diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index c74825e57..2977adca9 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -34,7 +34,7 @@ def __init__(self, channels: int, channels_emb: int) -> None: Linear(channels_emb, 2 * channels), ) - def __call__(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM. Parameters @@ -46,7 +46,7 @@ def __call__(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: """ return super().__call__(x, emb) - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM.""" if emb is None: return x diff --git a/src/mrpro/nn/GroupNorm32.py b/src/mrpro/nn/GroupNorm32.py index 0baf66abe..55c11d1f6 100644 --- a/src/mrpro/nn/GroupNorm32.py +++ b/src/mrpro/nn/GroupNorm32.py @@ -44,4 +44,4 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply GroupNorm32.""" - return super().__call__(x.float()).type(x.dtype) + return super().forward(x.float()).type(x.dtype) diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index 35a13baf2..da6ee3fff 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -1,12 +1,13 @@ """Residual convolution block with two convolutions.""" import torch -from torch.nn import Identity, Module, Sequential, SiLU +from torch.nn import Identity, Module, SiLU from mrpro.nn.EmbMixin import EmbMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm32 import GroupNorm32 from mrpro.nn.NDModules import ConvND +from mrpro.nn.Sequential import Sequential class ResBlock(EmbMixin, Module): @@ -33,10 +34,10 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, channels_emb: self.block = Sequential( GroupNorm32(channels_in), SiLU(), - ConvND(dim)(channels_in, channels_out, kernel_size=3), + ConvND(dim)(channels_in, channels_out, kernel_size=3, padding=1), GroupNorm32(channels_out), SiLU(), - ConvND(dim)(channels_out, channels_out, kernel_size=3), + ConvND(dim)(channels_out, channels_out, kernel_size=3, padding=1), ) if channels_emb > 0: self.block.insert(-3, FiLM(channels_out, channels_emb)) diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 33b4eba33..b08e43f62 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -1,5 +1,4 @@ import torch -from torch.nn import Module from mrpro.nn.EmbMixin import EmbMixin from mrpro.operators import Operator @@ -27,7 +26,7 @@ def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Te def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: """Apply all modules in series to the input.""" for module in self: - if isinstance(EmbMixin, Module): + if isinstance(module, EmbMixin): x = module(x, emb) elif isinstance(module, Operator): (x,) = module(x) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index b6e3f9b8d..68b911e02 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -21,22 +21,21 @@ from mrpro.nn.TransposedAttention import TransposedAttention __all__ = [ - 'AdaptiveAvgPoolND', - 'AttentionGate', - 'AvgPoolND', - 'BatchNormND', - 'ConvND', - 'ConvTransposeND', - 'EmbMixin', - 'EmbSequential', - 'FiLM', - 'GroupNorm32', - 'InstanceNormND', - 'MaxPoolND', - 'NeighborhoodSelfAttention', - 'ResBlock', - 'Sequential', - 'ShiftedWindowAttention', - 'SqueezeExcitation', - 'TransposedAttention', + "AdaptiveAvgPoolND", + "AttentionGate", + "AvgPoolND", + "BatchNormND", + "ConvND", + "ConvTransposeND", + "EmbMixin", + "FiLM", + "GroupNorm32", + "InstanceNormND", + "MaxPoolND", + "NeighborhoodSelfAttention", + "ResBlock", + "Sequential", + "ShiftedWindowAttention", + "SqueezeExcitation", + "TransposedAttention" ] \ No newline at end of file diff --git a/tests/nn/test_complexaschannel.py b/tests/nn/test_complexaschannel.py index e731da5dd..5eb1e87f8 100644 --- a/tests/nn/test_complexaschannel.py +++ b/tests/nn/test_complexaschannel.py @@ -1,9 +1,9 @@ """Tests for ComplexAsChannel module.""" +import pytest from mrpro.nn.ComplexAsChannel import ComplexAsChannel from mrpro.utils import RandomGenerator from torch.nn import Linear -import pytest @pytest.mark.parametrize( @@ -22,7 +22,7 @@ def test_complexaschannel(device): output = module(x) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' assert output.is_complex(), 'Output is not complex' - output.sum().backward() + output.sum().abs().backward() assert x.grad is not None, 'No gradient computed for input' assert not x.isnan().any(), 'NaN values in input' assert not x.grad.isnan().any(), 'NaN values in input gradients' diff --git a/tests/nn/test_groupnorm32.py b/tests/nn/test_groupnorm32.py index 549dbf2b0..389b8ca85 100644 --- a/tests/nn/test_groupnorm32.py +++ b/tests/nn/test_groupnorm32.py @@ -1,7 +1,6 @@ """Tests for GroupNorm32 module.""" import pytest - from mrpro.nn import GroupNorm32 from mrpro.utils import RandomGenerator diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index e16acfbf9..957eb70c6 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -1,7 +1,6 @@ """Tests for ResBlock module.""" import pytest - from mrpro.nn import ResBlock from mrpro.utils import RandomGenerator @@ -35,7 +34,6 @@ def test_resblock(dim, channels_in, channels_out, channels_emb, input_shape, emb assert not x.isnan().any(), 'NaN values in input' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert res.block[2].weight.grad is not None, 'No gradient computed for first Conv' - assert res.block[5].weight.grad is not None, 'No gradient computed for second Conv' if emb is not None: assert emb.grad is not None, 'No gradient computed for embedding' assert not emb.isnan().any(), 'NaN values in embedding' diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index 545f90360..773e0daff 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -3,6 +3,13 @@ from mrpro.utils import RandomGenerator +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) @pytest.mark.parametrize( ('dim', 'window_size', 'shifted'), [ @@ -10,14 +17,16 @@ (4, 4, True), ], ) -def test_shifted_window_attentio(dim: int, window_size: int, shifted) -> None: +def test_shifted_window_attentio(dim: int, window_size: int, shifted: bool, device: str) -> None: batch = 2 channels = 8 n_heads = 2 spatial_shape = (window_size * 4,) * dim rng = RandomGenerator(13) - x = rng.float32_tensor((batch, channels, *spatial_shape)).requires_grad_(True) - swin = ShiftedWindowAttention(dim=dim, channels=channels, n_heads=n_heads, window_size=window_size, shifted=shifted) + x = rng.float32_tensor((batch, channels, *spatial_shape)).to(device).requires_grad_(True) + swin = ShiftedWindowAttention( + dim=dim, channels=channels, n_heads=n_heads, window_size=window_size, shifted=shifted + ).to(device) out = swin(x) assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' assert not out.isnan().any(), 'NaN values in output' diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index e865fbc7a..417743135 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -1,7 +1,6 @@ """Tests for TransposedAttention module.""" import pytest - from mrpro.nn import TransposedAttention from mrpro.utils import RandomGenerator From cf4be7f27b769f1a6070bf0169037ad996a452e9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 14 May 2025 17:14:33 +0200 Subject: [PATCH 027/205] uformer --- src/mrpro/nn/ShiftedWindowAttention.py | 11 +- src/mrpro/nn/nets/Uformer.py | 169 +++++++++++++------------ 2 files changed, 92 insertions(+), 88 deletions(-) diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py index d29ba8238..d0978f8f1 100644 --- a/src/mrpro/nn/ShiftedWindowAttention.py +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -22,15 +22,15 @@ class ShiftedWindowAttention(Module): rel_position_index: torch.Tensor - def __init__(self, dim: int, channels: int, n_heads: int, window_size: int = 7, shifted: bool = True): + def __init__(self, dim: int, n_channels_per_head: int, n_heads: int, window_size: int = 7, shifted: bool = True): """Initialize the ShiftedWindowAttention module. Parameters ---------- dim : int The dimension of the input. - channels : int - The number of channels in the input. + n_channels_per_head : int + The number of channels per head. n_heads : int The number of attention heads. The number if channels per head is ``channels // n_heads``. window_size : int @@ -39,13 +39,10 @@ def __init__(self, dim: int, channels: int, n_heads: int, window_size: int = 7, Whether to shift the window. """ super().__init__() - if channels % n_heads: - raise ValueError('channels must be divisible by n_heads.') - self.channels = channels self.n_heads = n_heads self.window_size = window_size self.shifted = shifted - self.to_qkv = ConvND(dim)(channels, 3 * channels, 1) + self.to_qkv = ConvND(dim)(n_channels_per_head * n_heads, 3 * n_channels_per_head * n_heads, 1) self.dim = dim coords_1d = torch.arange(window_size) coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * dim), indexing='ij'), 0).flatten(1) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 552a7c6d9..f2bd80fc4 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -1,7 +1,13 @@ +from collections.abc import Sequence +from itertools import pairwise + import torch -from torch.nn import GELU, Module, Sequential +from sympy import Identity +from torch.nn import GELU, LeakyReLU, Module, Sequential -from mrpro.nn.NDModules import ConvND +from mrpro.nn.NDModules import ConvND, ConvTransposeND, InstanceNormND +from mrpro.nn.nets import UNet +from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention class LeFF(Module): @@ -47,89 +53,90 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LeWinTransformerBlock(Module): def __init__( self, - dim, - channels, - input_resolution, - num_heads, - win_size=8, - shift_size=0, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - norm_layer=nn.LayerNorm, - token_projection='linear', - ): + dim: int, + n_channels_per_head: int, + n_heads: int, + window_size: int = 8, + shifted: bool = False, + mlp_ratio: float = 4.0, + ) -> None: super().__init__() - self.channels = channels - self.input_resolution = input_resolution - self.num_heads = num_heads - self.win_size = win_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - self.token_mlp = token_mlp - self.modulator = Embedding(win_size * win_size, channels) # modulator - self.norm1 = norm_layer(channels) - self.attn = WindowAttention( - channels, - win_size=to_2tuple(self.win_size), - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - token_projection=token_projection, + channels = n_channels_per_head * n_heads + self.norm1 = InstanceNormND(dim)(channels) + self.attn = ShiftedWindowAttention( + dim=dim, + n_channels_per_head=n_channels_per_head, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, ) - self.norm2 = norm_layer(channels) - mlp_hidden_dim = int(channels * mlp_ratio) - self.mlp = LeFF(channels, mlp_hidden_dim) - - def extra_repr(self) -> str: - return ( - f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' - f'win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio},modulator={self.modulator}' - ) - - def forward(self, x, mask=None): - B, L, C = x.shape - H = int(math.sqrt(L)) - W = int(math.sqrt(L)) - - ## input mask - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - shifted_x = x - x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C - x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C - wmsa_in = self.with_pos_embed(x_windows, self.modulator.weight) - attn_windows = self.attn(wmsa_in, mask=attn_mask) # nW*B, win_size*win_size, C - attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) - shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C - - x = shortcut + x - x = x + self.mlp(self.norm2(x)) - return x + self.norm2 = InstanceNormND(dim)(channels) + self.ff = LeFF(dim=dim, channels_in=channels, channels_out=channels, expand_ratio=mlp_ratio) + self.modulator = torch.nn.Parameter(torch.empty(channels, *((window_size,) * dim))) + torch.nn.init.trunc_normal_(self.modulator) + def forward(self, x): + modulator = self.modulator.tile([t // s for t, s in zip(x.shape[1:], self.modulator.shape, strict=False)]) + x_mod = self.norm1(x) + modulator + x_attn = self.attn(x_mod) + x_ff = self.ff(self.norm2(x_attn)) + return x + x_ff -class SAM(Module): - """Spatial Attention Module. - Part of the Uformer architecture. - """ - - def __init__(self, dim, channels): +class Uformer(UNet): + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + n_features_per_head: int = 32, + n_heads: Sequence[int] = (1, 2, 4, 8), + n_blocks: int = 2, + window_size: int = 8, + mlp_ratio: float = 4.0, + drop_path_rate: float = 0.1, + ): super().__init__() - self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) - self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) - self.conv3 = conv(3, n_feat, kernel_size, bias=bias) - - def forward(self, x, x_img): - x1 = self.conv1(x) - img = self.conv2(x) + x_img - x2 = torch.sigmoid(self.conv3(img)) - x1 = x1 * x2 - x1 = x1 + x - return x1, img + + def blocks(n_heads: int): + return [ + LeWinTransformerBlock( + dim=dim, + n_heads=n_heads, + n_features_per_head=n_features_per_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + shifted=bool(i % 2), + ) + for i in range(n_blocks) + ] + + for n_head in n_heads: + self.input_blocks.extend(blocks(n_heads=n_head)) + self.output_blocks.extend(blocks(n_heads=n_head)) + self.skip_blocks.append(Identity()) + self.middle_block = torch.nn.Sequential(*blocks(n_heads=n_heads[-1])) + + for n_head_current, n_head_next in pairwise(n_heads): + self.down_blocks.append( + ConvND(dim)( + n_features_per_head * n_head_current, + n_features_per_head * n_head_next, + kernel_size=4, + stride=2, + padding=1, + ) + ) + self.up_blocks.append( + ConvTransposeND(dim)( + n_features_per_head * n_head_next, n_features_per_head * n_head_current, kernel_size=2, stride=2 + ) + ) + self.first = torch.nn.Sequential( + ConvND(dim)(channels_in, n_features_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), + LeakyReLU(), + ) + self.last = ConvND(dim)( + n_features_per_head * n_heads[-1], channels_out, kernel_size=3, stride=1, padding='same' + ) From 54a66b61321c58dd95c6056030ac7e32986608f6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 14 May 2025 17:17:38 +0200 Subject: [PATCH 028/205] fix --- src/mrpro/nn/nets/Uformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index f2bd80fc4..f313baf4c 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -6,7 +6,7 @@ from torch.nn import GELU, LeakyReLU, Module, Sequential from mrpro.nn.NDModules import ConvND, ConvTransposeND, InstanceNormND -from mrpro.nn.nets import UNet +from mrpro.nn.nets.UNet import UNetBase from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention @@ -84,7 +84,7 @@ def forward(self, x): return x + x_ff -class Uformer(UNet): +class Uformer(UNetBase): def __init__( self, dim: int, From 3ae37d1ceaf5b8f1ff22e98517b0873efb04e07f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 15 May 2025 02:33:40 +0200 Subject: [PATCH 029/205] update --- src/mrpro/nn/DropPath.py | 55 +++++++ src/mrpro/nn/PixelShuffle.py | 49 ++++++ src/mrpro/nn/__init__.py | 3 +- src/mrpro/nn/nets/Restormer.py | 274 +++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/Uformer.py | 122 ++++++++++++--- 5 files changed, 481 insertions(+), 22 deletions(-) create mode 100644 src/mrpro/nn/DropPath.py create mode 100644 src/mrpro/nn/PixelShuffle.py create mode 100644 src/mrpro/nn/nets/Restormer.py diff --git a/src/mrpro/nn/DropPath.py b/src/mrpro/nn/DropPath.py new file mode 100644 index 000000000..2d0abba1e --- /dev/null +++ b/src/mrpro/nn/DropPath.py @@ -0,0 +1,55 @@ +"""DropPath (stochastic depth).""" + +import torch +from torch.nn import Module + + +class DropPath(Module): + """Drop path or stochastic depth. + + Drops full samples from batch with probability `droprate`. + Should be used in the main path of a Resblock. + + References + ---------- + .. [HUANG16] Huang, G., Sun, Y., Liu, Z., Sedra, D., & Weinberger, K. Q. Deep networks with stochastic depth. + ECCV 2016. https://link.springer.com/chapter/10.1007/978-3-319-46493-0_39 + """ + + def __init__(self, droprate: float = 0.0, scale_by_keep: bool = False): + """Initialize the DropPath module. + + Parameters + ---------- + droprate : float, optional + Drop probability + scale_by_keep : bool, optional + If True, the kept samples are scaled by `1/(1-droprate)` + """ + super().__init__() + self.droprate = droprate + self.scale_by_keep = scale_by_keep + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath. + + Parameters + ---------- + x : torch.Tensor + Input tensor + + Returns + ------- + Tensor with + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath.""" + if self.droprate == 0 or not self.training: + return x + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = ( + ((1 - self.droprate) + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_().div_(1 - self.droprate) + ) + return x * mask diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py new file mode 100644 index 000000000..fddb903de --- /dev/null +++ b/src/mrpro/nn/PixelShuffle.py @@ -0,0 +1,49 @@ +"""ND-version of PixelShuffle and PixelUnshuffle.""" + +import torch +from torch.nn import Module + + +class PixelUnshuffle(Module): + """ND-version of PixelUnshuffle downscaling.""" + + def __init__(self, downscale_factor: int): + super().__init__() + self.downscale_factor = downscale_factor + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dim = x.ndim - 2 + if dim == 2: # fast path for 2D + return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor) + + new_shape = list(x.shape[:2]) + source_positions = [] + for i, old in enumerate(x.shape[2:]): + new_shape.append(old // self.downscale_factor) + new_shape.append(self.downscale_factor) + source_positions.append(2 + 2 * i) + + x = x.view(new_shape) + x = x.moveaxis(source_positions, tuple(range(-dim, 0))) + x = x.flatten(1, -dim - 1) + return x + + +class PixelShuffle(Module): + """ND-version of PixelShuffle upscaling.""" + + def __init__(self, upscale_factor: int): + super().__init__() + self.upscale_factor = upscale_factor + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dim = x.ndim - 2 + if dim == 2: # fast path for 2D + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + new_shape = (x.shape[0], -1, *(old * self.upscale_factor for old in x.shape[-dim:])) + + x = x.unflatten(1, (-1, *(self.upscale_factor,) * dim)) + x = x.moveaxis(tuple(range(2, 2 + dim)), tuple(range(-2 * dim + 1, 0, 2))) + x = x.reshape(new_shape) + return x diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index 68b911e02..6bf744e31 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -19,7 +19,7 @@ from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention from mrpro.nn.SqueezeExcitation import SqueezeExcitation from mrpro.nn.TransposedAttention import TransposedAttention - +from mrpro.nn.DropPath import DropPath __all__ = [ "AdaptiveAvgPoolND", "AttentionGate", @@ -27,6 +27,7 @@ "BatchNormND", "ConvND", "ConvTransposeND", + "DropPath", "EmbMixin", "FiLM", "GroupNorm32", diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py new file mode 100644 index 000000000..b37284676 --- /dev/null +++ b/src/mrpro/nn/nets/Restormer.py @@ -0,0 +1,274 @@ +"""Restormer implementation.""" + +from collections.abc import Sequence +import torch +from torch.nn import Module, PixelUnshuffle, PixelShuffle +from mrpro.nn.TransposedAttention import TransposedAttention +from mrpro.nn.NDModules import ConvNd, InstanceNormNd +from mrpro.nn.FiLM import FiLM + + +class GDFN(Module): + """Gated depthwise feed forward network. + + As used in the Restormer architecture. + """ + + def __init__(self, dim: int, channels: int, mlp_ratio: float): + super().__init__() + + hidden_features = int(channels * mlp_ratio) + self.project_in = ConvNd(dim)(channels, hidden_features * 2, kernel_size=1) + self.depthwise_conv = ConvNd(dim)( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + ) + self.project_out = ConvNd(dim)(hidden_features, channels, kernel_size=1) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.depthwise_conv(x).chunk(2, dim=1) + return self.project_out(torch.nn.functional.gelu(x1) * x2) + + +class RestormerBlock(Module): + """Transformer block with transposed attention and gated depthwise feed forward network.""" + + def __init__(self, dim: int, channels: int, num_heads: int, mlp_ratio: float): + super().__init__() + self.norm1 = InstanceNormNd(dim)(channels) + self.attn = TransposedAttention(dim, channels, num_heads) + self.norm2 = InstanceNormNd(dim)(channels) + self.ffn = GDFN(dim, channels, mlp_ratio) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x + + +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = + + def forward(self, x): + return self.body(x) + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = + + def forward(self, x): + return self.body(x) + + +class Restormer(UNetBase): + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + n_blocks: Sequence[int] = (4, 6, 6, 8), + n_refinement_blocks: int = 4, + n_heads: Sequence[int] = (1, 2, 4, 8), + n_channels_per_head: int = 48, + mlp_ratio: float = 2.66, + emb_dim: int = 0, + ): + super().__init__() + + self.first = ConvNd(dim)(channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) + + def blocks(n_heads: int, n_blocks: int): + layers = Sequential( + *(RestormerBlock(dim, n_channels_per_head, n_heads, mlp_ratio) for i in range(n_blocks)) + ) + + if emb_dim > 0 and n_blocks > 1: + layers.insert(1, FiLM(channels=n_features_per_head * n_heads, channels_emb=emb_dim)) + return layers + + + + for n_block, n_heads in zip(n_blocks, n_heads): + self.input_blocks.append(blocks(n_heads, n_block)) + self.output_blocks.append(blocks(n_heads, n_block)) + self.skip_blocks.append(Identity()) + + + for n_head_current, n_head_next in pairwise(n_heads): + self.down_blocks.append( + Sequential( + ConvND(dim)(n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=4, stride=2, padding=1, + PixelUnshuffle(2) + ) + self.up_blocks.append( + nn.Sequential( + nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), PixelShuffle(2) + ) + ) + self.output_blocks = self.output_blocks[::-1] + self.middle_block = blocks(n_heads, n_blocks) + + num_heads=heads[0], + ffn_expansion_factor=mlp_ratio, + ) + for i in range(num_blocks[0]) + ] + ) + + self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 + self.encoder_level2 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**1), + num_heads=heads[1], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + ) + for i in range(num_blocks[1]) + ] + ) + + self.down2_3 = Downsample(int(dim * 2**1)) ## From Level 2 to Level 3 + self.encoder_level3 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**2), + num_heads=heads[2], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + ) + for i in range(num_blocks[2]) + ] + ) + + self.down3_4 = Downsample(int(dim * 2**2)) ## From Level 3 to Level 4 + self.latent = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**3), + num_heads=heads[3], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + ) + for i in range(num_blocks[3]) + ] + ) + + self.up4_3 = Upsample(int(dim * 2**3)) ## From Level 4 to Level 3 + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2**3), int(dim * 2**2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**2), + num_heads=heads[2], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + ) + for i in range(num_blocks[2]) + ] + ) + + self.up3_2 = Upsample(int(dim * 2**2)) ## From Level 3 to Level 2 + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2**2), int(dim * 2**1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**1), + num_heads=heads[1], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + ) + for i in range(num_blocks[1]) + ] + ) + + self.up2_1 = Upsample(int(dim * 2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) + + self.decoder_level1 = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**1), + num_heads=heads[0], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + ) + for i in range(num_blocks[0]) + ] + ) + + self.refinement = nn.Sequential( + *[ + TransformerBlock( + dim=int(dim * 2**1), + num_heads=heads[0], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + ) + for i in range(num_refinement_blocks) + ] + ) + + #### For Dual-Pixel Defocus Deblurring Task #### + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim * 2**1), kernel_size=1, bias=bias) + ########################### + + self.output = nn.Conv2d(int(dim * 2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, inp_img): + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + inp_enc_level4 = self.down3_4(out_enc_level3) + latent = self.latent(inp_enc_level4) + + inp_dec_level3 = self.up4_3(latent) + inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) + + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + out_dec_level1 = self.refinement(out_dec_level1) + + #### For Dual-Pixel Defocus Deblurring Task #### + if self.dual_pixel_task: + out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) + out_dec_level1 = self.output(out_dec_level1) + ########################### + else: + out_dec_level1 = self.output(out_dec_level1) + inp_img + + return out_dec_level1 diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index f313baf4c..ee2c86d81 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -8,6 +8,9 @@ from mrpro.nn.NDModules import ConvND, ConvTransposeND, InstanceNormND from mrpro.nn.nets.UNet import UNetBase from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.FiLM import FiLM +from mrpro.nn.Sequential import Sequential +from mrpro.nn.DropPath import DropPath class LeFF(Module): @@ -51,21 +54,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LeWinTransformerBlock(Module): + """Locally-enhanced windowed attention transformer block. + + Part of the Uformer architecture. + """ + def __init__( self, dim: int, - n_channels_per_head: int, + n_features_per_head: int, n_heads: int, window_size: int = 8, shifted: bool = False, mlp_ratio: float = 4.0, + p_droppath: float = 0.0, ) -> None: + """Initialize the LeWinTransformerBlock module. + + Parameters + ---------- + dim : int + Dimension of the input, e.g. 2 or 3 + n_features_per_head : int + Number of features per head + n_heads : int + Number of attention heads + window_size : int, optional + Size of the attention window + shifted : bool, optional + Whether to use shifted variant of the attention + mlp_ratio : float, optional + Ratio of the hidden dimension to the input dimension + p_droppath : float, optional + Dropout probability for the drop path. + """ super().__init__() - channels = n_channels_per_head * n_heads + channels = n_features_per_head * n_heads self.norm1 = InstanceNormND(dim)(channels) self.attn = ShiftedWindowAttention( dim=dim, - n_channels_per_head=n_channels_per_head, + n_channels_per_head=n_features_per_head, n_heads=n_heads, window_size=window_size, shifted=shifted, @@ -75,16 +103,30 @@ def __init__( self.ff = LeFF(dim=dim, channels_in=channels, channels_out=channels, expand_ratio=mlp_ratio) self.modulator = torch.nn.Parameter(torch.empty(channels, *((window_size,) * dim))) torch.nn.init.trunc_normal_(self.modulator) + self.drop_path = DropPath(droprate=p_droppath) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the transformer block.""" modulator = self.modulator.tile([t // s for t, s in zip(x.shape[1:], self.modulator.shape, strict=False)]) x_mod = self.norm1(x) + modulator x_attn = self.attn(x_mod) x_ff = self.ff(self.norm2(x_attn)) - return x + x_ff + return x + self.drop_path(x_ff) class Uformer(UNetBase): + """Uformer: U-Net with window attention. + + Implements the Uformer network proposed in [WANG21]_ + It is SWIN/U-Net hybrid consisting of (shifted) windows attention transformer layers at different + resolution levels, extended by FiLM layers for conditioning. + + References + ---------- + .. [WANG21] Wang, Z., Cun, X., Bao, J., Zhou, W., Liu, J., & Li, H. Uformer: A general u-shaped transformer for + image restoration. CVPR 2022. https://doi.org/10.48550/arXiv.2106.03106 + """ + def __init__( self, dim: int, @@ -93,30 +135,68 @@ def __init__( n_features_per_head: int = 32, n_heads: Sequence[int] = (1, 2, 4, 8), n_blocks: int = 2, + emb_dim: int = 0, window_size: int = 8, mlp_ratio: float = 4.0, - drop_path_rate: float = 0.1, + max_droppath_rate: float = 0.1, ): + """Initialize the Uformer module. + + Parameters + ---------- + dim : int + Dimension of the input, e.g. 2 or 3 + channels_in : int + Number of input channels + channels_out : int + Number of output channels + n_features_per_head : int, optional + Number of features per head. The number of features at a resolution level is given by + `n_features_per_head * n_heads`. + n_heads : Sequence[int], optional + Number of attention heads at each resolution level. + n_blocks : int, optional + Number of transformer blocks at each resolution level in the input and output path + emb_dim : int, optional + Dimension of the embedding. If `0`, no FiLM layers are added. + window_size : int, optional + Size of the attention windows in the (shifted) window attention layers. + mlp_ratio : float, optional + Ratio of the hidden dimension to the input dimension in the feed-forward blocks + max_droppath_rate : float, optional + Maximum drop path rate. As in the original implementation, the drop path rate in the input path + is linearly increased from `0` to `max_droppath_rate` with decreasing resolution. The rate in output + blocks is fixed to `max_droppath_rate`. + """ super().__init__() - def blocks(n_heads: int): - return [ - LeWinTransformerBlock( - dim=dim, - n_heads=n_heads, - n_features_per_head=n_features_per_head, - window_size=window_size, - mlp_ratio=mlp_ratio, - shifted=bool(i % 2), + def blocks(n_heads: int, p_droppath: float = 0.0): + layers = Sequential( + *( + LeWinTransformerBlock( + dim=dim, + n_heads=n_heads, + n_features_per_head=n_features_per_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + shifted=bool(i % 2), + p_droppath=p_droppath, + ) + for i in range(n_blocks) ) - for i in range(n_blocks) - ] + ) + + if emb_dim > 0 and n_blocks > 1: + layers.insert(1, FiLM(channels=n_features_per_head * n_heads, channels_emb=emb_dim)) + return layers - for n_head in n_heads: - self.input_blocks.extend(blocks(n_heads=n_head)) - self.output_blocks.extend(blocks(n_heads=n_head)) + drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() + for n_head, p_droppath_input in zip(n_heads, drop_path_rates, strict=True): + self.input_blocks.append(blocks(n_heads=n_head, p_droppath=p_droppath_input)) + self.output_blocks.append(blocks(n_heads=n_head, p_droppath=max_droppath_rate)) self.skip_blocks.append(Identity()) - self.middle_block = torch.nn.Sequential(*blocks(n_heads=n_heads[-1])) + self.output_blocks = self.output_blocks[::-1] + self.middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) for n_head_current, n_head_next in pairwise(n_heads): self.down_blocks.append( From 912d7c8ecfa0b1a54d9e1c53f819b0dc550067c2 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 16 May 2025 00:21:58 +0200 Subject: [PATCH 030/205] update --- src/mrpro/nn/FiLM.py | 13 +++--- src/mrpro/nn/__init__.py | 1 + src/mrpro/nn/nets/Restormer.py | 31 ++++--------- src/mrpro/nn/nets/UNet.py | 79 ++++++++++++++++++++++++++-------- 4 files changed, 78 insertions(+), 46 deletions(-) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index 2977adca9..9944dc0a4 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -1,7 +1,7 @@ """Feature-wise Linear Modulation.""" import torch -from torch.nn import Linear, Module, Sequential, SiLU +from torch.nn import Identity, Linear, Module, Sequential, SiLU from mrpro.nn.EmbMixin import EmbMixin from mrpro.utils.reshape import unsqueeze_tensors_right @@ -29,10 +29,13 @@ def __init__(self, channels: int, channels_emb: int) -> None: The number of channels in the embedding tensor. """ super().__init__() - self.project = Sequential( - SiLU(), - Linear(channels_emb, 2 * channels), - ) + if channels_emb > 0: + self.project = Sequential( + SiLU(), + Linear(channels_emb, 2 * channels), + ) + else: + self.project = Identity() def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM. diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index 6bf744e31..bc6a93aea 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -20,6 +20,7 @@ from mrpro.nn.SqueezeExcitation import SqueezeExcitation from mrpro.nn.TransposedAttention import TransposedAttention from mrpro.nn.DropPath import DropPath +import mrpro.nn.nets __all__ = [ "AdaptiveAvgPoolND", "AttentionGate", diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index b37284676..126d88c08 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -6,7 +6,8 @@ from mrpro.nn.TransposedAttention import TransposedAttention from mrpro.nn.NDModules import ConvNd, InstanceNormNd from mrpro.nn.FiLM import FiLM - +from mrpro.nn.nets.UNetBase import UNetBase +from mrpro.nn.Sequential import Sequential class GDFN(Module): """Gated depthwise feed forward network. @@ -38,12 +39,15 @@ def forward(self, x): class RestormerBlock(Module): """Transformer block with transposed attention and gated depthwise feed forward network.""" - def __init__(self, dim: int, channels: int, num_heads: int, mlp_ratio: float): + def __init__(self, dim: int, channels: int, num_heads: int, mlp_ratio: float, emb_dim: int = 0): super().__init__() - self.norm1 = InstanceNormNd(dim)(channels) + self.norm1 = Sequential(InstanceNormNd(dim)(channels)) self.attn = TransposedAttention(dim, channels, num_heads) - self.norm2 = InstanceNormNd(dim)(channels) + self.norm2 = Sequential(InstanceNormNd(dim)(channels)) self.ffn = GDFN(dim, channels, mlp_ratio) + if emb_dim > 0: + self.norm1.append(FiLM(channels=channels, channels_emb=emb_dim)) + self.norm2.append(FiLM(channels=channels, channels_emb=emb_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.norm1(x)) @@ -51,25 +55,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Downsample(nn.Module): - def __init__(self, n_feat): - super(Downsample, self).__init__() - - self.body = - - def forward(self, x): - return self.body(x) - - -class Upsample(nn.Module): - def __init__(self, n_feat): - super(Upsample, self).__init__() - - self.body = - - def forward(self, x): - return self.body(x) - class Restormer(UNetBase): def __init__( diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index cceb68dce..dd28e4715 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -1,30 +1,42 @@ from functools import partial import torch -from torch.nn import Module, ModuleList +from torch.nn import Identity, Module, ModuleList from mrpro.nn.EmbMixin import call_with_emb class UNetBase(Module): - def __init__( - self, - in_channels: int, - out_channels: int, - channels_emb: int, - dim: int, - num_blocks: int, - ) -> None: ... + """Base class for U-shaped networks.""" + + def __init__(self) -> None: + super().__init__() + self.input_blocks = ModuleList() + """The encoder blocks. Order is highest resolution to lowest resolution.""" + + self.down_blocks = ModuleList() + """The downsampling blocks""" + + self.skip_blocks = ModuleList() + """Modifications to the skip connections""" + + self.middle_block = Module() + """Also called bottleneck block""" + + self.output_blocks = ModuleList() + """Also called decoder blocks. Order is lowest resolution to highest resolution.""" + + self.up_blocks = ModuleList() + """The upsampling blocks""" + + self.concat_blocks = ModuleList() + """Joins the skip connections with the upsampled features from a lower resolution level""" - input_blocks: ModuleList - down_blocks: ModuleList - skip_blocks: ModuleList - middle_block: Module - output_blocks: ModuleList - up_blocks: ModuleList - concat_blocks: ModuleList - last: Module - first: Module + self.last = Identity() + """The last block""" + + self.first = Identity() + """The first block""" def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: """Apply to Network.""" @@ -57,3 +69,34 @@ def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: The output tensor. """ return self(x, emb) + + +class UNet(UNetBase): + """UNet. + + U-shaped convolutional network [UNET]_ with optional patch attention. + Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_. + + References + ---------- + .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image + segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 + .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py + """ + + def __init__( + self, + dim:int, + + in_channels: int, + out_channels: int, + n_features: Sequence[int], + n_heads:Sequence[int] + n_blocks:int|Sequence[int] + channels_emb: int, + dim: int, + num_blocks: int, + attention_gate: + padding_modes:str|Sequence[str] + + ) -> None: ... From b6a1db3c809c8f2338e7891ab897406da9a0a08f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 16 May 2025 14:02:00 +0200 Subject: [PATCH 031/205] doc --- src/mrpro/nn/PixelShuffle.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index fddb903de..8894bdff7 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -8,6 +8,19 @@ class PixelUnshuffle(Module): """ND-version of PixelUnshuffle downscaling.""" def __init__(self, downscale_factor: int): + """Initialize PixelUnshuffle. + + Reduces spatial dimensions and increases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are downscaled. + + See `mrpro.nn.PixelShuffle` for the inverse operation. + + Parameters + ---------- + downscale_factor : int + The factor by which to downscale the input tensor. + """ super().__init__() self.downscale_factor = downscale_factor @@ -33,10 +46,24 @@ class PixelShuffle(Module): """ND-version of PixelShuffle upscaling.""" def __init__(self, upscale_factor: int): + """Initialize PixelShuffle. + + Upscales spatial dimensions and decreases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are upscaled. + + See `mrpro.nn.PixelUnshuffle` for the inverse operation. + + Parameters + ---------- + upscale_factor : int + The factor by which to upscale the spatial dimensions. + """ super().__init__() self.upscale_factor = upscale_factor def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input.""" dim = x.ndim - 2 if dim == 2: # fast path for 2D return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) From 3d0122093a1acf71bbeb167005c90398d2bd7eca Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 18 May 2025 16:30:43 +0200 Subject: [PATCH 032/205] update --- src/mrpro/nn/GluMBConvResBlock.py | 102 +++++++++++ src/mrpro/nn/LinearSelfAttention.py | 89 ++++++++++ src/mrpro/nn/MultiHeadAttention.py | 74 ++++++++ src/mrpro/nn/NeighborhoodSelfAttention.py | 55 ++---- src/mrpro/nn/PixelShuffle.py | 117 +++++++++++++ src/mrpro/nn/RMSNorm.py | 47 ++++++ src/mrpro/nn/RoPE.py | 77 +++++++++ src/mrpro/nn/ShiftedWindowAttention.py | 29 ++-- src/mrpro/nn/TransposedAttention.py | 59 ++++--- src/mrpro/nn/activations.py | 29 ++++ src/mrpro/nn/encoding.py | 46 +++++ src/mrpro/nn/nets/DCAE.py | 162 ++++++++++++++++++ src/mrpro/nn/nets/Restormer.py | 196 +++------------------- src/mrpro/nn/nets/UNet.py | 26 ++- src/mrpro/nn/nets/Uformer.py | 35 ++-- src/mrpro/utils/__init__.py | 5 +- src/mrpro/utils/to_tuple.py | 36 ++++ tests/nn/test_transposedattention.py | 2 +- 18 files changed, 915 insertions(+), 271 deletions(-) create mode 100644 src/mrpro/nn/GluMBConvResBlock.py create mode 100644 src/mrpro/nn/LinearSelfAttention.py create mode 100644 src/mrpro/nn/MultiHeadAttention.py create mode 100644 src/mrpro/nn/RMSNorm.py create mode 100644 src/mrpro/nn/RoPE.py create mode 100644 src/mrpro/nn/activations.py create mode 100644 src/mrpro/nn/encoding.py create mode 100644 src/mrpro/nn/nets/DCAE.py create mode 100644 src/mrpro/utils/to_tuple.py diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py new file mode 100644 index 000000000..1b9059a04 --- /dev/null +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -0,0 +1,102 @@ +"""Gateded MBConv Residual Block.""" + +import torch +from torch.nn import Identity, Module, Sequential, SiLU + +from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.NDModules import ConvND +from mrpro.nn.RMSNorm import RMSNorm + + +class GluMBConvResBlock(EmbMixin, Module): + """Gated MBConv residual block. + + Gated variant [DCAE]_ of the MBConv block [EffNet]_ with a residual connection. + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + .. [EffNet] Tan et al. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. ICML 2019 + https://arxiv.org/abs/1905.11946 + """ + + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + expand_ratio: int = 6, + stride: int = 1, + kernel_size: int = 3, + emb_dim: int = 0, + ): + """Initialize MBConv block. + + Parameters + ---------- + dim + Number of spatial dimensions. + channels_in + Number of input channels. + channels_out + Number of output channels. + expand_ratio + Expansion ratio inside the block. + stride + Stride of the depthwise convolution. + kernel_size + Kernel size of the depthwise convolution. + emb_dim + Size of the FiLM embedding. If 0, no embedding is used. + """ + super().__init__() + channels_mid = channels_in * expand_ratio + if stride == 1 and channels_in == channels_out: + self.skip: Module = Identity() + else: + self.skip = ConvND(dim)(channels_in, channels_out, kernel_size=1, stride=stride) + self.inverted_conv = Sequential( + ConvND(dim)( + channels_in, + channels_mid * 2, + kernel_size=1, + ), + SiLU(), + ) + self.depth_conv = Sequential( + ConvND(dim)( + channels_mid * 2, + channels_mid * 2, + kernel_size=kernel_size, + stride=stride, + padding='same', + groups=channels_mid * 2, + ), + SiLU(), + ) + self.point_conv = Sequential( + ConvND(dim)( + channels_mid, + channels_out, + kernel_size=1, + ), + RMSNorm(channels_out), + SiLU(), + ) + if emb_dim > 0: + self.film: FiLM | None = FiLM(channels_mid, emb_dim) + else: + self.film = None + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + """Apply MBConv block.""" + h = self.inverted_conv(x) + h = self.depth_conv(h) + h, gate = torch.chunk(h, 2, dim=1) + h = h * torch.nn.functional.silu(gate) + if self.film is not None: + h = self.film(h, emb) + h = self.point_conv(h) + return self.skip(x) + h diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py new file mode 100644 index 000000000..362c78269 --- /dev/null +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -0,0 +1,89 @@ +import torch +from einops import rearrange +from torch import Tensor +from torch.nn import Linear, Module, ReLU + + +class LinearSelfAttention(Module): + """Linear multi-head self-attention via kernel trick. + + Uses a ReLU kernel to compute attention in O(N) [KAT20]_ time and space. + + + Refereces + .. [KAT20] Katharopoulos, Angelos, et al. Transformers are rnns: Fast autoregressive transformers with linear + attention. ICML 2020. https://arxiv.org/abs/2006.16236 + + Parameters + ---------- + channels + Input and output channel dimension. + num_heads + Number of attention heads. + bias + Whether to use bias in the QKV projection. + eps + Small epsilon for numerical stability in normalization. + """ + + def __init__( + self, + channels_in: int, + channels_out: int, + n_heads: int, + eps: float = 1e-6, + channel_last: bool = False, + ): + super().__init__() + self.channel_last = channel_last + self.eps = eps + self.n_heads = n_heads + channels_per_head = channels_in // n_heads + self.to_qkv = Linear(channels_in, 3 * channels_per_head * n_heads) + self.kernel_function = ReLU() + self.to_out = Linear(channels_per_head * n_heads, channels_out) + + def __call__(self, x: Tensor) -> Tensor: + """Apply linear self-attention. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or (`batch, *spatial_dims, channels` if `channel_last`) + + Returns + ------- + Tensor after attention, same shape as input. + """ + return super().__call__(x) + + def forward(self, x: Tensor) -> Tensor: + """Apply linear self-attention.""" + orig_dtype = x.dtype + if x.dtype == torch.float16: + x = x.float() + if not self.channel_last: + x = x.moveaxis(1, -1) + spatial_shape = x.shape[2:-1] + + qkv = self.to_qkv(x) + query, key, value = rearrange( + qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channel', qkv=3, head=self.n_heads + ) + + query = self.kernel_function(query) + key = self.kernel_function(key) + + # trick to avoid second attention calculation: add normalization slot + value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode='constant', value=1.0) + + value_key = value @ key.transpose(-1, -2) + value_key_query = value_key @ query + normalisation = value_key_query[..., -1:, :] + self.eps + attn = value_key_query[..., :-1, :] / normalisation + out = self.to_out(attn) + out = out.to(orig_dtype) + out.unflatten(-2, spatial_shape) + if not self.channel_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py new file mode 100644 index 000000000..332ada94c --- /dev/null +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -0,0 +1,74 @@ +"""Multi-head Attention.""" + +import torch +from torch.nn import Linear, Module + + +class MultiHeadAttention(Module): + """Multi-head Attention. + + Implements multihead scaled dot-product attention and supports "image-like" inputs, + i.e. `batch, channels, *spatial_dims` as well as "transformer-like" inputs, `batch, sequence, features`. + """ + + def __init__( + self, + channels_in: int, + channels_out: int, + num_heads: int, + features_last: bool = False, + p_dropout: float = 0.0, + ): + """Initialize the Multi-head Attention. + + Parameters + ---------- + dim + Number of spatial dimensions. + channels_in + Number of input channels. + channels_out + Number of output channels. + num_heads + number of attention heads + """ + super().__init__() + self.mha = torch.nn.MultiheadAttention( + embed_dim=channels_in, num_heads=num_heads, batch_first=True, dropout=p_dropout + ) + self.features_last = features_last + self.to_out = Linear(channels_in, channels_out) + + def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + """Apply multi-head attention. + + Parameters + ---------- + x + The input tensor. + cross_attention + The key and value tensors for cross-attention. If `None`, self-attention is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cross_attention) + + def _reshape(self, x: torch.Tensor) -> torch.Tensor: + if not self.features_last: + x = x.moveaxis(1, -1) + return x.flatten(2, -2) + + def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + """Apply multi-head attention.""" + reshaped_x = self._reshape(x) + reshaped_cross_attention = self._reshape(cross_attention) if cross_attention is not None else reshaped_x + + y = self.mha(reshaped_cross_attention, reshaped_cross_attention, reshaped_x) + out = self.to_out(y) + + if not self.features_last: + out = out.moveaxes(-1, 1) + + return out.reshape(x.shape) diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index 762c68ab3..625c47174 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -1,3 +1,5 @@ +"""Neighborhood Self Attention.""" + from collections.abc import Sequence from functools import cache, reduce from typing import TypeVar @@ -7,37 +9,9 @@ from torch.nn import Linear, Module from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention -T = TypeVar('T') - - -def check_arg(length: int, arg: Sequence[T] | T) -> tuple[T, ...]: - """Standardize an argument to a fixed-length tuple. - - If the argument is a sequence, it checks if its length matches the - specified dimension. If it's a single value, it replicates it `dim` times. - - Parameters - ---------- - length - The expected length of the sequence. - arg - The argument to check. Can be a single value of type T or a - sequence of T. - - Returns - ------- - A tuple of length `dim` containing elements of type T. +from mrpro.utils.to_tuple import to_tuple - Raises - ------ - ValueError - If `arg` is a sequence and its length does not match `length`. - """ - if isinstance(arg, Sequence): - if not len(arg) == length: - raise ValueError(f'The arguments must be either single values or have length {length}. Got {arg}.') - return tuple(arg) - return (arg,) * length +T = TypeVar('T') @cache @@ -75,7 +49,7 @@ def neighborhood_mask( allowed attention connections. """ kernel_size_tuple, dilation_tuple, circular_tuple = ( - check_arg(len(input_size), x) for x in (kernel_size, dilation, circular) + to_tuple(len(input_size), x) for x in (kernel_size, dilation, circular) ) def unravel_index(idx: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -144,8 +118,9 @@ class NeighborhoodSelfAttention(Module): def __init__( self, - channels: int, - n_head: int, + channels_in: int, + channels_out: int, + n_heads: int, kernel_size: int | Sequence[int], dilation: int | Sequence[int] = 1, circular: bool | Sequence[bool] = False, @@ -169,16 +144,18 @@ def __init__( circular Whether the neighborhood wraps around the edges (circular padding) channel_last - Whether the channels are in the last dimension of the tensor, as common in transformers. + Whether the channels are in the last dimension of the tensor, as common in visíon transformers. + Otherwise, assume the channels are in the second dimension, as common in CNN models. """ super().__init__() - self.n_head = n_head + self.n_head = n_heads self.kernel_size = kernel_size self.dilation = dilation self.circular = circular self.channel_last = channel_last - self.to_qkv = Linear(channels, 3 * channels * n_head) - self.to_out = Linear(channels * n_head, channels) + channels_per_head = channels_in // n_heads + self.to_qkv = Linear(channels_in, 3 * channels_per_head * n_heads) + self.to_out = Linear(channels_per_head * n_heads, channels_out) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply neighborhood attention to the input tensor. @@ -196,7 +173,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.moveaxis(1, -1) spatial_shape = x.shape[2:-1] qkv = self.to_qkv(x) - query, key, value = rearrange(qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channel') + query, key, value = rearrange( + qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channel', qkv=3, head=self.n_head + ) # the mask depends on the input size. To be more flexible if used within CNNs, we compute it here. # The computation is cached.. mask = neighborhood_mask( diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index 8894bdff7..4994ff0a7 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -3,6 +3,8 @@ import torch from torch.nn import Module +from mrpro.nn.NDModules import ConvND + class PixelUnshuffle(Module): """ND-version of PixelUnshuffle downscaling.""" @@ -42,6 +44,121 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class PixelUnshuffleDownsample(Module): + """PixelUnshuffle Downsampling. + + PixelUnshuffle followed by a convolution. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, dim: int, channels_in: int, channels_out: int, downscale_factor: int = 2, residual: bool = False + ): + """Initialize a PixelUnshuffleDownsample layer. + + Parameters + ---------- + dim : int + Dimension of the input tensor. + channels_in : int + Number of channels in the input tensor. + channels_out : int + Number of channels in the output tensor. + downscale_factor : int, optional + Factor by which to downscale the input tensor. + residual : bool, optional + Whether to use a residual connection as proposed in [DCAE]_. + """ + super().__init__() + self.pixel_unshuffle = PixelUnshuffle(downscale_factor) + out_ratio = downscale_factor**dim + if channels_out % out_ratio != 0: + raise ValueError(f'channels_out must be divisible by downscale_factor**{dim}.') + self.conv = ConvND(dim)(channels_in, channels_out // out_ratio, kernel_size=3, padding='same') + self.residual = residual + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims/downscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling.""" + x = self.pixel_unshuffle(x) + h = self.conv(x) + if self.residual: + h = h + x.unflatten(1, (h.shape[1], -1)).mean(2) + return h + + +class PixelShuffleUpsample(Module): + """PixelShuffle Upsampling. + + Convolution followed by PixelShuffle. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__(self, dim: int, channels_in: int, channels_out: int, upscale_factor: int = 2, residual: bool = False): + """Initialize a PixelShuffleUpsample layer. + + Parameters + ---------- + dim : int + Dimension of the input tensor. + channels_in : int + Number of channels in the input tensor. + channels_out : int + Number of channels in the output tensor. + upscale_factor : int, optional + Factor by which to upscale the input tensor. + residual : bool, optional + Whether to use a residual connection as proposed in [DCAE]_. + """ + super().__init__() + self.conv = ConvND(dim)(channels_in, channels_out * upscale_factor**dim, kernel_size=3, padding='same') + self.pixel_shuffle = PixelShuffle(upscale_factor) + self.residual = residual + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims * upscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling.""" + h = self.conv(x) + if self.residual: + h = h + x.repeat_interleave(h.shape[1] // x.shape[1], dim=1) + out = self.pixel_shuffle(h) + return out + + class PixelShuffle(Module): """ND-version of PixelShuffle upscaling.""" diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py new file mode 100644 index 000000000..89d32ee2b --- /dev/null +++ b/src/mrpro/nn/RMSNorm.py @@ -0,0 +1,47 @@ +import torch +from torch.nn import Module, Parameter + + +class RMSNorm(Module): + """RMSNorm over the channel dimension.""" + + def __init__(self, channels: int, eps: float = 1e-8): + """Initialize RMSNorm. + + Includes a learnable weight and bias. + + Parameters + ---------- + channels + Number of channels. + eps + Epsilon value to avoid division by zero. + """ + super().__init__() + self.weight = Parameter(torch.zeros(channels)) + self.bias = Parameter(torch.zeros(channels)) + self.eps = eps + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension. + + Parameters + ---------- + x + Input tensor. + + Returns + ------- + Normalized tensor. + """ + return self.forward(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension.""" + mean_square = x.pow(2).mean(dim=1, keepdim=True) + scale = (mean_square + self.eps).rsqrt() + x = x * scale + shape = (1, -1, *([1] * (x.ndim - 2))) + weight = (1 + self.weight).view(shape) + bias = self.bias.view(shape) + return x * weight + bias diff --git a/src/mrpro/nn/RoPE.py b/src/mrpro/nn/RoPE.py new file mode 100644 index 000000000..4f71d1165 --- /dev/null +++ b/src/mrpro/nn/RoPE.py @@ -0,0 +1,77 @@ +from math import log + +import torch + + +# Rotary position embeddings +@torch.compile +def apply_rotary_emb_(x: torch.Tensor, theta: torch.Tensor, conjugated: bool): + """Adds the rotary embedding to the input tensor (inplace). + + This is a helper function for the `AxialRoPE` class. + """ + n_emb = theta.shape[-1] * 2 + if n_emb > x.shape[-1]: + raise ValueError('More theta values then channels//2 in the input tensor.') + x1, x2 = x[..., :n_emb].chunk(2, dim=-1) + dtype = torch.promote_type(torch.result_type(x, theta), torch.float32) + x1_, x2_, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype) + cos, sin = torch.cos(theta), torch.sin(theta) + sin = -sin if conjugated else sin + y1 = x1_ * cos - x2_ * sin + y2 = x2_ * cos + x1_ * sin + x1.copy_(y1) + x2.copy_(y2) + + +class RotaryEmbedding_(torch.autograd.Function): + """Adds the rotary embedding to the input tensor (inplace). + + This is a autograd helper class for the `AxialRoPE` class. + """ + + @staticmethod + def forward(x: torch.Tensor, theta: torch.Tensor, conjugated: bool) -> torch.Tensor: + apply_rotary_emb_(x, theta, conj=conj) + return x + + @staticmethod + def setup_context(ctx, inputs: tuple[torch.Tensor, torch.Tensor, bool], output: torch.Tensor): + _, theta, conjugated = inputs + ctx.save_for_backward(theta) + ctx.conjugated = conjugated + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]: + (theta,) = ctx.saved_tensors + apply_rotary_emb_(grad_output, theta, conjugated=not ctx.conjugated) + return grad_output, None, None + + +class AxialRoPE(Module): + def __init__(self, dim: int, d_head: int, n_heads: int, headpos: int = -2, non_embed_fraction: float = 0.5): + super().__init__() + log_min = log(torch.pi) + log_max = log(100 ** (1 / dim) * torch.pi) + d_per_head = int(d_head / dim * (1 - non_embed_fraction)) + freqs = torch.linspace(log_min, log_max, n_heads * d_per_head).exp() + freqs = freqs.view(-1, n_heads).T + freqs = freqs.unsqueeze(-2).repeat(1, dim, 1).contiguous() + self.freqs = torch.nn.Parameter(freqs) + self.headpos = headpos + + def get_theta(self, pos): + return (self.freqs * pos[..., None, :, None]).flatten(start_dim=-2).movedim(-2, self.headpos) + + def forward(self, pos, *tensors): + theta = self.get_theta(pos) + tuple(RotaryEmbedding_.apply(x, theta, False) for x in tensors) + + @staticmethod + def make_axial_positions(*shape): + shape = torch.as_tensor(shape) + m = shape.max() + pos = torch.stack( + torch.meshgrid([torch.linspace(-1 + 1 / s, 1 - 1 / s, s) * (s / m) for s in shape], indexing='ij'), -1 + ) + return pos diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py index d0978f8f1..61b40351b 100644 --- a/src/mrpro/nn/ShiftedWindowAttention.py +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -22,15 +22,19 @@ class ShiftedWindowAttention(Module): rel_position_index: torch.Tensor - def __init__(self, dim: int, n_channels_per_head: int, n_heads: int, window_size: int = 7, shifted: bool = True): + def __init__( + self, dim: int, channels_in: int, channels_out: int, n_heads: int, window_size: int = 7, shifted: bool = True + ): """Initialize the ShiftedWindowAttention module. Parameters ---------- dim : int The dimension of the input. - n_channels_per_head : int - The number of channels per head. + channels_in : int + The number of channels in the input tensor. + channels_out : int + The number of channels in the output tensor. n_heads : int The number of attention heads. The number if channels per head is ``channels // n_heads``. window_size : int @@ -42,7 +46,9 @@ def __init__(self, dim: int, n_channels_per_head: int, n_heads: int, window_size self.n_heads = n_heads self.window_size = window_size self.shifted = shifted - self.to_qkv = ConvND(dim)(n_channels_per_head * n_heads, 3 * n_channels_per_head * n_heads, 1) + channels_per_head = channels_in // n_heads + self.to_qkv = ConvND(dim)(channels_per_head * n_heads, 3 * channels_per_head * n_heads, 1) + self.to_out = ConvND(dim)(channels_per_head * n_heads, channels_out, 1) self.dim = dim coords_1d = torch.arange(window_size) coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * dim), indexing='ij'), 0).flatten(1) @@ -82,15 +88,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: qkv=3, ) bias = rearrange(self.relative_position_bias_table[self.rel_position_index], 'wd1 wd2 heads -> 1 heads wd1 wd2') - result = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) - result = rearrange(result, 'spatial batch head window channels->batch (head channels) spatial window') - result = result.unflatten(-2, windowed.shape[: self.dim]).unflatten(-1, (self.window_size,) * self.dim) + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) + attention = rearrange(attention, 'spatial batch head window channels->batch (head channels) spatial window') + attention = attention.unflatten(-2, windowed.shape[: self.dim]).unflatten(-1, (self.window_size,) * self.dim) # permute (in 3d) batch channels z y x wz wy wx -> batch channels wz z wy y wx x - result = result.moveaxis(list(range(-self.dim, 0)), list(range(3, 3 + 2 * self.dim, 2))) - result = result.reshape(x.shape) + attention = attention.moveaxis(list(range(-self.dim, 0)), list(range(3, 3 + 2 * self.dim, 2))) + attention = attention.reshape(x.shape) if self.shifted: - result = torch.roll(result, (self.window_size // 2,) * self.dim, dims=tuple(range(-self.dim, 0))) - return result + attention = torch.roll(attention, (self.window_size // 2,) * self.dim, dims=tuple(range(-self.dim, 0))) + out = self.to_out(attention) + return out '' diff --git a/src/mrpro/nn/TransposedAttention.py b/src/mrpro/nn/TransposedAttention.py index 7b42f794a..b9105285a 100644 --- a/src/mrpro/nn/TransposedAttention.py +++ b/src/mrpro/nn/TransposedAttention.py @@ -8,38 +8,45 @@ class TransposedAttention(Module): - def __init__(self, dim: int, channels: int, num_heads: int): - """Transposed Self Attention from Restormer. + """Transposed Self Attention from Restormer. - Implements the transposed self-attention, i.e. channel-wise multihead self-attention, - layer from Restormer [ZAM22]_. + Implements the transposed self-attention, i.e. channel-wise multihead self-attention, + layer from Restormer [ZAM22]_. - References - ---------- - ..[ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." - CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + """ + + def __init__(self, dim: int, channels_in: int, channels_out: int, n_heads: int): + """Initialize a TransposedAttention layer. Parameters ---------- dim input dimension - channels - input channels - num_heads - number of attention heads + channels_in + Number of channels in the input tensor. + channels_out + Number of channels in the output tensor. + n_heads + Number of attention heads. """ super().__init__() - self.num_heads = num_heads - self.temperature = Parameter(torch.ones(num_heads, 1, 1)) - self.qkv = ConvND(dim)(channels, channels * 3, kernel_size=1, bias=True) + self.n_heads = n_heads + self.temperature = Parameter(torch.ones(n_heads, 1, 1)) + channels_per_head = channels_in // n_heads + self.to_qkv = ConvND(dim)(channels_in, channels_per_head * n_heads * 3, kernel_size=1) self.qkv_dwconv = ConvND(dim)( - channels * 3, - channels * 3, + channels_per_head * n_heads * 3, + channels_per_head * n_heads * 3, kernel_size=3, - groups=channels * 3, + groups=channels_in * 3, + padding=1, bias=False, ) - self.project_out = ConvND(dim)(channels, channels, kernel_size=1, bias=True) + self.to_out = ConvND(dim)(channels_per_head * n_heads * 3, channels_out, kernel_size=1) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply transposed attention. @@ -56,12 +63,14 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: return super().__call__(x) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply transposed Attention.""" - qkv = self.qkv_dwconv(self.qkv(x)) - q, k, v = rearrange(qkv, 'b (qkv head c) ... -> qkv b head (...) c', head=self.num_heads, qkv=3) + """Apply transposed attention.""" + qkv = self.qkv_dwconv(self.to_qkv(x)) + q, k, v = rearrange(qkv, 'b (qkv heads channels) ... -> qkv b heads (...) channels', heads=self.n_heads, qkv=3) q = torch.nn.functional.normalize(q, dim=-1) * self.temperature k = torch.nn.functional.normalize(k, dim=-1) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0) - out = rearrange(out, '... head points c -> ... (head c) points').reshape(x.shape) - out = self.project_out(out) + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0) + out = rearrange(attention, '... heads points channels -> ... (heads channels) points').unflatten( + -1, x.shape[2:] + ) + out = self.to_out(out) return out diff --git a/src/mrpro/nn/activations.py b/src/mrpro/nn/activations.py new file mode 100644 index 000000000..4f757476a --- /dev/null +++ b/src/mrpro/nn/activations.py @@ -0,0 +1,29 @@ +import torch +from torch.nn import Linear, Module + + +class GEGLU(Module): + r"""Gated linear unit activation function. + + References + ---------- + ..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202 + """ + + def __init__(self, in_features: int, out_features: int | None = None): + """Initialize the GEGLU activation function. + + Parameters + ---------- + in_features : int + The number of input features. + out_features : int + The number of output features. If None, the number of output features is the same as the number of input features. + """ + super().__init__() + self.proj = Linear(in_features, out_features * 2) + + def forward(self, x): + h, gate = self.proj(x).chunk(2, dim=-1) + gate = torch.nn.functional.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + return h * gate diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py new file mode 100644 index 000000000..32e239adc --- /dev/null +++ b/src/mrpro/nn/encoding.py @@ -0,0 +1,46 @@ +from itertools import combinations +from math import ceil + +import torch +from torch.nn import Module + +from mrpro.utils.reshape import unsqueeze_right + + +class FourierFeatures(Module): + def __init__(self, in_features: int, out_features: int, std: float = 1.0): + super().__init__() + assert out_features % 2 == 0 + self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + f = 2 * torch.pi * x @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + + +class AbsolutePositionEncoding(Module): + def __init__(self, dim: int, features: int, include_radii: bool = True, base_resolution: int = 128): + super().__init__() + + coords = [unsqueeze_right(torch.linspace(-1, 1, base_resolution), i) for i in range(dim)] + if include_radii: + for n in range(2, dim + 1): + for combination in combinations(coords, n): + coords.append(2**0.5 * torch.sqrt(sum([c**2 for c in combination])) - 1) + n_freqs = ceil(features / len(coords) / 2) + freqs = unsqueeze_right((base_resolution) ** torch.linspace(0, 1, n_freqs), dim) + encoding = [] + for coord in coords: + encoding.append(torch.sin(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * dim))) + encoding.append(torch.cos(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * dim))) + self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :features]) + self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][dim - 1] + + def forward(self, x): + features = self.encoding.shape[1] + if features > x.shape[1]: + raise ValueError(f'x has {x.shape[1]} features, but {features} are required') + + x_enc, x_unenc = x.split([features, x.shape[1] - features], dim=1) + encoding = torch.nn.functional.interpolate(self.encoding, size=x_unenc.shape[2:], mode=self.interpolation_mode) + return torch.cat((x_enc + encoding, x_unenc), dim=1) diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py new file mode 100644 index 000000000..a00702613 --- /dev/null +++ b/src/mrpro/nn/nets/DCAE.py @@ -0,0 +1,162 @@ +from collections.abc import Sequence +from torch.nn import Module +import torch +from mrpro.nn import Sequential, SiLU +from mrpro.nn.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.NDModules import ConvND +from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock +from mrpro.nn.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.PixelShuffle import PixelUnshuffleDownsample, PixelShuffleUpsampe +from mrpro.nn.RMSNorm import RMSNorm + + +class ResBlock(Module): + def __init__( + self, + dim: int, + channels int, + + ): + super().__init__() + self.inner=Sequential( + ConvND(dim)( + channels,channels,kernel_size=3,padding=1 + ), + SiLU(), + ConvND(dim)( + channels,channels,kernel_size=3,padding=1, bias=False + ), + RMSNorm(channels) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.inner(x) + x + + + +class EfficientViTBlock(Module): + def __init__( + self, + dim: int, + channels: int, + n_heads: int, + expand_ratio: float = 4, + linear_attn: bool = False, + ): + super().__init__() + if linear_attn: + attention = LinearSelfAttention(channels, channels, n_heads) #TODO: check heads and head dim + else: + attention = MultiHeadAttention(channels, channels, n_heads, features_last=False) + self.context_module=Sequential(attention,RMSNorm(channels)) + self.local_module = GluMBConvResBlock( + dim=dim, + channels_in=channels, + channels_out=channels, + expand_ratio=expand_ratio, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.context_module(x) + x + x = self.local_module(x) # is already residual + return x + + + +class Encoder(Sequential): + def __init__(self, dim:int=2, channels_in:int=3, channels_out:int=32,block_types:Sequence[str]=("ResBlock","ResBlock","LinearViT","LinearViT","ViT"),widths:Sequence[int]=(256,512,512,1024,1024),depths:Sequence[int]=(4,6,2,2,2)): + super().__init__() + self.append(PixelUnshuffleDownsample(dim,channels_in,widths[0], downscale_factor=2, residual=False)) + if len(block_types) != len(widths) or len(block_types) != len(depths): + raise ValueError("block_types, widths, and depths must have the same length") + for block_type,width, depth in zip(block_types,widths,depths): + match block_type: + case "ResBlock": + stage = [ResBlock(dim,width) for _ in range(depth)] + case "LinearViT": + stage = [EfficientViTBlock(dim,width,n_heads=1, linear_attn=True) for _ in range(depth)] # TODO: heads + case "ViT": + stage = [EfficientViTBlock(dim,width,n_heads=1, linear_attn=False) for _ in range(depth)] + case _: + raise ValueError(f"Block type {block_type} not supported") + self.append(Sequential(stage)) + if len(self) < len(widths): + self.append(PixelUnshuffleDownsample(dim,width,width, downscale_factor=2, residual=True)) + self.append(PixelUnshuffleDownsample(dim,widths[-1],channels_out, downscale_factor=1, residual=True)) + + +class Decoder(Module): +def __init__( + self, + dim: int = 2, + channels_in:int=32, + channels_out: int = 3, + block_types: Sequence[str] = ( + "ViT", "LinearViT", + "LinearViT", "ResBlock", "ResBlock" + ), + widths: Sequence[int] = (1024, 1024, 512, 512, 256), + depths: Sequence[int] = (2, 2, 2, 6, 4), + ): + super().__init__() + if not (len(block_types) == len(widths) == len(depths)): + raise ValueError( + "block_types, widths, and depths must have the same length" + ) + # "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] " + # "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[0,5,10,2,2,2] " + # "decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu]" + self.append( + PixelShuffleUpsampe(dim, channels_in, widths[0], upscale_factor=1, residual=True) + ) + + + + self.stages: list[OpSequential] = [] + for block_type,width, depth in zip(block_types,widths,depths): + match block_type: + case "ResBlock": + stage = [ResBlock(dim,width) for _ in range(depth)] + case "LinearViT": + stage = [EfficientViTBlock(dim,width,n_heads=1, linear_attn=True) for _ in range(depth)] # TODO: heads + case "ViT": + stage = [EfficientViTBlock(dim,width,n_heads=1, linear_attn=False) for _ in range(depth)] + case _: + raise ValueError(f"Block type {block_type} not supported") + self.append(Sequential(stage)) + if len(self) < len(widths): + self.append(PixelShuffleUpsampe(dim,width,width, upscale_factor=2, residual=True)) + + + stage.extend( + build_stage_main( + width=width, + depth=depth, + block_type=block_type, + norm=norm, + act=act, + input_width=( + width if cfg.upsample_match_channel else cfg.width_list[min(stage_id + 1, num_stages - 1)] + ), + ) + ) + self.stages.insert(0, OpSequential(stage)) + self.stages = nn.ModuleList(self.stages) + + self.project_out = build_decoder_project_out_block( + in_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1], + out_channels=cfg.in_channels, + factor=1 if cfg.depth_list[0] > 0 else 2, + upsample_block_type=cfg.upsample_block_type, + norm=cfg.out_norm, + act=cfg.out_act, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.project_in(x) + for stage in reversed(self.stages): + if len(stage.op_list) == 0: + continue + x = stage(x) + x = self.project_out(x) + return x diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 126d88c08..505c88bcf 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -2,12 +2,13 @@ from collections.abc import Sequence import torch -from torch.nn import Module, PixelUnshuffle, PixelShuffle +from torch.nn import Module from mrpro.nn.TransposedAttention import TransposedAttention -from mrpro.nn.NDModules import ConvNd, InstanceNormNd +from mrpro.nn.NDModules import ConvND, ConvNd, InstanceNormNd from mrpro.nn.FiLM import FiLM -from mrpro.nn.nets.UNetBase import UNetBase +from mrpro.nn.nets.UNet import UNetBase from mrpro.nn.Sequential import Sequential +from mrpro.nn.PixelShuffle import PixelShuffle, PixelUnshuffle class GDFN(Module): """Gated depthwise feed forward network. @@ -55,8 +56,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x - class Restormer(UNetBase): + """Restormer architecture. + + Implements the Restormer [ZAM22]_ network, which is a U-shaped transformer + with channel wise attention and depthwise convolutions in the feed forward network. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + """ + def __init__( self, dim: int, @@ -75,185 +86,22 @@ def __init__( def blocks(n_heads: int, n_blocks: int): layers = Sequential( - *(RestormerBlock(dim, n_channels_per_head, n_heads, mlp_ratio) for i in range(n_blocks)) + *(RestormerBlock(dim, n_channels_per_head, n_heads, mlp_ratio) for _ in range(n_blocks)) ) if emb_dim > 0 and n_blocks > 1: - layers.insert(1, FiLM(channels=n_features_per_head * n_heads, channels_emb=emb_dim)) + layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, channels_emb=emb_dim)) return layers - - for n_block, n_heads in zip(n_blocks, n_heads): self.input_blocks.append(blocks(n_heads, n_block)) self.output_blocks.append(blocks(n_heads, n_block)) self.skip_blocks.append(Identity()) - - for n_head_current, n_head_next in pairwise(n_heads): - self.down_blocks.append( - Sequential( - ConvND(dim)(n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=4, stride=2, padding=1, - PixelUnshuffle(2) - ) - self.up_blocks.append( - nn.Sequential( - nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), PixelShuffle(2) - ) - ) self.output_blocks = self.output_blocks[::-1] - self.middle_block = blocks(n_heads, n_blocks) - - num_heads=heads[0], - ffn_expansion_factor=mlp_ratio, - ) - for i in range(num_blocks[0]) - ] - ) - - self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 - self.encoder_level2 = nn.Sequential( - *[ - TransformerBlock( - dim=int(dim * 2**1), - num_heads=heads[1], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - ) - for i in range(num_blocks[1]) - ] - ) - - self.down2_3 = Downsample(int(dim * 2**1)) ## From Level 2 to Level 3 - self.encoder_level3 = nn.Sequential( - *[ - TransformerBlock( - dim=int(dim * 2**2), - num_heads=heads[2], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - ) - for i in range(num_blocks[2]) - ] - ) - - self.down3_4 = Downsample(int(dim * 2**2)) ## From Level 3 to Level 4 - self.latent = nn.Sequential( - *[ - TransformerBlock( - dim=int(dim * 2**3), - num_heads=heads[3], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - ) - for i in range(num_blocks[3]) - ] - ) - - self.up4_3 = Upsample(int(dim * 2**3)) ## From Level 4 to Level 3 - self.reduce_chan_level3 = nn.Conv2d(int(dim * 2**3), int(dim * 2**2), kernel_size=1, bias=bias) - self.decoder_level3 = nn.Sequential( - *[ - TransformerBlock( - dim=int(dim * 2**2), - num_heads=heads[2], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - ) - for i in range(num_blocks[2]) - ] - ) - - self.up3_2 = Upsample(int(dim * 2**2)) ## From Level 3 to Level 2 - self.reduce_chan_level2 = nn.Conv2d(int(dim * 2**2), int(dim * 2**1), kernel_size=1, bias=bias) - self.decoder_level2 = nn.Sequential( - *[ - TransformerBlock( - dim=int(dim * 2**1), - num_heads=heads[1], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - ) - for i in range(num_blocks[1]) - ] - ) - - self.up2_1 = Upsample(int(dim * 2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) - - self.decoder_level1 = nn.Sequential( - *[ - TransformerBlock( - dim=int(dim * 2**1), - num_heads=heads[0], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - ) - for i in range(num_blocks[0]) - ] - ) - - self.refinement = nn.Sequential( - *[ - TransformerBlock( - dim=int(dim * 2**1), - num_heads=heads[0], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - ) - for i in range(num_refinement_blocks) - ] - ) - - #### For Dual-Pixel Defocus Deblurring Task #### - self.dual_pixel_task = dual_pixel_task - if self.dual_pixel_task: - self.skip_conv = nn.Conv2d(dim, int(dim * 2**1), kernel_size=1, bias=bias) - ########################### - - self.output = nn.Conv2d(int(dim * 2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) - - def forward(self, inp_img): - inp_enc_level1 = self.patch_embed(inp_img) - out_enc_level1 = self.encoder_level1(inp_enc_level1) - - inp_enc_level2 = self.down1_2(out_enc_level1) - out_enc_level2 = self.encoder_level2(inp_enc_level2) - - inp_enc_level3 = self.down2_3(out_enc_level2) - out_enc_level3 = self.encoder_level3(inp_enc_level3) - - inp_enc_level4 = self.down3_4(out_enc_level3) - latent = self.latent(inp_enc_level4) - - inp_dec_level3 = self.up4_3(latent) - inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1) - inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) - out_dec_level3 = self.decoder_level3(inp_dec_level3) - - inp_dec_level2 = self.up3_2(out_dec_level3) - inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1) - inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) - out_dec_level2 = self.decoder_level2(inp_dec_level2) - - inp_dec_level1 = self.up2_1(out_dec_level2) - inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1) - out_dec_level1 = self.decoder_level1(inp_dec_level1) - - out_dec_level1 = self.refinement(out_dec_level1) - - #### For Dual-Pixel Defocus Deblurring Task #### - if self.dual_pixel_task: - out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) - out_dec_level1 = self.output(out_dec_level1) - ########################### - else: - out_dec_level1 = self.output(out_dec_level1) + inp_img + for n_head_current, n_head_next in pairwise(n_heads): + self.down_blocks.append(Sequential(ConvND(dim)(n_head_current*n_channels_per_head, n_head_next*n_channels_per_head // 2**dim, kernel_size=3, padding=1, bias=False), PixelUnshuffle(2))) + self.up_blocks.append(Sequential(ConvND(dim)(n_head_next*n_channels_per_head, n_head_current*n_channels_per_head*2**dim, kernel_size=3, padding=1, bias=False), PixelShuffle(2))) - return out_dec_level1 + self.middle_block = blocks(n_heads, n_blocks) + self.last = Sequential(*blocks(n_heads[0],n_refinement_blocks), ConvND(dim)(n_channels_per_head*n_heads[0], channels_out, kernel_size=3, stride=1, padding=1)) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index dd28e4715..6b8eebe1d 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -71,6 +71,27 @@ def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: return self(x, emb) +class AttentionUNet(UNet): + """UNet with attention gates. + + References + ---------- + .. [OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + +class SeparableUNet(UNetBase): + """UNet where blocks apply separable convolutions in different dimensions + + Based on the pseudo-3D residual network of [QUI]_ and the residual blocks of [ZIM]_. + + References + ---------- + .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. + ICCV 2017. https://arxiv.org/abs/1711.10305 + .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction without explicit sensitivity maps. + STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 + """ class UNet(UNetBase): """UNet. @@ -87,7 +108,7 @@ class UNet(UNetBase): def __init__( self, dim:int, - + in_channels: int, out_channels: int, n_features: Sequence[int], @@ -96,7 +117,6 @@ def __init__( channels_emb: int, dim: int, num_blocks: int, - attention_gate: padding_modes:str|Sequence[str] - + ) -> None: ... diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index ee2c86d81..e351e32a6 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -5,12 +5,12 @@ from sympy import Identity from torch.nn import GELU, LeakyReLU, Module, Sequential +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM from mrpro.nn.NDModules import ConvND, ConvTransposeND, InstanceNormND from mrpro.nn.nets.UNet import UNetBase -from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention -from mrpro.nn.FiLM import FiLM from mrpro.nn.Sequential import Sequential -from mrpro.nn.DropPath import DropPath +from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention class LeFF(Module): @@ -62,7 +62,7 @@ class LeWinTransformerBlock(Module): def __init__( self, dim: int, - n_features_per_head: int, + n_channels_per_head: int, n_heads: int, window_size: int = 8, shifted: bool = False, @@ -75,7 +75,7 @@ def __init__( ---------- dim : int Dimension of the input, e.g. 2 or 3 - n_features_per_head : int + n_channels_per_head : int Number of features per head n_heads : int Number of attention heads @@ -89,11 +89,12 @@ def __init__( Dropout probability for the drop path. """ super().__init__() - channels = n_features_per_head * n_heads + channels = n_channels_per_head * n_heads self.norm1 = InstanceNormND(dim)(channels) self.attn = ShiftedWindowAttention( dim=dim, - n_channels_per_head=n_features_per_head, + channels_in=channels, + channels_out=channels, n_heads=n_heads, window_size=window_size, shifted=shifted, @@ -132,7 +133,7 @@ def __init__( dim: int, channels_in: int, channels_out: int, - n_features_per_head: int = 32, + n_channels_per_head: int = 32, n_heads: Sequence[int] = (1, 2, 4, 8), n_blocks: int = 2, emb_dim: int = 0, @@ -150,9 +151,9 @@ def __init__( Number of input channels channels_out : int Number of output channels - n_features_per_head : int, optional + n_channels_per_head : int, optional Number of features per head. The number of features at a resolution level is given by - `n_features_per_head * n_heads`. + `n_channels_per_head * n_heads`. n_heads : Sequence[int], optional Number of attention heads at each resolution level. n_blocks : int, optional @@ -176,7 +177,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): LeWinTransformerBlock( dim=dim, n_heads=n_heads, - n_features_per_head=n_features_per_head, + n_channels_per_head=n_channels_per_head, window_size=window_size, mlp_ratio=mlp_ratio, shifted=bool(i % 2), @@ -187,7 +188,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): ) if emb_dim > 0 and n_blocks > 1: - layers.insert(1, FiLM(channels=n_features_per_head * n_heads, channels_emb=emb_dim)) + layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, channels_emb=emb_dim)) return layers drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() @@ -201,8 +202,8 @@ def blocks(n_heads: int, p_droppath: float = 0.0): for n_head_current, n_head_next in pairwise(n_heads): self.down_blocks.append( ConvND(dim)( - n_features_per_head * n_head_current, - n_features_per_head * n_head_next, + n_channels_per_head * n_head_current, + n_channels_per_head * n_head_next, kernel_size=4, stride=2, padding=1, @@ -210,13 +211,13 @@ def blocks(n_heads: int, p_droppath: float = 0.0): ) self.up_blocks.append( ConvTransposeND(dim)( - n_features_per_head * n_head_next, n_features_per_head * n_head_current, kernel_size=2, stride=2 + n_channels_per_head * n_head_next, n_channels_per_head * n_head_current, kernel_size=2, stride=2 ) ) self.first = torch.nn.Sequential( - ConvND(dim)(channels_in, n_features_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), + ConvND(dim)(channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), LeakyReLU(), ) self.last = ConvND(dim)( - n_features_per_head * n_heads[-1], channels_out, kernel_size=3, stride=1, padding='same' + n_channels_per_head * n_heads[-1], channels_out, kernel_size=3, stride=1, padding='same' ) diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 1f36ff0c3..92465470d 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -15,7 +15,7 @@ from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin from mrpro.utils.interpolate import interpolate, apply_lowres from mrpro.utils.RandomGenerator import RandomGenerator - +from mrpro.utils.to_tuple import to_tuple __all__ = [ "Indexer", "RandomGenerator", @@ -35,6 +35,7 @@ "split_idx", "summarize_object", "summarize_values", + "to_tuple", "typing", "unit_conversion", "unsqueeze_at", @@ -42,5 +43,5 @@ "unsqueeze_right", "unsqueeze_tensors_at", "unsqueeze_tensors_left", - "unsqueeze_tensors_right", + "unsqueeze_tensors_right" ] \ No newline at end of file diff --git a/src/mrpro/utils/to_tuple.py b/src/mrpro/utils/to_tuple.py new file mode 100644 index 000000000..657d7bf56 --- /dev/null +++ b/src/mrpro/utils/to_tuple.py @@ -0,0 +1,36 @@ +"""Standardize an argument to a fixed-length tuple.""" + +from collections.abc import Sequence +from typing import TypeVar + +T = TypeVar('T') + + +def to_tuple(length: int, arg: Sequence[T] | T) -> tuple[T, ...]: + """Standardize an argument to a fixed-length tuple. + + If the argument is a sequence, it checks if its length matches the + specified dimension. If it's a single value, it replicates it `dim` times. + + Parameters + ---------- + length + The expected length of the sequence. + arg + The argument to check. Can be a single value of type T or a + sequence of T. + + Returns + ------- + A tuple of length `dim` containing elements of type T. + + Raises + ------ + ValueError + If `arg` is a sequence and its length does not match `length`. + """ + if isinstance(arg, Sequence): + if not len(arg) == length: + raise ValueError(f'The arguments must be either single values or have length {length}. Got {arg}.') + return tuple(arg) + return (arg,) * length diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index 417743135..8768301df 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -30,7 +30,7 @@ def test_transposed_attention(dim, channels, num_heads, input_shape, device): assert x.grad is not None, 'No gradient computed for input' assert not x.isnan().any(), 'NaN values in input' assert not x.grad.isnan().any(), 'NaN values in input gradients' - assert attn.qkv.weight.grad is not None, 'No gradient computed for qkv' + assert attn.to_qkv.weight.grad is not None, 'No gradient computed for qkv' assert attn.qkv_dwconv.weight.grad is not None, 'No gradient computed for qkv_dwconv' assert attn.project_out.weight.grad is not None, 'No gradient computed for project_out' assert attn.temperature.grad is not None, 'No gradient computed for temperature' From 7f37fa99ba4392cd89a52cca08b5e18ff0fe50a3 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 19 May 2025 01:52:18 +0200 Subject: [PATCH 033/205] update --- src/mrpro/nn/AttentionGate.py | 2 +- src/mrpro/nn/ComplexAsChannel.py | 16 +- src/mrpro/nn/CondMixin.py | 21 ++ src/mrpro/nn/CoordConv.py | 0 src/mrpro/nn/EmbMixin.py | 21 -- src/mrpro/nn/FiLM.py | 37 ++-- src/mrpro/nn/GluMBConvResBlock.py | 18 +- src/mrpro/nn/{GroupNorm32.py => GroupNorm.py} | 2 +- src/mrpro/nn/LayerNorm.py | 64 ++++++ src/mrpro/nn/LinearSelfAttention.py | 25 ++- src/mrpro/nn/MultiHeadAttention.py | 7 +- src/mrpro/nn/NeighborhoodSelfAttention.py | 3 +- src/mrpro/nn/PixelShuffle.py | 29 +++ src/mrpro/nn/ResBlock.py | 34 ++-- src/mrpro/nn/Residual.py | 43 ++++ src/mrpro/nn/Sequential.py | 18 +- src/mrpro/nn/__init__.py | 11 +- src/mrpro/nn/nets/CNN.py | 58 ++++++ src/mrpro/nn/nets/DCAE.py | 109 +++++----- src/mrpro/nn/nets/Restormer.py | 58 ++++-- src/mrpro/nn/nets/SwinIR.py | 189 ++++++++++++++++++ src/mrpro/nn/nets/UNet.py | 35 ++-- src/mrpro/nn/nets/Uformer.py | 10 +- src/mrpro/nn/nets/VAE.py | 55 +++++ tests/nn/test_film.py | 2 +- tests/nn/test_groupnorm32.py | 4 +- tests/nn/test_sequential.py | 2 +- 27 files changed, 682 insertions(+), 191 deletions(-) create mode 100644 src/mrpro/nn/CondMixin.py create mode 100644 src/mrpro/nn/CoordConv.py delete mode 100644 src/mrpro/nn/EmbMixin.py rename src/mrpro/nn/{GroupNorm32.py => GroupNorm.py} (97%) create mode 100644 src/mrpro/nn/LayerNorm.py create mode 100644 src/mrpro/nn/Residual.py create mode 100644 src/mrpro/nn/nets/CNN.py create mode 100644 src/mrpro/nn/nets/SwinIR.py create mode 100644 src/mrpro/nn/nets/VAE.py diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py index a20f04396..db1c4aac7 100644 --- a/src/mrpro/nn/AttentionGate.py +++ b/src/mrpro/nn/AttentionGate.py @@ -13,7 +13,7 @@ class AttentionGate(Module): References ---------- - ..[OKT18] Oktay, Ozan, et al. "Attention u-net: Learning where to look for the pancreas." MIDL (2018). + ..[OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). https://arxiv.org/abs/1804.03999 """ diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py index 5acce1245..bab6cf6cb 100644 --- a/src/mrpro/nn/ComplexAsChannel.py +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -2,10 +2,10 @@ from einops import rearrange from torch.nn import Module -from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.CondMixin import CondMixin, call_with_cond -class ComplexAsChannel(EmbMixin, Module): +class ComplexAsChannel(CondMixin, Module): """Wrap module to treat complex numbers as a channel dimension.""" def __init__(self, module: Module): @@ -24,19 +24,19 @@ def __init__(self, module: Module): super().__init__() self.module = module - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module. Parameters ---------- x : torch.Tensor The input tensor. - emb : torch.Tensor | None - The embedding tensor. + cond : torch.Tensor | None + The conditioning tensor (if used by the wrapped module) """ - return super().__call__(x, emb) + return super().__call__(x, cond) - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module.""" if x.is_complex(): x_real = torch.view_as_real(x) @@ -44,7 +44,7 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Ten else: x_real = x - y = self.module(x_real) + y = call_with_cond(self.module, x_real, cond) if x.is_complex(): y = rearrange(y, 'b (channel complex) ... -> b channel ... complex', complex=2).contiguous() diff --git a/src/mrpro/nn/CondMixin.py b/src/mrpro/nn/CondMixin.py new file mode 100644 index 000000000..2dccbd305 --- /dev/null +++ b/src/mrpro/nn/CondMixin.py @@ -0,0 +1,21 @@ +"""Base class for modules using a conditioning.""" + +import torch +from torch.nn import Module + + +def call_with_cond(module: Module, x: torch.Tensor, cond: torch.Tensor | None) -> torch.Tensor: + if isinstance(CondMixin, Module): + return module(x, cond) + return module(x) + + +class CondMixin(Module): + """Mixin for modules using a conditioning. + + Used to determine if a module uses a conditioning within a Sequential container. + """ + + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module to the input.""" + return super().__call__(x, cond) diff --git a/src/mrpro/nn/CoordConv.py b/src/mrpro/nn/CoordConv.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mrpro/nn/EmbMixin.py b/src/mrpro/nn/EmbMixin.py deleted file mode 100644 index 5188ae964..000000000 --- a/src/mrpro/nn/EmbMixin.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Base class for modules using an embedding.""" - -import torch -from torch.nn import Module - - -def call_with_emb(module: Module, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: - if isinstance(EmbMixin, Module): - return module(x, emb) - return module(x) - - -class EmbMixin(Module): - """Mixin for modules using an embedding. - - Used to determine if a module uses an embedding within a Sequential container. - """ - - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - """Apply the module to the input.""" - return super().__call__(x, emb) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index 9944dc0a4..9cac169e9 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -3,58 +3,59 @@ import torch from torch.nn import Identity, Linear, Module, Sequential, SiLU -from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.CondMixin import CondMixin from mrpro.utils.reshape import unsqueeze_tensors_right -class FiLM(EmbMixin, Module): +class FiLM(CondMixin, Module): """Feature-wise Linear Modulation. - Feature-wise Linear Modulation from [FiLM]_ + Feature-wise Linear Modulation from [FiLM]_ to condition a network on a conditioning tensor. + References ---------- - ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "Film: Visual reasoning with a general conditioning layer." AAAI (2018). - https://arxiv.org/abs/1709.07871 + ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "FiLM: Visual reasoning with a general + conditioning layer." AAAI (2018). https://arxiv.org/abs/1709.07871 """ - def __init__(self, channels: int, channels_emb: int) -> None: + def __init__(self, channels: int, cond_dim: int) -> None: """Initialize FiLM. Parameters ---------- channels The number of channels in the input tensor. - channels_emb - The number of channels in the embedding tensor. + cond_dim + The dimension of the conditioning tensor. """ super().__init__() - if channels_emb > 0: + if cond_dim > 0: self.project = Sequential( SiLU(), - Linear(channels_emb, 2 * channels), + Linear(cond_dim, 2 * channels), ) else: self.project = Identity() - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM. Parameters ---------- x The input tensor. - emb - The embedding tensor. + cond + The conditioning tensor. """ - return super().__call__(x, emb) + return super().__call__(x, cond) - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM.""" - if emb is None: + if cond is None: return x - emb = self.project(emb) - scale, shift = emb.chunk(2, dim=1) + cond = self.project(cond) + scale, shift = cond.chunk(2, dim=1) scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) return x * (1 + scale) + shift diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py index 1b9059a04..f48883125 100644 --- a/src/mrpro/nn/GluMBConvResBlock.py +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -3,13 +3,13 @@ import torch from torch.nn import Identity, Module, Sequential, SiLU -from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.NDModules import ConvND from mrpro.nn.RMSNorm import RMSNorm -class GluMBConvResBlock(EmbMixin, Module): +class GluMBConvResBlock(CondMixin, Module): """Gated MBConv residual block. Gated variant [DCAE]_ of the MBConv block [EffNet]_ with a residual connection. @@ -30,7 +30,7 @@ def __init__( expand_ratio: int = 6, stride: int = 1, kernel_size: int = 3, - emb_dim: int = 0, + cond_dim: int = 0, ): """Initialize MBConv block. @@ -48,8 +48,8 @@ def __init__( Stride of the depthwise convolution. kernel_size Kernel size of the depthwise convolution. - emb_dim - Size of the FiLM embedding. If 0, no embedding is used. + cond_dim + Dimension of the conditioning tensor used in a FiLM. If 0, no FiLM is used. """ super().__init__() channels_mid = channels_in * expand_ratio @@ -85,18 +85,18 @@ def __init__( RMSNorm(channels_out), SiLU(), ) - if emb_dim > 0: - self.film: FiLM | None = FiLM(channels_mid, emb_dim) + if cond_dim > 0: + self.film: FiLM | None = FiLM(channels_mid, cond_dim) else: self.film = None - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply MBConv block.""" h = self.inverted_conv(x) h = self.depth_conv(h) h, gate = torch.chunk(h, 2, dim=1) h = h * torch.nn.functional.silu(gate) if self.film is not None: - h = self.film(h, emb) + h = self.film(h, cond) h = self.point_conv(h) return self.skip(x) + h diff --git a/src/mrpro/nn/GroupNorm32.py b/src/mrpro/nn/GroupNorm.py similarity index 97% rename from src/mrpro/nn/GroupNorm32.py rename to src/mrpro/nn/GroupNorm.py index 55c11d1f6..9a50a6319 100644 --- a/src/mrpro/nn/GroupNorm32.py +++ b/src/mrpro/nn/GroupNorm.py @@ -3,7 +3,7 @@ import torch -class GroupNorm32(torch.nn.GroupNorm): +class GroupNorm(torch.nn.GroupNorm): """A 32-bit GroupNorm. Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py new file mode 100644 index 000000000..1bea902d1 --- /dev/null +++ b/src/mrpro/nn/LayerNorm.py @@ -0,0 +1,64 @@ +from torch.nn import Module, Parameter +import torch +from mrpro.nn.utils import unsqueeze_right + + +class LayerNorm(Module): + """Layer normalization.""" + + def __init__(self, channels: int | None, channel_last: bool = False, bias: bool = True) -> None: + """Initialize the layer normalization. + + Parameters + ---------- + channels + Number of channels in the input tensor. If `None`, the layer normalization does not do an elementwise + affine transformation. + channel_last + If `True`, the channel dimension is the last dimension. + bias + If `False`, only a scaling is applied without an offset if an affine transformation is used. + """ + super().__init__() + if channels is not None: + self.weight = Parameter(torch.ones(channels)) + self.bias = Parameter(torch.zeros(channels)) if bias else None + else: + self.weight = None + self.bias = None + self.channel_last = channel_last + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply layer normalization to the input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor + + Returns + ------- + Normalized output tensor + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply layer normalization to the input tensor.""" + dims = tuple(range(1, x.ndim)) + mean = x.mean(dim=dims, keepdim=True) + std = x.std(dim=dims, keepdim=True, unbiased=False) + x = (x - mean) / (std + 1e-5) + + if self.weight is not None: + if self.channel_last: + x = x * self.weight + else: + x = x * unsqueeze_right(self.weight, x.ndim - 2) + + if self.bias is not None: + if self.channel_last: + x = x + self.bias + else: + x = x + unsqueeze_right(self.bias, x.ndim - 2) + + return x diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py index 362c78269..4adb77e88 100644 --- a/src/mrpro/nn/LinearSelfAttention.py +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -1,3 +1,5 @@ +"""Linear self-attention""" + import torch from einops import rearrange from torch import Tensor @@ -10,7 +12,8 @@ class LinearSelfAttention(Module): Uses a ReLU kernel to compute attention in O(N) [KAT20]_ time and space. - Refereces + References + ---------- .. [KAT20] Katharopoulos, Angelos, et al. Transformers are rnns: Fast autoregressive transformers with linear attention. ICML 2020. https://arxiv.org/abs/2006.16236 @@ -34,6 +37,22 @@ def __init__( eps: float = 1e-6, channel_last: bool = False, ): + """Initialize linear self-attention layer. + + Parameters + ---------- + channels_in + Input channel dimension. + channels_out + Output channel dimension. + n_heads + Number of attention heads. + eps + Small epsilon for numerical stability in normalization. + channel_last + Whether the channel dimension is the last dimension, as common in transformer models, + or the second dimension, as common in image models. + """ super().__init__() self.channel_last = channel_last self.eps = eps @@ -79,8 +98,8 @@ def forward(self, x: Tensor) -> Tensor: value_key = value @ key.transpose(-1, -2) value_key_query = value_key @ query - normalisation = value_key_query[..., -1:, :] + self.eps - attn = value_key_query[..., :-1, :] / normalisation + normalization = value_key_query[..., -1:, :] + self.eps + attn = value_key_query[..., :-1, :] / normalization out = self.to_out(attn) out = out.to(orig_dtype) out.unflatten(-2, spatial_shape) diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py index 332ada94c..953f0a1e7 100644 --- a/src/mrpro/nn/MultiHeadAttention.py +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -31,10 +31,15 @@ def __init__( Number of output channels. num_heads number of attention heads + features_last + Whether the features dimension is the last dimension, as common in transformer models, + or the second dimension, as common in image models. + p_dropout + Dropout probability. """ super().__init__() self.mha = torch.nn.MultiheadAttention( - embed_dim=channels_in, num_heads=num_heads, batch_first=True, dropout=p_dropout + conded_dim=channels_in, num_heads=num_heads, batch_first=True, dropout=p_dropout ) self.features_last = features_last self.to_out = Linear(channels_in, channels_out) diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index 625c47174..fe7aeec74 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -107,7 +107,8 @@ class NeighborhoodSelfAttention(Module): Neighborhood attention is a type of attention where each query attends to a neighborhood of the key and value. It is a more efficient alternative to regular attention, especially for large input sizes [NAT]_. - This implementation uses `~torch.nn.attention.flex_attention`. For a more efficient implementation, see also [NATTEN]_. + This implementation uses `~torch.nn.attention.flex_attention`. For a more efficient implementation, + see also [NATTEN]_. References diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index 4994ff0a7..fedacb9dd 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -26,7 +26,22 @@ def __init__(self, downscale_factor: int): super().__init__() self.downscale_factor = downscale_factor + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels * downscale_factor**dim, *spatial_dims/downscale_factor` + """ + return super().__call__(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input.""" dim = x.ndim - 2 if dim == 2: # fast path for 2D return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor) @@ -179,6 +194,20 @@ def __init__(self, upscale_factor: int): super().__init__() self.upscale_factor = upscale_factor + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels / upscale_factor**dim, *spatial_dims * upscale_factor` + """ + return super().__call__(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Upscale the input.""" dim = x.ndim - 2 diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index da6ee3fff..bc0f4f9cf 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -3,17 +3,17 @@ import torch from torch.nn import Identity, Module, SiLU -from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM -from mrpro.nn.GroupNorm32 import GroupNorm32 +from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.NDModules import ConvND from mrpro.nn.Sequential import Sequential -class ResBlock(EmbMixin, Module): +class ResBlock(CondMixin, Module): """Residual convolution block with two convolutions.""" - def __init__(self, dim: int, channels_in: int, channels_out: int, channels_emb: int) -> None: + def __init__(self, dim: int, channels_in: int, channels_out: int, cond_dim: int) -> None: """Initialize the ResBlock. Parameters @@ -24,47 +24,47 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, channels_emb: The number of channels in the input tensor. channels_out The number of channels in the output tensor. - channels_emb - The number of channels in the embedding tensor used in a FiLM embedding. - If set to 0 no FiLM embedding is used. + cond_dim + The number of features in the conditioning tensor used in a FiLM. + If set to 0 no FiLM is used. """ super().__init__() self.rezero = torch.nn.Parameter(torch.tensor(1e-6)) self.block = Sequential( - GroupNorm32(channels_in), + GroupNorm(channels_in), SiLU(), ConvND(dim)(channels_in, channels_out, kernel_size=3, padding=1), - GroupNorm32(channels_out), + GroupNorm(channels_out), SiLU(), ConvND(dim)(channels_out, channels_out, kernel_size=3, padding=1), ) - if channels_emb > 0: - self.block.insert(-3, FiLM(channels_out, channels_emb)) + if cond_dim > 0: + self.block.insert(-3, FiLM(channels_out, cond_dim)) if channels_out == channels_in: self.skip_connection: Module = Identity() else: self.skip_connection = ConvND(dim)(channels_in, channels_out, kernel_size=1) - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock. Parameters ---------- x The input tensor. - emb - An embedding tensor to be used for FiLM. + cond + A conditioning tensor to be used for FiLM. Returns ------- The output tensor. """ - return super().__call__(x, emb) + return super().__call__(x, cond) - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock.""" - h = self.block(x, emb) + h = self.block(x, cond) x = self.skip_connection(x) + h return x diff --git a/src/mrpro/nn/Residual.py b/src/mrpro/nn/Residual.py new file mode 100644 index 000000000..86698ce6b --- /dev/null +++ b/src/mrpro/nn/Residual.py @@ -0,0 +1,43 @@ +from mrpro.nn.CondMixin import CondMixin, call_with_cond + +import torch +from torch.nn import Module, Identity + + +class Residual(CondMixin, Module): + """Residual connection.""" + + def __init__(self, module: Module, skip: Module | None = None): + """Initialize the residual connection. + + Parameters + ---------- + module + The main path of the residual connection. + skip + The skip path of the residual connection. If None, the identity function is used. + """ + super().__init__() + self.module = module + self.skip = Identity() if skip is None else skip + + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x + The input tensor. + cond + The optional conditioning tensor. If the modules are an instance of `CondMixin`, + the conditioning is passed to the modules. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond) + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + return call_with_cond(self.module, x, cond) + call_with_cond(self.skip, x, cond) diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index b08e43f62..84c19dd46 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -1,33 +1,33 @@ import torch -from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.CondMixin import CondMixin from mrpro.operators import Operator class Sequential(torch.nn.Sequential): - """Sequential container with support for embedding and Operators.""" + """Sequential container with support for conditioning and Operators.""" - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply all modules in series to the input. Parameters ---------- x The input tensor. - emb - The (optional) embedding tensor. + cond + The (optional) conditioning tensor. Returns ------- The output tensor. """ - return super().__call__(x, emb) + return super().__call__(x, cond) - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply all modules in series to the input.""" for module in self: - if isinstance(module, EmbMixin): - x = module(x, emb) + if isinstance(module, CondMixin): + x = module(x, cond) elif isinstance(module, Operator): (x,) = module(x) else: diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index bc6a93aea..93f243237 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -1,9 +1,9 @@ """Neural network modules and utilities.""" from mrpro.nn.AttentionGate import AttentionGate -from mrpro.nn.EmbMixin import EmbMixin +from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM -from mrpro.nn.GroupNorm32 import GroupNorm32 +from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.NDModules import ( AdaptiveAvgPoolND, AvgPoolND, @@ -26,12 +26,12 @@ "AttentionGate", "AvgPoolND", "BatchNormND", + "CondMixin", "ConvND", "ConvTransposeND", "DropPath", - "EmbMixin", "FiLM", - "GroupNorm32", + "GroupNorm", "InstanceNormND", "MaxPoolND", "NeighborhoodSelfAttention", @@ -39,5 +39,6 @@ "Sequential", "ShiftedWindowAttention", "SqueezeExcitation", - "TransposedAttention" + "TransposedAttention", + "nets" ] \ No newline at end of file diff --git a/src/mrpro/nn/nets/CNN.py b/src/mrpro/nn/nets/CNN.py new file mode 100644 index 000000000..39e5d48b0 --- /dev/null +++ b/src/mrpro/nn/nets/CNN.py @@ -0,0 +1,58 @@ +from collections.abc import Sequence +from itertools import pairwise + +from torch.nn import SiLU + +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.NDModules import ConvND +from mrpro.nn.Residual import Residual +from mrpro.nn.Sequential import Sequential +from mrpro.nn.FiLM import FiLM + + +class CNN(Sequential): + """A simple CNN network.""" + + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + features: Sequence[int], + norm: bool = True, + residual: bool = True, + cond_dim: int = 0, + ): + """Initialize the CNN. + + Parameters + ---------- + dim + The number of spatial dimensions. + channels_in + The number of input channels. + channels_out + The number of output channels. + features + The number of features in each layer. The length of the list is the number of hidden layers. + norm + Whether to use layer normalization. + residual + Whether to use residual connections. + cond_dim + The dimension of the conditioning tensor. If 0, no FiLM is used. + """ + super().__init__() + channels = [channels_in, *features] + for i, (channels_current, channels_next) in enumerate(pairwise(channels)): + block = Sequential(ConvND(dim)(channels_current, channels_next, 3, padding=1), SiLU(True)) + if norm: + block.append(GroupNorm(1)) + if cond_dim > 0 and i % 2 == 0: + block.append(FiLM(channels_next, cond_dim)) + if residual: + self.append(Residual(block)) + else: + self.append(block) + + self.append(ConvND(dim)(channels_next, channels_out, 3, padding=1)) diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index a00702613..31266ede3 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -1,12 +1,14 @@ from collections.abc import Sequence -from torch.nn import Module + import torch +from torch.nn import Module + from mrpro.nn import Sequential, SiLU -from mrpro.nn.LinearSelfAttention import LinearSelfAttention -from mrpro.nn.NDModules import ConvND from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock +from mrpro.nn.LinearSelfAttention import LinearSelfAttention from mrpro.nn.MultiHeadAttention import MultiHeadAttention -from mrpro.nn.PixelShuffle import PixelUnshuffleDownsample, PixelShuffleUpsampe +from mrpro.nn.NDModules import ConvND +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.RMSNorm import RMSNorm @@ -14,26 +16,20 @@ class ResBlock(Module): def __init__( self, dim: int, - channels int, - + channels: int, ): super().__init__() - self.inner=Sequential( - ConvND(dim)( - channels,channels,kernel_size=3,padding=1 - ), + self.inner = Sequential( + ConvND(dim)(channels, channels, kernel_size=3, padding=1), SiLU(), - ConvND(dim)( - channels,channels,kernel_size=3,padding=1, bias=False - ), - RMSNorm(channels) + ConvND(dim)(channels, channels, kernel_size=3, padding=1, bias=False), + RMSNorm(channels), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.inner(x) + x - class EfficientViTBlock(Module): def __init__( self, @@ -45,10 +41,10 @@ def __init__( ): super().__init__() if linear_attn: - attention = LinearSelfAttention(channels, channels, n_heads) #TODO: check heads and head dim + attention = LinearSelfAttention(channels, channels, n_heads) # TODO: check heads and head dim else: attention = MultiHeadAttention(channels, channels, n_heads, features_last=False) - self.context_module=Sequential(attention,RMSNorm(channels)) + self.context_module = Sequential(attention, RMSNorm(channels)) self.local_module = GluMBConvResBlock( dim=dim, channels_in=channels, @@ -58,75 +54,76 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.context_module(x) + x - x = self.local_module(x) # is already residual + x = self.local_module(x) # is already residual return x - class Encoder(Sequential): - def __init__(self, dim:int=2, channels_in:int=3, channels_out:int=32,block_types:Sequence[str]=("ResBlock","ResBlock","LinearViT","LinearViT","ViT"),widths:Sequence[int]=(256,512,512,1024,1024),depths:Sequence[int]=(4,6,2,2,2)): + def __init__( + self, + dim: int = 2, + channels_in: int = 3, + channels_out: int = 32, + block_types: Sequence[str] = ('ResBlock', 'ResBlock', 'LinearViT', 'LinearViT', 'ViT'), + widths: Sequence[int] = (256, 512, 512, 1024, 1024), + depths: Sequence[int] = (4, 6, 2, 2, 2), + ): super().__init__() - self.append(PixelUnshuffleDownsample(dim,channels_in,widths[0], downscale_factor=2, residual=False)) + self.append(PixelUnshuffleDownsample(dim, channels_in, widths[0], downscale_factor=2, residual=False)) if len(block_types) != len(widths) or len(block_types) != len(depths): - raise ValueError("block_types, widths, and depths must have the same length") - for block_type,width, depth in zip(block_types,widths,depths): + raise ValueError('block_types, widths, and depths must have the same length') + for block_type, width, depth in zip(block_types, widths, depths, strict=False): match block_type: - case "ResBlock": - stage = [ResBlock(dim,width) for _ in range(depth)] - case "LinearViT": - stage = [EfficientViTBlock(dim,width,n_heads=1, linear_attn=True) for _ in range(depth)] # TODO: heads - case "ViT": - stage = [EfficientViTBlock(dim,width,n_heads=1, linear_attn=False) for _ in range(depth)] + case 'ResBlock': + stage = [ResBlock(dim, width) for _ in range(depth)] + case 'LinearViT': + stage = [ + EfficientViTBlock(dim, width, n_heads=1, linear_attn=True) for _ in range(depth) + ] # TODO: heads + case 'ViT': + stage = [EfficientViTBlock(dim, width, n_heads=1, linear_attn=False) for _ in range(depth)] case _: - raise ValueError(f"Block type {block_type} not supported") + raise ValueError(f'Block type {block_type} not supported') self.append(Sequential(stage)) if len(self) < len(widths): - self.append(PixelUnshuffleDownsample(dim,width,width, downscale_factor=2, residual=True)) - self.append(PixelUnshuffleDownsample(dim,widths[-1],channels_out, downscale_factor=1, residual=True)) + self.append(PixelUnshuffleDownsample(dim, width, width, downscale_factor=2, residual=True)) + self.append(PixelUnshuffleDownsample(dim, widths[-1], channels_out, downscale_factor=1, residual=True)) class Decoder(Module): -def __init__( + def __init__( self, dim: int = 2, - channels_in:int=32, + channels_in: int = 32, channels_out: int = 3, - block_types: Sequence[str] = ( - "ViT", "LinearViT", - "LinearViT", "ResBlock", "ResBlock" - ), + block_types: Sequence[str] = ('ViT', 'LinearViT', 'LinearViT', 'ResBlock', 'ResBlock'), widths: Sequence[int] = (1024, 1024, 512, 512, 256), depths: Sequence[int] = (2, 2, 2, 6, 4), ): super().__init__() if not (len(block_types) == len(widths) == len(depths)): - raise ValueError( - "block_types, widths, and depths must have the same length" - ) + raise ValueError('block_types, widths, and depths must have the same length') # "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] " # "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[0,5,10,2,2,2] " # "decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu]" - self.append( - PixelShuffleUpsampe(dim, channels_in, widths[0], upscale_factor=1, residual=True) - ) - - + self.append(PixelShuffleUpsampe(dim, channels_in, widths[0], upscale_factor=1, residual=True)) self.stages: list[OpSequential] = [] - for block_type,width, depth in zip(block_types,widths,depths): + for block_type, width, depth in zip(block_types, widths, depths, strict=False): match block_type: - case "ResBlock": - stage = [ResBlock(dim,width) for _ in range(depth)] - case "LinearViT": - stage = [EfficientViTBlock(dim,width,n_heads=1, linear_attn=True) for _ in range(depth)] # TODO: heads - case "ViT": - stage = [EfficientViTBlock(dim,width,n_heads=1, linear_attn=False) for _ in range(depth)] + case 'ResBlock': + stage = [ResBlock(dim, width) for _ in range(depth)] + case 'LinearViT': + stage = [ + EfficientViTBlock(dim, width, n_heads=1, linear_attn=True) for _ in range(depth) + ] # TODO: heads + case 'ViT': + stage = [EfficientViTBlock(dim, width, n_heads=1, linear_attn=False) for _ in range(depth)] case _: - raise ValueError(f"Block type {block_type} not supported") + raise ValueError(f'Block type {block_type} not supported') self.append(Sequential(stage)) if len(self) < len(widths): - self.append(PixelShuffleUpsampe(dim,width,width, upscale_factor=2, residual=True)) - + self.append(PixelShuffleUpsample(dim, width, width, upscale_factor=2, residual=True)) stage.extend( build_stage_main( diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 505c88bcf..49f437d15 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -1,14 +1,17 @@ """Restormer implementation.""" from collections.abc import Sequence + import torch from torch.nn import Module -from mrpro.nn.TransposedAttention import TransposedAttention -from mrpro.nn.NDModules import ConvND, ConvNd, InstanceNormNd + from mrpro.nn.FiLM import FiLM +from mrpro.nn.NDModules import ConvND, ConvNd, InstanceNormNd from mrpro.nn.nets.UNet import UNetBase -from mrpro.nn.Sequential import Sequential from mrpro.nn.PixelShuffle import PixelShuffle, PixelUnshuffle +from mrpro.nn.Sequential import Sequential +from mrpro.nn.TransposedAttention import TransposedAttention + class GDFN(Module): """Gated depthwise feed forward network. @@ -40,15 +43,15 @@ def forward(self, x): class RestormerBlock(Module): """Transformer block with transposed attention and gated depthwise feed forward network.""" - def __init__(self, dim: int, channels: int, num_heads: int, mlp_ratio: float, emb_dim: int = 0): + def __init__(self, dim: int, channels: int, num_heads: int, mlp_ratio: float, cond_dim: int = 0): super().__init__() self.norm1 = Sequential(InstanceNormNd(dim)(channels)) self.attn = TransposedAttention(dim, channels, num_heads) self.norm2 = Sequential(InstanceNormNd(dim)(channels)) self.ffn = GDFN(dim, channels, mlp_ratio) - if emb_dim > 0: - self.norm1.append(FiLM(channels=channels, channels_emb=emb_dim)) - self.norm2.append(FiLM(channels=channels, channels_emb=emb_dim)) + if cond_dim > 0: + self.norm1.append(FiLM(channels=channels, cond_dim=cond_dim)) + self.norm2.append(FiLM(channels=channels, cond_dim=cond_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.norm1(x)) @@ -78,7 +81,7 @@ def __init__( n_heads: Sequence[int] = (1, 2, 4, 8), n_channels_per_head: int = 48, mlp_ratio: float = 2.66, - emb_dim: int = 0, + cond_dim: int = 0, ): super().__init__() @@ -89,19 +92,44 @@ def blocks(n_heads: int, n_blocks: int): *(RestormerBlock(dim, n_channels_per_head, n_heads, mlp_ratio) for _ in range(n_blocks)) ) - if emb_dim > 0 and n_blocks > 1: - layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, channels_emb=emb_dim)) + if cond_dim > 0 and n_blocks > 1: + layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) return layers - for n_block, n_heads in zip(n_blocks, n_heads): + for n_block, n_heads in zip(n_blocks, n_heads, strict=False): self.input_blocks.append(blocks(n_heads, n_block)) self.output_blocks.append(blocks(n_heads, n_block)) self.skip_blocks.append(Identity()) self.output_blocks = self.output_blocks[::-1] - for n_head_current, n_head_next in pairwise(n_heads): - self.down_blocks.append(Sequential(ConvND(dim)(n_head_current*n_channels_per_head, n_head_next*n_channels_per_head // 2**dim, kernel_size=3, padding=1, bias=False), PixelUnshuffle(2))) - self.up_blocks.append(Sequential(ConvND(dim)(n_head_next*n_channels_per_head, n_head_current*n_channels_per_head*2**dim, kernel_size=3, padding=1, bias=False), PixelShuffle(2))) + for n_head_current, n_head_next in pairwise(n_heads): + self.down_blocks.append( + Sequential( + ConvND(dim)( + n_head_current * n_channels_per_head, + n_head_next * n_channels_per_head // 2**dim, + kernel_size=3, + padding=1, + bias=False, + ), + PixelUnshuffle(2), + ) + ) + self.up_blocks.append( + Sequential( + ConvND(dim)( + n_head_next * n_channels_per_head, + n_head_current * n_channels_per_head * 2**dim, + kernel_size=3, + padding=1, + bias=False, + ), + PixelShuffle(2), + ) + ) self.middle_block = blocks(n_heads, n_blocks) - self.last = Sequential(*blocks(n_heads[0],n_refinement_blocks), ConvND(dim)(n_channels_per_head*n_heads[0], channels_out, kernel_size=3, stride=1, padding=1)) + self.last = Sequential( + *blocks(n_heads[0], n_refinement_blocks), + ConvND(dim)(n_channels_per_head * n_heads[0], channels_out, kernel_size=3, stride=1, padding=1), + ) diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py new file mode 100644 index 000000000..90fd8f508 --- /dev/null +++ b/src/mrpro/nn/nets/SwinIR.py @@ -0,0 +1,189 @@ +import torch +from torch.nn import GELU, Module + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.NDModules import ConvND +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention + + +class SwinTransformerLayer(Module): + """Swin Transformer layer. + + As used in the SwinIR network. + """ + + def __init__( + self, + dim: int, + channels: int, + n_heads: int, + window_size: int, + shifted: bool, + mlp_ratio: int = 4, + cond_dim: int = 0, + ): + """Initialize the Swin Transformer layer. + + Parameters + ---------- + dim + Number of spatial dimensions (1D, 2D, or 3D) + channels + Number of channels in the input tensor + n_heads + Number of attention heads + window_size + Size of the local window for computing windowed self-attention + shifted + Whether to use shifted window attention + mlp_ratio + Expansion ratio for the MLP + cond_dim + Dimension of optional tensor for FiLM conditioning. If 0, no conditioning is used + """ + super().__init__() + self.norm1 = LayerNorm(channels) + self.attn = ShiftedWindowAttention(dim, channels, n_heads, window_size, shifted) + if cond_dim > 0: + self.norm2 = Sequential(LayerNorm(None), FiLM(cond_dim)) + else: + self.norm2 = Sequential(LayerNorm(channels)) + self.mlp = Sequential( + ConvND(dim)(channels, channels * mlp_ratio, 1), + GELU(), + ConvND(dim)(channels * mlp_ratio, channels, 1), + ) + + def __call__(self, x, cond: torch.Tensor | None = None): + """Apply the Swin Transformer layer. + + Parameters + ---------- + x + Input tensor of shape (batch_size, channels, *spatial_dims) + cond + Optional conditioning tensor of shape (batch_size, cond_dim) + + Returns + ------- + Output tensor of shape (batch_size, channels, *spatial_dims) + """ + return super().__call__(x, cond) + + def forward(self, x, cond: torch.Tensor | None = None): + """Apply the Swin Transformer layer.""" + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x, cond)) + return x + + +class RSTB(Module): + """Residual Swin Transformer block. + + As used in the SwinIR network. + """ + + def __init__(self, dim: int, channels: int, n_heads: int, window_size: int, depth: int, cond_dim: int = 0): + super().__init__() + self.layers = Sequential( + *[ + SwinTransformerLayer(dim, channels, n_heads, window_size, shifted=(i % 2 == 1), cond_dim=cond_dim) + for i in range(depth) + ] + ) + self.conv = ConvND(dim)(channels, channels, 3, padding=1) + + def __call__(self, x, cond: torch.Tensor | None = None): + """Apply the residual Swin Transformer block. + + Parameters + ---------- + x + Input tensor of shape (batch_size, channels, *spatial_dims) + cond + Optional conditioning tensor of shape (batch_size, cond_dim) + + Returns + ------- + Output tensor of shape (batch_size, channels, *spatial_dims) + """ + return super().__call__(x, cond) + + def forward(self, x, cond: torch.Tensor | None = None): + """Apply the residual Swin Transformer block.""" + return x + self.conv(self.layers(x, cond)) + + +class SwinIR(Module): + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + channels_per_head: int = 16, + n_heads: int = 6, + window_size: int = 64, + n_blocks: int = 6, + n_attn_per_block: int = 6, + cond_dim: int = 0, + ): + """Initialize the SwinIR model. + + Parameters + ---------- + dim + Number of spatial dimensions (1D, 2D, or 3D) + channels_in + Number of input channels + channels_out + Number of output channels + channels_per_head + Number of channels per attention head + n_heads + Number of attention heads + window_size + Size of the local window for computing windowed self-attention + n_blocks + Number of residual Swin Transformer blocks (RSTB) + n_attn_per_block + Number of windowed attention layers per RSTB block + cond_dim + Dimension of optional tensor for FiLM conditioning. If 0, no conditioning is used + """ + super().__init__() + self.shallow = ConvND(dim)(channels_in, channels_per_head * n_heads, 3, padding=1) + self.body = Sequential( + *[ + RSTB(dim, channels_per_head * n_heads, n_heads, window_size, n_attn_per_block, cond_dim) + for _ in range(n_blocks) + ] + ) + self.body.append(ConvND(dim)(channels_per_head * n_heads, channels_per_head * n_heads, 3, padding=1)) + self.final = ConvND(dim)(channels_per_head, channels_out, 3, padding=1) + self.skip = ConvND(dim)(channels_in, channels_out, 1, padding=1) + + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SwinIR model. + + Parameters + ---------- + x + Input tensor of shape (batch_size, channels_in, *spatial_dims) + cond + Optional conditioning tensor of shape (batch_size, cond_dim) + + Returns + ------- + out + Output tensor of shape (batch_size, channels_out, *spatial_dims) + """ + return super().__call__(x, cond) + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SwinIR model.""" + h = self.shallow(x) + h = self.body(h, cond) + self.skip(x) + out = self.final(h) + return out diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 6b8eebe1d..b1cedfac5 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -1,9 +1,10 @@ +from collections.abc import Sequence from functools import partial import torch from torch.nn import Identity, Module, ModuleList -from mrpro.nn.EmbMixin import call_with_emb +from mrpro.nn.CondMixin import call_with_cond class UNetBase(Module): @@ -38,9 +39,9 @@ def __init__(self) -> None: self.first = Identity() """The first block""" - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: """Apply to Network.""" - call = partial(call_with_emb, emb=emb) + call = partial(call_with_cond, cond=cond) x = call(self.first, x) xs = [] for block, down, skip in zip(self.input_blocks, self.down_blocks, self.skip_blocks, strict=True): @@ -54,21 +55,21 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: x = call(block, x) return call(self.last, x) - def __call__(self, x: torch.Tensor, emb: torch.Tensor | None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None) -> torch.Tensor: """Apply to Network. Parameters ---------- x The input tensor. - emb - The embedding tensor. + cond + The conditioning tensor. Returns ------- The output tensor. """ - return self(x, emb) + return self(x, cond) class AttentionUNet(UNet): @@ -80,6 +81,7 @@ class AttentionUNet(UNet): https://arxiv.org/abs/1804.03999 """ + class SeparableUNet(UNetBase): """UNet where blocks apply separable convolutions in different dimensions @@ -89,9 +91,11 @@ class SeparableUNet(UNetBase): ---------- .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. ICCV 2017. https://arxiv.org/abs/1711.10305 - .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction without explicit sensitivity maps. - STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 + .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction + without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 """ + + class UNet(UNetBase): """UNet. @@ -107,16 +111,13 @@ class UNet(UNetBase): def __init__( self, - dim:int, - + dim: int, in_channels: int, out_channels: int, n_features: Sequence[int], - n_heads:Sequence[int] - n_blocks:int|Sequence[int] - channels_emb: int, - dim: int, + n_heads: Sequence[int], + n_blocks: int | Sequence[int], + cond_dim: int, num_blocks: int, - padding_modes:str|Sequence[str] - + padding_modes: str | Sequence[str], ) -> None: ... diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index e351e32a6..6b4f6fa9f 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -136,7 +136,7 @@ def __init__( n_channels_per_head: int = 32, n_heads: Sequence[int] = (1, 2, 4, 8), n_blocks: int = 2, - emb_dim: int = 0, + cond_dim: int = 0, window_size: int = 8, mlp_ratio: float = 4.0, max_droppath_rate: float = 0.1, @@ -158,8 +158,8 @@ def __init__( Number of attention heads at each resolution level. n_blocks : int, optional Number of transformer blocks at each resolution level in the input and output path - emb_dim : int, optional - Dimension of the embedding. If `0`, no FiLM layers are added. + cond_dim : int, optional + Dimension of a conditioning tensor. If `0`, no FiLM layers are added. window_size : int, optional Size of the attention windows in the (shifted) window attention layers. mlp_ratio : float, optional @@ -187,8 +187,8 @@ def blocks(n_heads: int, p_droppath: float = 0.0): ) ) - if emb_dim > 0 and n_blocks > 1: - layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, channels_emb=emb_dim)) + if cond_dim > 0 and n_blocks > 1: + layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) return layers drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py new file mode 100644 index 000000000..3fcb2b417 --- /dev/null +++ b/src/mrpro/nn/nets/VAE.py @@ -0,0 +1,55 @@ +import torch +from torch.nn import Module + + +class VAE(Module): + """Basic Variational Autoencoder. + + Consists of an encoder to transform the input into a latent space and a decoder to transform the latent space back + into the original space. The encoder should return twice the number of channels as the decoder needs to reconstruct + the input: half of the channels are the mean and the other half the log variance of the latent space. + The reparameterization trick is used to sample from the latent space. + The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard normal distribution. + """ + + def __init__(self, encoder: Module, decoder: Module): + """Initialize the VAE. + + Parameters + ---------- + encoder : Module + Encoder module. Should return double the number of channels of the latent space. + decoder : Module + Decoder module + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE. + + Calculates the reconstruction as well as the KL divergence between the latent space and the + standard normal distribution. + + Parameters + ---------- + x : torch.Tensor + Input tensor + + Returns + ------- + tuple of the reconstructed image and + the KL divergence between the latent space and the standard normal distribution. + """ + return self.forward(x) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE.""" + z = self.encoder(x) + mean, logvar = z.chunk(2, dim=1) + std = torch.exp(0.5 * logvar) + sample = mean + torch.randn_like(std) * std + reconstruction = self.decoder(sample) + kl = -0.5 * torch.sum(1 + logvar - mean.square() - std.square()) + return reconstruction, kl diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py index e069d4529..535b76055 100644 --- a/tests/nn/test_film.py +++ b/tests/nn/test_film.py @@ -24,7 +24,7 @@ def test_film(channels, channels_emb, input_shape, emb_shape, device): rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) emb = rng.float32_tensor(emb_shape).to(device).requires_grad_(True) - film = FiLM(channels=channels, channels_emb=channels_emb).to(device) + film = FiLM(channels=channels, cond_dim=channels_emb).to(device) output = film(x, emb) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() diff --git a/tests/nn/test_groupnorm32.py b/tests/nn/test_groupnorm32.py index 389b8ca85..0c936dca7 100644 --- a/tests/nn/test_groupnorm32.py +++ b/tests/nn/test_groupnorm32.py @@ -1,7 +1,7 @@ """Tests for GroupNorm32 module.""" import pytest -from mrpro.nn import GroupNorm32 +from mrpro.nn import GroupNorm from mrpro.utils import RandomGenerator @@ -23,7 +23,7 @@ def test_groupnorm32(channels, groups, input_shape, device): """Test GroupNorm32 output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - norm = GroupNorm32(channels=channels, groups=groups).to(device) + norm = GroupNorm(channels=channels, groups=groups).to(device) output = norm(x) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py index da3297cdb..c52b0ed2a 100644 --- a/tests/nn/test_sequential.py +++ b/tests/nn/test_sequential.py @@ -29,7 +29,7 @@ def test_sequential(input_shape, emb_shape, device): seq = Sequential( Linear(input_shape[1], 64), FastFourierOp(), - FiLM(channels=64, channels_emb=16), + FiLM(channels=64, cond_dim=16), ).to(device) output = seq(x, emb) assert output.shape == (input_shape[0], 32), f'Output shape {output.shape} != expected {(input_shape[0], 32)}' From 97115e77559e36e4adf02c09cba5b96cc53c4100 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 19 May 2025 02:23:57 +0200 Subject: [PATCH 034/205] update --- src/mrpro/nn/FiLM.py | 8 +- src/mrpro/nn/LayerNorm.py | 9 +- src/mrpro/nn/LinearSelfAttention.py | 2 +- src/mrpro/nn/MultiHeadAttention.py | 6 +- src/mrpro/nn/Residual.py | 6 +- src/mrpro/nn/RoPE.py | 132 +++++++++++----- src/mrpro/nn/activations.py | 21 ++- src/mrpro/nn/encoding.py | 53 +++++++ src/mrpro/nn/nets/CNN.py | 2 +- src/mrpro/nn/nets/DCAE.py | 54 +++---- src/mrpro/nn/nets/Restormer.py | 126 ++++++++++++---- src/mrpro/nn/nets/SwinIR.py | 218 +++++++++++++++------------ src/mrpro/nn/nets/UNet.py | 48 +++--- src/mrpro/nn/nets/Uformer.py | 20 ++- src/mrpro/nn/nets/VAE.py | 5 +- src/mrpro/nn/test.ipynb | 149 ------------------ tests/nn/test_film.py | 2 +- tests/nn/test_transposedattention.py | 2 +- 18 files changed, 479 insertions(+), 384 deletions(-) delete mode 100644 src/mrpro/nn/test.ipynb diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index 9cac169e9..014aa5835 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -15,7 +15,7 @@ class FiLM(CondMixin, Module): References ---------- - ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "FiLM: Visual reasoning with a general + ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "FiLM : Visual reasoning with a general conditioning layer." AAAI (2018). https://arxiv.org/abs/1709.07871 """ @@ -31,7 +31,7 @@ def __init__(self, channels: int, cond_dim: int) -> None: """ super().__init__() if cond_dim > 0: - self.project = Sequential( + self.project: Module = Sequential( SiLU(), Linear(cond_dim, 2 * channels), ) @@ -54,8 +54,6 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Te """Apply FiLM.""" if cond is None: return x - - cond = self.project(cond) - scale, shift = cond.chunk(2, dim=1) + scale, shift = self.project(cond).chunk(2, dim=1) scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) return x * (1 + scale) + shift diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 1bea902d1..863e81e13 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -1,6 +1,7 @@ -from torch.nn import Module, Parameter import torch -from mrpro.nn.utils import unsqueeze_right +from torch.nn import Module, Parameter + +from mrpro.utils.reshape import unsqueeze_right class LayerNorm(Module): @@ -21,8 +22,8 @@ def __init__(self, channels: int | None, channel_last: bool = False, bias: bool """ super().__init__() if channels is not None: - self.weight = Parameter(torch.ones(channels)) - self.bias = Parameter(torch.zeros(channels)) if bias else None + self.weight: Parameter | None = Parameter(torch.ones(channels)) + self.bias: Parameter | None = Parameter(torch.zeros(channels)) if bias else None else: self.weight = None self.bias = None diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py index 4adb77e88..2626fb262 100644 --- a/src/mrpro/nn/LinearSelfAttention.py +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -21,7 +21,7 @@ class LinearSelfAttention(Module): ---------- channels Input and output channel dimension. - num_heads + n_heads Number of attention heads. bias Whether to use bias in the QKV projection. diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py index 953f0a1e7..cc94e3d3d 100644 --- a/src/mrpro/nn/MultiHeadAttention.py +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -15,7 +15,7 @@ def __init__( self, channels_in: int, channels_out: int, - num_heads: int, + n_heads: int, features_last: bool = False, p_dropout: float = 0.0, ): @@ -29,7 +29,7 @@ def __init__( Number of input channels. channels_out Number of output channels. - num_heads + n_heads number of attention heads features_last Whether the features dimension is the last dimension, as common in transformer models, @@ -39,7 +39,7 @@ def __init__( """ super().__init__() self.mha = torch.nn.MultiheadAttention( - conded_dim=channels_in, num_heads=num_heads, batch_first=True, dropout=p_dropout + embed_dim=channels_in, n_heads=n_heads, batch_first=True, dropout=p_dropout ) self.features_last = features_last self.to_out = Linear(channels_in, channels_out) diff --git a/src/mrpro/nn/Residual.py b/src/mrpro/nn/Residual.py index 86698ce6b..19b047405 100644 --- a/src/mrpro/nn/Residual.py +++ b/src/mrpro/nn/Residual.py @@ -1,7 +1,7 @@ -from mrpro.nn.CondMixin import CondMixin, call_with_cond - import torch -from torch.nn import Module, Identity +from torch.nn import Identity, Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond class Residual(CondMixin, Module): diff --git a/src/mrpro/nn/RoPE.py b/src/mrpro/nn/RoPE.py index 4f71d1165..06f63b7b7 100644 --- a/src/mrpro/nn/RoPE.py +++ b/src/mrpro/nn/RoPE.py @@ -1,77 +1,139 @@ -from math import log +"""Rotary Position Embeddings (RoPE) implementation.""" import torch +from torch.nn import Module + +from mrpro.nn.NDModules import ConvND -# Rotary position embeddings @torch.compile -def apply_rotary_emb_(x: torch.Tensor, theta: torch.Tensor, conjugated: bool): - """Adds the rotary embedding to the input tensor (inplace). +def apply_rotary_emb_(x: torch.Tensor, theta: torch.Tensor, conjugated: bool) -> None: + """Add rotary embedding to the input tensor (inplace). This is a helper function for the `AxialRoPE` class. + + Parameters + ---------- + x : torch.Tensor + Input tensor to modify + theta : torch.Tensor + Rotation angles + conjugated : bool + Whether to use conjugated rotation """ n_emb = theta.shape[-1] * 2 if n_emb > x.shape[-1]: - raise ValueError('More theta values then channels//2 in the input tensor.') + raise ValueError(f'Embedding dimension {n_emb} is larger than input dimension {x.shape[-1]}') x1, x2 = x[..., :n_emb].chunk(2, dim=-1) - dtype = torch.promote_type(torch.result_type(x, theta), torch.float32) - x1_, x2_, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype) - cos, sin = torch.cos(theta), torch.sin(theta) - sin = -sin if conjugated else sin - y1 = x1_ * cos - x2_ * sin - y2 = x2_ * cos + x1_ * sin - x1.copy_(y1) - x2.copy_(y2) + if conjugated: + x1, x2 = x2, x1 + x[..., :n_emb] = torch.cat([x1 * theta.cos() - x2 * theta.sin(), x2 * theta.cos() + x1 * theta.sin()], dim=-1) class RotaryEmbedding_(torch.autograd.Function): - """Adds the rotary embedding to the input tensor (inplace). - - This is a autograd helper class for the `AxialRoPE` class. - """ + """Custom autograd function for rotary embeddings.""" @staticmethod - def forward(x: torch.Tensor, theta: torch.Tensor, conjugated: bool) -> torch.Tensor: - apply_rotary_emb_(x, theta, conj=conj) + def forward( + ctx: torch.autograd.function.FunctionCtx, x: torch.Tensor, theta: torch.Tensor, conjugated: bool + ) -> torch.Tensor: + """Apply rotary embedding in forward pass.""" + apply_rotary_emb_(x, theta, conjugated) return x @staticmethod - def setup_context(ctx, inputs: tuple[torch.Tensor, torch.Tensor, bool], output: torch.Tensor): + def setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple[torch.Tensor, torch.Tensor, bool], output: torch.Tensor + ) -> None: + """Save tensors for backward pass.""" _, theta, conjugated = inputs ctx.save_for_backward(theta) ctx.conjugated = conjugated @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]: + def backward( + ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor + ) -> tuple[torch.Tensor, None, None]: + """Apply backward pass.""" (theta,) = ctx.saved_tensors - apply_rotary_emb_(grad_output, theta, conjugated=not ctx.conjugated) + apply_rotary_emb_(grad_output, theta, ctx.conjugated) return grad_output, None, None class AxialRoPE(Module): + """Axial Rotary Position Embedding. + + Applies rotary position embeddings along each axis independently. + """ + def __init__(self, dim: int, d_head: int, n_heads: int, headpos: int = -2, non_embed_fraction: float = 0.5): + """Initialize AxialRoPE. + + Parameters + ---------- + dim : int + Dimension of the input space + d_head : int + Dimension of each attention head + n_heads : int + Number of attention heads + headpos : int, optional + Position of the head dimension, by default -2 + non_embed_fraction : float, optional + Fraction of dimensions to not embed, by default 0.5 + """ super().__init__() - log_min = log(torch.pi) - log_max = log(100 ** (1 / dim) * torch.pi) - d_per_head = int(d_head / dim * (1 - non_embed_fraction)) - freqs = torch.linspace(log_min, log_max, n_heads * d_per_head).exp() - freqs = freqs.view(-1, n_heads).T - freqs = freqs.unsqueeze(-2).repeat(1, dim, 1).contiguous() - self.freqs = torch.nn.Parameter(freqs) + log_min = torch.log(torch.tensor(torch.pi)) + log_max = torch.log(torch.tensor(10000.0)) + freqs = torch.exp(torch.linspace(log_min, log_max, d_head // 2)) + self.register_buffer('freqs', freqs) self.headpos = headpos - def get_theta(self, pos): + def get_theta(self, pos: torch.Tensor) -> torch.Tensor: + """Get rotation angles for given positions. + + Parameters + ---------- + pos : torch.Tensor + Position tensor + + Returns + ------- + torch.Tensor + Rotation angles + """ return (self.freqs * pos[..., None, :, None]).flatten(start_dim=-2).movedim(-2, self.headpos) - def forward(self, pos, *tensors): + def forward(self, pos: torch.Tensor, *tensors: torch.Tensor) -> None: + """Apply rotary embeddings to input tensors. + + Parameters + ---------- + pos : torch.Tensor + Position tensor + *tensors : torch.Tensor + Tensors to apply rotary embeddings to + """ theta = self.get_theta(pos) tuple(RotaryEmbedding_.apply(x, theta, False) for x in tensors) @staticmethod - def make_axial_positions(*shape): - shape = torch.as_tensor(shape) - m = shape.max() + def make_axial_positions(*shape: int) -> torch.Tensor: + """Create axial position tensors. + + Parameters + ---------- + *shape : int + Shape of the position tensor + + Returns + ------- + torch.Tensor + Position tensor + """ + m = torch.as_tensor(shape).max() pos = torch.stack( - torch.meshgrid([torch.linspace(-1 + 1 / s, 1 - 1 / s, s) * (s / m) for s in shape], indexing='ij'), -1 + [torch.arange(s, device=m.device) - s // 2 for s in shape], + dim=-1, ) return pos diff --git a/src/mrpro/nn/activations.py b/src/mrpro/nn/activations.py index 4f757476a..cc5dcc1ba 100644 --- a/src/mrpro/nn/activations.py +++ b/src/mrpro/nn/activations.py @@ -10,7 +10,7 @@ class GEGLU(Module): ..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202 """ - def __init__(self, in_features: int, out_features: int | None = None): + def __init__(self, in_features: int, out_features: int | None = None, channels_last: bool = False): """Initialize the GEGLU activation function. Parameters @@ -18,12 +18,23 @@ def __init__(self, in_features: int, out_features: int | None = None): in_features : int The number of input features. out_features : int - The number of output features. If None, the number of output features is the same as the number of input features. + The number of output features. If None, the number of + output features is the same as the number of input features. + channels_last + If True, the channel dimension is the last dimension, else in the second dimension. """ super().__init__() - self.proj = Linear(in_features, out_features * 2) + out_features_ = in_features if out_features is None else out_features + self.proj = Linear(in_features, out_features_ * 2) # gate and output stacked + self.channels_last = channels_last - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation.""" + if not self.channels_last: + x = x.moveaxis(1, -1) h, gate = self.proj(x).chunk(2, dim=-1) gate = torch.nn.functional.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - return h * gate + out = h * gate + if not self.channels_last: + out = out.moveaxis(out, -1, 1) + return out diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py index 32e239adc..f562bf7f0 100644 --- a/src/mrpro/nn/encoding.py +++ b/src/mrpro/nn/encoding.py @@ -1,3 +1,5 @@ +"""Encoding modules for neural networks.""" + from itertools import combinations from math import ceil @@ -8,18 +10,69 @@ class FourierFeatures(Module): + """Fourier feature encoding layer. + + Projects input features into a higher dimensional space using random Fourier features. + This is useful for encoding positional information in neural networks. + """ + + weight: torch.Tensor + def __init__(self, in_features: int, out_features: int, std: float = 1.0): + """Initialize Fourier feature encoding layer. + + Parameters + ---------- + in_features : int + Number of input features + out_features : int + Number of output features (must be even) + std : float, optional + Standard deviation for random initialization, by default 1.0 + """ super().__init__() assert out_features % 2 == 0 self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (..., in_features) + + Returns + ------- + torch.Tensor + Encoded features of shape (..., out_features) + """ f = 2 * torch.pi * x @ self.weight.T return torch.cat([f.cos(), f.sin()], dim=-1) class AbsolutePositionEncoding(Module): + """Absolute position encoding layer. + + Encodes absolute positions in a grid using learned embeddings. + """ + + encoding: torch.Tensor + def __init__(self, dim: int, features: int, include_radii: bool = True, base_resolution: int = 128): + """Initialize absolute position encoding layer. + + Parameters + ---------- + dim : int + Dimension of the input space (1, 2, or 3) + features : int + Number of output features + include_radii : bool, optional + Whether to include radius features, by default True + base_resolution : int, optional + Base resolution for position encoding, by default 128 + """ super().__init__() coords = [unsqueeze_right(torch.linspace(-1, 1, base_resolution), i) for i in range(dim)] diff --git a/src/mrpro/nn/nets/CNN.py b/src/mrpro/nn/nets/CNN.py index 39e5d48b0..f278759b6 100644 --- a/src/mrpro/nn/nets/CNN.py +++ b/src/mrpro/nn/nets/CNN.py @@ -3,11 +3,11 @@ from torch.nn import SiLU +from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.NDModules import ConvND from mrpro.nn.Residual import Residual from mrpro.nn.Sequential import Sequential -from mrpro.nn.FiLM import FiLM class CNN(Sequential): diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index 31266ede3..49be43589 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -1,9 +1,9 @@ from collections.abc import Sequence import torch -from torch.nn import Module +from torch.nn import Module, Sequential -from mrpro.nn import Sequential, SiLU +from mrpro.nn import SiLU from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock from mrpro.nn.LinearSelfAttention import LinearSelfAttention from mrpro.nn.MultiHeadAttention import MultiHeadAttention @@ -106,9 +106,9 @@ def __init__( # "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] " # "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[0,5,10,2,2,2] " # "decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu]" - self.append(PixelShuffleUpsampe(dim, channels_in, widths[0], upscale_factor=1, residual=True)) + self.append(PixelShuffleUpsample(dim, channels_in, widths[0], upscale_factor=1, residual=True)) - self.stages: list[OpSequential] = [] + self.stages: list[Sequential] = [] for block_type, width, depth in zip(block_types, widths, depths, strict=False): match block_type: case 'ResBlock': @@ -125,29 +125,29 @@ def __init__( if len(self) < len(widths): self.append(PixelShuffleUpsample(dim, width, width, upscale_factor=2, residual=True)) - stage.extend( - build_stage_main( - width=width, - depth=depth, - block_type=block_type, - norm=norm, - act=act, - input_width=( - width if cfg.upsample_match_channel else cfg.width_list[min(stage_id + 1, num_stages - 1)] - ), - ) - ) - self.stages.insert(0, OpSequential(stage)) - self.stages = nn.ModuleList(self.stages) - - self.project_out = build_decoder_project_out_block( - in_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1], - out_channels=cfg.in_channels, - factor=1 if cfg.depth_list[0] > 0 else 2, - upsample_block_type=cfg.upsample_block_type, - norm=cfg.out_norm, - act=cfg.out_act, - ) + # stage.extend( + # build_stage_main( + # width=width, + # depth=depth, + # block_type=block_type, + # norm=norm, + # act=act, + # input_width=( + # width if cfg.upsample_match_channel else cfg.width_list[min(stage_id + 1, num_stages - 1)] + # ), + # ) + # ) + # self.stages.insert(0, OpSequential(stage)) + # self.stages = nn.ModuleList(self.stages) + + # self.project_out = build_decoder_project_out_block( + # in_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1], + # out_channels=cfg.in_channels, + # factor=1 if cfg.depth_list[0] > 0 else 2, + # upsample_block_type=cfg.upsample_block_type, + # norm=cfg.out_norm, + # act=cfg.out_act, + # ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.project_in(x) diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 49f437d15..9f9824cbd 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -3,14 +3,15 @@ from collections.abc import Sequence import torch -from torch.nn import Module +from torch.nn import Module, Identity from mrpro.nn.FiLM import FiLM -from mrpro.nn.NDModules import ConvND, ConvNd, InstanceNormNd +from mrpro.nn.NDModules import ConvND, InstanceNormND from mrpro.nn.nets.UNet import UNetBase from mrpro.nn.PixelShuffle import PixelShuffle, PixelUnshuffle from mrpro.nn.Sequential import Sequential from mrpro.nn.TransposedAttention import TransposedAttention +from mrpro.utils import pairwise class GDFN(Module): @@ -20,11 +21,22 @@ class GDFN(Module): """ def __init__(self, dim: int, channels: int, mlp_ratio: float): + """Initialize GDFN. + + Parameters + ---------- + dim : int + Dimension of the input space + channels : int + Number of input/output channels + mlp_ratio : float + Ratio for hidden dimension expansion + """ super().__init__() hidden_features = int(channels * mlp_ratio) - self.project_in = ConvNd(dim)(channels, hidden_features * 2, kernel_size=1) - self.depthwise_conv = ConvNd(dim)( + self.project_in = ConvND(dim)(channels, hidden_features * 2, kernel_size=1) + self.depthwise_conv = ConvND(dim)( hidden_features * 2, hidden_features * 2, kernel_size=3, @@ -32,28 +44,68 @@ def __init__(self, dim: int, channels: int, mlp_ratio: float): padding=1, groups=hidden_features * 2, ) - self.project_out = ConvNd(dim)(hidden_features, channels, kernel_size=1) + self.project_out = ConvND(dim)(hidden_features, channels, kernel_size=1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply gated depthwise feed forward network. + + Parameters + ---------- + x : torch.Tensor + Input tensor + + Returns + ------- + torch.Tensor + Output tensor + """ x = self.project_in(x) x1, x2 = self.depthwise_conv(x).chunk(2, dim=1) - return self.project_out(torch.nn.functional.gelu(x1) * x2) + x = x1 * torch.sigmoid(x2) + x = self.project_out(x) + return x class RestormerBlock(Module): """Transformer block with transposed attention and gated depthwise feed forward network.""" - def __init__(self, dim: int, channels: int, num_heads: int, mlp_ratio: float, cond_dim: int = 0): + def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond_dim: int = 0): + """Initialize RestormerBlock. + + Parameters + ---------- + dim : int + Dimension of the input space + channels : int + Number of input/output channels + n_heads : int + Number of attention heads + mlp_ratio : float + Ratio for hidden dimension expansion + cond_dim : int, optional + Dimension of conditioning input, by default 0 + """ super().__init__() - self.norm1 = Sequential(InstanceNormNd(dim)(channels)) - self.attn = TransposedAttention(dim, channels, num_heads) - self.norm2 = Sequential(InstanceNormNd(dim)(channels)) + self.norm1 = Sequential(InstanceNormND(dim)(channels)) + self.attn = TransposedAttention(dim, channels, n_heads) + self.norm2 = Sequential(InstanceNormND(dim)(channels)) self.ffn = GDFN(dim, channels, mlp_ratio) if cond_dim > 0: - self.norm1.append(FiLM(channels=channels, cond_dim=cond_dim)) self.norm2.append(FiLM(channels=channels, cond_dim=cond_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Restormer block. + + Parameters + ---------- + x : torch.Tensor + Input tensor + + Returns + ------- + torch.Tensor + Output tensor + """ x = x + self.attn(self.norm1(x)) x = x + self.ffn(self.norm2(x)) return x @@ -83,9 +135,32 @@ def __init__( mlp_ratio: float = 2.66, cond_dim: int = 0, ): + """Initialize Restormer. + + Parameters + ---------- + dim : int + Dimension of the input space + channels_in : int + Number of input channels + channels_out : int + Number of output channels + n_blocks : Sequence[int], optional + Number of blocks in each stage, by default (4, 6, 6, 8) + n_refinement_blocks : int, optional + Number of refinement blocks, by default 4 + n_heads : Sequence[int], optional + Number of attention heads in each stage, by default (1, 2, 4, 8) + n_channels_per_head : int, optional + Number of channels per attention head, by default 48 + mlp_ratio : float, optional + Ratio for hidden dimension expansion, by default 2.66 + cond_dim : int, optional + Dimension of conditioning input, by default 0 + """ super().__init__() - self.first = ConvNd(dim)(channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) + self.first = ConvND(dim)(channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) def blocks(n_heads: int, n_blocks: int): layers = Sequential( @@ -106,30 +181,27 @@ def blocks(n_heads: int, n_blocks: int): self.down_blocks.append( Sequential( ConvND(dim)( - n_head_current * n_channels_per_head, - n_head_next * n_channels_per_head // 2**dim, + n_channels_per_head * n_head_current, + n_channels_per_head * n_head_next, kernel_size=3, + stride=2, padding=1, - bias=False, - ), - PixelUnshuffle(2), + ) ) ) self.up_blocks.append( Sequential( ConvND(dim)( - n_head_next * n_channels_per_head, - n_head_current * n_channels_per_head * 2**dim, + n_channels_per_head * n_head_next, + n_channels_per_head * n_head_current, kernel_size=3, + stride=1, padding=1, - bias=False, - ), - PixelShuffle(2), + ) ) ) - self.middle_block = blocks(n_heads, n_blocks) - self.last = Sequential( - *blocks(n_heads[0], n_refinement_blocks), - ConvND(dim)(n_channels_per_head * n_heads[0], channels_out, kernel_size=3, stride=1, padding=1), + self.refinement_blocks = Sequential( + *(RestormerBlock(dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)) ) + self.last = ConvND(dim)(n_channels_per_head, channels_out, kernel_size=3, stride=1, padding=1) diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index 90fd8f508..aff4037ad 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -1,9 +1,10 @@ +"""SwinIR implementation.""" + import torch -from torch.nn import GELU, Module +from torch.nn import Module from mrpro.nn.FiLM import FiLM -from mrpro.nn.LayerNorm import LayerNorm -from mrpro.nn.NDModules import ConvND +from mrpro.nn.NDModules import ConvND, InstanceNormND from mrpro.nn.Sequential import Sequential from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention @@ -11,7 +12,7 @@ class SwinTransformerLayer(Module): """Swin Transformer layer. - As used in the SwinIR network. + Implements a single layer of the Swin Transformer architecture. """ def __init__( @@ -20,103 +21,129 @@ def __init__( channels: int, n_heads: int, window_size: int, - shifted: bool, mlp_ratio: int = 4, - cond_dim: int = 0, + emb_dim: int = 0, ): - """Initialize the Swin Transformer layer. + """Initialize SwinTransformerLayer. Parameters ---------- - dim - Number of spatial dimensions (1D, 2D, or 3D) - channels - Number of channels in the input tensor - n_heads + dim : int + Dimension of the input space + channels : int + Number of input/output channels + n_heads : int Number of attention heads - window_size - Size of the local window for computing windowed self-attention - shifted - Whether to use shifted window attention - mlp_ratio - Expansion ratio for the MLP - cond_dim - Dimension of optional tensor for FiLM conditioning. If 0, no conditioning is used + window_size : int + Size of the attention window + mlp_ratio : int, optional + Ratio for hidden dimension expansion, by default 4 + emb_dim : int, optional + Dimension of conditioning input, by default 0 """ super().__init__() - self.norm1 = LayerNorm(channels) - self.attn = ShiftedWindowAttention(dim, channels, n_heads, window_size, shifted) - if cond_dim > 0: - self.norm2 = Sequential(LayerNorm(None), FiLM(cond_dim)) - else: - self.norm2 = Sequential(LayerNorm(channels)) - self.mlp = Sequential( - ConvND(dim)(channels, channels * mlp_ratio, 1), - GELU(), - ConvND(dim)(channels * mlp_ratio, channels, 1), - ) + self.norm1 = Sequential(InstanceNormND(dim)(channels)) + self.attn = ShiftedWindowAttention(dim, channels, n_heads, window_size) + self.norm2 = Sequential(InstanceNormND(dim)(channels)) + if emb_dim > 0: + self.norm2.append(FiLM(channels=channels, cond_dim=emb_dim)) - def __call__(self, x, cond: torch.Tensor | None = None): + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the Swin Transformer layer. Parameters ---------- - x - Input tensor of shape (batch_size, channels, *spatial_dims) - cond - Optional conditioning tensor of shape (batch_size, cond_dim) + x : torch.Tensor + Input tensor + cond : torch.Tensor | None, optional + Conditioning input, by default None Returns ------- - Output tensor of shape (batch_size, channels, *spatial_dims) + torch.Tensor + Output tensor """ return super().__call__(x, cond) - def forward(self, x, cond: torch.Tensor | None = None): + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the Swin Transformer layer.""" x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x, cond)) + x = x + self.norm2(x) return x -class RSTB(Module): +class ResidualSwinTransformerBlock(Module): """Residual Swin Transformer block. - As used in the SwinIR network. + Combines a Swin Transformer layer with a residual connection. """ - def __init__(self, dim: int, channels: int, n_heads: int, window_size: int, depth: int, cond_dim: int = 0): + def __init__( + self, + dim: int, + channels: int, + n_heads: int, + window_size: int, + depth: int, + emb_dim: int = 0, + ): + """Initialize ResidualSwinTransformerBlock. + + Parameters + ---------- + dim : int + Dimension of the input space + channels : int + Number of input/output channels + n_heads : int + Number of attention heads + window_size : int + Size of the attention window + depth : int + Number of Swin Transformer layers + emb_dim : int, optional + Dimension of conditioning input, by default 0 + """ super().__init__() self.layers = Sequential( - *[ - SwinTransformerLayer(dim, channels, n_heads, window_size, shifted=(i % 2 == 1), cond_dim=cond_dim) - for i in range(depth) - ] + *(SwinTransformerLayer(dim, channels, n_heads, window_size, emb_dim=emb_dim) for _ in range(depth)) ) self.conv = ConvND(dim)(channels, channels, 3, padding=1) - def __call__(self, x, cond: torch.Tensor | None = None): + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the residual Swin Transformer block. Parameters ---------- - x - Input tensor of shape (batch_size, channels, *spatial_dims) - cond - Optional conditioning tensor of shape (batch_size, cond_dim) + x : torch.Tensor + Input tensor + cond : torch.Tensor | None, optional + Conditioning input, by default None Returns ------- - Output tensor of shape (batch_size, channels, *spatial_dims) + torch.Tensor + Output tensor """ return super().__call__(x, cond) - def forward(self, x, cond: torch.Tensor | None = None): + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the residual Swin Transformer block.""" return x + self.conv(self.layers(x, cond)) class SwinIR(Module): + """SwinIR architecture. + + Implements the SwinIR [LZL21]_ network, which is a Swin Transformer based + image restoration network. + + References + ---------- + .. [LZL21] Liang, Jie, et al. "SwinIR: Image restoration using swin transformer." + ICCVW 2021, https://arxiv.org/pdf/2108.10257.pdf + """ + def __init__( self, dim: int, @@ -127,63 +154,64 @@ def __init__( window_size: int = 64, n_blocks: int = 6, n_attn_per_block: int = 6, - cond_dim: int = 0, + emb_dim: int = 0, ): - """Initialize the SwinIR model. + """Initialize SwinIR. Parameters ---------- - dim - Number of spatial dimensions (1D, 2D, or 3D) - channels_in + dim : int + Dimension of the input space + channels_in : int Number of input channels - channels_out + channels_out : int Number of output channels - channels_per_head - Number of channels per attention head - n_heads - Number of attention heads - window_size - Size of the local window for computing windowed self-attention - n_blocks - Number of residual Swin Transformer blocks (RSTB) - n_attn_per_block - Number of windowed attention layers per RSTB block - cond_dim - Dimension of optional tensor for FiLM conditioning. If 0, no conditioning is used + channels_per_head : int, optional + Number of channels per attention head, by default 16 + n_heads : int, optional + Number of attention heads, by default 6 + window_size : int, optional + Size of the attention window, by default 64 + n_blocks : int, optional + Number of residual blocks, by default 6 + n_attn_per_block : int, optional + Number of attention layers per block, by default 6 + emb_dim : int, optional + Dimension of conditioning input, by default 0 """ super().__init__() - self.shallow = ConvND(dim)(channels_in, channels_per_head * n_heads, 3, padding=1) - self.body = Sequential( - *[ - RSTB(dim, channels_per_head * n_heads, n_heads, window_size, n_attn_per_block, cond_dim) + self.first = ConvND(dim)(channels_in, channels_per_head * n_heads, kernel_size=3, padding=1) + self.blocks = Sequential( + *( + ResidualSwinTransformerBlock( + dim, + channels_per_head * n_heads, + n_heads, + window_size, + n_attn_per_block, + emb_dim, + ) for _ in range(n_blocks) - ] + ) ) - self.body.append(ConvND(dim)(channels_per_head * n_heads, channels_per_head * n_heads, 3, padding=1)) - self.final = ConvND(dim)(channels_per_head, channels_out, 3, padding=1) - self.skip = ConvND(dim)(channels_in, channels_out, 1, padding=1) + self.last = ConvND(dim)(channels_per_head * n_heads, channels_out, kernel_size=3, padding=1) - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: - """Apply the SwinIR model. + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply SwinIR. Parameters ---------- - x - Input tensor of shape (batch_size, channels_in, *spatial_dims) - cond - Optional conditioning tensor of shape (batch_size, cond_dim) + x : torch.Tensor + Input tensor + cond : torch.Tensor | None, optional + Conditioning input, by default None Returns ------- - out - Output tensor of shape (batch_size, channels_out, *spatial_dims) + torch.Tensor + Output tensor """ - return super().__call__(x, cond) - - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: - """Apply the SwinIR model.""" - h = self.shallow(x) - h = self.body(h, cond) + self.skip(x) - out = self.final(h) - return out + x = self.first(x) + x = self.blocks(x, cond) + x = self.last(x) + return x diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index b1cedfac5..2192db24f 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -72,30 +72,6 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None) -> torch.Tensor: return self(x, cond) -class AttentionUNet(UNet): - """UNet with attention gates. - - References - ---------- - .. [OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). - https://arxiv.org/abs/1804.03999 - """ - - -class SeparableUNet(UNetBase): - """UNet where blocks apply separable convolutions in different dimensions - - Based on the pseudo-3D residual network of [QUI]_ and the residual blocks of [ZIM]_. - - References - ---------- - .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. - ICCV 2017. https://arxiv.org/abs/1711.10305 - .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction - without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 - """ - - class UNet(UNetBase): """UNet. @@ -121,3 +97,27 @@ def __init__( num_blocks: int, padding_modes: str | Sequence[str], ) -> None: ... + + +class AttentionUNet(UNet): + """UNet with attention gates. + + References + ---------- + .. [OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + +class SeparableUNet(UNetBase): + """UNet where blocks apply separable convolutions in different dimensions. + + Based on the pseudo-3D residual network of [QUI]_ and the residual blocks of [ZIM]_. + + References + ---------- + .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. + ICCV 2017. https://arxiv.org/abs/1711.10305 + .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction + without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 + """ diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 6b4f6fa9f..3b567350a 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -1,9 +1,10 @@ +"""Uformer: U-Net with window attention.""" + from collections.abc import Sequence from itertools import pairwise import torch -from sympy import Identity -from torch.nn import GELU, LeakyReLU, Module, Sequential +from torch.nn import GELU, Identity, LeakyReLU, Module from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM @@ -49,7 +50,22 @@ def __init__( ConvND(dim)(hidden_dim, channels_out, 1), ) + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the LeFF module. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the LeFF module.""" return self.block(x) diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py index 3fcb2b417..6d4a5bd6c 100644 --- a/src/mrpro/nn/nets/VAE.py +++ b/src/mrpro/nn/nets/VAE.py @@ -1,3 +1,5 @@ +"""Variational Autoencoder with a Gaussian latent space.""" + import torch from torch.nn import Module @@ -9,7 +11,8 @@ class VAE(Module): into the original space. The encoder should return twice the number of channels as the decoder needs to reconstruct the input: half of the channels are the mean and the other half the log variance of the latent space. The reparameterization trick is used to sample from the latent space. - The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard normal distribution. + The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard normal + distribution. """ def __init__(self, encoder: Module, decoder: Module): diff --git a/src/mrpro/nn/test.ipynb b/src/mrpro/nn/test.ipynb deleted file mode 100644 index ddfd6b8b3..000000000 --- a/src/mrpro/nn/test.ipynb +++ /dev/null @@ -1,149 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from mrpro.utils.sliding_window import sliding_window\n", - "from torch.nn import Module\n", - "from einops import rearrange\n", - "from mrpro.nn.NDModules import ConvND\n", - "class ShiftedWindowMSA(Module):\n", - " def __init__(self, dim, channels, n_heads, window_size=7, shifted=True):\n", - " super().__init__()\n", - " self.channels = channels\n", - " self.n_heads = n_heads\n", - " self.window_size = window_size\n", - " self.shifted = shifted\n", - " self.to_qkv = ConvND(dim)(channels, 3*channels, 1)\n", - " self.dim=dim\n", - " def forward(self, x):\n", - " if self.shifted:\n", - " x = torch.roll(x, (-(self.window_size//2),)*self.dim,dims=tuple(range(-self.dim, 0)))\n", - " qkv = self.to_qkv(x) \n", - " windowed = sliding_window(qkv, window_shape=self.window_size, stride=self.window_size, dim=range(-self.dim, 0))\n", - " flat = windowed.flatten(0,self.dim-1).flatten(-self.dim) \n", - " q,k,v = rearrange(flat, 'spatial batch (qkv heads channels) window->qkv spatial batch heads window channels', heads = self.n_heads, qkv=3)\n", - " result = torch.nn.functional.scaled_dot_product_attention(q,k,v, attn_mask=None)\n", - " result = rearrange(result, 'spatial batch head window channels->batch (head channels) spatial window')\n", - " result=result.unflatten(-2, windowed.shape[:self.dim]).unflatten(-1, (self.window_size,)*self.dim)\n", - " result=result.moveaxis(list(range(-self.dim, 0)), list(range(3, 3+2*self.dim, 2)))\n", - " result = result.reshape(x.shape)\n", - " if self.shifted:\n", - " result = torch.roll(result, (self.window_size//2,)*self.dim,dims=tuple(range(-self.dim, 0)))\n", - " return result" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "metadata": {}, - "outputs": [], - "source": [ - "m=ShiftedWindowMSA(dim=2, channels=16, n_heads=4, window_size=5, shifted=True).cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": {}, - "outputs": [], - "source": [ - "x=torch.arange(2*16*20*30).reshape(2,16,20,30).float().cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 16, 20, 30])" - ] - }, - "execution_count": 108, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "m(x).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[100]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmatplotlib\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mplt\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mplt\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimshow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3\u001b[39m plt.show()\n\u001b[32m 4\u001b[39m plt.imshow(m(x)[\u001b[32m0\u001b[39m,\u001b[32m0\u001b[39m])\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/pyplot.py:3590\u001b[39m, in \u001b[36mimshow\u001b[39m\u001b[34m(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, colorizer, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)\u001b[39m\n\u001b[32m 3568\u001b[39m \u001b[38;5;129m@_copy_docstring_and_deprecators\u001b[39m(Axes.imshow)\n\u001b[32m 3569\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mimshow\u001b[39m(\n\u001b[32m 3570\u001b[39m X: ArrayLike | PIL.Image.Image,\n\u001b[32m (...)\u001b[39m\u001b[32m 3588\u001b[39m **kwargs,\n\u001b[32m 3589\u001b[39m ) -> AxesImage:\n\u001b[32m-> \u001b[39m\u001b[32m3590\u001b[39m __ret = \u001b[43mgca\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimshow\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 3591\u001b[39m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3592\u001b[39m \u001b[43m \u001b[49m\u001b[43mcmap\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3593\u001b[39m \u001b[43m \u001b[49m\u001b[43mnorm\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnorm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3594\u001b[39m \u001b[43m \u001b[49m\u001b[43maspect\u001b[49m\u001b[43m=\u001b[49m\u001b[43maspect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3595\u001b[39m \u001b[43m \u001b[49m\u001b[43minterpolation\u001b[49m\u001b[43m=\u001b[49m\u001b[43minterpolation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3596\u001b[39m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m=\u001b[49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3597\u001b[39m \u001b[43m \u001b[49m\u001b[43mvmin\u001b[49m\u001b[43m=\u001b[49m\u001b[43mvmin\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3598\u001b[39m \u001b[43m \u001b[49m\u001b[43mvmax\u001b[49m\u001b[43m=\u001b[49m\u001b[43mvmax\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3599\u001b[39m \u001b[43m \u001b[49m\u001b[43mcolorizer\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcolorizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3600\u001b[39m \u001b[43m \u001b[49m\u001b[43morigin\u001b[49m\u001b[43m=\u001b[49m\u001b[43morigin\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3601\u001b[39m \u001b[43m \u001b[49m\u001b[43mextent\u001b[49m\u001b[43m=\u001b[49m\u001b[43mextent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3602\u001b[39m \u001b[43m \u001b[49m\u001b[43minterpolation_stage\u001b[49m\u001b[43m=\u001b[49m\u001b[43minterpolation_stage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3603\u001b[39m \u001b[43m \u001b[49m\u001b[43mfilternorm\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfilternorm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3604\u001b[39m \u001b[43m \u001b[49m\u001b[43mfilterrad\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfilterrad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3605\u001b[39m \u001b[43m \u001b[49m\u001b[43mresample\u001b[49m\u001b[43m=\u001b[49m\u001b[43mresample\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3606\u001b[39m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m=\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3607\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdata\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m}\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3608\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3609\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3610\u001b[39m sci(__ret)\n\u001b[32m 3611\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m __ret\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/__init__.py:1521\u001b[39m, in \u001b[36m_preprocess_data..inner\u001b[39m\u001b[34m(ax, data, *args, **kwargs)\u001b[39m\n\u001b[32m 1518\u001b[39m \u001b[38;5;129m@functools\u001b[39m.wraps(func)\n\u001b[32m 1519\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34minner\u001b[39m(ax, *args, data=\u001b[38;5;28;01mNone\u001b[39;00m, **kwargs):\n\u001b[32m 1520\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m data \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1521\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1522\u001b[39m \u001b[43m \u001b[49m\u001b[43max\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1523\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[38;5;28;43mmap\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcbook\u001b[49m\u001b[43m.\u001b[49m\u001b[43msanitize_sequence\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1524\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43m{\u001b[49m\u001b[43mk\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mcbook\u001b[49m\u001b[43m.\u001b[49m\u001b[43msanitize_sequence\u001b[49m\u001b[43m(\u001b[49m\u001b[43mv\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mitems\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1526\u001b[39m bound = new_sig.bind(ax, *args, **kwargs)\n\u001b[32m 1527\u001b[39m auto_label = (bound.arguments.get(label_namer)\n\u001b[32m 1528\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m bound.kwargs.get(label_namer))\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/axes/_axes.py:5976\u001b[39m, in \u001b[36mAxes.imshow\u001b[39m\u001b[34m(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, colorizer, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)\u001b[39m\n\u001b[32m 5973\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m aspect \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 5974\u001b[39m \u001b[38;5;28mself\u001b[39m.set_aspect(aspect)\n\u001b[32m-> \u001b[39m\u001b[32m5976\u001b[39m \u001b[43mim\u001b[49m\u001b[43m.\u001b[49m\u001b[43mset_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 5977\u001b[39m im.set_alpha(alpha)\n\u001b[32m 5978\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m im.get_clip_path() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 5979\u001b[39m \u001b[38;5;66;03m# image does not already have clipping set, clip to Axes patch\u001b[39;00m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/image.py:685\u001b[39m, in \u001b[36m_ImageBase.set_data\u001b[39m\u001b[34m(self, A)\u001b[39m\n\u001b[32m 683\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(A, PIL.Image.Image):\n\u001b[32m 684\u001b[39m A = pil_to_array(A) \u001b[38;5;66;03m# Needed e.g. to apply png palette.\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m685\u001b[39m \u001b[38;5;28mself\u001b[39m._A = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_normalize_image_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 686\u001b[39m \u001b[38;5;28mself\u001b[39m._imcache = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 687\u001b[39m \u001b[38;5;28mself\u001b[39m.stale = \u001b[38;5;28;01mTrue\u001b[39;00m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/image.py:646\u001b[39m, in \u001b[36m_ImageBase._normalize_image_array\u001b[39m\u001b[34m(A)\u001b[39m\n\u001b[32m 640\u001b[39m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[32m 641\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_normalize_image_array\u001b[39m(A):\n\u001b[32m 642\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 643\u001b[39m \u001b[33;03m Check validity of image-like input *A* and normalize it to a format suitable for\u001b[39;00m\n\u001b[32m 644\u001b[39m \u001b[33;03m Image subclasses.\u001b[39;00m\n\u001b[32m 645\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m646\u001b[39m A = \u001b[43mcbook\u001b[49m\u001b[43m.\u001b[49m\u001b[43msafe_masked_invalid\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 647\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m A.dtype != np.uint8 \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m np.can_cast(A.dtype, \u001b[38;5;28mfloat\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33msame_kind\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m 648\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mImage data of dtype \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mA.dtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m cannot be \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 649\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mconverted to float\u001b[39m\u001b[33m\"\u001b[39m)\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/matplotlib/cbook.py:684\u001b[39m, in \u001b[36msafe_masked_invalid\u001b[39m\u001b[34m(x, copy)\u001b[39m\n\u001b[32m 683\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34msafe_masked_invalid\u001b[39m(x, copy=\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[32m--> \u001b[39m\u001b[32m684\u001b[39m x = \u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msubok\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 685\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m x.dtype.isnative:\n\u001b[32m 686\u001b[39m \u001b[38;5;66;03m# If we have already made a copy, do the byteswap in place, else make a\u001b[39;00m\n\u001b[32m 687\u001b[39m \u001b[38;5;66;03m# copy with the byte order swapped.\u001b[39;00m\n\u001b[32m 688\u001b[39m \u001b[38;5;66;03m# Swap to native order.\u001b[39;00m\n\u001b[32m 689\u001b[39m x = x.byteswap(inplace=copy).view(x.dtype.newbyteorder(\u001b[33m'\u001b[39m\u001b[33mN\u001b[39m\u001b[33m'\u001b[39m))\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/mrpro/.venv/lib/python3.11/site-packages/torch/_tensor.py:1194\u001b[39m, in \u001b[36mTensor.__array__\u001b[39m\u001b[34m(self, dtype)\u001b[39m\n\u001b[32m 1192\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(Tensor.__array__, (\u001b[38;5;28mself\u001b[39m,), \u001b[38;5;28mself\u001b[39m, dtype=dtype)\n\u001b[32m 1193\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1194\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1195\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 1196\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.numpy().astype(dtype, copy=\u001b[38;5;28;01mFalse\u001b[39;00m)\n", - "\u001b[31mTypeError\u001b[39m: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first." - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAAGiCAYAAACGUJO6AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGwdJREFUeJzt3X9M3dX9x/EX0HKpsdA6xoWyq6x1/ralgmVYG+dyJ4kG1z8WmTWFEX9MZUZ7s9liW1Crpau2I7NoY9XpHzqqRo2xBKdMYlSWRloSnW1NpRVmvLclrtyOKrTc8/1j316HBcsH+dG3PB/J5w/OPud+zj1h9+m9vfeS4JxzAgDAmMSJXgAAACNBwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmeQ7Y22+/reLiYs2aNUsJCQl65ZVXTjqnublZl1xyiXw+n84++2w9/fTTI1gqAABf8xywnp4ezZs3T3V1dcM6f9++fbrmmmt05ZVXqq2tTXfddZduuukmvf76654XCwDAcQnf5ct8ExIS9PLLL2vx4sVDnrN8+XJt27ZNH374YXzs17/+tQ4dOqTGxsaRXhoAMMlNGesLtLS0KBgMDhgrKirSXXfdNeSc3t5e9fb2xn+OxWL64osv9IMf/EAJCQljtVQAwBhwzunw4cOaNWuWEhNH760XYx6wcDgsv98/YMzv9ysajerLL7/UtGnTTphTU1Oj++67b6yXBgAYR52dnfrRj340arc35gEbicrKSoVCofjP3d3dOvPMM9XZ2anU1NQJXBkAwKtoNKpAIKDp06eP6u2OecAyMzMViUQGjEUiEaWmpg767EuSfD6ffD7fCeOpqakEDACMGu1/Ahrzz4EVFhaqqalpwNgbb7yhwsLCsb40AOB7zHPA/vOf/6itrU1tbW2S/vs2+ba2NnV0dEj678t/paWl8fNvvfVWtbe36+6779bu3bv16KOP6vnnn9eyZctG5x4AACYlzwF7//33NX/+fM2fP1+SFAqFNH/+fFVVVUmSPv/883jMJOnHP/6xtm3bpjfeeEPz5s3Thg0b9MQTT6ioqGiU7gIAYDL6Tp8DGy/RaFRpaWnq7u7m38AAwJixegznuxABACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGDSiAJWV1ennJwcpaSkqKCgQNu3b//W82tra3Xuuedq2rRpCgQCWrZsmb766qsRLRgAAGkEAdu6datCoZCqq6u1Y8cOzZs3T0VFRTpw4MCg5z/33HNasWKFqqurtWvXLj355JPaunWr7rnnnu+8eADA5OU5YBs3btTNN9+s8vJyXXDBBdq8ebNOO+00PfXUU4Oe/95772nhwoVasmSJcnJydNVVV+n6668/6bM2AAC+jaeA9fX1qbW1VcFg8OsbSExUMBhUS0vLoHMuu+wytba2xoPV3t6uhoYGXX311UNep7e3V9FodMABAMD/muLl5K6uLvX398vv9w8Y9/v92r1796BzlixZoq6uLl1++eVyzunYsWO69dZbv/UlxJqaGt13331elgYAmGTG/F2Izc3NWrt2rR599FHt2LFDL730krZt26Y1a9YMOaeyslLd3d3xo7Ozc6yXCQAwxtMzsPT0dCUlJSkSiQwYj0QiyszMHHTO6tWrtXTpUt10002SpIsvvlg9PT265ZZbtHLlSiUmnthQn88nn8/nZWkAgEnG0zOw5ORk5eXlqampKT4Wi8XU1NSkwsLCQeccOXLkhEglJSVJkpxzXtcLAIAkj8/AJCkUCqmsrEz5+flasGCBamtr1dPTo/LycklSaWmpsrOzVVNTI0kqLi7Wxo0bNX/+fBUUFGjv3r1avXq1iouL4yEDAMArzwErKSnRwYMHVVVVpXA4rNzcXDU2Nsbf2NHR0THgGdeqVauUkJCgVatW6bPPPtMPf/hDFRcX68EHHxy9ewEAmHQSnIHX8aLRqNLS0tTd3a3U1NSJXg4AwIOxegznuxABACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABgEgEDAJhEwAAAJhEwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGDSiAJWV1ennJwcpaSkqKCgQNu3b//W8w8dOqSKigplZWXJ5/PpnHPOUUNDw4gWDACAJE3xOmHr1q0KhULavHmzCgoKVFtbq6KiIu3Zs0cZGRknnN/X16df/OIXysjI0Isvvqjs7Gx9+umnmjFjxmisHwAwSSU455yXCQUFBbr00ku1adMmSVIsFlMgENAdd9yhFStWnHD+5s2b9dBDD2n37t2aOnXqiBYZjUaVlpam7u5upaamjug2AAATY6wewz29hNjX16fW1lYFg8GvbyAxUcFgUC0tLYPOefXVV1VYWKiKigr5/X5ddNFFWrt2rfr7+4e8Tm9vr6LR6IADAID/5SlgXV1d6u/vl9/vHzDu9/sVDocHndPe3q4XX3xR/f39amho0OrVq7VhwwY98MADQ16npqZGaWlp8SMQCHhZJgBgEhjzdyHGYjFlZGTo8ccfV15enkpKSrRy5Upt3rx5yDmVlZXq7u6OH52dnWO9TACAMZ7exJGenq6kpCRFIpEB45FIRJmZmYPOycrK0tSpU5WUlBQfO//88xUOh9XX16fk5OQT5vh8Pvl8Pi9LAwBMMp6egSUnJysvL09NTU3xsVgspqamJhUWFg46Z+HChdq7d69isVh87OOPP1ZWVtag8QIAYDg8v4QYCoW0ZcsWPfPMM9q1a5duu+029fT0qLy8XJJUWlqqysrK+Pm33XabvvjiC9155536+OOPtW3bNq1du1YVFRWjdy8AAJOO58+BlZSU6ODBg6qqqlI4HFZubq4aGxvjb+zo6OhQYuLXXQwEAnr99de1bNkyzZ07V9nZ2brzzju1fPny0bsXAIBJx/PnwCYCnwMDALtOic+BAQBwqiBgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwKQRBayurk45OTlKSUlRQUGBtm/fPqx59fX1SkhI0OLFi0dyWQAA4jwHbOvWrQqFQqqurtaOHTs0b948FRUV6cCBA986b//+/fr973+vRYsWjXixAAAc5zlgGzdu1M0336zy8nJdcMEF2rx5s0477TQ99dRTQ87p7+/XDTfcoPvuu0+zZ88+6TV6e3sVjUYHHAAA/C9PAevr61Nra6uCweDXN5CYqGAwqJaWliHn3X///crIyNCNN944rOvU1NQoLS0tfgQCAS/LBABMAp4C1tXVpf7+fvn9/gHjfr9f4XB40DnvvPOOnnzySW3ZsmXY16msrFR3d3f86Ozs9LJMAMAkMGUsb/zw4cNaunSptmzZovT09GHP8/l88vl8Y7gyAIB1ngKWnp6upKQkRSKRAeORSESZmZknnP/JJ59o//79Ki4ujo/FYrH/XnjKFO3Zs0dz5swZyboBAJOcp5cQk5OTlZeXp6ampvhYLBZTU1OTCgsLTzj/vPPO0wcffKC2trb4ce211+rKK69UW1sb/7YFABgxzy8hhkIhlZWVKT8/XwsWLFBtba16enpUXl4uSSotLVV2drZqamqUkpKiiy66aMD8GTNmSNIJ4wAAeOE5YCUlJTp48KCqqqoUDoeVm5urxsbG+Bs7Ojo6lJjIF3wAAMZWgnPOTfQiTiYajSotLU3d3d1KTU2d6OUAADwYq8dwnioBAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMCkEQWsrq5OOTk5SklJUUFBgbZv3z7kuVu2bNGiRYs0c+ZMzZw5U8Fg8FvPBwBgODwHbOvWrQqFQqqurtaOHTs0b948FRUV6cCBA4Oe39zcrOuvv15vvfWWWlpaFAgEdNVVV+mzzz77zosHAExeCc4552VCQUGBLr30Um3atEmSFIvFFAgEdMcdd2jFihUnnd/f36+ZM2dq06ZNKi0tHfSc3t5e9fb2xn+ORqMKBALq7u5Wamqql+UCACZYNBpVWlraqD+Ge3oG1tfXp9bWVgWDwa9vIDFRwWBQLS0tw7qNI0eO6OjRozrjjDOGPKempkZpaWnxIxAIeFkmAGAS8BSwrq4u9ff3y+/3Dxj3+/0Kh8PDuo3ly5dr1qxZAyL4TZWVleru7o4fnZ2dXpYJAJgEpoznxdatW6f6+no1NzcrJSVlyPN8Pp98Pt84rgwAYI2ngKWnpyspKUmRSGTAeCQSUWZm5rfOffjhh7Vu3Tq9+eabmjt3rveVAgDwPzy9hJicnKy8vDw1NTXFx2KxmJqamlRYWDjkvPXr12vNmjVqbGxUfn7+yFcLAMD/8/wSYigUUllZmfLz87VgwQLV1taqp6dH5eXlkqTS0lJlZ2erpqZGkvTHP/5RVVVVeu6555STkxP/t7LTTz9dp59++ijeFQDAZOI5YCUlJTp48KCqqqoUDoeVm5urxsbG+Bs7Ojo6lJj49RO7xx57TH19ffrVr3414Haqq6t17733frfVAwAmLc+fA5sIY/UZAgDA2DslPgcGAMCpgoABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAkwgYAMAkAgYAMImAAQBMImAAAJMIGADAJAIGADCJgAEATCJgAACTCBgAwCQCBgAwiYABAEwiYAAAk0YUsLq6OuXk5CglJUUFBQXavn37t57/wgsv6LzzzlNKSoouvvhiNTQ0jGixAAAc5zlgW7duVSgUUnV1tXbs2KF58+apqKhIBw4cGPT89957T9dff71uvPFG7dy5U4sXL9bixYv14YcffufFAwAmrwTnnPMyoaCgQJdeeqk2bdokSYrFYgoEArrjjju0YsWKE84vKSlRT0+PXnvttfjYT3/6U+Xm5mrz5s2DXqO3t1e9vb3xn7u7u3XmmWeqs7NTqampXpYLAJhg0WhUgUBAhw4dUlpa2ujdsPOgt7fXJSUluZdffnnAeGlpqbv22msHnRMIBNyf/vSnAWNVVVVu7ty5Q16nurraSeLg4ODg+B4dn3zyiZfknNQUedDV1aX+/n75/f4B436/X7t37x50TjgcHvT8cDg85HUqKysVCoXiPx86dEhnnXWWOjo6Rrfe3zPH/yuHZ6rfjn06OfZoeNin4Tn+KtoZZ5wxqrfrKWDjxefzyefznTCelpbGL8kwpKamsk/DwD6dHHs0POzT8CQmju4b3z3dWnp6upKSkhSJRAaMRyIRZWZmDjonMzPT0/kAAAyHp4AlJycrLy9PTU1N8bFYLKampiYVFhYOOqewsHDA+ZL0xhtvDHk+AADD4fklxFAopLKyMuXn52vBggWqra1VT0+PysvLJUmlpaXKzs5WTU2NJOnOO+/UFVdcoQ0bNuiaa65RfX293n//fT3++OPDvqbP51N1dfWgLyvia+zT8LBPJ8ceDQ/7NDxjtU+e30YvSZs2bdJDDz2kcDis3Nxc/fnPf1ZBQYEk6Wc/+5lycnL09NNPx89/4YUXtGrVKu3fv18/+clPtH79el199dWjdicAAJPPiAIGAMBE47sQAQAmETAAgEkEDABgEgEDAJh0ygSMP9EyPF72acuWLVq0aJFmzpypmTNnKhgMnnRfvw+8/i4dV19fr4SEBC1evHhsF3iK8LpPhw4dUkVFhbKysuTz+XTOOedMiv/fed2n2tpanXvuuZo2bZoCgYCWLVumr776apxWOzHefvttFRcXa9asWUpISNArr7xy0jnNzc265JJL5PP5dPbZZw945/qwjeo3K45QfX29S05Odk899ZT75z//6W6++WY3Y8YMF4lEBj3/3XffdUlJSW79+vXuo48+cqtWrXJTp051H3zwwTivfHx53aclS5a4uro6t3PnTrdr1y73m9/8xqWlpbl//etf47zy8eN1j47bt2+fy87OdosWLXK//OUvx2exE8jrPvX29rr8/Hx39dVXu3feecft27fPNTc3u7a2tnFe+fjyuk/PPvus8/l87tlnn3X79u1zr7/+usvKynLLli0b55WPr4aGBrdy5Ur30ksvOUknfOH7N7W3t7vTTjvNhUIh99FHH7lHHnnEJSUlucbGRk/XPSUCtmDBAldRURH/ub+/382aNcvV1NQMev51113nrrnmmgFjBQUF7re//e2YrnOied2nbzp27JibPn26e+aZZ8ZqiRNuJHt07Ngxd9lll7knnnjClZWVTYqAed2nxx57zM2ePdv19fWN1xJPCV73qaKiwv385z8fMBYKhdzChQvHdJ2nkuEE7O6773YXXnjhgLGSkhJXVFTk6VoT/hJiX1+fWltbFQwG42OJiYkKBoNqaWkZdE5LS8uA8yWpqKhoyPO/D0ayT9905MgRHT16dNS/EfpUMdI9uv/++5WRkaEbb7xxPJY54UayT6+++qoKCwtVUVEhv9+viy66SGvXrlV/f/94LXvcjWSfLrvsMrW2tsZfZmxvb1dDQwNf3PANo/UYPuHfRj9ef6LFupHs0zctX75cs2bNOuEX5/tiJHv0zjvv6Mknn1RbW9s4rPDUMJJ9am9v19///nfdcMMNamho0N69e3X77bfr6NGjqq6uHo9lj7uR7NOSJUvU1dWlyy+/XM45HTt2TLfeeqvuueee8ViyGUM9hkejUX355ZeaNm3asG5nwp+BYXysW7dO9fX1evnll5WSkjLRyzklHD58WEuXLtWWLVuUnp4+0cs5pcViMWVkZOjxxx9XXl6eSkpKtHLlyiH/qvpk1dzcrLVr1+rRRx/Vjh079NJLL2nbtm1as2bNRC/te2nCn4HxJ1qGZyT7dNzDDz+sdevW6c0339TcuXPHcpkTyuseffLJJ9q/f7+Ki4vjY7FYTJI0ZcoU7dmzR3PmzBnbRU+AkfwuZWVlaerUqUpKSoqPnX/++QqHw+rr61NycvKYrnkijGSfVq9eraVLl+qmm26SJF188cXq6enRLbfcopUrV47638OyaqjH8NTU1GE/+5JOgWdg/ImW4RnJPknS+vXrtWbNGjU2Nio/P388ljphvO7Reeedpw8++EBtbW3x49prr9WVV16ptrY2BQKB8Vz+uBnJ79LChQu1d+/eeOAl6eOPP1ZWVtb3Ml7SyPbpyJEjJ0TqePQdXzsbN2qP4d7eXzI26uvrnc/nc08//bT76KOP3C233OJmzJjhwuGwc865pUuXuhUrVsTPf/fdd92UKVPcww8/7Hbt2uWqq6snzdvovezTunXrXHJysnvxxRfd559/Hj8OHz48UXdhzHndo2+aLO9C9LpPHR0dbvr06e53v/ud27Nnj3vttddcRkaGe+CBBybqLowLr/tUXV3tpk+f7v7617+69vZ297e//c3NmTPHXXfddRN1F8bF4cOH3c6dO93OnTudJLdx40a3c+dO9+mnnzrnnFuxYoVbunRp/Pzjb6P/wx/+4Hbt2uXq6ursvo3eOeceeeQRd+aZZ7rk5GS3YMEC949//CP+v11xxRWurKxswPnPP/+8O+ecc1xycrK78MIL3bZt28Z5xRPDyz6dddZZTtIJR3V19fgvfBx5/V36X5MlYM5536f33nvPFRQUOJ/P52bPnu0efPBBd+zYsXFe9fjzsk9Hjx519957r5szZ45LSUlxgUDA3X777e7f//73+C98HL311luDPtYc35uysjJ3xRVXnDAnNzfXJScnu9mzZ7u//OUvnq/Ln1MBAJg04f8GBgDASBAwAIBJBAwAYBIBAwCYRMAAACYRMACASQQMAGASAQMAmETAAAAmETAAgEkEDABg0v8Bc0z++5j1+JwAAAAASUVORK5CYII=", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "plt.imshow(x[0,0])\n", - "plt.show()\n", - "plt.imshow(m(x)[0,0])\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py index 535b76055..e3913cb5b 100644 --- a/tests/nn/test_film.py +++ b/tests/nn/test_film.py @@ -34,4 +34,4 @@ def test_film(channels, channels_emb, input_shape, emb_shape, device): assert not emb.isnan().any(), 'NaN values in embedding' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert not emb.grad.isnan().any(), 'NaN values in embedding gradients' - assert film.project[1].weight.grad is not None, 'No gradient computed for Linear layer' + assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index 8768301df..0fdd80612 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -23,7 +23,7 @@ def test_transposed_attention(dim, channels, num_heads, input_shape, device): """Test TransposedAttention output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - attn = TransposedAttention(dim=dim, channels=channels, num_heads=num_heads).to(device) + attn = TransposedAttention(dim=dim, channels_in=channels, channels_out=channels, n_heads=num_heads).to(device) output = attn(x) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() From 33d95572ecb7e46cf5049d31a26ed5e4d6d69885 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 19 May 2025 14:41:48 +0200 Subject: [PATCH 035/205] update --- .vscode/settings.json | 4 +- src/mrpro/nn/ComplexAsChannel.py | 2 + src/mrpro/nn/CondMixin.py | 3 +- src/mrpro/nn/CoordConv.py | 0 src/mrpro/nn/{activations.py => GEGLU.py} | 2 + src/mrpro/nn/LayerNorm.py | 2 + src/mrpro/nn/LinearSelfAttention.py | 2 +- src/mrpro/nn/MultiHeadAttention.py | 2 +- src/mrpro/nn/NeighborhoodSelfAttention.py | 6 +- src/mrpro/nn/RMSNorm.py | 2 + src/mrpro/nn/Residual.py | 2 + src/mrpro/nn/RoPE.py | 16 ++--- src/mrpro/nn/Sequential.py | 2 + src/mrpro/nn/encoding.py | 6 +- src/mrpro/nn/nets/CNN.py | 6 +- src/mrpro/nn/nets/DCAE.py | 72 ++++++++++++++-------- src/mrpro/nn/nets/Restormer.py | 23 ++++--- src/mrpro/nn/nets/SwinIR.py | 74 ++++++++++++++++------- src/mrpro/nn/nets/UNet.py | 18 ++++-- src/mrpro/nn/nets/VAE.py | 4 +- src/mrpro/nn/nets/__init__.py | 4 ++ tests/nn/test_shiftedwindowattention.py | 2 +- 22 files changed, 167 insertions(+), 87 deletions(-) delete mode 100644 src/mrpro/nn/CoordConv.py rename src/mrpro/nn/{activations.py => GEGLU.py} (96%) create mode 100644 src/mrpro/nn/nets/__init__.py diff --git a/.vscode/settings.json b/.vscode/settings.json index a63f276cc..849d7c5a9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,8 +14,8 @@ }, "python.testing.pytestArgs": [ "tests", - // "-m not cuda" + "-m not cuda" ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, -} +} \ No newline at end of file diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py index bab6cf6cb..5ce5b02cb 100644 --- a/src/mrpro/nn/ComplexAsChannel.py +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -1,3 +1,5 @@ +"""ComplexAsChannel: handling complex-valued tensors as channels.""" + import torch from einops import rearrange from torch.nn import Module diff --git a/src/mrpro/nn/CondMixin.py b/src/mrpro/nn/CondMixin.py index 2dccbd305..f5cb4a4c4 100644 --- a/src/mrpro/nn/CondMixin.py +++ b/src/mrpro/nn/CondMixin.py @@ -5,7 +5,8 @@ def call_with_cond(module: Module, x: torch.Tensor, cond: torch.Tensor | None) -> torch.Tensor: - if isinstance(CondMixin, Module): + """Call a module with conditioning if it is a CondMixin.""" + if isinstance(module, CondMixin): return module(x, cond) return module(x) diff --git a/src/mrpro/nn/CoordConv.py b/src/mrpro/nn/CoordConv.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/mrpro/nn/activations.py b/src/mrpro/nn/GEGLU.py similarity index 96% rename from src/mrpro/nn/activations.py rename to src/mrpro/nn/GEGLU.py index cc5dcc1ba..42605ea45 100644 --- a/src/mrpro/nn/activations.py +++ b/src/mrpro/nn/GEGLU.py @@ -1,3 +1,5 @@ +"""Gated linear unit activation function.""" + import torch from torch.nn import Linear, Module diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 863e81e13..ce7b60553 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -1,3 +1,5 @@ +"""Layer normalization.""" + import torch from torch.nn import Module, Parameter diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py index 2626fb262..3645905a0 100644 --- a/src/mrpro/nn/LinearSelfAttention.py +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -1,4 +1,4 @@ -"""Linear self-attention""" +"""Linear self-attention.""" import torch from einops import rearrange diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py index cc94e3d3d..884bc8097 100644 --- a/src/mrpro/nn/MultiHeadAttention.py +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -39,7 +39,7 @@ def __init__( """ super().__init__() self.mha = torch.nn.MultiheadAttention( - embed_dim=channels_in, n_heads=n_heads, batch_first=True, dropout=p_dropout + embed_dim=channels_in, num_heads=n_heads, batch_first=True, dropout=p_dropout ) self.features_last = features_last self.to_out = Linear(channels_in, channels_out) diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index fe7aeec74..3ff753ab7 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -134,9 +134,11 @@ def __init__( Parameters ---------- - channels + channels_in The number of channels in the input tensor. - n_head + channels_out + The number of channels in the output tensor. + n_heads The number of attention heads. kernel_size The size of the attention neighborhood window. diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py index 89d32ee2b..7ffbcfeec 100644 --- a/src/mrpro/nn/RMSNorm.py +++ b/src/mrpro/nn/RMSNorm.py @@ -1,3 +1,5 @@ +"""RMSNorm module for root mean square normalization.""" + import torch from torch.nn import Module, Parameter diff --git a/src/mrpro/nn/Residual.py b/src/mrpro/nn/Residual.py index 19b047405..9a59c4016 100644 --- a/src/mrpro/nn/Residual.py +++ b/src/mrpro/nn/Residual.py @@ -1,3 +1,5 @@ +"""Residual connection.""" + import torch from torch.nn import Identity, Module diff --git a/src/mrpro/nn/RoPE.py b/src/mrpro/nn/RoPE.py index 06f63b7b7..e84ec06c4 100644 --- a/src/mrpro/nn/RoPE.py +++ b/src/mrpro/nn/RoPE.py @@ -1,10 +1,8 @@ -"""Rotary Position Embeddings (RoPE) implementation.""" +"""Rotary Position Embedding (RoPE).""" import torch from torch.nn import Module -from mrpro.nn.NDModules import ConvND - @torch.compile def apply_rotary_emb_(x: torch.Tensor, theta: torch.Tensor, conjugated: bool) -> None: @@ -30,12 +28,14 @@ def apply_rotary_emb_(x: torch.Tensor, theta: torch.Tensor, conjugated: bool) -> x[..., :n_emb] = torch.cat([x1 * theta.cos() - x2 * theta.sin(), x2 * theta.cos() + x1 * theta.sin()], dim=-1) -class RotaryEmbedding_(torch.autograd.Function): +class RotaryEmbedding(torch.autograd.Function): """Custom autograd function for rotary embeddings.""" @staticmethod def forward( - ctx: torch.autograd.function.FunctionCtx, x: torch.Tensor, theta: torch.Tensor, conjugated: bool + x: torch.Tensor, + theta: torch.Tensor, + conjugated: bool, ) -> torch.Tensor: """Apply rotary embedding in forward pass.""" apply_rotary_emb_(x, theta, conjugated) @@ -43,12 +43,12 @@ def forward( @staticmethod def setup_context( - ctx: torch.autograd.function.FunctionCtx, inputs: tuple[torch.Tensor, torch.Tensor, bool], output: torch.Tensor + ctx: torch.autograd.function.FunctionCtx, inputs: tuple[torch.Tensor, torch.Tensor, bool], _output: torch.Tensor ) -> None: """Save tensors for backward pass.""" _, theta, conjugated = inputs ctx.save_for_backward(theta) - ctx.conjugated = conjugated + ctx.conjugated = conjugated # type: ignore[attr-defined] @staticmethod def backward( @@ -115,7 +115,7 @@ def forward(self, pos: torch.Tensor, *tensors: torch.Tensor) -> None: Tensors to apply rotary embeddings to """ theta = self.get_theta(pos) - tuple(RotaryEmbedding_.apply(x, theta, False) for x in tensors) + tuple(RotaryEmbedding.apply(x, theta, False) for x in tensors) @staticmethod def make_axial_positions(*shape: int) -> torch.Tensor: diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 84c19dd46..77a375d39 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -1,3 +1,5 @@ +"""Sequential container with support for conditioning and Operators.""" + import torch from mrpro.nn.CondMixin import CondMixin diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py index f562bf7f0..5a3501e5f 100644 --- a/src/mrpro/nn/encoding.py +++ b/src/mrpro/nn/encoding.py @@ -30,8 +30,9 @@ def __init__(self, in_features: int, out_features: int, std: float = 1.0): std : float, optional Standard deviation for random initialization, by default 1.0 """ + if out_features % 2 != 0: + raise ValueError('out_features must be even.') super().__init__() - assert out_features % 2 == 0 self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -89,7 +90,8 @@ def __init__(self, dim: int, features: int, include_radii: bool = True, base_res self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :features]) self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][dim - 1] - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for encoding.""" features = self.encoding.shape[1] if features > x.shape[1]: raise ValueError(f'x has {x.shape[1]} features, but {features} are required') diff --git a/src/mrpro/nn/nets/CNN.py b/src/mrpro/nn/nets/CNN.py index f278759b6..6a4ade796 100644 --- a/src/mrpro/nn/nets/CNN.py +++ b/src/mrpro/nn/nets/CNN.py @@ -1,7 +1,9 @@ +"""Simple Convolutional Neural Network.""" + from collections.abc import Sequence from itertools import pairwise -from torch.nn import SiLU +from torch.nn import ReLU from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm @@ -45,7 +47,7 @@ def __init__( super().__init__() channels = [channels_in, *features] for i, (channels_current, channels_next) in enumerate(pairwise(channels)): - block = Sequential(ConvND(dim)(channels_current, channels_next, 3, padding=1), SiLU(True)) + block = Sequential(ConvND(dim)(channels_current, channels_next, 3, padding=1), ReLU(True)) if norm: block.append(GroupNorm(1)) if cond_dim > 0 and i % 2 == 0: diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index 49be43589..c3d3b586c 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -1,36 +1,51 @@ +"""Deep Compression Autoencoder.""" + from collections.abc import Sequence import torch -from torch.nn import Module, Sequential +from torch.nn import Module, Sequential, SiLU -from mrpro.nn import SiLU from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock from mrpro.nn.LinearSelfAttention import LinearSelfAttention from mrpro.nn.MultiHeadAttention import MultiHeadAttention from mrpro.nn.NDModules import ConvND from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Residual import Residual from mrpro.nn.RMSNorm import RMSNorm -class ResBlock(Module): +class ResBlock(Residual): + """Residual block with two convolutions and normalization.""" + def __init__( self, dim: int, channels: int, ): - super().__init__() - self.inner = Sequential( - ConvND(dim)(channels, channels, kernel_size=3, padding=1), - SiLU(), - ConvND(dim)(channels, channels, kernel_size=3, padding=1, bias=False), - RMSNorm(channels), + """Initialize the ResBlock. + + Parameters + ---------- + dim : int + The spatial dimension of the input tensor. + channels : int + The number of channels in the input tensor. + """ + super().__init__( + Residual( + Sequential( + ConvND(dim)(channels, channels, kernel_size=3, padding=1), + SiLU(), + ConvND(dim)(channels, channels, kernel_size=3, padding=1, bias=False), + RMSNorm(channels), + ) + ) ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.inner(x) + x - class EfficientViTBlock(Module): + """Efficient Vision Transformer block with optional linear attention.""" + def __init__( self, dim: int, @@ -44,7 +59,7 @@ def __init__( attention = LinearSelfAttention(channels, channels, n_heads) # TODO: check heads and head dim else: attention = MultiHeadAttention(channels, channels, n_heads, features_last=False) - self.context_module = Sequential(attention, RMSNorm(channels)) + self.context_module = Residual(Sequential(attention, RMSNorm(channels))) self.local_module = GluMBConvResBlock( dim=dim, channels_in=channels, @@ -53,12 +68,15 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.context_module(x) + x - x = self.local_module(x) # is already residual + """Forward pass for EfficientViTBlock.""" + x = self.context_module(x) + x = self.local_module(x) return x class Encoder(Sequential): + """Encoder for DCAE.""" + def __init__( self, dim: int = 2, @@ -84,13 +102,15 @@ def __init__( stage = [EfficientViTBlock(dim, width, n_heads=1, linear_attn=False) for _ in range(depth)] case _: raise ValueError(f'Block type {block_type} not supported') - self.append(Sequential(stage)) + self.append(Sequential(*stage)) if len(self) < len(widths): self.append(PixelUnshuffleDownsample(dim, width, width, downscale_factor=2, residual=True)) self.append(PixelUnshuffleDownsample(dim, widths[-1], channels_out, downscale_factor=1, residual=True)) class Decoder(Module): + """Decoder for DCAE.""" + def __init__( self, dim: int = 2, @@ -121,7 +141,8 @@ def __init__( stage = [EfficientViTBlock(dim, width, n_heads=1, linear_attn=False) for _ in range(depth)] case _: raise ValueError(f'Block type {block_type} not supported') - self.append(Sequential(stage)) + + self.stages.append(Sequential(*stage)) if len(self) < len(widths): self.append(PixelShuffleUpsample(dim, width, width, upscale_factor=2, residual=True)) @@ -149,11 +170,12 @@ def __init__( # act=cfg.out_act, # ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.project_in(x) - for stage in reversed(self.stages): - if len(stage.op_list) == 0: - continue - x = stage(x) - x = self.project_out(x) - return x + self.project_out = PixelShuffleUpsample(dim, widths[-1], channels_out, upscale_factor=1, residual=True) + + # def forward(self, x: torch.Tensor) -> torch.Tensor: + # """Forward pass for Decoder.""" + # x = self.project_in(x) + # for stage in reversed(self.stages): + # x = stage(x) + # x = self.project_out(x) + # return x diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 9f9824cbd..55b46fd94 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -1,17 +1,16 @@ """Restormer implementation.""" from collections.abc import Sequence +from itertools import pairwise import torch -from torch.nn import Module, Identity +from torch.nn import Identity, Module from mrpro.nn.FiLM import FiLM from mrpro.nn.NDModules import ConvND, InstanceNormND from mrpro.nn.nets.UNet import UNetBase -from mrpro.nn.PixelShuffle import PixelShuffle, PixelUnshuffle from mrpro.nn.Sequential import Sequential from mrpro.nn.TransposedAttention import TransposedAttention -from mrpro.utils import pairwise class GDFN(Module): @@ -87,7 +86,7 @@ def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond """ super().__init__() self.norm1 = Sequential(InstanceNormND(dim)(channels)) - self.attn = TransposedAttention(dim, channels, n_heads) + self.attn = TransposedAttention(dim, channels, channels, n_heads) self.norm2 = Sequential(InstanceNormND(dim)(channels)) self.ffn = GDFN(dim, channels, mlp_ratio) if cond_dim > 0: @@ -171,18 +170,18 @@ def blocks(n_heads: int, n_blocks: int): layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) return layers - for n_block, n_heads in zip(n_blocks, n_heads, strict=False): - self.input_blocks.append(blocks(n_heads, n_block)) - self.output_blocks.append(blocks(n_heads, n_block)) + for block, head in zip(n_blocks, n_heads, strict=False): + self.input_blocks.append(blocks(head, block)) + self.output_blocks.append(blocks(head, block)) self.skip_blocks.append(Identity()) self.output_blocks = self.output_blocks[::-1] - for n_head_current, n_head_next in pairwise(n_heads): + for head_current, head_next in pairwise(n_heads): self.down_blocks.append( Sequential( ConvND(dim)( - n_channels_per_head * n_head_current, - n_channels_per_head * n_head_next, + n_channels_per_head * head_current, + n_channels_per_head * head_next, kernel_size=3, stride=2, padding=1, @@ -192,8 +191,8 @@ def blocks(n_heads: int, n_blocks: int): self.up_blocks.append( Sequential( ConvND(dim)( - n_channels_per_head * n_head_next, - n_channels_per_head * n_head_current, + n_channels_per_head * head_next, + n_channels_per_head * head_current, kernel_size=3, stride=1, padding=1, diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index aff4037ad..ecfc356f5 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -1,8 +1,9 @@ """SwinIR implementation.""" import torch -from torch.nn import Module +from torch.nn import GELU, Module +from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM from mrpro.nn.NDModules import ConvND, InstanceNormND from mrpro.nn.Sequential import Sequential @@ -23,6 +24,7 @@ def __init__( window_size: int, mlp_ratio: int = 4, emb_dim: int = 0, + p_droppath: float = 0.0, ): """Initialize SwinTransformerLayer. @@ -36,17 +38,25 @@ def __init__( Number of attention heads window_size : int Size of the attention window - mlp_ratio : int, optional - Ratio for hidden dimension expansion, by default 4 - emb_dim : int, optional - Dimension of conditioning input, by default 0 + mlp_ratio : int + Ratio for hidden dimension expansion in MLP + emb_dim : int + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath : float + Droppath probability for MLP """ super().__init__() - self.norm1 = Sequential(InstanceNormND(dim)(channels)) + self.norm1 = InstanceNormND(dim)(channels) self.attn = ShiftedWindowAttention(dim, channels, n_heads, window_size) self.norm2 = Sequential(InstanceNormND(dim)(channels)) if emb_dim > 0: self.norm2.append(FiLM(channels=channels, cond_dim=emb_dim)) + self.mlp = Sequential( + ConvND(dim)(channels, channels * mlp_ratio, 1), + GELU(True), + ConvND(dim)(channels * mlp_ratio, channels, 1), + DropPath(p_droppath), + ) def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the Swin Transformer layer. @@ -68,14 +78,15 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the Swin Transformer layer.""" x = x + self.attn(self.norm1(x)) - x = x + self.norm2(x) + x = x + self.mlp(self.norm2(x, cond)) return x class ResidualSwinTransformerBlock(Module): - """Residual Swin Transformer block. + """Residual Swin Transformer block (RSTB). - Combines a Swin Transformer layer with a residual connection. + Combines a Swin Transformer layer with a residual connection, + as used in the SwinIR architecture. """ def __init__( @@ -86,6 +97,8 @@ def __init__( window_size: int, depth: int, emb_dim: int = 0, + p_droppath: float = 0.0, + mlp_ratio: int = 4, ): """Initialize ResidualSwinTransformerBlock. @@ -102,11 +115,20 @@ def __init__( depth : int Number of Swin Transformer layers emb_dim : int, optional - Dimension of conditioning input, by default 0 + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath : float, optional + Droppath probability for MLP. + mlp_ratio : int, optional + Ratio for hidden dimension expansion in MLP """ super().__init__() self.layers = Sequential( - *(SwinTransformerLayer(dim, channels, n_heads, window_size, emb_dim=emb_dim) for _ in range(depth)) + *( + SwinTransformerLayer( + dim, channels, n_heads, window_size, emb_dim=emb_dim, p_droppath=p_droppath, mlp_ratio=mlp_ratio + ) + for _ in range(depth) + ) ) self.conv = ConvND(dim)(channels, channels, 3, padding=1) @@ -118,7 +140,7 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T x : torch.Tensor Input tensor cond : torch.Tensor | None, optional - Conditioning input, by default None + Conditioning input. If None, no FiLM conditioning is used. Returns ------- @@ -155,6 +177,8 @@ def __init__( n_blocks: int = 6, n_attn_per_block: int = 6, emb_dim: int = 0, + p_droppath: float = 0.0, + mlp_ratio: int = 4, ): """Initialize SwinIR. @@ -167,17 +191,21 @@ def __init__( channels_out : int Number of output channels channels_per_head : int, optional - Number of channels per attention head, by default 16 + Number of channels per attention head n_heads : int, optional - Number of attention heads, by default 6 - window_size : int, optional - Size of the attention window, by default 64 - n_blocks : int, optional - Number of residual blocks, by default 6 - n_attn_per_block : int, optional - Number of attention layers per block, by default 6 + Number of attention heads + window_size : int + Size of the attention window. Inputs sizes must be divisible by this value. + n_blocks : int + Number of residual blocks + n_attn_per_block : int + Number of attention layers per block emb_dim : int, optional - Dimension of conditioning input, by default 0 + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath : float, optional + Droppath probability for MLP. + mlp_ratio : int, optional + Ratio for hidden dimension expansion in MLP. """ super().__init__() self.first = ConvND(dim)(channels_in, channels_per_head * n_heads, kernel_size=3, padding=1) @@ -190,6 +218,8 @@ def __init__( window_size, n_attn_per_block, emb_dim, + p_droppath, + mlp_ratio, ) for _ in range(n_blocks) ) @@ -204,7 +234,7 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Te x : torch.Tensor Input tensor cond : torch.Tensor | None, optional - Conditioning input, by default None + Conditioning input. If None, no FiLM conditioning is used. Returns ------- diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 2192db24f..3c9974245 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -1,3 +1,5 @@ +"""UNet variants.""" + from collections.abc import Sequence from functools import partial @@ -21,7 +23,7 @@ def __init__(self) -> None: self.skip_blocks = ModuleList() """Modifications to the skip connections""" - self.middle_block = Module() + self.middle_block: Module = Identity() """Also called bottleneck block""" self.output_blocks = ModuleList() @@ -33,11 +35,11 @@ def __init__(self) -> None: self.concat_blocks = ModuleList() """Joins the skip connections with the upsampled features from a lower resolution level""" - self.last = Identity() - """The last block""" + self.last: Module = Identity() + """The last block. Should reduce to the number of output channels.""" - self.first = Identity() - """The first block""" + self.first: Module = Identity() + """The first block. Should expand from the number of input channels.""" def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: """Apply to Network.""" @@ -98,6 +100,8 @@ def __init__( padding_modes: str | Sequence[str], ) -> None: ... + """Initialize the UNet.""" + class AttentionUNet(UNet): """UNet with attention gates. @@ -112,10 +116,12 @@ class AttentionUNet(UNet): class SeparableUNet(UNetBase): """UNet where blocks apply separable convolutions in different dimensions. - Based on the pseudo-3D residual network of [QUI]_ and the residual blocks of [ZIM]_. + Based on the pseudo-3D residual network of [QUI]_, [TRAN]_ and the residual blocks of [ZIM]_. References ---------- + .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal convolutions for action recognition. + CVPR 2018. https://arxiv.org/abs/1711.11248 .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. ICCV 2017. https://arxiv.org/abs/1711.10305 .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py index 6d4a5bd6c..9f98ac9c3 100644 --- a/src/mrpro/nn/nets/VAE.py +++ b/src/mrpro/nn/nets/VAE.py @@ -11,8 +11,8 @@ class VAE(Module): into the original space. The encoder should return twice the number of channels as the decoder needs to reconstruct the input: half of the channels are the mean and the other half the log variance of the latent space. The reparameterization trick is used to sample from the latent space. - The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard normal - distribution. + The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard + normal distribution. """ def __init__(self, encoder: Module, decoder: Module): diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py new file mode 100644 index 000000000..14908f57e --- /dev/null +++ b/src/mrpro/nn/nets/__init__.py @@ -0,0 +1,4 @@ +from mrpro.nn.nets.Restormer import Restormer +from mrpro.nn.nets.Uformer import Uformer + +__all__ = ["Restormer", "Uformer"] diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index 773e0daff..7ea8a4175 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -25,7 +25,7 @@ def test_shifted_window_attentio(dim: int, window_size: int, shifted: bool, devi rng = RandomGenerator(13) x = rng.float32_tensor((batch, channels, *spatial_shape)).to(device).requires_grad_(True) swin = ShiftedWindowAttention( - dim=dim, channels=channels, n_heads=n_heads, window_size=window_size, shifted=shifted + dim=dim, channels_in=channels, channels_out=channels, n_heads=n_heads, window_size=window_size, shifted=shifted ).to(device) out = swin(x) assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' From 62e04a1f4685464bda7a35f2d15ed35e7a1fc2ff Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 19 May 2025 17:20:20 +0200 Subject: [PATCH 036/205] update --- src/mrpro/nn/RoPE.py | 8 +++++--- src/mrpro/nn/encoding.py | 2 +- src/mrpro/nn/nets/DCAE.py | 22 +++++++++++----------- src/mrpro/nn/nets/SwinIR.py | 2 +- src/mrpro/nn/nets/UNet.py | 4 ++-- tests/nn/test_resblock.py | 6 +++--- tests/nn/test_transposedattention.py | 2 +- 7 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/mrpro/nn/RoPE.py b/src/mrpro/nn/RoPE.py index e84ec06c4..bd3cf5e79 100644 --- a/src/mrpro/nn/RoPE.py +++ b/src/mrpro/nn/RoPE.py @@ -51,12 +51,12 @@ def setup_context( ctx.conjugated = conjugated # type: ignore[attr-defined] @staticmethod - def backward( + def backward( # type: ignore[override] ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor ) -> tuple[torch.Tensor, None, None]: """Apply backward pass.""" - (theta,) = ctx.saved_tensors - apply_rotary_emb_(grad_output, theta, ctx.conjugated) + (theta,) = ctx.saved_tensors # type: ignore[attr-defined] + apply_rotary_emb_(grad_output, theta, ctx.conjugated) # type: ignore[attr-defined] return grad_output, None, None @@ -66,6 +66,8 @@ class AxialRoPE(Module): Applies rotary position embeddings along each axis independently. """ + freqs: torch.Tensor + def __init__(self, dim: int, d_head: int, n_heads: int, headpos: int = -2, non_embed_fraction: float = 0.5): """Initialize AxialRoPE. diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py index 5a3501e5f..9828b2cf0 100644 --- a/src/mrpro/nn/encoding.py +++ b/src/mrpro/nn/encoding.py @@ -80,7 +80,7 @@ def __init__(self, dim: int, features: int, include_radii: bool = True, base_res if include_radii: for n in range(2, dim + 1): for combination in combinations(coords, n): - coords.append(2**0.5 * torch.sqrt(sum([c**2 for c in combination])) - 1) + coords.append((2 * sum([c**2 for c in combination])) ** 0.5 - 1) n_freqs = ceil(features / len(coords) / 2) freqs = unsqueeze_right((base_resolution) ** torch.linspace(0, 1, n_freqs), dim) encoding = [] diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index c3d3b586c..6e169f5cc 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -14,15 +14,15 @@ from mrpro.nn.RMSNorm import RMSNorm -class ResBlock(Residual): - """Residual block with two convolutions and normalization.""" +class CNNBlock(Residual): + """Block with two convolutions and normalization.""" def __init__( self, dim: int, channels: int, ): - """Initialize the ResBlock. + """Initialize the CNNBlock. Parameters ---------- @@ -51,12 +51,12 @@ def __init__( dim: int, channels: int, n_heads: int, - expand_ratio: float = 4, + expand_ratio: int = 4, linear_attn: bool = False, ): super().__init__() if linear_attn: - attention = LinearSelfAttention(channels, channels, n_heads) # TODO: check heads and head dim + attention: Module = LinearSelfAttention(channels, channels, n_heads) # TODO: check heads and head dim else: attention = MultiHeadAttention(channels, channels, n_heads, features_last=False) self.context_module = Residual(Sequential(attention, RMSNorm(channels))) @@ -82,7 +82,7 @@ def __init__( dim: int = 2, channels_in: int = 3, channels_out: int = 32, - block_types: Sequence[str] = ('ResBlock', 'ResBlock', 'LinearViT', 'LinearViT', 'ViT'), + block_types: Sequence[str] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), widths: Sequence[int] = (256, 512, 512, 1024, 1024), depths: Sequence[int] = (4, 6, 2, 2, 2), ): @@ -92,8 +92,8 @@ def __init__( raise ValueError('block_types, widths, and depths must have the same length') for block_type, width, depth in zip(block_types, widths, depths, strict=False): match block_type: - case 'ResBlock': - stage = [ResBlock(dim, width) for _ in range(depth)] + case 'CNN': + stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] case 'LinearViT': stage = [ EfficientViTBlock(dim, width, n_heads=1, linear_attn=True) for _ in range(depth) @@ -108,7 +108,7 @@ def __init__( self.append(PixelUnshuffleDownsample(dim, widths[-1], channels_out, downscale_factor=1, residual=True)) -class Decoder(Module): +class Decoder(Sequential): """Decoder for DCAE.""" def __init__( @@ -116,7 +116,7 @@ def __init__( dim: int = 2, channels_in: int = 32, channels_out: int = 3, - block_types: Sequence[str] = ('ViT', 'LinearViT', 'LinearViT', 'ResBlock', 'ResBlock'), + block_types: Sequence[str] = ('ViT', 'LinearViT', 'LinearViT', 'CNN', 'CNN'), widths: Sequence[int] = (1024, 1024, 512, 512, 256), depths: Sequence[int] = (2, 2, 2, 6, 4), ): @@ -132,7 +132,7 @@ def __init__( for block_type, width, depth in zip(block_types, widths, depths, strict=False): match block_type: case 'ResBlock': - stage = [ResBlock(dim, width) for _ in range(depth)] + stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] case 'LinearViT': stage = [ EfficientViTBlock(dim, width, n_heads=1, linear_attn=True) for _ in range(depth) diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index ecfc356f5..58493cfc3 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -53,7 +53,7 @@ def __init__( self.norm2.append(FiLM(channels=channels, cond_dim=emb_dim)) self.mlp = Sequential( ConvND(dim)(channels, channels * mlp_ratio, 1), - GELU(True), + GELU('tanh'), ConvND(dim)(channels * mlp_ratio, channels, 1), DropPath(p_droppath), ) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 3c9974245..655a877da 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -120,8 +120,8 @@ class SeparableUNet(UNetBase): References ---------- - .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal convolutions for action recognition. - CVPR 2018. https://arxiv.org/abs/1711.11248 + .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal + convolutions for action recognition. CVPR 2018. https://arxiv.org/abs/1711.11248 .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. ICCV 2017. https://arxiv.org/abs/1711.10305 .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index 957eb70c6..195b88f01 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -13,18 +13,18 @@ ], ) @pytest.mark.parametrize( - ('dim', 'channels_in', 'channels_out', 'channels_emb', 'input_shape', 'emb_shape'), + ('dim', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'emb_shape'), [ (2, 32, 32, 16, (1, 32, 32, 32), (1, 16)), (3, 64, 32, 0, (2, 64, 16, 16, 16), None), ], ) -def test_resblock(dim, channels_in, channels_out, channels_emb, input_shape, emb_shape, device): +def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, emb_shape, device): """Test ResBlock output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) emb = rng.float32_tensor(emb_shape).to(device).requires_grad_(True) if emb_shape else None - res = ResBlock(dim=dim, channels_in=channels_in, channels_out=channels_out, channels_emb=channels_emb).to(device) + res = ResBlock(dim=dim, channels_in=channels_in, channels_out=channels_out, cond_dim=cond_dim).to(device) output = res(x, emb) assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index 0fdd80612..ea39781b3 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -32,5 +32,5 @@ def test_transposed_attention(dim, channels, num_heads, input_shape, device): assert not x.grad.isnan().any(), 'NaN values in input gradients' assert attn.to_qkv.weight.grad is not None, 'No gradient computed for qkv' assert attn.qkv_dwconv.weight.grad is not None, 'No gradient computed for qkv_dwconv' - assert attn.project_out.weight.grad is not None, 'No gradient computed for project_out' + assert attn.to_out.weight.grad is not None, 'No gradient computed for project_out' assert attn.temperature.grad is not None, 'No gradient computed for temperature' From 7e9d12183fae6d953967b3409a64e729848274bf Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 19 May 2025 21:52:03 +0200 Subject: [PATCH 037/205] update --- src/mrpro/nn/nets/DCAE.py | 163 ++++++++++++++++++++++++++++---------- 1 file changed, 120 insertions(+), 43 deletions(-) diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index 6e169f5cc..bc2fdaeda 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -1,6 +1,7 @@ """Deep Compression Autoencoder.""" from collections.abc import Sequence +from typing import Literal import torch from torch.nn import Module, Sequential, SiLU @@ -9,13 +10,22 @@ from mrpro.nn.LinearSelfAttention import LinearSelfAttention from mrpro.nn.MultiHeadAttention import MultiHeadAttention from mrpro.nn.NDModules import ConvND +from mrpro.nn.nets.VAE import VAE from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.Residual import Residual from mrpro.nn.RMSNorm import RMSNorm class CNNBlock(Residual): - """Block with two convolutions and normalization.""" + """Block with two convolutions and normalization. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ def __init__( self, @@ -44,7 +54,15 @@ def __init__( class EfficientViTBlock(Module): - """Efficient Vision Transformer block with optional linear attention.""" + """Efficient Vision Transformer block with optional linear attention. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ def __init__( self, @@ -54,6 +72,21 @@ def __init__( expand_ratio: int = 4, linear_attn: bool = False, ): + """Initialize the EfficientViTBlock. + + Parameters + ---------- + dim : int + The spatial dimension of the input tensor. + channels : int + The number of channels in the input tensor. + n_heads : int + The number of attention heads. + expand_ratio : int + The expansion ratio of the GluMBConvResBlock. + linear_attn : bool + Whether to use linear attention instead of softmax attention with quadratic complexity. + """ super().__init__() if linear_attn: attention: Module = LinearSelfAttention(channels, channels, n_heads) # TODO: check heads and head dim @@ -75,17 +108,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Encoder(Sequential): - """Encoder for DCAE.""" + """Encoder for DCAE. + + As used in the DC-Autoencoder [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ def __init__( self, dim: int = 2, channels_in: int = 3, channels_out: int = 32, - block_types: Sequence[str] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), + block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), widths: Sequence[int] = (256, 512, 512, 1024, 1024), depths: Sequence[int] = (4, 6, 2, 2, 2), ): + """Initialize the Encoder. + + The length of the `block_types`, `widths`, and `depths` must be the same and determine + the number of stages in the encoder. Between the stages, downsampling is performed. + + Parameters + ---------- + dim : int + The spatial dimension of the input tensor. + channels_in : int + The number of channels in the input tensor, i.e. the latent space + channels_out : int + The number of channels in the output tensor, i.e. the original space + block_types : Sequence[str] + The types of blocks to use in the decoder. + widths : Sequence[int] + The widths of the blocks in the decoder, i.e. the number of channels in the blocks + depths : Sequence[int] + The depths of the blocks in the decoder, i.e. the number blocks in the stage + """ super().__init__() self.append(PixelUnshuffleDownsample(dim, channels_in, widths[0], downscale_factor=2, residual=False)) if len(block_types) != len(widths) or len(block_types) != len(depths): @@ -109,29 +170,54 @@ def __init__( class Decoder(Sequential): - """Decoder for DCAE.""" + """Decoder for DCAE. + + As used in the DC-Autoencoder [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ def __init__( self, dim: int = 2, channels_in: int = 32, channels_out: int = 3, - block_types: Sequence[str] = ('ViT', 'LinearViT', 'LinearViT', 'CNN', 'CNN'), + block_types: Sequence[Literal['ViT', 'LinearViT', 'CNN']] = ('ViT', 'LinearViT', 'LinearViT', 'CNN', 'CNN'), widths: Sequence[int] = (1024, 1024, 512, 512, 256), depths: Sequence[int] = (2, 2, 2, 6, 4), ): + """Initialize the Decoder. + + The length of the `block_types`, `widths`, and `depths` must be the same and determine + the number of stages in the decoder. Between the stages, upsampling is performed. + + Parameters + ---------- + dim : int + The spatial dimension of the input tensor. + channels_in : int + The number of channels in the input tensor, i.e. the latent space + channels_out : int + The number of channels in the output tensor, i.e. the original space + block_types : Sequence[str] + The types of blocks to use in the decoder. + widths : Sequence[int] + The widths of the blocks in the decoder, i.e. the number of channels in the blocks + depths : Sequence[int] + The depths of the blocks in the decoder, i.e. the number blocks in the stage + """ super().__init__() if not (len(block_types) == len(widths) == len(depths)): raise ValueError('block_types, widths, and depths must have the same length') - # "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] " - # "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[0,5,10,2,2,2] " - # "decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu]" self.append(PixelShuffleUpsample(dim, channels_in, widths[0], upscale_factor=1, residual=True)) self.stages: list[Sequential] = [] for block_type, width, depth in zip(block_types, widths, depths, strict=False): match block_type: - case 'ResBlock': + case 'CNN': stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] case 'LinearViT': stage = [ @@ -146,36 +232,27 @@ def __init__( if len(self) < len(widths): self.append(PixelShuffleUpsample(dim, width, width, upscale_factor=2, residual=True)) - # stage.extend( - # build_stage_main( - # width=width, - # depth=depth, - # block_type=block_type, - # norm=norm, - # act=act, - # input_width=( - # width if cfg.upsample_match_channel else cfg.width_list[min(stage_id + 1, num_stages - 1)] - # ), - # ) - # ) - # self.stages.insert(0, OpSequential(stage)) - # self.stages = nn.ModuleList(self.stages) - - # self.project_out = build_decoder_project_out_block( - # in_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1], - # out_channels=cfg.in_channels, - # factor=1 if cfg.depth_list[0] > 0 else 2, - # upsample_block_type=cfg.upsample_block_type, - # norm=cfg.out_norm, - # act=cfg.out_act, - # ) - - self.project_out = PixelShuffleUpsample(dim, widths[-1], channels_out, upscale_factor=1, residual=True) - - # def forward(self, x: torch.Tensor) -> torch.Tensor: - # """Forward pass for Decoder.""" - # x = self.project_in(x) - # for stage in reversed(self.stages): - # x = stage(x) - # x = self.project_out(x) - # return x + self.append(PixelShuffleUpsample(dim, widths[-1], channels_out, upscale_factor=1, residual=True)) + + +class DCVAE(VAE): + """Variational Autoencoder based on DCAE. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + dim: int, + channels: int, + latent_dim: int = 32, + block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), + widths: Sequence[int] = (256, 512, 512, 1024, 1024), + depths: Sequence[int] = (4, 6, 2, 2, 2), + ): + encoder = Encoder(dim, channels, latent_dim * 2, block_types, widths, depths) + decoder = Decoder(dim, latent_dim, channels, block_types[::-1], widths[::-1], depths[::-1]) + super().__init__(encoder, decoder) From 3d259bbc3122f3d312f9a0a628c994e452820d81 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 20 May 2025 02:09:41 +0200 Subject: [PATCH 038/205] update --- src/mrpro/nn/LinearSelfAttention.py | 7 ++--- src/mrpro/nn/MultiHeadAttention.py | 8 +++--- src/mrpro/nn/PixelShuffle.py | 4 ++- src/mrpro/nn/RMSNorm.py | 12 ++++++--- src/mrpro/nn/Sequential.py | 12 +++++++++ src/mrpro/nn/nets/DCAE.py | 42 ++++++++++++++++------------- src/mrpro/nn/nets/__init__.py | 5 +++- 7 files changed, 58 insertions(+), 32 deletions(-) diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py index 3645905a0..3ad70f14f 100644 --- a/src/mrpro/nn/LinearSelfAttention.py +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -83,11 +83,11 @@ def forward(self, x: Tensor) -> Tensor: x = x.float() if not self.channel_last: x = x.moveaxis(1, -1) - spatial_shape = x.shape[2:-1] + spatial_shape = x.shape[1:-1] qkv = self.to_qkv(x) query, key, value = rearrange( - qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channel', qkv=3, head=self.n_heads + qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channels', qkv=3, head=self.n_heads ) query = self.kernel_function(query) @@ -102,7 +102,8 @@ def forward(self, x: Tensor) -> Tensor: attn = value_key_query[..., :-1, :] / normalization out = self.to_out(attn) out = out.to(orig_dtype) - out.unflatten(-2, spatial_shape) + out = out.moveaxis(1, -1).flatten(-2) # join heads and channels + out = out.unflatten(-2, spatial_shape) if not self.channel_last: out = out.moveaxis(-1, 1) return out diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py index 884bc8097..d7e146b6b 100644 --- a/src/mrpro/nn/MultiHeadAttention.py +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -63,17 +63,17 @@ def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) def _reshape(self, x: torch.Tensor) -> torch.Tensor: if not self.features_last: x = x.moveaxis(1, -1) - return x.flatten(2, -2) + return x.flatten(1, -2) def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: """Apply multi-head attention.""" reshaped_x = self._reshape(x) reshaped_cross_attention = self._reshape(cross_attention) if cross_attention is not None else reshaped_x - y = self.mha(reshaped_cross_attention, reshaped_cross_attention, reshaped_x) - out = self.to_out(y) + y = self.mha(reshaped_cross_attention, reshaped_cross_attention, reshaped_x, need_weights=False)[0] + out: torch.Tensor = self.to_out(y) if not self.features_last: - out = out.moveaxes(-1, 1) + out = out.moveaxis(-1, 1) return out.reshape(x.shape) diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index fedacb9dd..9a474d3ff 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -112,9 +112,11 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply downsampling.""" - x = self.pixel_unshuffle(x) h = self.conv(x) + h = self.pixel_unshuffle(h) + if self.residual: + x = self.pixel_unshuffle(x) h = h + x.unflatten(1, (h.shape[1], -1)).mean(2) return h diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py index 7ffbcfeec..52a2bce43 100644 --- a/src/mrpro/nn/RMSNorm.py +++ b/src/mrpro/nn/RMSNorm.py @@ -7,7 +7,7 @@ class RMSNorm(Module): """RMSNorm over the channel dimension.""" - def __init__(self, channels: int, eps: float = 1e-8): + def __init__(self, channels: int, eps: float = 1e-8, channel_last: bool = False): """Initialize RMSNorm. Includes a learnable weight and bias. @@ -18,11 +18,14 @@ def __init__(self, channels: int, eps: float = 1e-8): Number of channels. eps Epsilon value to avoid division by zero. + channel_last + If True, the channel dimension is the last dimension. """ super().__init__() self.weight = Parameter(torch.zeros(channels)) self.bias = Parameter(torch.zeros(channels)) self.eps = eps + self.channel_dim = -1 if channel_last else 1 def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply RMSNorm over the channel dimension. @@ -40,10 +43,11 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply RMSNorm over the channel dimension.""" - mean_square = x.pow(2).mean(dim=1, keepdim=True) + mean_square = x.square().mean(dim=self.channel_dim, keepdim=True) scale = (mean_square + self.eps).rsqrt() x = x * scale - shape = (1, -1, *([1] * (x.ndim - 2))) - weight = (1 + self.weight).view(shape) + shape = [1] * x.ndim + shape[self.channel_dim] = -1 + weight = (self.weight + 1).view(shape) bias = self.bias.view(shape) return x * weight + bias diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 77a375d39..99629b5b8 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -1,5 +1,7 @@ """Sequential container with support for conditioning and Operators.""" +from collections import OrderedDict + import torch from mrpro.nn.CondMixin import CondMixin @@ -35,3 +37,13 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Te else: x = module(x) return x + + def __getitem__(self, idx: slice | int) -> 'Sequential': + """Get a slice or item from the Sequential container. + + Subclasses will decompose to `Sequential` on indexing. + """ + if isinstance(idx, slice): + return Sequential(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index bc2fdaeda..c54cd838a 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -4,7 +4,7 @@ from typing import Literal import torch -from torch.nn import Module, Sequential, SiLU +from torch.nn import Module, ReLU, SiLU from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock from mrpro.nn.LinearSelfAttention import LinearSelfAttention @@ -14,6 +14,7 @@ from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.Residual import Residual from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Sequential import Sequential class CNNBlock(Residual): @@ -42,13 +43,11 @@ def __init__( The number of channels in the input tensor. """ super().__init__( - Residual( - Sequential( - ConvND(dim)(channels, channels, kernel_size=3, padding=1), - SiLU(), - ConvND(dim)(channels, channels, kernel_size=3, padding=1, bias=False), - RMSNorm(channels), - ) + Sequential( + ConvND(dim)(channels, channels, kernel_size=3, padding=1), + SiLU(True), + ConvND(dim)(channels, channels, kernel_size=3, padding=1, bias=False), + RMSNorm(channels), ) ) @@ -151,7 +150,7 @@ def __init__( self.append(PixelUnshuffleDownsample(dim, channels_in, widths[0], downscale_factor=2, residual=False)) if len(block_types) != len(widths) or len(block_types) != len(depths): raise ValueError('block_types, widths, and depths must have the same length') - for block_type, width, depth in zip(block_types, widths, depths, strict=False): + for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): match block_type: case 'CNN': stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] @@ -164,8 +163,9 @@ def __init__( case _: raise ValueError(f'Block type {block_type} not supported') self.append(Sequential(*stage)) - if len(self) < len(widths): - self.append(PixelUnshuffleDownsample(dim, width, width, downscale_factor=2, residual=True)) + if next_width: + self.append(PixelUnshuffleDownsample(dim, width, next_width, downscale_factor=2, residual=True)) + # Norm # relu self.append(PixelUnshuffleDownsample(dim, widths[-1], channels_out, downscale_factor=1, residual=True)) @@ -214,8 +214,7 @@ def __init__( raise ValueError('block_types, widths, and depths must have the same length') self.append(PixelShuffleUpsample(dim, channels_in, widths[0], upscale_factor=1, residual=True)) - self.stages: list[Sequential] = [] - for block_type, width, depth in zip(block_types, widths, depths, strict=False): + for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): match block_type: case 'CNN': stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] @@ -227,12 +226,17 @@ def __init__( stage = [EfficientViTBlock(dim, width, n_heads=1, linear_attn=False) for _ in range(depth)] case _: raise ValueError(f'Block type {block_type} not supported') - - self.stages.append(Sequential(*stage)) - if len(self) < len(widths): - self.append(PixelShuffleUpsample(dim, width, width, upscale_factor=2, residual=True)) - - self.append(PixelShuffleUpsample(dim, widths[-1], channels_out, upscale_factor=1, residual=True)) + self.append(Sequential(*stage)) + if next_width: + self.append(PixelShuffleUpsample(dim, width, next_width, upscale_factor=2, residual=True)) + + self.append( + Sequential( + RMSNorm(widths[-1]), + ReLU(), + PixelShuffleUpsample(dim, widths[-1], channels_out, upscale_factor=2), + ) + ) class DCVAE(VAE): diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 14908f57e..35c701a5d 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,4 +1,7 @@ from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.Uformer import Uformer +from mrpro.nn.nets.DCAE import DCVAE +from mrpro.nn.nets.VAE import VAE +from mrpro.nn.nets.UNet import UNet, AttentionUNet -__all__ = ["Restormer", "Uformer"] +__all__ = ["AttentionUNet", "DCVAE", "Restormer", "UNet", "Uformer", "VAE"] \ No newline at end of file From 52c8630c89ce48b82120ab4a5f22597389cb5778 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 20 May 2025 22:43:25 +0200 Subject: [PATCH 039/205] update --- src/mrpro/nn/RoPE.py | 4 ++-- src/mrpro/nn/nets/Restormer.py | 14 +++++++------- src/mrpro/nn/nets/SwinIR.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/mrpro/nn/RoPE.py b/src/mrpro/nn/RoPE.py index bd3cf5e79..90ecb8739 100644 --- a/src/mrpro/nn/RoPE.py +++ b/src/mrpro/nn/RoPE.py @@ -80,9 +80,9 @@ def __init__(self, dim: int, d_head: int, n_heads: int, headpos: int = -2, non_e n_heads : int Number of attention heads headpos : int, optional - Position of the head dimension, by default -2 + Position of the head dimension non_embed_fraction : float, optional - Fraction of dimensions to not embed, by default 0.5 + Fraction of dimensions to not embed """ super().__init__() log_min = torch.log(torch.tensor(torch.pi)) diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 55b46fd94..831ac880e 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -82,7 +82,7 @@ def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond mlp_ratio : float Ratio for hidden dimension expansion cond_dim : int, optional - Dimension of conditioning input, by default 0 + Dimension of conditioning input """ super().__init__() self.norm1 = Sequential(InstanceNormND(dim)(channels)) @@ -145,17 +145,17 @@ def __init__( channels_out : int Number of output channels n_blocks : Sequence[int], optional - Number of blocks in each stage, by default (4, 6, 6, 8) + Number of blocks in each stage n_refinement_blocks : int, optional - Number of refinement blocks, by default 4 + Number of refinement blocks n_heads : Sequence[int], optional - Number of attention heads in each stage, by default (1, 2, 4, 8) + Number of attention heads in each stage n_channels_per_head : int, optional - Number of channels per attention head, by default 48 + Number of channels per attention head mlp_ratio : float, optional - Ratio for hidden dimension expansion, by default 2.66 + Ratio for hidden dimension expansion cond_dim : int, optional - Dimension of conditioning input, by default 0 + Dimension of conditioning input """ super().__init__() diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index 58493cfc3..b86a02e9a 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -66,7 +66,7 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T x : torch.Tensor Input tensor cond : torch.Tensor | None, optional - Conditioning input, by default None + Conditioning input Returns ------- From d626bbbdd420c3b0a59c1deef62444afcf784170 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 21 May 2025 01:26:18 +0200 Subject: [PATCH 040/205] update --- src/mrpro/nn/TransposedAttention.py | 2 +- src/mrpro/nn/join.py | 91 +++++++++++++++++++++++++++++ src/mrpro/nn/nets/Restormer.py | 2 +- src/mrpro/nn/nets/UNet.py | 6 +- 4 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 src/mrpro/nn/join.py diff --git a/src/mrpro/nn/TransposedAttention.py b/src/mrpro/nn/TransposedAttention.py index b9105285a..2cbfac17c 100644 --- a/src/mrpro/nn/TransposedAttention.py +++ b/src/mrpro/nn/TransposedAttention.py @@ -46,7 +46,7 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, n_heads: int): padding=1, bias=False, ) - self.to_out = ConvND(dim)(channels_per_head * n_heads * 3, channels_out, kernel_size=1) + self.to_out = ConvND(dim)(channels_per_head * n_heads, channels_out, kernel_size=1) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply transposed attention. diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py new file mode 100644 index 000000000..22baf70b0 --- /dev/null +++ b/src/mrpro/nn/join.py @@ -0,0 +1,91 @@ +from typing import Literal, Sequence + +import torch +from torch.nn import Module + +from mrpro.utils.pad_or_crop import pad_or_crop + + +def fix_shapes( + xs: Sequence[torch.Tensor], mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'], dim: Sequence[int] +) -> tuple[torch.Tensor, ...]: + if mode == 'fail': + return tuple(xs) + + shapes = [[x.shape[d] for d in dim] for x in xs] + if mode == 'crop': + target = tuple(min(s) for s in zip(*shapes, strict=True)) + else: + target = tuple(max(s) for s in zip(*shapes, strict=True)) + if mode in ('crop', 'zero'): + mode = 'constant' + return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) + + # # def pad(x) -> torch.Tensor: + # # if x.shape[2:] == target: + # # return x + # # pad = [] + # # for cur, tgt in zip(reversed(x.shape[2:]), reversed(target), strict=True): + # # left = (tgt - cur) // 2 + # # right = tgt - cur - left + # # pad.extend([left, right]) + # # return torch.nn.functional.pad(x, pad, mode=mode) + + # return tuple(pad(x) for x in xs) + + +class Concat(Module): + """Concatenate tensors along the channel dimension""" + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'zero', dim: int = 1) -> None: + """Initialize Concat. + + Parameters + ---------- + mode : {'fail', 'crop', 'zero', 'replicate', 'circular'}, default='zero' + How to handle mismatched spatial dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + dim + Dimension along which to concatenate. + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.dim = dim + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + xs = fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) + return torch.cat(xs, dim=1) + + +class Add(Module): + """Add tensors""" + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'zero') -> None: + """Initialize Add. + + Parameters + ---------- + mode : {'fail', 'crop', 'zero', 'replicate', 'circular'}, default='zero' + How to handle mismatched spatial dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + xs = fix_shapes(xs, self.mode, dim=range(max(x.ndim for x in xs))) + return sum(xs, start=torch.tensor(0.0)) diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 831ac880e..067210054 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -163,7 +163,7 @@ def __init__( def blocks(n_heads: int, n_blocks: int): layers = Sequential( - *(RestormerBlock(dim, n_channels_per_head, n_heads, mlp_ratio) for _ in range(n_blocks)) + *(RestormerBlock(dim, n_channels_per_head * n_heads, n_heads, mlp_ratio) for _ in range(n_blocks)) ) if cond_dim > 0 and n_blocks > 1: diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 655a877da..a4e88198c 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -41,7 +41,7 @@ def __init__(self) -> None: self.first: Module = Identity() """The first block. Should expand from the number of input channels.""" - def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply to Network.""" call = partial(call_with_cond, cond=cond) x = call(self.first, x) @@ -57,7 +57,7 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: x = call(block, x) return call(self.last, x) - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply to Network. Parameters @@ -71,7 +71,7 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None) -> torch.Tensor: ------- The output tensor. """ - return self(x, cond) + return super().__call__(x, cond) class UNet(UNetBase): From 01881fef9076635036b932e2a94c01e90145cbb9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 21 May 2025 16:04:02 +0200 Subject: [PATCH 041/205] Refactor imports to use lowercase 'ndmodules' and update method signatures for better compatibility with multiple input tensors. Introduce conversion functions between Linear and Conv layers, along with corresponding tests. --- src/mrpro/nn/AttentionGate.py | 2 +- src/mrpro/nn/ComplexAsChannel.py | 33 +-- src/mrpro/nn/CondMixin.py | 10 +- src/mrpro/nn/FiLM.py | 7 +- src/mrpro/nn/PixelShuffle.py | 2 +- src/mrpro/nn/ResBlock.py | 2 +- src/mrpro/nn/Residual.py | 8 +- src/mrpro/nn/Sequential.py | 30 ++- src/mrpro/nn/ShiftedWindowAttention.py | 2 +- src/mrpro/nn/SqueezeExcitation.py | 2 +- src/mrpro/nn/TransposedAttention.py | 2 +- src/mrpro/nn/__init__.py | 2 +- src/mrpro/nn/convert_linear_conv.py | 97 +++++++++ src/mrpro/nn/join.py | 41 ++-- src/mrpro/nn/{NDModules.py => ndmodules.py} | 0 src/mrpro/nn/nets/CNN.py | 2 +- src/mrpro/nn/nets/DCAE.py | 2 +- src/mrpro/nn/nets/Restormer.py | 33 ++- src/mrpro/nn/nets/SwinIR.py | 8 +- src/mrpro/nn/nets/Uformer.py | 12 +- tests/nn/test_convert_linear_conv.py | 217 ++++++++++++++++++++ 21 files changed, 419 insertions(+), 95 deletions(-) create mode 100644 src/mrpro/nn/convert_linear_conv.py rename src/mrpro/nn/{NDModules.py => ndmodules.py} (100%) create mode 100644 tests/nn/test_convert_linear_conv.py diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py index db1c4aac7..96ebe6cf9 100644 --- a/src/mrpro/nn/AttentionGate.py +++ b/src/mrpro/nn/AttentionGate.py @@ -3,7 +3,7 @@ import torch from torch.nn import Module, ReLU, Sequential, Sigmoid -from mrpro.nn.NDModules import ConvND +from mrpro.nn.ndmodules import ConvND class AttentionGate(Module): diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py index 5ce5b02cb..7c1bec0fd 100644 --- a/src/mrpro/nn/ComplexAsChannel.py +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -10,23 +10,27 @@ class ComplexAsChannel(CondMixin, Module): """Wrap module to treat complex numbers as a channel dimension.""" - def __init__(self, module: Module): + def __init__(self, module: Module, convert_back: bool = True): """Initialize the ComplexAsChannel module. Wraps a module to treat complex numbers as a channel dimension. - If called with a complex tensor, real and imaginary parts are concatenated along the channel dimension. - as ``(batch, (channel real/imaginary), ...)``. + For each complex tensor in the input, real and imaginary parts are concatenated along the channel dimension + before being passed to the wrapped module. Parameters ---------- module : Module - The module to wrap. + The module to wrap. Should output a single real tensor. + convert_back : bool + If True, the output is converted back to a complex tensor. + The output should have a number of channels that is a multiple of 2. """ super().__init__() self.module = module + self.convert_back = convert_back - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module. Parameters @@ -36,19 +40,20 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T cond : torch.Tensor | None The conditioning tensor (if used by the wrapped module) """ - return super().__call__(x, cond) + return super().__call__(*x, cond=cond) - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module.""" - if x.is_complex(): - x_real = torch.view_as_real(x) - x_real = rearrange(x_real, 'batch channel ... complex -> batch (channel complex) ...') - else: - x_real = x + x_real = [ + rearrange(torch.view_as_real(c), 'batch channel ... complex -> batch (channel complex) ...') + if c.is_complex() + else c + for c in x + ] - y = call_with_cond(self.module, x_real, cond) + y = call_with_cond(self.module, *x_real, cond=cond) - if x.is_complex(): + if self.convert_back: y = rearrange(y, 'b (channel complex) ... -> b channel ... complex', complex=2).contiguous() y = torch.view_as_complex(y) return y diff --git a/src/mrpro/nn/CondMixin.py b/src/mrpro/nn/CondMixin.py index f5cb4a4c4..bd87b2a4c 100644 --- a/src/mrpro/nn/CondMixin.py +++ b/src/mrpro/nn/CondMixin.py @@ -4,11 +4,11 @@ from torch.nn import Module -def call_with_cond(module: Module, x: torch.Tensor, cond: torch.Tensor | None) -> torch.Tensor: +def call_with_cond(module: Module, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Call a module with conditioning if it is a CondMixin.""" if isinstance(module, CondMixin): - return module(x, cond) - return module(x) + return module(*x, cond=cond) + return module(*x) class CondMixin(Module): @@ -17,6 +17,6 @@ class CondMixin(Module): Used to determine if a module uses a conditioning within a Sequential container. """ - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module to the input.""" - return super().__call__(x, cond) + return super().__call__(*x, cond=cond) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index 014aa5835..1c2e3587b 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -38,7 +38,7 @@ def __init__(self, channels: int, cond_dim: int) -> None: else: self.project = Identity() - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM. Parameters @@ -48,12 +48,15 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T cond The conditioning tensor. """ - return super().__call__(x, cond) + if len(x) != 1: + raise ValueError('FiLM expects a single input tensor') + return super().__call__(x[0], cond=cond) def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM.""" if cond is None: return x scale, shift = self.project(cond).chunk(2, dim=1) + scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) return x * (1 + scale) + shift diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index 9a474d3ff..09f6d7ab6 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -3,7 +3,7 @@ import torch from torch.nn import Module -from mrpro.nn.NDModules import ConvND +from mrpro.nn.ndmodules import ConvND class PixelUnshuffle(Module): diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index bc0f4f9cf..878e03553 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -6,7 +6,7 @@ from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm -from mrpro.nn.NDModules import ConvND +from mrpro.nn.ndmodules import ConvND from mrpro.nn.Sequential import Sequential diff --git a/src/mrpro/nn/Residual.py b/src/mrpro/nn/Residual.py index 9a59c4016..e524fe169 100644 --- a/src/mrpro/nn/Residual.py +++ b/src/mrpro/nn/Residual.py @@ -23,7 +23,7 @@ def __init__(self, module: Module, skip: Module | None = None): self.module = module self.skip = Identity() if skip is None else skip - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module. Parameters @@ -38,8 +38,8 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T ------- The output tensor. """ - return super().__call__(x, cond) + return super().__call__(*x, cond=cond) - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module.""" - return call_with_cond(self.module, x, cond) + call_with_cond(self.skip, x, cond) + return call_with_cond(self.module, *x, cond=cond) + call_with_cond(self.skip, *x, cond=cond) diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 99629b5b8..95eb9b555 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -9,9 +9,13 @@ class Sequential(torch.nn.Sequential): - """Sequential container with support for conditioning and Operators.""" + """Sequential container with support for conditioning and Operators - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + Allows multiple input tensors and a single output tensor of the sequential block. + + """ + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply all modules in series to the input. Parameters @@ -25,18 +29,24 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T ------- The output tensor. """ - return super().__call__(x, cond) + return super().__call__(*x, cond=cond) - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply all modules in series to the input.""" for module in self: - if isinstance(module, CondMixin): - x = module(x, cond) - elif isinstance(module, Operator): - (x,) = module(x) + if isinstance(module, Operator): + x = module(*x) else: - x = module(x) - return x + ret: torch.Tensor | tuple[torch.Tensor, ...] + if isinstance(module, CondMixin): + ret = module(*x, cond=cond) + else: + ret = module(*x) + if isinstance(ret, tuple): + x = ret + else: + x = (ret,) + return x[0] def __getitem__(self, idx: slice | int) -> 'Sequential': """Get a slice or item from the Sequential container. diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py index 61b40351b..f66e2277e 100644 --- a/src/mrpro/nn/ShiftedWindowAttention.py +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -4,7 +4,7 @@ from einops import rearrange from torch.nn import Module -from mrpro.nn.NDModules import ConvND +from mrpro.nn.ndmodules import ConvND from mrpro.utils.reshape import ravel_multi_index from mrpro.utils.sliding_window import sliding_window diff --git a/src/mrpro/nn/SqueezeExcitation.py b/src/mrpro/nn/SqueezeExcitation.py index 8dcb87b65..787817173 100644 --- a/src/mrpro/nn/SqueezeExcitation.py +++ b/src/mrpro/nn/SqueezeExcitation.py @@ -3,7 +3,7 @@ import torch from torch.nn import Module, ReLU, Sigmoid -from mrpro.nn.NDModules import AdaptiveAvgPoolND, ConvND +from mrpro.nn.ndmodules import AdaptiveAvgPoolND, ConvND from mrpro.nn.Sequential import Sequential diff --git a/src/mrpro/nn/TransposedAttention.py b/src/mrpro/nn/TransposedAttention.py index 2cbfac17c..043afa750 100644 --- a/src/mrpro/nn/TransposedAttention.py +++ b/src/mrpro/nn/TransposedAttention.py @@ -4,7 +4,7 @@ from einops import rearrange from torch.nn import Module, Parameter -from mrpro.nn.NDModules import ConvND +from mrpro.nn.ndmodules import ConvND class TransposedAttention(Module): diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index 93f243237..368fe925d 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -4,7 +4,7 @@ from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm -from mrpro.nn.NDModules import ( +from mrpro.nn.ndmodules import ( AdaptiveAvgPoolND, AvgPoolND, BatchNormND, diff --git a/src/mrpro/nn/convert_linear_conv.py b/src/mrpro/nn/convert_linear_conv.py new file mode 100644 index 000000000..4a45a9ddf --- /dev/null +++ b/src/mrpro/nn/convert_linear_conv.py @@ -0,0 +1,97 @@ +"""Convert Linear layers to kernel size 1 ConvNd layers and vice versa.""" + +import torch +import torch.nn as nn +from torch.nn import Module, Conv1d, Conv2d, Conv3d, Linear +from mrpro.nn.ndmodules import ConvND +from typing import Literal, overload + + +@overload +def linear_to_conv(linear_layer: Linear, dim: Literal[1]) -> Conv1d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, dim: Literal[2]) -> Conv2d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, dim: Literal[3]) -> Conv3d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: ... + + +def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: + """Convert a Linear layer to a ConvNd layer with kernel size 1. + + Rearranging the spatial dimensions to the batch dimension, applying the linear layer and rearranging the spatial dimensions back + it equivalent to applying the a kernel size 1 ConvNd layer. + This function will create the ConvNd with the correct weights and bias. + + See :func:`conv_to_linear` for the reverse operation. + + + + Parameters + ---------- + linear_layer : nn.Linear + The linear layer to convert. + dim : int + The convolution dimension (1, 2, or 3). + + Returns + ------- + A Conv layer with equivalent weights and bias. + """ + conv = ConvND(dim)( + in_channels=linear_layer.in_features, + out_channels=linear_layer.out_features, + kernel_size=1, + bias=linear_layer.bias is not None, + device=linear_layer.weight.device, + dtype=linear_layer.weight.dtype, + ) + + with torch.no_grad(): + conv.weight.copy_(linear_layer.weight.view_as(conv.weight)) + if conv.bias is not None and linear_layer.bias is not None: + conv.bias.copy_(linear_layer.bias) + + return conv + + +def conv_to_linear(conv_layer: Conv1d | Conv2d | Conv3d) -> Linear: + """ + Convert a Conv1d, Conv2d, or Conv3d layer with kernel size 1 to a Linear layer. + + Applying a kernel size 1 ConvNd layer is equivalent to applying a Linear layer to each voxel. + This function will create the Linear layer with the correct weights and bias. + + See :func:`linear_to_conv` for the reverse operation. + + Parameters + ---------- + conv_layer : nn.Module + The convolutional layer to convert. Must have kernel size 1. + + Returns + ------- + A linear layer with equivalent weights and bias. + """ + if not all(k == 1 for k in conv_layer.kernel_size): + raise ValueError('Kernel size must be 1 for conversion.') + linear = Linear( + conv_layer.weight.shape[0], + conv_layer.weight.shape[1], + bias=conv_layer.bias is not None, + device=conv_layer.weight.device, + dtype=conv_layer.weight.dtype, + ) + with torch.no_grad(): + linear.weight.copy_(conv_layer.weight.view_as(linear.weight)) + if linear.bias is not None and conv_layer.bias is not None: + linear.bias.copy_(conv_layer.bias) + + return linear diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py index 22baf70b0..5802902a6 100644 --- a/src/mrpro/nn/join.py +++ b/src/mrpro/nn/join.py @@ -1,4 +1,7 @@ -from typing import Literal, Sequence +"""Modules for concatenating or adding tensors.""" + +from collections.abc import Sequence +from typing import Literal import torch from torch.nn import Module @@ -6,9 +9,10 @@ from mrpro.utils.pad_or_crop import pad_or_crop -def fix_shapes( +def _fix_shapes( xs: Sequence[torch.Tensor], mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'], dim: Sequence[int] ) -> tuple[torch.Tensor, ...]: + """Fix shapes of input tensors by padding or cropping.""" if mode == 'fail': return tuple(xs) @@ -17,27 +21,16 @@ def fix_shapes( target = tuple(min(s) for s in zip(*shapes, strict=True)) else: target = tuple(max(s) for s in zip(*shapes, strict=True)) - if mode in ('crop', 'zero'): - mode = 'constant' - return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) - - # # def pad(x) -> torch.Tensor: - # # if x.shape[2:] == target: - # # return x - # # pad = [] - # # for cur, tgt in zip(reversed(x.shape[2:]), reversed(target), strict=True): - # # left = (tgt - cur) // 2 - # # right = tgt - cur - left - # # pad.extend([left, right]) - # # return torch.nn.functional.pad(x, pad, mode=mode) - - # return tuple(pad(x) for x in xs) + if mode == 'zero' or mode == 'crop': + return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) + else: + return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) class Concat(Module): - """Concatenate tensors along the channel dimension""" + """Concatenate tensors along the channel dimension.""" - def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'zero', dim: int = 1) -> None: + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail', dim: int = 1) -> None: """Initialize Concat. Parameters @@ -60,14 +53,15 @@ def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular' self.dim = dim def forward(self, *xs: torch.Tensor) -> torch.Tensor: - xs = fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) + """Concatenate input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) return torch.cat(xs, dim=1) class Add(Module): - """Add tensors""" + """Add tensors.""" - def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'zero') -> None: + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: """Initialize Add. Parameters @@ -87,5 +81,6 @@ def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular' self.mode = mode def forward(self, *xs: torch.Tensor) -> torch.Tensor: - xs = fix_shapes(xs, self.mode, dim=range(max(x.ndim for x in xs))) + """Add input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=range(max(x.ndim for x in xs))) return sum(xs, start=torch.tensor(0.0)) diff --git a/src/mrpro/nn/NDModules.py b/src/mrpro/nn/ndmodules.py similarity index 100% rename from src/mrpro/nn/NDModules.py rename to src/mrpro/nn/ndmodules.py diff --git a/src/mrpro/nn/nets/CNN.py b/src/mrpro/nn/nets/CNN.py index 6a4ade796..3fbabcc4f 100644 --- a/src/mrpro/nn/nets/CNN.py +++ b/src/mrpro/nn/nets/CNN.py @@ -7,7 +7,7 @@ from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm -from mrpro.nn.NDModules import ConvND +from mrpro.nn.ndmodules import ConvND from mrpro.nn.Residual import Residual from mrpro.nn.Sequential import Sequential diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index c54cd838a..a66a26bb0 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -9,7 +9,7 @@ from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock from mrpro.nn.LinearSelfAttention import LinearSelfAttention from mrpro.nn.MultiHeadAttention import MultiHeadAttention -from mrpro.nn.NDModules import ConvND +from mrpro.nn.ndmodules import ConvND from mrpro.nn.nets.VAE import VAE from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.Residual import Residual diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 067210054..26eea7dc4 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -7,8 +7,10 @@ from torch.nn import Identity, Module from mrpro.nn.FiLM import FiLM -from mrpro.nn.NDModules import ConvND, InstanceNormND +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import ConvND, InstanceNormND from mrpro.nn.nets.UNet import UNetBase +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.Sequential import Sequential from mrpro.nn.TransposedAttention import TransposedAttention @@ -170,36 +172,27 @@ def blocks(n_heads: int, n_blocks: int): layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) return layers - for block, head in zip(n_blocks, n_heads, strict=False): + for block, head in zip(n_blocks[:-1], n_heads[:-1], strict=True): self.input_blocks.append(blocks(head, block)) self.output_blocks.append(blocks(head, block)) + self.skip_blocks.append(Identity()) + self.concat_blocks.append(Concat()) + self.middle_block = blocks(n_heads[-1], n_blocks[-1]) self.output_blocks = self.output_blocks[::-1] for head_current, head_next in pairwise(n_heads): self.down_blocks.append( - Sequential( - ConvND(dim)( - n_channels_per_head * head_current, - n_channels_per_head * head_next, - kernel_size=3, - stride=2, - padding=1, - ) - ) + PixelUnshuffleDownsample(dim, n_channels_per_head * head_current, n_channels_per_head * head_next) ) + self.up_blocks.append( - Sequential( - ConvND(dim)( - n_channels_per_head * head_next, - n_channels_per_head * head_current, - kernel_size=3, - stride=1, - padding=1, - ) - ) + PixelShuffleUpsample(dim, n_channels_per_head * head_next, n_channels_per_head * head_current) ) + self.output_blocks = self.input_blocks[::-1] + self.up_blocks = self.up_blocks[::-1] + self.concat_blocks = self.concat_blocks[::-1] self.refinement_blocks = Sequential( *(RestormerBlock(dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)) ) diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index b86a02e9a..2269ecfb5 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -5,7 +5,7 @@ from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM -from mrpro.nn.NDModules import ConvND, InstanceNormND +from mrpro.nn.ndmodules import ConvND, InstanceNormND from mrpro.nn.Sequential import Sequential from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention @@ -78,7 +78,7 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the Swin Transformer layer.""" x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x, cond)) + x = x + self.mlp(self.norm2(x, cond=cond)) return x @@ -151,7 +151,7 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the residual Swin Transformer block.""" - return x + self.conv(self.layers(x, cond)) + return x + self.conv(self.layers(x, cond=cond)) class SwinIR(Module): @@ -242,6 +242,6 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Te Output tensor """ x = self.first(x) - x = self.blocks(x, cond) + x = self.blocks(x, cond=cond) x = self.last(x) return x diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 3b567350a..5f53754c4 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -8,7 +8,8 @@ from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM -from mrpro.nn.NDModules import ConvND, ConvTransposeND, InstanceNormND +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import ConvND, ConvTransposeND, InstanceNormND from mrpro.nn.nets.UNet import UNetBase from mrpro.nn.Sequential import Sequential from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention @@ -135,7 +136,7 @@ class Uformer(UNetBase): """Uformer: U-Net with window attention. Implements the Uformer network proposed in [WANG21]_ - It is SWIN/U-Net hybrid consisting of (shifted) windows attention transformer layers at different + It is SWin-Transformer/U-Net hybrid consisting of (shifted) windows attention transformer layers at different resolution levels, extended by FiLM layers for conditioning. References @@ -208,11 +209,11 @@ def blocks(n_heads: int, p_droppath: float = 0.0): return layers drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() - for n_head, p_droppath_input in zip(n_heads, drop_path_rates, strict=True): + for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True): self.input_blocks.append(blocks(n_heads=n_head, p_droppath=p_droppath_input)) self.output_blocks.append(blocks(n_heads=n_head, p_droppath=max_droppath_rate)) self.skip_blocks.append(Identity()) - self.output_blocks = self.output_blocks[::-1] + self.concat_blocks.append(Concat()) self.middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) for n_head_current, n_head_next in pairwise(n_heads): @@ -230,6 +231,9 @@ def blocks(n_heads: int, p_droppath: float = 0.0): n_channels_per_head * n_head_next, n_channels_per_head * n_head_current, kernel_size=2, stride=2 ) ) + self.output_blocks = self.output_blocks[::-1] + self.up_blocks = self.up_blocks[::-1] + self.first = torch.nn.Sequential( ConvND(dim)(channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), LeakyReLU(), diff --git a/tests/nn/test_convert_linear_conv.py b/tests/nn/test_convert_linear_conv.py new file mode 100644 index 000000000..adcab5711 --- /dev/null +++ b/tests/nn/test_convert_linear_conv.py @@ -0,0 +1,217 @@ +"""Tests for converting between Linear and Conv layers.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.convert_linear_conv import conv_to_linear, linear_to_conv +from mrpro.utils import RandomGenerator +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'in_channels', 'out_channels', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], +) +def test_linear_to_conv(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Linear to Conv layer.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias).to(device) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + assert isinstance(conv, (Conv1d, Conv2d, Conv3d)[dim - 1]) + + assert conv.in_channels == channels_in + assert conv.out_channels == channels_out + assert conv.kernel_size == (1,) * dim + assert conv.bias is not None if bias else conv.bias is None + + assert conv.weight.device == device + if conv.bias is not None: + assert conv.bias.device == device + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'in_channels', 'out_channels', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], +) +def test_linear_to_conv_functional( + device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool +) -> None: + """Test functional equivalence of Linear to Conv conversion.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias).to(device) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + spatial_shape = (4,) * dim + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32).to(device) + + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + assert torch.allclose(y_conv, y_linear) + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'in_channels', 'out_channels', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], +) +def test_conv_to_linear(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Conv layer to Linear.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias).to(device) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + + assert isinstance(linear, Linear) + assert linear.in_features == channels_in + assert linear.out_features == channels_out + assert linear.bias is not None if bias else linear.bias is None + + assert linear.weight.device == device + if bias: + assert linear.bias.device == device + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], +) +def test_conv_to_linear_functional( + device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool +) -> None: + """Test functional equivalence of Conv to Linear conversion.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias).to(device) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + + # Create input tensor with spatial dimensions + spatial_shape = (4,) * dim + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32).to(device) + + # Apply conv layer + y_conv = conv(x) + + # Reshape input for linear layer + x_reshaped = x.flatten(0, -2) # Flatten all dimensions except last + y_linear = linear(x_reshaped) + y_linear = y_linear.view(2, channels_out, *spatial_shape) + + # Compare outputs + assert torch.allclose(y_conv, y_linear) + + +def test_conv_to_linear_invalid_kernel(): + """Test conv_to_linear with invalid kernel size.""" + conv = Conv2d(32, 64, kernel_size=3, bias=True) + with pytest.raises(ValueError, match='Kernel size must be 1'): + conv_to_linear(conv) + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], +) +def test_round_trip_conversion( + device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool +) -> None: + """Test round-trip conversion between Linear and Conv layers.""" + rng = RandomGenerator(seed=42) + + linear1 = Linear(channels_in, channels_out, bias=bias).to(device) + linear1.weight.data = rng.rand_like(linear1.weight) + if bias: + linear1.bias.data = rng.rand_like(linear1.bias) + + conv = linear_to_conv(linear1, dim) + linear2 = conv_to_linear(conv) + + assert linear2.in_features == channels_in + assert linear2.out_features == channels_out + assert linear2.bias is not None if bias else linear2.bias is None + + assert torch.allclose(linear2.weight, linear1.weight) + if bias: + assert torch.allclose(linear2.bias, linear1.bias) + + assert linear2.weight.device == device + if bias: + assert linear2.bias.device == device From 4f6a603a53e05323be4219c049c84a479e139f6d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 21 May 2025 18:19:46 +0200 Subject: [PATCH 042/205] Enhance conversion functions between Linear and Conv layers by refining method signatures and improving test structure. Update parameter names for clarity and ensure compatibility with multiple input tensors. --- src/mrpro/nn/GluMBConvResBlock.py | 2 +- src/mrpro/nn/Sequential.py | 2 +- src/mrpro/nn/convert_linear_conv.py | 19 ++-- src/mrpro/nn/encoding.py | 6 +- src/mrpro/nn/nets/DCAE.py | 1 + src/mrpro/nn/nets/UNet.py | 7 +- tests/nn/test_convert_linear_conv.py | 127 +++++++-------------------- tests/nn/test_sequential.py | 8 +- 8 files changed, 55 insertions(+), 117 deletions(-) diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py index f48883125..4940a966e 100644 --- a/src/mrpro/nn/GluMBConvResBlock.py +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -5,7 +5,7 @@ from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM -from mrpro.nn.NDModules import ConvND +from mrpro.nn.ndmodules import ConvND from mrpro.nn.RMSNorm import RMSNorm diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 95eb9b555..aaad42b52 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -9,7 +9,7 @@ class Sequential(torch.nn.Sequential): - """Sequential container with support for conditioning and Operators + """Sequential container with support for conditioning and Operators. Allows multiple input tensors and a single output tensor of the sequential block. diff --git a/src/mrpro/nn/convert_linear_conv.py b/src/mrpro/nn/convert_linear_conv.py index 4a45a9ddf..a6dac5f33 100644 --- a/src/mrpro/nn/convert_linear_conv.py +++ b/src/mrpro/nn/convert_linear_conv.py @@ -1,10 +1,11 @@ """Convert Linear layers to kernel size 1 ConvNd layers and vice versa.""" +from typing import Literal, overload + import torch -import torch.nn as nn -from torch.nn import Module, Conv1d, Conv2d, Conv3d, Linear +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + from mrpro.nn.ndmodules import ConvND -from typing import Literal, overload @overload @@ -26,9 +27,11 @@ def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: """Convert a Linear layer to a ConvNd layer with kernel size 1. - Rearranging the spatial dimensions to the batch dimension, applying the linear layer and rearranging the spatial dimensions back - it equivalent to applying the a kernel size 1 ConvNd layer. - This function will create the ConvNd with the correct weights and bias. + Rearranging the spatial dimensions to the batch dimension, + applying the linear layer and rearranging the spatial dimensions back + is equivalent to applying a kernel size 1 ConvNd layer. + + This function will create the Conv1d, Conv2d, or Conv3d with the correct weights and bias. See :func:`conv_to_linear` for the reverse operation. @@ -83,8 +86,8 @@ def conv_to_linear(conv_layer: Conv1d | Conv2d | Conv3d) -> Linear: if not all(k == 1 for k in conv_layer.kernel_size): raise ValueError('Kernel size must be 1 for conversion.') linear = Linear( - conv_layer.weight.shape[0], - conv_layer.weight.shape[1], + conv_layer.in_channels, + conv_layer.out_channels, bias=conv_layer.bias is not None, device=conv_layer.weight.device, dtype=conv_layer.weight.dtype, diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py index 9828b2cf0..679f38685 100644 --- a/src/mrpro/nn/encoding.py +++ b/src/mrpro/nn/encoding.py @@ -28,7 +28,7 @@ def __init__(self, in_features: int, out_features: int, std: float = 1.0): out_features : int Number of output features (must be even) std : float, optional - Standard deviation for random initialization, by default 1.0 + Standard deviation for random initialization """ if out_features % 2 != 0: raise ValueError('out_features must be even.') @@ -70,9 +70,9 @@ def __init__(self, dim: int, features: int, include_radii: bool = True, base_res features : int Number of output features include_radii : bool, optional - Whether to include radius features, by default True + Whether to include radius features base_resolution : int, optional - Base resolution for position encoding, by default 128 + Base resolution for position encoding """ super().__init__() diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index a66a26bb0..d6fdbfb00 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -257,6 +257,7 @@ def __init__( widths: Sequence[int] = (256, 512, 512, 1024, 1024), depths: Sequence[int] = (4, 6, 2, 2, 2), ): + """Initialize the DCVAE.""" encoder = Encoder(dim, channels, latent_dim * 2, block_types, widths, depths) decoder = Decoder(dim, latent_dim, channels, block_types[::-1], widths[::-1], depths[::-1]) super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index a4e88198c..599b9b4da 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -13,6 +13,7 @@ class UNetBase(Module): """Base class for U-shaped networks.""" def __init__(self) -> None: + """Initialize the UNetBase.""" super().__init__() self.input_blocks = ModuleList() """The encoder blocks. Order is highest resolution to lowest resolution.""" @@ -98,9 +99,9 @@ def __init__( cond_dim: int, num_blocks: int, padding_modes: str | Sequence[str], - ) -> None: ... - - """Initialize the UNet.""" + ) -> None: + """Initialize the UNet.""" + super().__init__() class AttentionUNet(UNet): diff --git a/tests/nn/test_convert_linear_conv.py b/tests/nn/test_convert_linear_conv.py index adcab5711..c977f0936 100644 --- a/tests/nn/test_convert_linear_conv.py +++ b/tests/nn/test_convert_linear_conv.py @@ -8,23 +8,27 @@ from mrpro.utils import RandomGenerator from torch.nn import Conv1d, Conv2d, Conv3d, Linear - -@pytest.mark.parametrize( +DEVICES = pytest.mark.parametrize( 'device', [ pytest.param('cpu', id='cpu'), pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), ], ) -@pytest.mark.parametrize( - ('dim', 'in_channels', 'out_channels', 'bias'), +SHAPES = pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'bias'), [ (1, 32, 64, True), (2, 16, 32, True), (3, 8, 16, True), (3, 1, 1, False), ], + ids=['1d', '2d', '3d', '3d_no_bias'], ) + + +@SHAPES +@DEVICES def test_linear_to_conv(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: """Test converting Linear to Conv layer.""" rng = RandomGenerator(seed=42) @@ -41,40 +45,23 @@ def test_linear_to_conv(device: str, dim: Literal[1, 2, 3], channels_in: int, ch assert conv.kernel_size == (1,) * dim assert conv.bias is not None if bias else conv.bias is None - assert conv.weight.device == device + assert conv.weight.device.type == device if conv.bias is not None: - assert conv.bias.device == device + assert conv.bias.device.type == device -@pytest.mark.parametrize( - 'device', - [ - pytest.param('cpu', id='cpu'), - pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), - ], -) -@pytest.mark.parametrize( - ('dim', 'in_channels', 'out_channels', 'bias'), - [ - (1, 32, 64, True), - (2, 16, 32, True), - (3, 8, 16, True), - (3, 1, 1, False), - ], -) -def test_linear_to_conv_functional( - device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool -) -> None: +@SHAPES +def test_linear_to_conv_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: """Test functional equivalence of Linear to Conv conversion.""" rng = RandomGenerator(seed=42) - linear = Linear(channels_in, channels_out, bias=bias).to(device) + linear = Linear(channels_in, channels_out, bias=bias) linear.weight.data = rng.rand_like(linear.weight) if bias: linear.bias.data = rng.rand_like(linear.bias) conv = linear_to_conv(linear, dim) spatial_shape = (4,) * dim - x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32).to(device) + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) y_conv = conv(x) y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) @@ -82,25 +69,11 @@ def test_linear_to_conv_functional( x_reshaped = x.moveaxis(1, -1).flatten(0, -2) y_linear = linear(x_reshaped) - assert torch.allclose(y_conv, y_linear) + torch.testing.assert_close(y_conv, y_linear) -@pytest.mark.parametrize( - 'device', - [ - pytest.param('cpu', id='cpu'), - pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), - ], -) -@pytest.mark.parametrize( - ('dim', 'in_channels', 'out_channels', 'bias'), - [ - (1, 32, 64, True), - (2, 16, 32, True), - (3, 8, 16, True), - (3, 1, 1, False), - ], -) +@SHAPES +@DEVICES def test_conv_to_linear(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: """Test converting Conv layer to Linear.""" rng = RandomGenerator(seed=42) @@ -117,54 +90,32 @@ def test_conv_to_linear(device: str, dim: Literal[1, 2, 3], channels_in: int, ch assert linear.out_features == channels_out assert linear.bias is not None if bias else linear.bias is None - assert linear.weight.device == device + assert linear.weight.device.type == device if bias: - assert linear.bias.device == device + assert linear.bias.device.type == device -@pytest.mark.parametrize( - 'device', - [ - pytest.param('cpu', id='cpu'), - pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), - ], -) -@pytest.mark.parametrize( - ('dim', 'channels_in', 'channels_out', 'bias'), - [ - (1, 32, 64, True), - (2, 16, 32, True), - (3, 8, 16, True), - (3, 1, 1, False), - ], -) -def test_conv_to_linear_functional( - device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool -) -> None: +@SHAPES +def test_conv_to_linear_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: """Test functional equivalence of Conv to Linear conversion.""" rng = RandomGenerator(seed=42) conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] - conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias).to(device) + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias) conv.weight.data = rng.rand_like(conv.weight) if conv.bias is not None: conv.bias.data = rng.rand_like(conv.bias) linear = conv_to_linear(conv) - - # Create input tensor with spatial dimensions spatial_shape = (4,) * dim - x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32).to(device) - # Apply conv layer + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) - # Reshape input for linear layer - x_reshaped = x.flatten(0, -2) # Flatten all dimensions except last + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) y_linear = linear(x_reshaped) - y_linear = y_linear.view(2, channels_out, *spatial_shape) - # Compare outputs - assert torch.allclose(y_conv, y_linear) + torch.testing.assert_close(y_conv, y_linear) def test_conv_to_linear_invalid_kernel(): @@ -174,22 +125,8 @@ def test_conv_to_linear_invalid_kernel(): conv_to_linear(conv) -@pytest.mark.parametrize( - 'device', - [ - pytest.param('cpu', id='cpu'), - pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), - ], -) -@pytest.mark.parametrize( - ('dim', 'channels_in', 'channels_out', 'bias'), - [ - (1, 32, 64, True), - (2, 16, 32, True), - (3, 8, 16, True), - (3, 1, 1, False), - ], -) +@SHAPES +@DEVICES def test_round_trip_conversion( device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool ) -> None: @@ -208,10 +145,6 @@ def test_round_trip_conversion( assert linear2.out_features == channels_out assert linear2.bias is not None if bias else linear2.bias is None - assert torch.allclose(linear2.weight, linear1.weight) - if bias: - assert torch.allclose(linear2.bias, linear1.bias) - - assert linear2.weight.device == device + torch.testing.assert_close(linear2.weight, linear1.weight) if bias: - assert linear2.bias.device == device + torch.testing.assert_close(linear2.bias, linear1.bias) diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py index c52b0ed2a..59e7dade9 100644 --- a/tests/nn/test_sequential.py +++ b/tests/nn/test_sequential.py @@ -15,23 +15,23 @@ ], ) @pytest.mark.parametrize( - ('input_shape', 'emb_shape'), + ('input_shape', 'cond_dim'), [ ((1, 32), (1, 16)), ((2, 64), None), ], ) -def test_sequential(input_shape, emb_shape, device): +def test_sequential(input_shape, cond_dim, device): """Test Sequential output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - emb = rng.float32_tensor(emb_shape).to(device).requires_grad_(True) if emb_shape else None + cond = rng.float32_tensor(cond_dim).to(device).requires_grad_(True) if cond_dim else None seq = Sequential( Linear(input_shape[1], 64), FastFourierOp(), FiLM(channels=64, cond_dim=16), ).to(device) - output = seq(x, emb) + output = seq(x, cond) assert output.shape == (input_shape[0], 32), f'Output shape {output.shape} != expected {(input_shape[0], 32)}' output.sum().backward() assert x.grad is not None, 'No gradient computed for input' From 7d608b5de0efaf063a36b137ac112106146f02cd Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 22 May 2025 00:04:31 +0200 Subject: [PATCH 043/205] Refactor method signatures in neural network modules to use keyword-only arguments for conditioning tensors. Update import statements for clarity and add new SwinIR network to the module. Enhance VAE with a mode method for improved functionality. --- src/mrpro/nn/CondMixin.py | 4 ++-- src/mrpro/nn/FiLM.py | 6 +++--- src/mrpro/nn/GluMBConvResBlock.py | 2 +- src/mrpro/nn/ResBlock.py | 8 ++++---- src/mrpro/nn/__init__.py | 2 +- src/mrpro/nn/nets/SwinIR.py | 2 +- src/mrpro/nn/nets/Uformer.py | 21 ++++++++++++++------- src/mrpro/nn/nets/VAE.py | 6 ++++++ src/mrpro/nn/nets/__init__.py | 3 ++- tests/nn/test_film.py | 16 ++++++++-------- tests/nn/test_resblock.py | 16 ++++++++-------- tests/nn/test_sequential.py | 2 +- 12 files changed, 51 insertions(+), 37 deletions(-) diff --git a/src/mrpro/nn/CondMixin.py b/src/mrpro/nn/CondMixin.py index bd87b2a4c..6a902c413 100644 --- a/src/mrpro/nn/CondMixin.py +++ b/src/mrpro/nn/CondMixin.py @@ -17,6 +17,6 @@ class CondMixin(Module): Used to determine if a module uses a conditioning within a Sequential container. """ - def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module to the input.""" - return super().__call__(*x, cond=cond) + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index 1c2e3587b..e6d101260 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -38,7 +38,7 @@ def __init__(self, channels: int, cond_dim: int) -> None: else: self.project = Identity() - def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM. Parameters @@ -50,9 +50,9 @@ def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch. """ if len(x) != 1: raise ValueError('FiLM expects a single input tensor') - return super().__call__(x[0], cond=cond) + return super().__call__(x, cond=cond) - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM.""" if cond is None: return x diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py index 4940a966e..02b623dfa 100644 --- a/src/mrpro/nn/GluMBConvResBlock.py +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -97,6 +97,6 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Te h, gate = torch.chunk(h, 2, dim=1) h = h * torch.nn.functional.silu(gate) if self.film is not None: - h = self.film(h, cond) + h = self.film(h, cond=cond) h = self.point_conv(h) return self.skip(x) + h diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index 878e03553..0897eb52c 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -47,7 +47,7 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, cond_dim: int) else: self.skip_connection = ConvND(dim)(channels_in, channels_out, kernel_size=1) - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock. Parameters @@ -61,10 +61,10 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T ------- The output tensor. """ - return super().__call__(x, cond) + return super().__call__(x, cond=cond) - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock.""" - h = self.block(x, cond) + h = self.block(x, cond=cond) x = self.skip_connection(x) + h return x diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index 368fe925d..e59e4efde 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -20,7 +20,7 @@ from mrpro.nn.SqueezeExcitation import SqueezeExcitation from mrpro.nn.TransposedAttention import TransposedAttention from mrpro.nn.DropPath import DropPath -import mrpro.nn.nets +from mrpro.nn import nets __all__ = [ "AdaptiveAvgPoolND", "AttentionGate", diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index 2269ecfb5..e3e8a440a 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -47,7 +47,7 @@ def __init__( """ super().__init__() self.norm1 = InstanceNormND(dim)(channels) - self.attn = ShiftedWindowAttention(dim, channels, n_heads, window_size) + self.attn = ShiftedWindowAttention(dim, channels, channels, n_heads, window_size) self.norm2 = Sequential(InstanceNormND(dim)(channels)) if emb_dim > 0: self.norm2.append(FiLM(channels=channels, cond_dim=emb_dim)) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 5f53754c4..12e60a8cf 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -42,7 +42,7 @@ def __init__( Expansion ratio of the hidden dimension """ super().__init__() - hidden_dim = int(dim * expand_ratio) + hidden_dim = int(channels_in * expand_ratio) self.block = Sequential( ConvND(dim)(channels_in, hidden_dim, 1), GELU(), @@ -211,10 +211,10 @@ def blocks(n_heads: int, p_droppath: float = 0.0): drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True): self.input_blocks.append(blocks(n_heads=n_head, p_droppath=p_droppath_input)) - self.output_blocks.append(blocks(n_heads=n_head, p_droppath=max_droppath_rate)) + self.output_blocks.append(blocks(n_heads=2 * n_head, p_droppath=max_droppath_rate)) self.skip_blocks.append(Identity()) self.concat_blocks.append(Concat()) - self.middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) + self.output_blocks = self.output_blocks[::-1] for n_head_current, n_head_next in pairwise(n_heads): self.down_blocks.append( @@ -226,18 +226,25 @@ def blocks(n_heads: int, p_droppath: float = 0.0): padding=1, ) ) + + self.middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) + + self.up_blocks.append( + ConvTransposeND(dim)( + n_channels_per_head * n_heads[-1], n_channels_per_head * n_heads[-2], kernel_size=2, stride=2 + ) + ) + for n_head_current, n_head_next in pairwise(n_heads[-2::-1]): self.up_blocks.append( ConvTransposeND(dim)( - n_channels_per_head * n_head_next, n_channels_per_head * n_head_current, kernel_size=2, stride=2 + 2 * n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=2, stride=2 ) ) - self.output_blocks = self.output_blocks[::-1] - self.up_blocks = self.up_blocks[::-1] self.first = torch.nn.Sequential( ConvND(dim)(channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), LeakyReLU(), ) self.last = ConvND(dim)( - n_channels_per_head * n_heads[-1], channels_out, kernel_size=3, stride=1, padding='same' + 2 * n_channels_per_head * n_heads[0], channels_out, kernel_size=3, stride=1, padding='same' ) diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py index 9f98ac9c3..cd4a1260a 100644 --- a/src/mrpro/nn/nets/VAE.py +++ b/src/mrpro/nn/nets/VAE.py @@ -47,6 +47,12 @@ def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ return self.forward(x) + def mode(self, x: torch.Tensor) -> torch.Tensor: + """Mode of the VAE.""" + z = self.encoder(x) + mean, _ = z.chunk(2, dim=1) + return self.decoder(mean) + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass of the VAE.""" z = self.encoder(x) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 35c701a5d..d6951b4a8 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -3,5 +3,6 @@ from mrpro.nn.nets.DCAE import DCVAE from mrpro.nn.nets.VAE import VAE from mrpro.nn.nets.UNet import UNet, AttentionUNet +from mrpro.nn.nets.SwinIR import SwinIR -__all__ = ["AttentionUNet", "DCVAE", "Restormer", "UNet", "Uformer", "VAE"] \ No newline at end of file +__all__ = ["AttentionUNet", "DCVAE", "Restormer", "UNet", "Uformer", "VAE", "SwinIR"] \ No newline at end of file diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py index e3913cb5b..0106aa1cb 100644 --- a/tests/nn/test_film.py +++ b/tests/nn/test_film.py @@ -13,25 +13,25 @@ ], ) @pytest.mark.parametrize( - ('channels', 'channels_emb', 'input_shape', 'emb_shape'), + ('channels', 'channels_cond', 'input_shape', 'cond_shape'), [ (64, 32, (1, 64, 32, 32), (1, 32)), (32, 16, (2, 32, 16, 16), (2, 16)), ], ) -def test_film(channels, channels_emb, input_shape, emb_shape, device): +def test_film(channels, channels_cond, input_shape, cond_shape, device): """Test FiLM output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - emb = rng.float32_tensor(emb_shape).to(device).requires_grad_(True) - film = FiLM(channels=channels, cond_dim=channels_emb).to(device) - output = film(x, emb) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) + film = FiLM(channels=channels, cond_dim=channels_cond).to(device) + output = film(x, cond=cond) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() assert x.grad is not None, 'No gradient computed for input' - assert emb.grad is not None, 'No gradient computed for embedding' + assert cond.grad is not None, 'No gradient computed for condedding' assert not x.isnan().any(), 'NaN values in input' - assert not emb.isnan().any(), 'NaN values in embedding' + assert not cond.isnan().any(), 'NaN values in condedding' assert not x.grad.isnan().any(), 'NaN values in input gradients' - assert not emb.grad.isnan().any(), 'NaN values in embedding gradients' + assert not cond.grad.isnan().any(), 'NaN values in condedding gradients' assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index 195b88f01..3787257ce 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -13,19 +13,19 @@ ], ) @pytest.mark.parametrize( - ('dim', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'emb_shape'), + ('dim', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), [ (2, 32, 32, 16, (1, 32, 32, 32), (1, 16)), (3, 64, 32, 0, (2, 64, 16, 16, 16), None), ], ) -def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, emb_shape, device): +def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, cond_shape, device): """Test ResBlock output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - emb = rng.float32_tensor(emb_shape).to(device).requires_grad_(True) if emb_shape else None + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None res = ResBlock(dim=dim, channels_in=channels_in, channels_out=channels_out, cond_dim=cond_dim).to(device) - output = res(x, emb) + output = res(x, cond=cond) assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' ) @@ -34,7 +34,7 @@ def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, emb_sha assert not x.isnan().any(), 'NaN values in input' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert res.block[2].weight.grad is not None, 'No gradient computed for first Conv' - if emb is not None: - assert emb.grad is not None, 'No gradient computed for embedding' - assert not emb.isnan().any(), 'NaN values in embedding' - assert not emb.grad.isnan().any(), 'NaN values in embedding gradients' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for condedding' + assert not cond.isnan().any(), 'NaN values in condedding' + assert not cond.grad.isnan().any(), 'NaN values in condedding gradients' diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py index 59e7dade9..83e585498 100644 --- a/tests/nn/test_sequential.py +++ b/tests/nn/test_sequential.py @@ -31,7 +31,7 @@ def test_sequential(input_shape, cond_dim, device): FastFourierOp(), FiLM(channels=64, cond_dim=16), ).to(device) - output = seq(x, cond) + output = seq(x, cond=cond) assert output.shape == (input_shape[0], 32), f'Output shape {output.shape} != expected {(input_shape[0], 32)}' output.sum().backward() assert x.grad is not None, 'No gradient computed for input' From 3757b1fa3a883c5413f1f94c24c6068bbb30335b Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 22 May 2025 02:11:50 +0200 Subject: [PATCH 044/205] Add EMADict class for Exponential Moving Average functionality and update imports - Introduced EMADict class to maintain exponential moving averages for various data types. - Updated __all__ lists in utils and nn modules to include new EMADict class. - Added tests for EMADict to ensure correct functionality and error handling. --- src/mrpro/nn/nets/__init__.py | 10 +++- src/mrpro/utils/__init__.py | 2 + src/mrpro/utils/ema.py | 92 +++++++++++++++++++++++++++++++++++ tests/utils/test_ema.py | 89 +++++++++++++++++++++++++++++++++ 4 files changed, 192 insertions(+), 1 deletion(-) create mode 100644 src/mrpro/utils/ema.py create mode 100644 tests/utils/test_ema.py diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index d6951b4a8..02d5a449f 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -5,4 +5,12 @@ from mrpro.nn.nets.UNet import UNet, AttentionUNet from mrpro.nn.nets.SwinIR import SwinIR -__all__ = ["AttentionUNet", "DCVAE", "Restormer", "UNet", "Uformer", "VAE", "SwinIR"] \ No newline at end of file +__all__ = [ + "AttentionUNet", + "DCVAE", + "Restormer", + "SwinIR", + "UNet", + "Uformer", + "VAE" +] \ No newline at end of file diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 92465470d..cff56dd67 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -16,7 +16,9 @@ from mrpro.utils.interpolate import interpolate, apply_lowres from mrpro.utils.RandomGenerator import RandomGenerator from mrpro.utils.to_tuple import to_tuple +from mrpro.utils.ema import EMADict __all__ = [ + "EMADict", "Indexer", "RandomGenerator", "TensorAttributeMixin", diff --git a/src/mrpro/utils/ema.py b/src/mrpro/utils/ema.py new file mode 100644 index 000000000..b45cc4d27 --- /dev/null +++ b/src/mrpro/utils/ema.py @@ -0,0 +1,92 @@ +from collections.abc import ItemsView, KeysView, Mapping, ValuesView +from typing import Any + +import torch + + +class EMADict: + """ + Exponential Moving Average (EMA) dictionary. + + Maintains an EMA of values for each key. On update, existing keys are + updated with EMA, and new keys are added directly. + + Detaches the values from the autograd graph. + + + """ + + def __init__( + self, + decay: float, + ): + """Initialize the EMA dictionary. + + Parameters + ---------- + decay : float + Decay rate for EMA (between 0 and 1). + """ + self.decay: float = decay + if not 0 <= decay <= 1: + raise ValueError(f'Decay must be between 0 and 1, got {decay}') + self._data: dict[str, Any] = dict() + + def __getitem__(self, key: str) -> Any: + """Get the value of the EMA dict for a given key.""" + return self._data[key] + + def __setitem__(self, key: str, value: Any) -> None: + """Set the value of the EMA dict for a given key.""" + if key in self._data: + old_v = self._data[key] + if isinstance(value, torch.Tensor): + if isinstance(old_v, torch.Tensor) and isinstance(value, torch.Tensor): + if torch.is_floating_point(old_v) or torch.is_complex(old_v): + old_v.mul_(self.decay).add_(value.detach().to(old_v.device), alpha=1.0 - self.decay) + else: + old_v.copy_(value) + return + elif isinstance(old_v, float) and isinstance(value, float): # noqa: SIM114 + self._data[key] = self.decay * old_v + (1.0 - self.decay) * value + return + elif isinstance(old_v, complex) and isinstance(value, complex): + self._data[key] = self.decay * old_v + (1.0 - self.decay) * value + return + if isinstance(value, torch.Tensor): + self._data[key] = value.detach().clone() + else: + self._data[key] = value + + def __delitem__(self, key: str) -> None: + """Delete a key from the EMA dict.""" + del self._data[key] + + def __contains__(self, key: str) -> bool: + """Check if a key is in the EMA dict.""" + return key in self._data + + def values(self) -> ValuesView[Any]: + """Get the values of the EMA dict.""" + return self._data.values() + + def keys(self) -> KeysView[str]: + """Get the keys of the EMA dict.""" + return self._data.keys() + + def items(self) -> ItemsView[str, Any]: + """Get the items of the EMA dict.""" + return self._data.items() + + def update(self, other: Mapping[Any, Any]) -> None: + """Update the EMA dict with another dictionary. + + For existing keys, performs EMA update. For new keys, adds them directly. + + Parameters + ---------- + other : dict + Dictionary to update from. + """ + for k, v in other.items(): + self.__setitem__(k, v) diff --git a/tests/utils/test_ema.py b/tests/utils/test_ema.py new file mode 100644 index 000000000..3f9ea5ca2 --- /dev/null +++ b/tests/utils/test_ema.py @@ -0,0 +1,89 @@ +"""Tests for EMADict.""" + +from typing import Any + +import pytest +import torch +from mrpro.utils import RandomGenerator +from mrpro.utils.ema import EMADict + + +@pytest.mark.parametrize( + ('key', 'value'), + [ + ('float', 1.0), + ('complex', 1.0 + 1.0j), + ('tensor', torch.ones(2, 3)), + ], +) +def test_ema_dict_numerical( + key: str, + value: Any, +) -> None: + """Test that EMA calculation is numerically correct.""" + decay = 0.8 + ema = EMADict(decay=decay) + + ema[key] = value + new_value = RandomGenerator(seed=42).float32() * value + ema.update({key: new_value}) + + expected = decay * value + (1 - decay) * new_value + if isinstance(value, torch.Tensor): + torch.testing.assert_close(ema[key], expected) + else: + assert ema[key] == pytest.approx(expected) + + +def test_ema_dict_invalid_decay() -> None: + """Test EMADict with invalid decay values.""" + with pytest.raises(ValueError, match='Decay must be between 0 and 1'): + EMADict(decay=-0.1) + with pytest.raises(ValueError, match='Decay must be between 0 and 1'): + EMADict(decay=1.1) + + +def test_ema_dict_update() -> None: + """Test EMADict update method.""" + rng = RandomGenerator(seed=42) + ema = EMADict(decay=0.9) + + new_dict: dict[str, Any] = { + 'float': rng.float32(), + 'complex': rng.complex64(), + 'tensor': rng.float32_tensor((2, 3)), + 'string': 'test', + } + ema.update(new_dict) + + for key, value in new_dict.items(): + assert key in ema + if isinstance(value, torch.Tensor): + torch.testing.assert_close(ema[key], value) + else: + assert ema[key] == value + + +def test_ema_dict_deletion() -> None: + """Test EMADict deletion.""" + rng = RandomGenerator(seed=42) + ema = EMADict(decay=0.9) + + ema['test'] = rng.float32() + assert 'test' in ema + + del ema['test'] + assert 'test' not in ema + + with pytest.raises(KeyError): + del ema['nonexistent'] + + +def test_ema_dict_tensor_detach() -> None: + """Test that tensors are detached from autograd graph.""" + rng = RandomGenerator(seed=42) + ema = EMADict(decay=0.9) + + tensor = rng.float32_tensor((2, 3)).requires_grad_(True) + ema['test'] = tensor + assert not ema['test'].requires_grad From f3aaa6ab527e765a0f89081ef27aad2d3fe890f5 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 22 May 2025 10:00:41 +0200 Subject: [PATCH 045/205] Refactor EfficientViTBlock and Encoder/Decoder stages to use dynamic head counts based on width; improve sequential structure in PixelUnshuffleDownsample. --- src/mrpro/nn/nets/DCAE.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index d6fdbfb00..c12829322 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -88,7 +88,7 @@ def __init__( """ super().__init__() if linear_attn: - attention: Module = LinearSelfAttention(channels, channels, n_heads) # TODO: check heads and head dim + attention: Module = LinearSelfAttention(channels, channels, n_heads) else: attention = MultiHeadAttention(channels, channels, n_heads, features_last=False) self.context_module = Residual(Sequential(attention, RMSNorm(channels))) @@ -155,18 +155,23 @@ def __init__( case 'CNN': stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] case 'LinearViT': - stage = [ - EfficientViTBlock(dim, width, n_heads=1, linear_attn=True) for _ in range(depth) - ] # TODO: heads + stage = [EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=True) for _ in range(depth)] case 'ViT': - stage = [EfficientViTBlock(dim, width, n_heads=1, linear_attn=False) for _ in range(depth)] + stage = [ + EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=False) for _ in range(depth) + ] case _: raise ValueError(f'Block type {block_type} not supported') self.append(Sequential(*stage)) if next_width: self.append(PixelUnshuffleDownsample(dim, width, next_width, downscale_factor=2, residual=True)) - # Norm # relu - self.append(PixelUnshuffleDownsample(dim, widths[-1], channels_out, downscale_factor=1, residual=True)) + self.append( + Sequential( + RMSNorm(widths[-1]), + ReLU(), + PixelUnshuffleDownsample(dim, widths[-1], channels_out, downscale_factor=1, residual=True), + ) + ) class Decoder(Sequential): @@ -219,11 +224,11 @@ def __init__( case 'CNN': stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] case 'LinearViT': - stage = [ - EfficientViTBlock(dim, width, n_heads=1, linear_attn=True) for _ in range(depth) - ] # TODO: heads + stage = [EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=True) for _ in range(depth)] case 'ViT': - stage = [EfficientViTBlock(dim, width, n_heads=1, linear_attn=False) for _ in range(depth)] + stage = [ + EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=False) for _ in range(depth) + ] case _: raise ValueError(f'Block type {block_type} not supported') self.append(Sequential(*stage)) From 4c497345b111d92e4b239391abafb712a8212f01 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 22 May 2025 16:48:14 +0200 Subject: [PATCH 046/205] Refactor Restormer and Uformer networks to utilize UNetEncoder and UNetDecoder classes for improved modularity and clarity. Update import statements and streamline block initialization for better readability and maintainability. --- src/mrpro/nn/nets/Restormer.py | 60 ++++++------ src/mrpro/nn/nets/UNet.py | 174 +++++++++++++++++++++++++++------ src/mrpro/nn/nets/Uformer.py | 72 ++++++++------ 3 files changed, 217 insertions(+), 89 deletions(-) diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 26eea7dc4..8b209406b 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -4,12 +4,12 @@ from itertools import pairwise import torch -from torch.nn import Identity, Module +from torch.nn import Module from mrpro.nn.FiLM import FiLM from mrpro.nn.join import Concat from mrpro.nn.ndmodules import ConvND, InstanceNormND -from mrpro.nn.nets.UNet import UNetBase +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.Sequential import Sequential from mrpro.nn.TransposedAttention import TransposedAttention @@ -159,9 +159,6 @@ def __init__( cond_dim : int, optional Dimension of conditioning input """ - super().__init__() - - self.first = ConvND(dim)(channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) def blocks(n_heads: int, n_blocks: int): layers = Sequential( @@ -172,28 +169,35 @@ def blocks(n_heads: int, n_blocks: int): layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) return layers - for block, head in zip(n_blocks[:-1], n_heads[:-1], strict=True): - self.input_blocks.append(blocks(head, block)) - self.output_blocks.append(blocks(head, block)) - - self.skip_blocks.append(Identity()) - self.concat_blocks.append(Concat()) - - self.middle_block = blocks(n_heads[-1], n_blocks[-1]) - self.output_blocks = self.output_blocks[::-1] - for head_current, head_next in pairwise(n_heads): - self.down_blocks.append( - PixelUnshuffleDownsample(dim, n_channels_per_head * head_current, n_channels_per_head * head_next) - ) - - self.up_blocks.append( - PixelShuffleUpsample(dim, n_channels_per_head * head_next, n_channels_per_head * head_current) - ) + first_block = ConvND(dim)(channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) + encoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)] + down_blocks = [ + PixelUnshuffleDownsample(dim, n_channels_per_head * head_current, n_channels_per_head * head_next) + for head_current, head_next in pairwise(n_heads) + ] + middle_block = blocks(n_heads[-1], n_blocks[-1]) + encoder = UNetEncoder( + first_block=first_block, + encoder_blocks=encoder_blocks, + down_blocks=down_blocks, + middle_block=middle_block, + ) - self.output_blocks = self.input_blocks[::-1] - self.up_blocks = self.up_blocks[::-1] - self.concat_blocks = self.concat_blocks[::-1] - self.refinement_blocks = Sequential( - *(RestormerBlock(dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)) + up_blocks = [ + PixelShuffleUpsample(dim, n_channels_per_head * head_next, n_channels_per_head * head_current) + for head_current, head_next in pairwise(n_heads) + ][::-1] + concat_blocks = [Concat() for _ in range(len(encoder_blocks))] + decoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1] + last_block = Sequential( + *(RestormerBlock(dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)), + ConvND(dim)(n_channels_per_head, channels_out, kernel_size=3, stride=1, padding=1), ) - self.last = ConvND(dim)(n_channels_per_head, channels_out, kernel_size=3, stride=1, padding=1) + decoder = UNetDecoder( + decoder_blocks=decoder_blocks, + up_blocks=up_blocks, + concat_blocks=concat_blocks, + last_block=last_block, + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 599b9b4da..f50ff1177 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -4,59 +4,173 @@ from functools import partial import torch -from torch.nn import Identity, Module, ModuleList +from sympy import Identity +from torch.nn import Module, ModuleList from mrpro.nn.CondMixin import call_with_cond -class UNetBase(Module): - """Base class for U-shaped networks.""" +class UNetEncoder(Module): + """Encoder.""" - def __init__(self) -> None: - """Initialize the UNetBase.""" + def __init__( + self, + first_block: Module, + encoder_blocks: Sequence[Module], + down_blocks: Sequence[Module], + middle_block: Module, + ) -> None: + """Initialize the UNetEncoder.""" super().__init__() - self.input_blocks = ModuleList() + self.first = first_block + """The first block. Should expand from the number of input channels.""" + + self.encoder_blocks = ModuleList(encoder_blocks) """The encoder blocks. Order is highest resolution to lowest resolution.""" - self.down_blocks = ModuleList() + self.down_blocks = ModuleList(down_blocks) """The downsampling blocks""" - self.skip_blocks = ModuleList() - """Modifications to the skip connections""" - - self.middle_block: Module = Identity() + self.middle_block = middle_block """Also called bottleneck block""" - self.output_blocks = ModuleList() - """Also called decoder blocks. Order is lowest resolution to highest resolution.""" + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.down_blocks) + 1 + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = call(self.first, x) + + xs = [] + for block, down in zip(self.encoder_blocks, self.down_blocks, strict=True): + x = call(block, x) + xs.append(x) + x = call(down, x) + + x = call(self.middle_block, x) + + return (*xs, x) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The tensors at the different resolutions, highest resolution first. + """ + return super().__call__(x, cond) + + +class UNetDecoder(Module): + """Decoder.""" + + def __init__( + self, + decoder_blocks: Sequence[Module], + up_blocks: Sequence[Module], + concat_blocks: Sequence[Module], + last_block: Module, + ) -> None: + """Initialize the UNetDecoder.""" + super().__init__() + self.decoder_blocks = ModuleList(decoder_blocks) + """The decoder blocks. Order is lowest resolution to highest resolution.""" - self.up_blocks = ModuleList() + self.up_blocks = ModuleList(up_blocks) """The upsampling blocks""" - self.concat_blocks = ModuleList() + self.concat_blocks = ModuleList(concat_blocks) """Joins the skip connections with the upsampled features from a lower resolution level""" - self.last: Module = Identity() + self.last_block = last_block """The last block. Should reduce to the number of output channels.""" - self.first: Module = Identity() - """The first block. Should expand from the number of input channels.""" + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.up_blocks) + 1 - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply to Network.""" call = partial(call_with_cond, cond=cond) - x = call(self.first, x) - xs = [] - for block, down, skip in zip(self.input_blocks, self.down_blocks, self.skip_blocks, strict=True): - x = call(block, x) - xs.append(call(skip, x)) - x = call(down, x) - x = call(self.middle_block, x) - for block, up, concat in zip(self.output_blocks, self.up_blocks, self.concat_blocks, strict=True): + + x = hs[-1] # lowest resolution, from middle block + for block, up, concat, h in zip( + self.decoder_blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True + ): x = call(up, x) - x = concat(x, xs.pop()) + x = concat(x, h) x = call(block, x) - return call(self.last, x) + + x = call(self.last_block, x) + return x + + def __call__(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + hs + The tensors at the different resolutions, highest resolution first. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(hs, cond=cond) + + +class UNetBase(Module): + """Base class for U-shaped networks.""" + + def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequence[Module] | None = None) -> None: + """Initialize the UNetBase.""" + super().__init__() + self.encoder = encoder + """The encoder.""" + + self.decoder = decoder + """The decoder.""" + + self.skip_blocks = ModuleList() + """Modifications of the skip connections.""" + + if len(decoder) != len(encoder): + raise ValueError( + 'The number of resolutions in the encoder and decoder must be the same, ' + f'got {len(decoder)} and {len(encoder)}' + ) + + if skip_blocks is None: + self.skip_blocks.extend(Identity() for _ in range(len(decoder))) + elif len(skip_blocks) != len(decoder): + raise ValueError( + f'The number of skip blocks must be the same as the number of resolutions, ' + f'got {len(skip_blocks)} and {len(encoder)}' + ) + else: + self.skip_blocks.extend(skip_blocks) + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network.""" + xs = self.encoder(x, cond=cond) + xs = tuple( + call_with_cond(self.skip_blocks[i], x, cond=cond) if i < len(self.skip_blocks) else x + for i, x in enumerate(xs) + ) + x = self.decoder(xs, cond=cond) + return x def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply to Network. @@ -72,7 +186,7 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T ------- The output tensor. """ - return super().__call__(x, cond) + return super().__call__(x, cond=cond) class UNet(UNetBase): diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 12e60a8cf..287cb98ea 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -4,13 +4,13 @@ from itertools import pairwise import torch -from torch.nn import GELU, Identity, LeakyReLU, Module +from torch.nn import GELU, LeakyReLU, Module from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM from mrpro.nn.join import Concat from mrpro.nn.ndmodules import ConvND, ConvTransposeND, InstanceNormND -from mrpro.nn.nets.UNet import UNetBase +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder from mrpro.nn.Sequential import Sequential from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention @@ -186,7 +186,6 @@ def __init__( is linearly increased from `0` to `max_droppath_rate` with decreasing resolution. The rate in output blocks is fixed to `max_droppath_rate`. """ - super().__init__() def blocks(n_heads: int, p_droppath: float = 0.0): layers = Sequential( @@ -208,43 +207,54 @@ def blocks(n_heads: int, p_droppath: float = 0.0): layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) return layers + first_block = torch.nn.Sequential( + ConvND(dim)(channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), + LeakyReLU(), + ) drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() - for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True): - self.input_blocks.append(blocks(n_heads=n_head, p_droppath=p_droppath_input)) - self.output_blocks.append(blocks(n_heads=2 * n_head, p_droppath=max_droppath_rate)) - self.skip_blocks.append(Identity()) - self.concat_blocks.append(Concat()) - self.output_blocks = self.output_blocks[::-1] - - for n_head_current, n_head_next in pairwise(n_heads): - self.down_blocks.append( - ConvND(dim)( - n_channels_per_head * n_head_current, - n_channels_per_head * n_head_next, - kernel_size=4, - stride=2, - padding=1, - ) + encoder_blocks = [ + blocks(n_heads=n_head, p_droppath=p_droppath_input) + for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True) + ] + down_blocks = [ + ConvND(dim)( + n_channels_per_head * n_head_current, + n_channels_per_head * n_head_next, + kernel_size=4, + stride=2, + padding=1, ) + for n_head_current, n_head_next in pairwise(n_heads) + ] + middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) + encoder = UNetEncoder( + first_block=first_block, + encoder_blocks=encoder_blocks, + down_blocks=down_blocks, + middle_block=middle_block, + ) - self.middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) - - self.up_blocks.append( + decoder_blocks = [blocks(n_heads=2 * n_head, p_droppath=max_droppath_rate) for n_head in reversed(n_heads[:-1])] + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] + up_blocks = [ ConvTransposeND(dim)( n_channels_per_head * n_heads[-1], n_channels_per_head * n_heads[-2], kernel_size=2, stride=2 ) - ) - for n_head_current, n_head_next in pairwise(n_heads[-2::-1]): - self.up_blocks.append( + ] + for n_head_current, n_head_next in pairwise(reversed(n_heads[:-1])): + up_blocks.append( ConvTransposeND(dim)( 2 * n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=2, stride=2 ) ) - - self.first = torch.nn.Sequential( - ConvND(dim)(channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), - LeakyReLU(), - ) - self.last = ConvND(dim)( + last_block = ConvND(dim)( 2 * n_channels_per_head * n_heads[0], channels_out, kernel_size=3, stride=1, padding='same' ) + decoder = UNetDecoder( + decoder_blocks=decoder_blocks, + concat_blocks=concat_blocks, + up_blocks=up_blocks, + last_block=last_block, + ) + + super().__init__(encoder=encoder, decoder=decoder) From eafbfc62a6432066e01c4cd8dcb03fbcecf7a7ab Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 23 May 2025 01:45:24 +0200 Subject: [PATCH 047/205] Add SpatialTransformerBlock and integrate into UNet architecture - Introduced SpatialTransformerBlock for enhanced attention mechanisms. - Updated MultiHeadAttention to support cross-attention channels. - Modified UNet to include SpatialTransformerBlock in encoder and decoder stages based on specified attention depths. - Improved modularity and flexibility of the UNet structure. --- src/mrpro/nn/MultiHeadAttention.py | 10 ++- src/mrpro/nn/SpatialTransformerBlock.py | 88 +++++++++++++++++++++++++ src/mrpro/nn/nets/UNet.py | 29 ++++++-- 3 files changed, 121 insertions(+), 6 deletions(-) create mode 100644 src/mrpro/nn/SpatialTransformerBlock.py diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py index d7e146b6b..c7b5e1cde 100644 --- a/src/mrpro/nn/MultiHeadAttention.py +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -18,6 +18,7 @@ def __init__( n_heads: int, features_last: bool = False, p_dropout: float = 0.0, + channels_cross: int | None = None, ): """Initialize the Multi-head Attention. @@ -36,10 +37,17 @@ def __init__( or the second dimension, as common in image models. p_dropout Dropout probability. + channels_cross + Number of channels for cross-attention. If `None`, use `channels_in`. """ super().__init__() self.mha = torch.nn.MultiheadAttention( - embed_dim=channels_in, num_heads=n_heads, batch_first=True, dropout=p_dropout + embed_dim=channels_in, + num_heads=n_heads, + batch_first=True, + dropout=p_dropout, + kdim=channels_cross, + vdim=channels_cross, ) self.features_last = features_last self.to_out = Linear(channels_in, channels_out) diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py new file mode 100644 index 000000000..414d579dc --- /dev/null +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -0,0 +1,88 @@ +"""Spatial transformer block.""" + +import torch +from torch.nn import Dropout, Linear, Module + +from mrpro.nn.GEGLU import GEGLU +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.Sequential import Sequential + + +def zero_init(m: Module) -> Module: + """Initialize module weights and bias to zero.""" + if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor): + torch.nn.init.zeros_(m.weight) + if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): + torch.nn.init.zeros_(m.bias) + return m + + +class BasicTransformerBlock(Module): + def __init__(self, channels: int, n_heads: int, p_dropout: float = 0.0, cond_dim: int = 0, mlp_ratio: float = 4): + super().__init__() + self.selfattention = Sequential( + LayerNorm(channels), + MultiHeadAttention(channels_in=channels, channels_out=channels, n_heads=n_heads, p_dropout=p_dropout), + ) + hidden_dim = int(channels * mlp_ratio) + self.ff = Sequential( + LayerNorm(channels), GEGLU(channels, hidden_dim), Dropout(p_dropout), Linear(hidden_dim, channels) + ) + self.crossattention = ( + Sequential( + LayerNorm(channels), + MultiHeadAttention( + channels_in=channels, + channels_out=channels, + n_heads=n_heads, + p_dropout=p_dropout, + channels_cross=cond_dim, + ), + ) + if cond_dim > 0 + else None + ) + self.norm2 = LayerNorm(channels) + self.cond_dim = cond_dim + + def forward(self, x, cond: torch.Tensor | None = None): + x = self.selfattention(x) + x + if cond is not None and self.crossattention is not None: + cond = cond.unflatten(-1, (-1, self.cond_dim)) + x = self.crossattention(x, cond=cond) + x + x = self.ff(x) + x + return x + + +class SpatialTransformerBlock(Module): + def __init__( + self, + dim: int, + channels: int, + n_heads: int, + channels_per_head: int, + depth: int = 1, + dropout: float = 0.0, + cond_dim: int = 0, + ): + super().__init__() + self.in_channels = channels + hidden_dim = n_heads * channels_per_head + self.norm = GroupNorm(channels) + + self.proj_in = ConvND(dim)(channels, hidden_dim, kernel_size=1, stride=1, padding=0) + blocks = [BasicTransformerBlock(channels, n_heads, p_dropout=dropout, cond_dim=cond_dim) for _ in range(depth)] + self.transformer_blocks = Sequential(*blocks) + + self.proj_out = zero_init(ConvND(dim)(hidden_dim, channels, kernel_size=1, stride=1, padding=0)) + + def forward(self, x, cond: torch.Tensor | None = None): + skip = x + x = self.norm(x) + x = self.proj_in(x) + x = self.transformer_blocks(x, cond=cond) + x = self.proj_out(x) + return x + skip diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index f50ff1177..2487b6b0c 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -8,7 +8,11 @@ from torch.nn import Module, ModuleList from mrpro.nn.CondMixin import call_with_cond - +from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.Sequential import Sequential +from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.ResBlock import ResBlock class UNetEncoder(Module): """Encoder.""" @@ -207,15 +211,30 @@ def __init__( dim: int, in_channels: int, out_channels: int, + attention_depths: Sequence[int], n_features: Sequence[int], - n_heads: Sequence[int], - n_blocks: int | Sequence[int], + n_heads: int, cond_dim: int, - num_blocks: int, + n_resblocks: int padding_modes: str | Sequence[str], ) -> None: """Initialize the UNet.""" - super().__init__() + + encoder_blocks = [] + decoder_blocks = [] + skip_blocks = [] + for i, (n_feat, n_heads, depth) in enumerate(zip(n_features, n_heads, n_resblocks, strict=True): + enc_block = Sequential(*[ResBlock(dim, n_feat, n_heads, cond_dim) for _ in range(depth)]) + dec_block = Sequential(*[ResBlock(dim, n_feat, n_heads, cond_dim) for _ in range(depth)]) + if i in attention_depths: + enc_block.append(SpatialTransformerBlock(dim, n_feat, n_heads, cond_dim)) + dec_block.append(SpatialTransformerBlock(dim, n_feat, n_heads, cond_dim)) + decoder_blocks.append(dec_block) + skip_blocks.append(enc_block) + + encoder = UNetEncoder(encoder_blocks, down_blocks, middle_block) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + super().__init__(encoder, decoder) class AttentionUNet(UNet): From afd7a4573c26567da47d61f92fba77e32ca657c9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 23 May 2025 01:45:39 +0200 Subject: [PATCH 048/205] Refactor FiLM and Uformer modules for improved clarity and functionality - Simplified the FiLM class by removing unnecessary Sequential and Identity layers. - Updated the Uformer architecture to conditionally include FiLM layers based on the provided conditioning dimension. - Enhanced the forward method of LeWinTransformerBlock to accept conditioning tensors, improving modularity and flexibility. --- src/mrpro/nn/FiLM.py | 12 ++---- src/mrpro/nn/GroupNorm.py | 6 +-- src/mrpro/nn/nets/Uformer.py | 81 ++++++++---------------------------- 3 files changed, 23 insertions(+), 76 deletions(-) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index e6d101260..31108918d 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -1,7 +1,7 @@ """Feature-wise Linear Modulation.""" import torch -from torch.nn import Identity, Linear, Module, Sequential, SiLU +from torch.nn import Linear, Module from mrpro.nn.CondMixin import CondMixin from mrpro.utils.reshape import unsqueeze_tensors_right @@ -30,13 +30,7 @@ def __init__(self, channels: int, cond_dim: int) -> None: The dimension of the conditioning tensor. """ super().__init__() - if cond_dim > 0: - self.project: Module = Sequential( - SiLU(), - Linear(cond_dim, 2 * channels), - ) - else: - self.project = Identity() + self.project = Linear(cond_dim, 2 * channels) if cond_dim > 0 else None def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM. @@ -54,7 +48,7 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply FiLM.""" - if cond is None: + if cond is None or self.project is None: return x scale, shift = self.project(cond).chunk(2, dim=1) diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py index 9a50a6319..09e91cf11 100644 --- a/src/mrpro/nn/GroupNorm.py +++ b/src/mrpro/nn/GroupNorm.py @@ -21,9 +21,9 @@ def __init__(self, channels: int, groups: int | None = None): a power of 2 that is less than or equal to 32 and leaves at least 4 channels per group. """ if groups is None: - groups_ = channels & -channels - while (groups_ >= channels // 4) or groups_ > 32: - groups_ //= 2 + groups_, candidate = 1, 2 + while (candidate <= min(32, channels // 4)) and (channels % candidate == 0): + groups_, candidate = candidate, groups_ * 2 else: groups_ = groups super().__init__(groups_, channels) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 287cb98ea..ae915d2b7 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -15,61 +15,6 @@ from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention -class LeFF(Module): - """Locally-enhanced Feed-Forward Network. - - Part of the Uformer architecture. - """ - - def __init__( - self, - dim: int, - channels_in: int = 32, - channels_out: int = 32, - expand_ratio: float = 4, - ) -> None: - """Initialize the LeFF module. - - Parameters - ---------- - dim : int - 2 or 3, for 2D or 3D input - channels_in : int - Input feature dimension - channels_out : int - Output feature dimension - expand_ratio : float - Expansion ratio of the hidden dimension - """ - super().__init__() - hidden_dim = int(channels_in * expand_ratio) - self.block = Sequential( - ConvND(dim)(channels_in, hidden_dim, 1), - GELU(), - ConvND(dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), - GELU(), - ConvND(dim)(hidden_dim, channels_out, 1), - ) - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - """Apply the LeFF module. - - Parameters - ---------- - x : torch.Tensor - The input tensor. - - Returns - ------- - The output tensor. - """ - return super().__call__(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply the LeFF module.""" - return self.block(x) - - class LeWinTransformerBlock(Module): """Locally-enhanced windowed attention transformer block. @@ -85,6 +30,7 @@ def __init__( shifted: bool = False, mlp_ratio: float = 4.0, p_droppath: float = 0.0, + cond_dim: int = 0, ) -> None: """Initialize the LeWinTransformerBlock module. @@ -104,9 +50,12 @@ def __init__( Ratio of the hidden dimension to the input dimension p_droppath : float, optional Dropout probability for the drop path. + cond_dim : int, optional + Dimension of a conditioning tensor. If `0`, no FiLM layers are added. """ super().__init__() channels = n_channels_per_head * n_heads + hidden_dim = int(channels * mlp_ratio) self.norm1 = InstanceNormND(dim)(channels) self.attn = ShiftedWindowAttention( dim=dim, @@ -116,19 +65,26 @@ def __init__( window_size=window_size, shifted=shifted, ) - self.norm2 = InstanceNormND(dim)(channels) - self.ff = LeFF(dim=dim, channels_in=channels, channels_out=channels, expand_ratio=mlp_ratio) + self.ff = Sequential( + ConvND(dim)(channels, hidden_dim, 1), + GELU(), + ConvND(dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), + GELU(), + ConvND(dim)(hidden_dim, channels, 1), + ) + if cond_dim > 0: + self.ff.append(FiLM(channels, cond_dim)) self.modulator = torch.nn.Parameter(torch.empty(channels, *((window_size,) * dim))) torch.nn.init.trunc_normal_(self.modulator) self.drop_path = DropPath(droprate=p_droppath) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the transformer block.""" modulator = self.modulator.tile([t // s for t, s in zip(x.shape[1:], self.modulator.shape, strict=False)]) x_mod = self.norm1(x) + modulator x_attn = self.attn(x_mod) - x_ff = self.ff(self.norm2(x_attn)) + x_ff = self.ff(self.norm2(x_attn), cond=cond) return x + self.drop_path(x_ff) @@ -188,7 +144,7 @@ def __init__( """ def blocks(n_heads: int, p_droppath: float = 0.0): - layers = Sequential( + return Sequential( *( LeWinTransformerBlock( dim=dim, @@ -198,15 +154,12 @@ def blocks(n_heads: int, p_droppath: float = 0.0): mlp_ratio=mlp_ratio, shifted=bool(i % 2), p_droppath=p_droppath, + cond_dim=cond_dim, ) for i in range(n_blocks) ) ) - if cond_dim > 0 and n_blocks > 1: - layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) - return layers - first_block = torch.nn.Sequential( ConvND(dim)(channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), LeakyReLU(), From 24bfbc9306db963ca80c7e6b0ea931dc4e1fed04 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 23 May 2025 01:46:07 +0200 Subject: [PATCH 049/205] Refactor ZeroPadOp and pad_or_crop utility for improved functionality and clarity - Updated import statement in ZeroPadOp to directly import pad_or_crop function. - Enhanced pad_or_crop function to include a new 'mode' parameter for padding options, improving flexibility in data manipulation. --- src/mrpro/operators/ZeroPadOp.py | 2 +- src/mrpro/utils/pad_or_crop.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/mrpro/operators/ZeroPadOp.py b/src/mrpro/operators/ZeroPadOp.py index c4adfc831..19f19b23e 100644 --- a/src/mrpro/operators/ZeroPadOp.py +++ b/src/mrpro/operators/ZeroPadOp.py @@ -5,7 +5,7 @@ import torch from mrpro.operators.LinearOperator import LinearOperator -from mrpro.utils import pad_or_crop +from mrpro.utils.pad_or_crop import pad_or_crop class ZeroPadOp(LinearOperator): diff --git a/src/mrpro/utils/pad_or_crop.py b/src/mrpro/utils/pad_or_crop.py index d45320e88..d0ed946ef 100644 --- a/src/mrpro/utils/pad_or_crop.py +++ b/src/mrpro/utils/pad_or_crop.py @@ -2,9 +2,9 @@ import math from collections.abc import Sequence +from typing import Literal import torch -import torch.nn.functional as F # noqa: N812 def normalize_index(ndim: int, index: int) -> int: @@ -34,6 +34,7 @@ def pad_or_crop( data: torch.Tensor, new_shape: Sequence[int] | torch.Size, dim: None | Sequence[int] = None, + mode: Literal['constant', 'reflect', 'replicate', 'circular'] = 'constant', value: float = 0.0, ) -> torch.Tensor: """Change shape of data by center cropping or symmetric padding. @@ -47,8 +48,10 @@ def pad_or_crop( dim Dimensions the `new_shape` corresponds to. `None` is interpreted as last ``len(new_shape)`` dimensions. + mode + Mode to use for padding. value - value to use for padding. + value to use for constant padding. Returns ------- @@ -78,5 +81,5 @@ def pad_or_crop( if any(npad): # F.pad expects paddings in reversed order - data = F.pad(data, npad[::-1], value=value) + data = torch.nn.functional.pad(data, npad[::-1], value=value, mode=mode) return data From 376be37273f3e3369f89871d164dc4c8cc820b2b Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 2 Jun 2025 02:24:51 +0200 Subject: [PATCH 050/205] Add Upsample module for tensor resizing functionality - Introduced the Upsample class to facilitate tensor upsampling with configurable scale factors and modes (nearest, linear). - Implemented the forward method to compute new tensor sizes based on the specified dimensions and scale factor, enhancing flexibility in tensor manipulation. --- src/mrpro/nn/Upsample.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 src/mrpro/nn/Upsample.py diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py new file mode 100644 index 000000000..e17384f13 --- /dev/null +++ b/src/mrpro/nn/Upsample.py @@ -0,0 +1,18 @@ +from typing import Literal + +import torch +from torch.nn import Module + +from mrpro.utils.interpolate import interpolate + + +class Upsample(Module): + def __init__(self, dim: int, scale_factor: int = 2, mode: Literal['nearest', 'linear'] = 'linear'): + super().__init__() + self.scale_factor = scale_factor + self.mode = mode + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + new_size = [d * self.scale_factor for d in x.shape[self.dim :]] + return interpolate(x, size=new_size, dim=range(-self.dim, 0)) From b1ff7f84991d2dfbcd1baf0c4eb9314464dfb7fe Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 2 Jun 2025 02:25:10 +0200 Subject: [PATCH 051/205] Refactor neural network modules to standardize feature dimension handling - Updated parameter names from 'channel_last' to 'features_last' across multiple modules for consistency. - Adjusted related logic in GEGLU, LayerNorm, LinearSelfAttention, NeighborhoodSelfAttention, RMSNorm, and BasicTransformerBlock to reflect the new parameter naming. - Enhanced clarity in the handling of feature dimensions, improving modularity and maintainability of the codebase. --- src/mrpro/nn/GEGLU.py | 12 ++++---- src/mrpro/nn/LayerNorm.py | 10 +++---- src/mrpro/nn/LinearSelfAttention.py | 12 ++++---- src/mrpro/nn/NeighborhoodSelfAttention.py | 10 +++---- src/mrpro/nn/RMSNorm.py | 6 ++-- src/mrpro/nn/SpatialTransformerBlock.py | 35 +++++++++++++++++++---- 6 files changed, 54 insertions(+), 31 deletions(-) diff --git a/src/mrpro/nn/GEGLU.py b/src/mrpro/nn/GEGLU.py index 42605ea45..787659e10 100644 --- a/src/mrpro/nn/GEGLU.py +++ b/src/mrpro/nn/GEGLU.py @@ -12,7 +12,7 @@ class GEGLU(Module): ..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202 """ - def __init__(self, in_features: int, out_features: int | None = None, channels_last: bool = False): + def __init__(self, in_features: int, out_features: int | None = None, features_last: bool = False): """Initialize the GEGLU activation function. Parameters @@ -22,21 +22,21 @@ def __init__(self, in_features: int, out_features: int | None = None, channels_l out_features : int The number of output features. If None, the number of output features is the same as the number of input features. - channels_last + features_last If True, the channel dimension is the last dimension, else in the second dimension. """ super().__init__() out_features_ = in_features if out_features is None else out_features self.proj = Linear(in_features, out_features_ * 2) # gate and output stacked - self.channels_last = channels_last + self.features_last = features_last def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the GEGLU activation.""" - if not self.channels_last: + if not self.features_last: x = x.moveaxis(1, -1) h, gate = self.proj(x).chunk(2, dim=-1) gate = torch.nn.functional.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) out = h * gate - if not self.channels_last: - out = out.moveaxis(out, -1, 1) + if not self.features_last: + out = out.moveaxis(-1, 1) return out diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index ce7b60553..07f9ba593 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -9,7 +9,7 @@ class LayerNorm(Module): """Layer normalization.""" - def __init__(self, channels: int | None, channel_last: bool = False, bias: bool = True) -> None: + def __init__(self, channels: int | None, features_last: bool = False, bias: bool = True) -> None: """Initialize the layer normalization. Parameters @@ -17,7 +17,7 @@ def __init__(self, channels: int | None, channel_last: bool = False, bias: bool channels Number of channels in the input tensor. If `None`, the layer normalization does not do an elementwise affine transformation. - channel_last + features_last If `True`, the channel dimension is the last dimension. bias If `False`, only a scaling is applied without an offset if an affine transformation is used. @@ -29,7 +29,7 @@ def __init__(self, channels: int | None, channel_last: bool = False, bias: bool else: self.weight = None self.bias = None - self.channel_last = channel_last + self.features_last = features_last def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply layer normalization to the input tensor. @@ -53,13 +53,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = (x - mean) / (std + 1e-5) if self.weight is not None: - if self.channel_last: + if self.features_last: x = x * self.weight else: x = x * unsqueeze_right(self.weight, x.ndim - 2) if self.bias is not None: - if self.channel_last: + if self.features_last: x = x + self.bias else: x = x + unsqueeze_right(self.bias, x.ndim - 2) diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py index 3ad70f14f..12dbe1718 100644 --- a/src/mrpro/nn/LinearSelfAttention.py +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -35,7 +35,7 @@ def __init__( channels_out: int, n_heads: int, eps: float = 1e-6, - channel_last: bool = False, + features_last: bool = False, ): """Initialize linear self-attention layer. @@ -49,12 +49,12 @@ def __init__( Number of attention heads. eps Small epsilon for numerical stability in normalization. - channel_last + features_last Whether the channel dimension is the last dimension, as common in transformer models, or the second dimension, as common in image models. """ super().__init__() - self.channel_last = channel_last + self.features_last = features_last self.eps = eps self.n_heads = n_heads channels_per_head = channels_in // n_heads @@ -68,7 +68,7 @@ def __call__(self, x: Tensor) -> Tensor: Parameters ---------- x - Tensor of shape `batch, channels, *spatial_dims` or (`batch, *spatial_dims, channels` if `channel_last`) + Tensor of shape `batch, channels, *spatial_dims` or (`batch, *spatial_dims, channels` if `features_last`) Returns ------- @@ -81,7 +81,7 @@ def forward(self, x: Tensor) -> Tensor: orig_dtype = x.dtype if x.dtype == torch.float16: x = x.float() - if not self.channel_last: + if not self.features_last: x = x.moveaxis(1, -1) spatial_shape = x.shape[1:-1] @@ -104,6 +104,6 @@ def forward(self, x: Tensor) -> Tensor: out = out.to(orig_dtype) out = out.moveaxis(1, -1).flatten(-2) # join heads and channels out = out.unflatten(-2, spatial_shape) - if not self.channel_last: + if not self.features_last: out = out.moveaxis(-1, 1) return out diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index 3ff753ab7..3c58bcfc1 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -125,7 +125,7 @@ def __init__( kernel_size: int | Sequence[int], dilation: int | Sequence[int] = 1, circular: bool | Sequence[bool] = False, - channel_last: bool = False, + features_last: bool = False, ) -> None: """Initialize a neighborhood attention module. @@ -146,7 +146,7 @@ def __init__( The dilation factor for the neighborhood. circular Whether the neighborhood wraps around the edges (circular padding) - channel_last + features_last Whether the channels are in the last dimension of the tensor, as common in visíon transformers. Otherwise, assume the channels are in the second dimension, as common in CNN models. """ @@ -155,7 +155,7 @@ def __init__( self.kernel_size = kernel_size self.dilation = dilation self.circular = circular - self.channel_last = channel_last + self.features_last = features_last channels_per_head = channels_in // n_heads self.to_qkv = Linear(channels_in, 3 * channels_per_head * n_heads) self.to_out = Linear(channels_per_head * n_heads, channels_out) @@ -172,7 +172,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ------- The output tensor after attention, with the same shape as the input tensor. """ - if not self.channel_last: + if not self.features_last: x = x.moveaxis(1, -1) spatial_shape = x.shape[2:-1] qkv = self.to_qkv(x) @@ -187,6 +187,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out: torch.Tensor = flex_attention(query.contiguous(), key.contiguous(), value.contiguous(), block_mask=mask) # type: ignore[assignment] # wrong type hints out = self.to_out(out) out = out.unflatten(-2, spatial_shape) - if not self.channel_last: + if not self.features_last: out = out.moveaxis(-1, 1) return out diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py index 52a2bce43..0b184ebac 100644 --- a/src/mrpro/nn/RMSNorm.py +++ b/src/mrpro/nn/RMSNorm.py @@ -7,7 +7,7 @@ class RMSNorm(Module): """RMSNorm over the channel dimension.""" - def __init__(self, channels: int, eps: float = 1e-8, channel_last: bool = False): + def __init__(self, channels: int, eps: float = 1e-8, features_last: bool = False): """Initialize RMSNorm. Includes a learnable weight and bias. @@ -18,14 +18,14 @@ def __init__(self, channels: int, eps: float = 1e-8, channel_last: bool = False) Number of channels. eps Epsilon value to avoid division by zero. - channel_last + features_last If True, the channel dimension is the last dimension. """ super().__init__() self.weight = Parameter(torch.zeros(channels)) self.bias = Parameter(torch.zeros(channels)) self.eps = eps - self.channel_dim = -1 if channel_last else 1 + self.channel_dim = -1 if features_last else 1 def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply RMSNorm over the channel dimension. diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 414d579dc..6751db45a 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -21,39 +21,62 @@ def zero_init(m: Module) -> Module: class BasicTransformerBlock(Module): - def __init__(self, channels: int, n_heads: int, p_dropout: float = 0.0, cond_dim: int = 0, mlp_ratio: float = 4): + def __init__( + self, + channels: int, + n_heads: int, + p_dropout: float = 0.0, + cond_dim: int = 0, + mlp_ratio: float = 4, + features_last: bool = False, + ): super().__init__() + self.features_last = features_last self.selfattention = Sequential( - LayerNorm(channels), - MultiHeadAttention(channels_in=channels, channels_out=channels, n_heads=n_heads, p_dropout=p_dropout), + LayerNorm(channels, features_last=True), + MultiHeadAttention( + channels_in=channels, + channels_out=channels, + n_heads=n_heads, + p_dropout=p_dropout, + features_last=True, + ), ) hidden_dim = int(channels * mlp_ratio) self.ff = Sequential( - LayerNorm(channels), GEGLU(channels, hidden_dim), Dropout(p_dropout), Linear(hidden_dim, channels) + LayerNorm(channels, features_last=True), + GEGLU(channels, hidden_dim, features_last=True), + Dropout(p_dropout), + Linear(hidden_dim, channels), ) self.crossattention = ( Sequential( - LayerNorm(channels), + LayerNorm(channels, features_last=True), MultiHeadAttention( channels_in=channels, channels_out=channels, n_heads=n_heads, p_dropout=p_dropout, channels_cross=cond_dim, + features_last=True, ), ) if cond_dim > 0 else None ) - self.norm2 = LayerNorm(channels) + self.norm2 = LayerNorm(channels, features_last=True) self.cond_dim = cond_dim def forward(self, x, cond: torch.Tensor | None = None): + if not self.features_last: + x = x.moveaxis(1, -1) x = self.selfattention(x) + x if cond is not None and self.crossattention is not None: cond = cond.unflatten(-1, (-1, self.cond_dim)) x = self.crossattention(x, cond=cond) + x x = self.ff(x) + x + if not self.features_last: + x = x.moveaxis(-1, 1) return x From dca726b02843484b9f6dcc5bad22c78b4d6fcc41 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 2 Jun 2025 02:25:23 +0200 Subject: [PATCH 052/205] Enhance UNet architecture with improved attention handling and modularity - Updated UNet class to include configurable attention depths and encoder blocks per scale, enhancing flexibility in model design. - Introduced new attention block functionality and refined block initialization for better clarity and maintainability. - Adjusted forward method to accept conditioning tensors explicitly, improving modularity in the encoder-decoder structure. - Integrated GroupNorm and SiLU into the final layers for improved performance and consistency. --- src/mrpro/nn/nets/UNet.py | 104 +++++++++++++++++++++++++++++--------- 1 file changed, 80 insertions(+), 24 deletions(-) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 2487b6b0c..7ff5b5179 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -4,15 +4,17 @@ from functools import partial import torch -from sympy import Identity -from torch.nn import Module, ModuleList +from torch.nn import Identity, Module, ModuleList, SiLU from mrpro.nn.CondMixin import call_with_cond -from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.join import Concat from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ResBlock import ResBlock from mrpro.nn.Sequential import Sequential from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock -from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Upsample import Upsample + class UNetEncoder(Module): """Encoder.""" @@ -72,7 +74,7 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tupl ------- The tensors at the different resolutions, highest resolution first. """ - return super().__call__(x, cond) + return super().__call__(x, cond=cond) class UNetDecoder(Module): @@ -211,28 +213,82 @@ def __init__( dim: int, in_channels: int, out_channels: int, - attention_depths: Sequence[int], - n_features: Sequence[int], - n_heads: int, - cond_dim: int, - n_resblocks: int - padding_modes: str | Sequence[str], + attention_depths: Sequence[int] = (-1, -2), + n_features: Sequence[int] = (64, 128, 192, 256), + n_heads: int = 4, + cond_dim: int = 0, + encoder_blocks_per_scale: int = 2, ) -> None: """Initialize the UNet.""" + depth = len(n_features) + if not all(-depth <= d < depth for d in attention_depths): + raise ValueError( + f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' + ) + attention_depths = tuple(d % depth for d in attention_depths) + if len(attention_depths) != len(set(attention_depths)): + raise ValueError(f'attention_depths must be unique, got {attention_depths=}') + + def attention_block(channels: int) -> Module: + return SpatialTransformerBlock( + dim, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim + ) + + def block(channels_in: int, channels_out: int, attention: bool) -> Module: + if not attention: + return ResBlock(dim, channels_in, channels_out, cond_dim) + return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) + + first_block = ConvND(dim)(in_channels, n_features[0], 3, padding=1) + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + skip_features = [] + n_feat_old = n_features[0] + + for i_level, n_feat in enumerate(n_features): + encoder_blocks.append(Identity()) + skip_features.append(n_feat_old) + for _ in range(encoder_blocks_per_scale): + encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) + n_feat_old = n_feat + down_blocks.append(Identity()) + skip_features.append(n_feat_old) + down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) + down_blocks[-1] = Identity() + + middle_block = Sequential( + ResBlock(dim, n_features[-1], n_features[-1], cond_dim), + ResBlock(dim, n_features[-1], n_features[-1], cond_dim), + ) + if i_level in attention_depths: + middle_block.insert(1, attention_block(n_features[-1])) + + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [Identity()] + for i_level, n_feat in reversed(list(enumerate(n_features))): + decoder_blocks.append( + block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) + ) + n_feat_old = n_feat + for _ in range(encoder_blocks_per_scale): + decoder_blocks.append( + block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) + ) + n_feat_old = n_feat + + up_blocks.append(Identity()) + n_feat_old = n_feat + up_blocks.append(Upsample(dim, scale_factor=2)) + up_blocks.pop() + + concat_blocks = [Concat()] * len(decoder_blocks) + last_block = Sequential( + GroupNorm(n_features[0]), SiLU(), ConvND(dim)(n_features[0], out_channels, 3, padding=1) + ) - encoder_blocks = [] - decoder_blocks = [] - skip_blocks = [] - for i, (n_feat, n_heads, depth) in enumerate(zip(n_features, n_heads, n_resblocks, strict=True): - enc_block = Sequential(*[ResBlock(dim, n_feat, n_heads, cond_dim) for _ in range(depth)]) - dec_block = Sequential(*[ResBlock(dim, n_feat, n_heads, cond_dim) for _ in range(depth)]) - if i in attention_depths: - enc_block.append(SpatialTransformerBlock(dim, n_feat, n_heads, cond_dim)) - dec_block.append(SpatialTransformerBlock(dim, n_feat, n_heads, cond_dim)) - decoder_blocks.append(dec_block) - skip_blocks.append(enc_block) - - encoder = UNetEncoder(encoder_blocks, down_blocks, middle_block) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) From 59569dc72c6c567fcdef479f4ce721c1a33c39d4 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 2 Jun 2025 16:25:15 +0200 Subject: [PATCH 053/205] - Updated AttentionGate to include a new 'concatenate' parameter, allowing for optional concatenation of gated and gate signals in the channel dimension. - Adjusted ResBlock to modify the rezero parameter for better stability during training. - Refactored forward methods in both AttentionGate and ResBlock to ensure compatibility with the new features and maintain clarity in tensor operations. - Updated import statements and class references in UNet and related modules to reflect the new AttentionGatedUNet class. --- src/mrpro/nn/AttentionGate.py | 10 ++- src/mrpro/nn/ResBlock.py | 4 +- src/mrpro/nn/SpatialTransformerBlock.py | 3 +- src/mrpro/nn/nets/Restormer.py | 4 +- src/mrpro/nn/nets/UNet.py | 113 +++++++++++++++++++----- src/mrpro/nn/nets/Uformer.py | 4 +- src/mrpro/nn/nets/__init__.py | 4 +- 7 files changed, 109 insertions(+), 33 deletions(-) diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py index 96ebe6cf9..e7dd40d6e 100644 --- a/src/mrpro/nn/AttentionGate.py +++ b/src/mrpro/nn/AttentionGate.py @@ -17,7 +17,7 @@ class AttentionGate(Module): https://arxiv.org/abs/1804.03999 """ - def __init__(self, dim: int, channels_gate: int, channels_in: int, channels_hidden: int): + def __init__(self, dim: int, channels_gate: int, channels_in: int, channels_hidden: int, concatenate: bool = False): """Initialize the attention gate. Parameters @@ -30,6 +30,8 @@ def __init__(self, dim: int, channels_gate: int, channels_in: int, channels_hidd The number of channels in the input tensor. channels_hidden The number of internal, hidden channels. + concatenate + Whether to concatenate the gated signal with the gate signal in the channel dimension (1) """ super().__init__() self.project_gate = ConvND(dim)(channels_gate, channels_hidden, kernel_size=1) @@ -39,6 +41,7 @@ def __init__(self, dim: int, channels_gate: int, channels_in: int, channels_hidd ConvND(dim)(channels_hidden, 1, kernel_size=1), Sigmoid(), ) + self.concatenate = concatenate def __call__(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: """Apply the attention gate. @@ -63,4 +66,7 @@ def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: if gate.shape[2:] != x.shape[2:]: projected_gate = torch.nn.functional.interpolate(projected_gate, size=x.shape[2:], mode='nearest') alpha = self.psi(projected_gate + projected_x) - return x * alpha + x = x * alpha + if self.concatenate: + x = torch.cat([x, gate], dim=1) + return x diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index 0897eb52c..bd09101ac 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -30,7 +30,7 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, cond_dim: int) """ super().__init__() - self.rezero = torch.nn.Parameter(torch.tensor(1e-6)) + self.rezero = torch.nn.Parameter(torch.tensor(1e-2)) self.block = Sequential( GroupNorm(channels_in), SiLU(), @@ -66,5 +66,5 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock.""" h = self.block(x, cond=cond) - x = self.skip_connection(x) + h + x = self.skip_connection(x) + self.rezero * h return x diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 6751db45a..2c482a239 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -64,10 +64,9 @@ def __init__( if cond_dim > 0 else None ) - self.norm2 = LayerNorm(channels, features_last=True) self.cond_dim = cond_dim - def forward(self, x, cond: torch.Tensor | None = None): + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: if not self.features_last: x = x.moveaxis(1, -1) x = self.selfattention(x) + x diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 8b209406b..c56b2c4a0 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -178,7 +178,7 @@ def blocks(n_heads: int, n_blocks: int): middle_block = blocks(n_heads[-1], n_blocks[-1]) encoder = UNetEncoder( first_block=first_block, - encoder_blocks=encoder_blocks, + blocks=encoder_blocks, down_blocks=down_blocks, middle_block=middle_block, ) @@ -194,7 +194,7 @@ def blocks(n_heads: int, n_blocks: int): ConvND(dim)(n_channels_per_head, channels_out, kernel_size=3, stride=1, padding=1), ) decoder = UNetDecoder( - decoder_blocks=decoder_blocks, + blocks=decoder_blocks, up_blocks=up_blocks, concat_blocks=concat_blocks, last_block=last_block, diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 7ff5b5179..0b1c74e03 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -2,14 +2,17 @@ from collections.abc import Sequence from functools import partial +from itertools import pairwise import torch -from torch.nn import Identity, Module, ModuleList, SiLU +from torch.nn import Identity, Module, ModuleList, ReLU, SiLU +from mrpro.nn.AttentionGate import AttentionGate from mrpro.nn.CondMixin import call_with_cond +from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.join import Concat -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import ConvND, MaxPoolND from mrpro.nn.ResBlock import ResBlock from mrpro.nn.Sequential import Sequential from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock @@ -22,7 +25,7 @@ class UNetEncoder(Module): def __init__( self, first_block: Module, - encoder_blocks: Sequence[Module], + blocks: Sequence[Module], down_blocks: Sequence[Module], middle_block: Module, ) -> None: @@ -31,7 +34,7 @@ def __init__( self.first = first_block """The first block. Should expand from the number of input channels.""" - self.encoder_blocks = ModuleList(encoder_blocks) + self.blocks = ModuleList(blocks) """The encoder blocks. Order is highest resolution to lowest resolution.""" self.down_blocks = ModuleList(down_blocks) @@ -51,7 +54,7 @@ def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple x = call(self.first, x) xs = [] - for block, down in zip(self.encoder_blocks, self.down_blocks, strict=True): + for block, down in zip(self.blocks, self.down_blocks, strict=True): x = call(block, x) xs.append(x) x = call(down, x) @@ -82,14 +85,14 @@ class UNetDecoder(Module): def __init__( self, - decoder_blocks: Sequence[Module], + blocks: Sequence[Module], up_blocks: Sequence[Module], concat_blocks: Sequence[Module], last_block: Module, ) -> None: """Initialize the UNetDecoder.""" super().__init__() - self.decoder_blocks = ModuleList(decoder_blocks) + self.blocks = ModuleList(blocks) """The decoder blocks. Order is lowest resolution to highest resolution.""" self.up_blocks = ModuleList(up_blocks) @@ -110,13 +113,10 @@ def forward(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = N call = partial(call_with_cond, cond=cond) x = hs[-1] # lowest resolution, from middle block - for block, up, concat, h in zip( - self.decoder_blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True - ): + for block, up, concat, h in zip(self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True): x = call(up, x) - x = concat(x, h) + x = concat(h, x) x = call(block, x) - x = call(self.last_block, x) return x @@ -195,11 +195,51 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T return super().__call__(x, cond=cond) +class BasicUNet(UNetBase): + """Basic UNet. + + A Basic UNet with residual blocks, convolutional downsampling, and nearest neighbor upsampling. + + + """ + + def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int): + """Initialize the BasicUNet.""" + encoder_blocks: list[Module] = [] + decoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + up_blocks: list[Module] = [] + concat_blocks: list[Module] = [] + for n_feat, n_feat_next in pairwise(n_features): + encoder_blocks.append(ResBlock(dim, n_feat, n_feat, cond_dim)) + decoder_blocks.append(ResBlock(dim, 2 * n_feat, n_feat, cond_dim)) + down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) + up_blocks.append(Sequential(Upsample(dim, scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1))) + concat_blocks.append(Concat()) + up_blocks = up_blocks[::-1] + decoder_blocks = decoder_blocks[::-1] + first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) + last_block = Sequential( + GroupNorm(n_features[0]), SiLU(), ConvND(dim)(n_features[0], channels_out, 3, padding=1) + ) + middle_block = ResBlock(dim, n_features[-1], n_features[-1], cond_dim) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + super().__init__(encoder, decoder) + + class UNet(UNetBase): """UNet. - U-shaped convolutional network [UNET]_ with optional patch attention. - Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_. + U-shaped convolutional network with optional patch attention. + Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_, + significant differences to the vanilla UNet [UNET]_ include: + - Spatial attention + - Multiple skip connections per resolution + - Convolutional downsampling, nearest neighbor upsampling + - Residual convolution blocks + - Group normalization + - SiLU activation References ---------- @@ -240,12 +280,10 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) first_block = ConvND(dim)(in_channels, n_features[0], 3, padding=1) - encoder_blocks: list[Module] = [] down_blocks: list[Module] = [] skip_features = [] n_feat_old = n_features[0] - for i_level, n_feat in enumerate(n_features): encoder_blocks.append(Identity()) skip_features.append(n_feat_old) @@ -256,14 +294,12 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: skip_features.append(n_feat_old) down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) down_blocks[-1] = Identity() - middle_block = Sequential( ResBlock(dim, n_features[-1], n_features[-1], cond_dim), ResBlock(dim, n_features[-1], n_features[-1], cond_dim), ) if i_level in attention_depths: middle_block.insert(1, attention_block(n_features[-1])) - encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) decoder_blocks: list[Module] = [] @@ -283,25 +319,60 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: n_feat_old = n_feat up_blocks.append(Upsample(dim, scale_factor=2)) up_blocks.pop() - concat_blocks = [Concat()] * len(decoder_blocks) last_block = Sequential( GroupNorm(n_features[0]), SiLU(), ConvND(dim)(n_features[0], out_channels, 3, padding=1) ) - decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + super().__init__(encoder, decoder) -class AttentionUNet(UNet): +class AttentionGatedUNet(UNetBase): """UNet with attention gates. + Basic UNet with attention gating of the skip signals by the lower resolution features [OKT18]_. + References ---------- .. [OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). https://arxiv.org/abs/1804.03999 """ + def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int = 0): + def block(channels_in: int, channels_out: int) -> Module: + block = Sequential( + ConvND(dim)(channels_in, channels_out, 3, padding=1), + ReLU(True), + ConvND(dim)(channels_out, channels_out, 3, padding=1), + ReLU(True), + ) + if cond_dim > 0: + block.insert(2, FiLM(cond_dim)) + return block + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + n_feat_old = channels_in + for n_feat in n_features[:-1]: + encoder_blocks.append(block(n_feat_old, n_feat)) + down_blocks.append(MaxPoolND(dim)(2)) + n_feat_old = n_feat + middle_block = block(n_features[-2], n_features[-1]) + encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) + + concat_blocks = [] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [] + for n_feat, n_feat_skip in pairwise(n_features[::-1]): + concat_blocks.append(AttentionGate(dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) + decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) + up_blocks.append(Upsample(dim, scale_factor=2)) + last_block = ConvND(dim)(n_features[0], channels_out, 1) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) + class SeparableUNet(UNetBase): """UNet where blocks apply separable convolutions in different dimensions. diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index ae915d2b7..eaa6cc089 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -182,7 +182,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) encoder = UNetEncoder( first_block=first_block, - encoder_blocks=encoder_blocks, + blocks=encoder_blocks, down_blocks=down_blocks, middle_block=middle_block, ) @@ -204,7 +204,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): 2 * n_channels_per_head * n_heads[0], channels_out, kernel_size=3, stride=1, padding='same' ) decoder = UNetDecoder( - decoder_blocks=decoder_blocks, + blocks=decoder_blocks, concat_blocks=concat_blocks, up_blocks=up_blocks, last_block=last_block, diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 02d5a449f..6f540e118 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -2,11 +2,11 @@ from mrpro.nn.nets.Uformer import Uformer from mrpro.nn.nets.DCAE import DCVAE from mrpro.nn.nets.VAE import VAE -from mrpro.nn.nets.UNet import UNet, AttentionUNet +from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet from mrpro.nn.nets.SwinIR import SwinIR __all__ = [ - "AttentionUNet", + "AttentionGatedUNet", "DCVAE", "Restormer", "SwinIR", From 6b942ca1dc336193a61f04ffe4619eae434292d9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 2 Jun 2025 17:22:00 +0200 Subject: [PATCH 054/205] wip --- src/mrpro/nn/GEGLU.py | 14 +++++++ src/mrpro/nn/SpatialTransformerBlock.py | 17 ++++++++ src/mrpro/nn/Upsample.py | 51 +++++++++++++++++++---- src/mrpro/nn/encoding.py | 26 +++++++++--- src/mrpro/nn/join.py | 30 ++++++++++++++ src/mrpro/nn/nets/DCAE.py | 14 +++++++ src/mrpro/nn/nets/Restormer.py | 55 ++++++++++++++----------- src/mrpro/nn/nets/Uformer.py | 22 ++++++---- 8 files changed, 184 insertions(+), 45 deletions(-) diff --git a/src/mrpro/nn/GEGLU.py b/src/mrpro/nn/GEGLU.py index 787659e10..d2fb64354 100644 --- a/src/mrpro/nn/GEGLU.py +++ b/src/mrpro/nn/GEGLU.py @@ -40,3 +40,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.features_last: out = out.moveaxis(-1, 1) return out + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Activated tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 2c482a239..a69e5c464 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -102,9 +102,26 @@ def __init__( self.proj_out = zero_init(ConvND(dim)(hidden_dim, channels, kernel_size=1, stride=1, padding=0)) def forward(self, x, cond: torch.Tensor | None = None): + """Apply the spatial transformer block.""" skip = x x = self.norm(x) x = self.proj_in(x) x = self.transformer_blocks(x, cond=cond) x = self.proj_out(x) return x + skip + + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the spatial transformer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor after spatial transformer + """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py index e17384f13..0c2169b0b 100644 --- a/src/mrpro/nn/Upsample.py +++ b/src/mrpro/nn/Upsample.py @@ -3,16 +3,53 @@ import torch from torch.nn import Module -from mrpro.utils.interpolate import interpolate - class Upsample(Module): - def __init__(self, dim: int, scale_factor: int = 2, mode: Literal['nearest', 'linear'] = 'linear'): + def __init__(self, dim: int, scale_factor: int = 2, mode: Literal['nearest', 'linear', 'cubic'] = 'linear'): + """Initialize the upsampling layer. + + Parameters + ---------- + dim + Spatial dimensions of the input tensor, i.e. 2 for 2D, 3 for 3D, etc. + scale_factor + Factor by which to upsample + mode + Interpolation mode. See `torch.nn.functional.interpolate` for details. + """ super().__init__() self.scale_factor = scale_factor - self.mode = mode - self.dim = dim + if mode == 'nearest': + self.mode = 'nearest' + elif dim == 1 and mode == 'linear': + self.mode = 'linear' + elif dim == 2 and mode == 'cubic': + self.mode = 'bicubic' + elif dim == 2 and mode == 'linear': + self.mode = 'bilinear' + elif dim == 3 and mode == 'linear': + self.mode = 'trilinear' + else: + raise ValueError(f'Invalid mode for dimension {dim}: {mode}') def forward(self, x: torch.Tensor) -> torch.Tensor: - new_size = [d * self.scale_factor for d in x.shape[self.dim :]] - return interpolate(x, size=new_size, dim=range(-self.dim, 0)) + """Upsample the input tensor.""" + return torch.nn.functional.interpolate( + x, + mode=self.mode, + scale_factor=self.scale_factor, + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Upsampled tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py index 679f38685..c65b86c61 100644 --- a/src/mrpro/nn/encoding.py +++ b/src/mrpro/nn/encoding.py @@ -35,19 +35,22 @@ def __init__(self, in_features: int, out_features: int, std: float = 1.0): super().__init__() self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply Fourier feature encoding. Parameters ---------- - x : torch.Tensor + x Input tensor of shape (..., in_features) Returns ------- - torch.Tensor - Encoded features of shape (..., out_features) + Encoded features of shape (..., out_features) """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding.""" f = 2 * torch.pi * x @ self.weight.T return torch.cat([f.cos(), f.sin()], dim=-1) @@ -90,8 +93,19 @@ def __init__(self, dim: int, features: int, include_radii: bool = True, base_res self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :features]) self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][dim - 1] - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass for encoding.""" + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for encoding. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Encoded tensor with absolute position information + """ features = self.encoding.shape[1] if features > x.shape[1]: raise ValueError(f'x has {x.shape[1]} features, but {features} are required') diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py index 5802902a6..5f08bb371 100644 --- a/src/mrpro/nn/join.py +++ b/src/mrpro/nn/join.py @@ -57,6 +57,21 @@ def forward(self, *xs: torch.Tensor) -> torch.Tensor: xs = _fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) return torch.cat(xs, dim=1) + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Concatenate input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Concatenated tensor + """ + return super().__call__(*xs) + class Add(Module): """Add tensors.""" @@ -84,3 +99,18 @@ def forward(self, *xs: torch.Tensor) -> torch.Tensor: """Add input tensors.""" xs = _fix_shapes(xs, self.mode, dim=range(max(x.ndim for x in xs))) return sum(xs, start=torch.tensor(0.0)) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Add input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Summed tensor + """ + return super().__call__(*xs) diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index c12829322..036e8a022 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -99,6 +99,20 @@ def __init__( expand_ratio=expand_ratio, ) + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the EfficientViTBlock. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Output tensor + """ + return super().__call__(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass for EfficientViTBlock.""" x = self.context_module(x) diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index c56b2c4a0..f95fd2f98 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -6,6 +6,7 @@ import torch from torch.nn import Module +from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.join import Concat from mrpro.nn.ndmodules import ConvND, InstanceNormND @@ -47,18 +48,17 @@ def __init__(self, dim: int, channels: int, mlp_ratio: float): ) self.project_out = ConvND(dim)(hidden_features, channels, kernel_size=1) - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply gated depthwise feed forward network. + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the gated depthwise feed forward network. Parameters ---------- - x : torch.Tensor + x Input tensor Returns ------- - torch.Tensor - Output tensor + Output tensor """ x = self.project_in(x) x1, x2 = self.depthwise_conv(x).chunk(2, dim=1) @@ -67,7 +67,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class RestormerBlock(Module): +class RestormerBlock(CondMixin, Module): """Transformer block with transposed attention and gated depthwise feed forward network.""" def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond_dim: int = 0): @@ -75,16 +75,16 @@ def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond Parameters ---------- - dim : int + dim Dimension of the input space channels : int Number of input/output channels - n_heads : int + n_heads Number of attention heads - mlp_ratio : float + mlp_ratio Ratio for hidden dimension expansion - cond_dim : int, optional - Dimension of conditioning input + cond_dim + Dimension of conditioning input. If 0, no conditioning is applied. """ super().__init__() self.norm1 = Sequential(InstanceNormND(dim)(channels)) @@ -94,21 +94,26 @@ def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond if cond_dim > 0: self.norm2.append(FiLM(channels=channels, cond_dim=cond_dim)) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply Restormer block. Parameters ---------- - x : torch.Tensor + x Input tensor + cond + Conditioning tensor. If None, no conditioning is applied. Returns ------- - torch.Tensor Output tensor """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Forward pass for RestormerBlock.""" x = x + self.attn(self.norm1(x)) - x = x + self.ffn(self.norm2(x)) + x = x + self.ffn(self.norm2(x, cond=cond)) return x @@ -140,24 +145,24 @@ def __init__( Parameters ---------- - dim : int + dim Dimension of the input space - channels_in : int + channels_in Number of input channels - channels_out : int + channels_out Number of output channels - n_blocks : Sequence[int], optional + n_blocks Number of blocks in each stage - n_refinement_blocks : int, optional + n_refinement_blocks Number of refinement blocks - n_heads : Sequence[int], optional + n_heads Number of attention heads in each stage - n_channels_per_head : int, optional + n_channels_per_head Number of channels per attention head - mlp_ratio : float, optional + mlp_ratio Ratio for hidden dimension expansion - cond_dim : int, optional - Dimension of conditioning input + cond_dim + Dimension of conditioning input. If 0, no conditioning is applied. """ def blocks(n_heads: int, n_blocks: int): diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index eaa6cc089..40ec3680d 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -79,13 +79,21 @@ def __init__( torch.nn.init.trunc_normal_(self.modulator) self.drop_path = DropPath(droprate=p_droppath) - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: - """Apply the transformer block.""" - modulator = self.modulator.tile([t // s for t, s in zip(x.shape[1:], self.modulator.shape, strict=False)]) - x_mod = self.norm1(x) + modulator - x_attn = self.attn(x_mod) - x_ff = self.ff(self.norm2(x_attn), cond=cond) - return x + self.drop_path(x_ff) + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the transformer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor + + Returns + ------- + Output tensor + """ + return super().__call__(x, cond=cond) class Uformer(UNetBase): From ea3109ef8caff466e4d4cb680cc65f992e2489f8 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 2 Jun 2025 22:56:20 +0200 Subject: [PATCH 055/205] update --- src/mrpro/nn/SpatialTransformerBlock.py | 75 +++++++++++++++++++++---- src/mrpro/nn/Upsample.py | 4 ++ src/mrpro/nn/nets/UNet.py | 18 +++++- src/mrpro/utils/ema.py | 6 +- tests/nn/test_film.py | 1 + 5 files changed, 89 insertions(+), 15 deletions(-) diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index a69e5c464..7be73de05 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -21,6 +21,8 @@ def zero_init(m: Module) -> Module: class BasicTransformerBlock(Module): + """Basic vision transformer block.""" + def __init__( self, channels: int, @@ -30,6 +32,23 @@ def __init__( mlp_ratio: float = 4, features_last: bool = False, ): + """Initialize the basic transformer block. + + Parameters + ---------- + channels + Number of channels in the input and output. + n_heads + Number of attention heads. + p_dropout + Dropout probability. + cond_dim + Number of channels in the conditioning tensor. + mlp_ratio + Ratio of the hidden dimension to the input dimension. + features_last + Whether the features are last in the input tensor. + """ super().__init__() self.features_last = features_last self.selfattention = Sequential( @@ -66,7 +85,20 @@ def __init__( ) self.cond_dim = cond_dim - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the basic transformer block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the basic transformer block.""" if not self.features_last: x = x.moveaxis(1, -1) x = self.selfattention(x) + x @@ -80,6 +112,8 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Te class SpatialTransformerBlock(Module): + """Spatial transformer block.""" + def __init__( self, dim: int, @@ -90,6 +124,25 @@ def __init__( dropout: float = 0.0, cond_dim: int = 0, ): + """Initialize the spatial transformer block. + + Parameters + ---------- + dim + Spatial dimension of the input tensor. + channels + Number of channels in the input and output. + n_heads + Number of attention heads. + channels_per_head + Number of channels per attention head. + depth + Number of transformer blocks. + dropout + Dropout probability. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ super().__init__() self.in_channels = channels hidden_dim = n_heads * channels_per_head @@ -101,16 +154,7 @@ def __init__( self.proj_out = zero_init(ConvND(dim)(hidden_dim, channels, kernel_size=1, stride=1, padding=0)) - def forward(self, x, cond: torch.Tensor | None = None): - """Apply the spatial transformer block.""" - skip = x - x = self.norm(x) - x = self.proj_in(x) - x = self.transformer_blocks(x, cond=cond) - x = self.proj_out(x) - return x + skip - - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the spatial transformer block. Parameters @@ -125,3 +169,12 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T Output tensor after spatial transformer """ return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the spatial transformer block.""" + skip = x + x = self.norm(x) + x = self.proj_in(x) + x = self.transformer_blocks(x, cond=cond) + x = self.proj_out(x) + return x + skip diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py index 0c2169b0b..74462fbff 100644 --- a/src/mrpro/nn/Upsample.py +++ b/src/mrpro/nn/Upsample.py @@ -1,3 +1,5 @@ +"""Upsampling by interpolation.""" + from typing import Literal import torch @@ -5,6 +7,8 @@ class Upsample(Module): + """Upsampling by interpolation.""" + def __init__(self, dim: int, scale_factor: int = 2, mode: Literal['nearest', 'linear', 'cubic'] = 'linear'): """Initialize the upsampling layer. diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 0b1c74e03..e5af1e16b 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -340,6 +340,22 @@ class AttentionGatedUNet(UNetBase): """ def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int = 0): + """Initialize the AttentionGatedUNet. + + Parameters + ---------- + dim + Spatial dimension of the input tensor. + channels_in + Number of channels in the input tensor. + channels_out + Number of channels in the output tensor. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + def block(channels_in: int, channels_out: int) -> Module: block = Sequential( ConvND(dim)(channels_in, channels_out, 3, padding=1), @@ -348,7 +364,7 @@ def block(channels_in: int, channels_out: int) -> Module: ReLU(True), ) if cond_dim > 0: - block.insert(2, FiLM(cond_dim)) + block.insert(2, FiLM(channels_out, cond_dim)) return block encoder_blocks: list[Module] = [] diff --git a/src/mrpro/utils/ema.py b/src/mrpro/utils/ema.py index b45cc4d27..28e840016 100644 --- a/src/mrpro/utils/ema.py +++ b/src/mrpro/utils/ema.py @@ -30,13 +30,13 @@ def __init__( self.decay: float = decay if not 0 <= decay <= 1: raise ValueError(f'Decay must be between 0 and 1, got {decay}') - self._data: dict[str, Any] = dict() + self._data: dict[str, Any] = {} - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> Any: # noqa: ANN401 """Get the value of the EMA dict for a given key.""" return self._data[key] - def __setitem__(self, key: str, value: Any) -> None: + def __setitem__(self, key: str, value: Any) -> None: # noqa: ANN401 """Set the value of the EMA dict for a given key.""" if key in self._data: old_v = self._data[key] diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py index 0106aa1cb..d49cd476b 100644 --- a/tests/nn/test_film.py +++ b/tests/nn/test_film.py @@ -34,4 +34,5 @@ def test_film(channels, channels_cond, input_shape, cond_shape, device): assert not cond.isnan().any(), 'NaN values in condedding' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert not cond.grad.isnan().any(), 'NaN values in condedding gradients' + assert film.project is not None, 'Linear layer is not initialized' assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' From 4c1ec0feaa67ca540f08126c35e756b88dfb0dba Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 3 Jun 2025 17:20:25 +0200 Subject: [PATCH 056/205] Refactor parameter documentation in encoding and normalization modules - Updated parameter documentation in FourierFeatures, AbsolutePositionEncoding, GEGLU, and LayerNorm classes to remove type hints from docstrings for consistency. - Enhanced clarity in parameter descriptions while maintaining the overall structure of the documentation. --- src/mrpro/nn/GEGLU.py | 4 ++-- src/mrpro/nn/LayerNorm.py | 2 +- src/mrpro/nn/encoding.py | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/mrpro/nn/GEGLU.py b/src/mrpro/nn/GEGLU.py index d2fb64354..0310c6a76 100644 --- a/src/mrpro/nn/GEGLU.py +++ b/src/mrpro/nn/GEGLU.py @@ -17,9 +17,9 @@ def __init__(self, in_features: int, out_features: int | None = None, features_l Parameters ---------- - in_features : int + in_features The number of input features. - out_features : int + out_features The number of output features. If None, the number of output features is the same as the number of input features. features_last diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 07f9ba593..75ada98d3 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -36,7 +36,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- - x : torch.Tensor + x Input tensor Returns diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py index c65b86c61..39f48c51e 100644 --- a/src/mrpro/nn/encoding.py +++ b/src/mrpro/nn/encoding.py @@ -23,11 +23,11 @@ def __init__(self, in_features: int, out_features: int, std: float = 1.0): Parameters ---------- - in_features : int + in_features Number of input features - out_features : int + out_features Number of output features (must be even) - std : float, optional + std Standard deviation for random initialization """ if out_features % 2 != 0: @@ -68,13 +68,13 @@ def __init__(self, dim: int, features: int, include_radii: bool = True, base_res Parameters ---------- - dim : int + dim Dimension of the input space (1, 2, or 3) - features : int + features Number of output features - include_radii : bool, optional + include_radii Whether to include radius features - base_resolution : int, optional + base_resolution Base resolution for position encoding """ super().__init__() From e20d6f7edaa50ef6de13877fea649eac940cf777 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 3 Jun 2025 23:31:04 +0200 Subject: [PATCH 057/205] wip --- src/mrpro/nn/GluMBConvResBlock.py | 18 +- src/mrpro/nn/SpatialTransformerBlock.py | 6 +- src/mrpro/nn/nets/UNet.py | 208 +++++++++++++++++++++--- 3 files changed, 203 insertions(+), 29 deletions(-) diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py index 02b623dfa..3eaf3b9d4 100644 --- a/src/mrpro/nn/GluMBConvResBlock.py +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -90,7 +90,23 @@ def __init__( else: self.film = None - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply MBConv block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply MBConv block.""" h = self.inverted_conv(x) h = self.depth_conv(h) diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 7be73de05..eb37daaff 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -3,12 +3,14 @@ import torch from torch.nn import Dropout, Linear, Module +from mrpro.nn.CondMixin import CondMixin from mrpro.nn.GEGLU import GEGLU from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.LayerNorm import LayerNorm from mrpro.nn.MultiHeadAttention import MultiHeadAttention from mrpro.nn.ndmodules import ConvND from mrpro.nn.Sequential import Sequential +from mrpro.nn.CondMixin import CondMixin def zero_init(m: Module) -> Module: @@ -20,7 +22,7 @@ def zero_init(m: Module) -> Module: return m -class BasicTransformerBlock(Module): +class BasicTransformerBlock(CondMixin, Module): """Basic vision transformer block.""" def __init__( @@ -111,7 +113,7 @@ def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch return x -class SpatialTransformerBlock(Module): +class SpatialTransformerBlock(CondMixin, Module): """Spatial transformer block.""" def __init__( diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index e5af1e16b..7419e71ae 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -200,7 +200,10 @@ class BasicUNet(UNetBase): A Basic UNet with residual blocks, convolutional downsampling, and nearest neighbor upsampling. - + References + ---------- + .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image + segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 """ def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int): @@ -234,12 +237,10 @@ class UNet(UNetBase): U-shaped convolutional network with optional patch attention. Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_, significant differences to the vanilla UNet [UNET]_ include: - - Spatial attention + - Spatial transformer blocks - Multiple skip connections per resolution - Convolutional downsampling, nearest neighbor upsampling - - Residual convolution blocks - - Group normalization - - SiLU activation + - Residual convolution blocks with group normalization and SiLU activation References ---------- @@ -251,15 +252,35 @@ class UNet(UNetBase): def __init__( self, dim: int, - in_channels: int, - out_channels: int, + channels_in: int, + channels_out: int, attention_depths: Sequence[int] = (-1, -2), n_features: Sequence[int] = (64, 128, 192, 256), n_heads: int = 4, cond_dim: int = 0, encoder_blocks_per_scale: int = 2, ) -> None: - """Initialize the UNet.""" + """Initialize the UNet. + + Parameters + ---------- + dim + Spatial dimension of the input tensor. + channels_in + Number of channels in the input tensor. + channels_out + Number of channels in the output tensor. + attention_depths + The depths at which to apply attention. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + n_heads + Number of attention heads. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + encoder_blocks_per_scale + Number of encoder blocks per resolution level. The number of decoder blocks is one more. + """ depth = len(n_features) if not all(-depth <= d < depth for d in attention_depths): raise ValueError( @@ -279,7 +300,7 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: return ResBlock(dim, channels_in, channels_out, cond_dim) return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) - first_block = ConvND(dim)(in_channels, n_features[0], 3, padding=1) + first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) encoder_blocks: list[Module] = [] down_blocks: list[Module] = [] skip_features = [] @@ -293,7 +314,7 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: down_blocks.append(Identity()) skip_features.append(n_feat_old) down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) - down_blocks[-1] = Identity() + down_blocks[-1] = Identity() # no downsampling after the last resolution level middle_block = Sequential( ResBlock(dim, n_features[-1], n_features[-1], cond_dim), ResBlock(dim, n_features[-1], n_features[-1], cond_dim), @@ -313,15 +334,14 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: decoder_blocks.append( block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) ) - n_feat_old = n_feat - up_blocks.append(Identity()) - n_feat_old = n_feat up_blocks.append(Upsample(dim, scale_factor=2)) - up_blocks.pop() + up_blocks.pop() # no upsampling after the last resolution level concat_blocks = [Concat()] * len(decoder_blocks) last_block = Sequential( - GroupNorm(n_features[0]), SiLU(), ConvND(dim)(n_features[0], out_channels, 3, padding=1) + GroupNorm(n_features[0]), + SiLU(), + ConvND(dim)(n_features[0], channels_out, 3, padding=1), ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) @@ -390,17 +410,153 @@ def block(channels_in: int, channels_out: int) -> Module: super().__init__(encoder, decoder) -class SeparableUNet(UNetBase): - """UNet where blocks apply separable convolutions in different dimensions. +from einops import rearrange - Based on the pseudo-3D residual network of [QUI]_, [TRAN]_ and the residual blocks of [ZIM]_. - References - ---------- - .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal - convolutions for action recognition. CVPR 2018. https://arxiv.org/abs/1711.11248 - .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. - ICCV 2017. https://arxiv.org/abs/1711.10305 - .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction - without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 +class SpatioTemporalBlock(Module): + """Spatio-temporal block. + + Applies first a spatial block then a temporal block. + In the spatial block, the time dimension is a batch dimension, + in the temporal block, the spatial dimensions are a batch dimension. """ + + def __init__(self, spatial_block: Module, temporal_block: Module): + """Initialize the SpatioTemporalBlock.""" + super().__init__() + self.spatial_block = spatial_block + self.temporal_block = temporal_block + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + batchsize = x.shape[0] + x = rearrange(x, 'batch channel time ... -> (batch time) channel ...') + x = call_with_cond(self.spatial_block, x, cond=cond) + spatial_shape = x.shape[2:] + x = rearrange(x, '(batch time) channel ... -> (batch ...) channel time', batch=batchsize) + x = call_with_cond(self.temporal_block, x, cond=cond) + x = rearrange(x, '(batch spatial) channel time -> batch channel time spatial').unflatten(-1, spatial_shape) + return x + + +# class SpatioTemporalUNet(UNetBase): +# """UNet where blocks apply separable convolutions in different dimensions. +# U-shaped convolutional network with optional patch attention. +# Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [UNET]_, [LDM]_, +# Based on the pseudo-3D residual network of [QUI]_, [TRAN]_, [HO]_, and the residual blocks of [ZIM]_. + +# References +# ---------- +# .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image +# segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 +# .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py +# .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal +# convolutions for action recognition. CVPR 2018. https://arxiv.org/abs/1711.11248 +# .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. +# ICCV 2017. https://arxiv.org/abs/1711.10305 +# .. [HO] Ho, J., Salimans, T., Gritsenko, A., Chan, W., Norouzi, M., & Fleet, D. J. Video diffusion models. +# NeurIPS 2022. https://arxiv.org/abs/2209.11168 +# .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction +# without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 +# """ + + +# def __init__( +# self, +# dim: int, +# in_channels: int, +# out_channels: int, +# attention_depths: Sequence[int] = (-1, -2), +# n_features: Sequence[int] = (64, 128, 192, 256), +# n_heads: int = 4, +# cond_dim: int = 0, +# encoder_blocks_per_scale: int = 2, +# temporal_downsampling: bool = False, +# ) -> None: +# """Initialize the UNet. + +# Parameters +# ---------- +# dim +# Spatial dimension of the input tensor. +# channels_in +# Number of channels in the input tensor. +# channels_out +# Number of channels in the output tensor. +# attention_depths +# The depths at which to apply attention. +# n_features +# Number of features at each resolution level. The length determines the number of resolution levels. +# n_heads +# Number of attention heads. +# cond_dim +# Number of channels in the conditioning tensor. If 0, no conditioning is applied. +# encoder_blocks_per_scale +# Number of encoder blocks per resolution level. The number of decoder blocks is one more. +# temporal_downsampling +# Whether to downsample the temporal dimension. +# """ +# depth = len(n_features) +# if not all(-depth <= d < depth for d in attention_depths): +# raise ValueError( +# f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' +# ) +# attention_depths = tuple(d % depth for d in attention_depths) +# if len(attention_depths) != len(set(attention_depths)): +# raise ValueError(f'attention_depths must be unique, got {attention_depths=}') + +# def attention_block(channels: int) -> Module: +# SpatioTemporalBlock(SpatialTransformerBlock( +# dim, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim +# ) + +# def block(channels_in: int, channels_out: int, attention: bool) -> Module: +# if not attention: +# return ResBlock(dim, channels_in, channels_out, cond_dim) +# return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) + +# first_block = ConvND(dim)(in_channels, n_features[0], 3, padding=1) +# encoder_blocks: list[Module] = [] +# down_blocks: list[Module] = [] +# skip_features = [] +# n_feat_old = n_features[0] +# for i_level, n_feat in enumerate(n_features): +# encoder_blocks.append(Identity()) +# skip_features.append(n_feat_old) +# for _ in range(encoder_blocks_per_scale): +# encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) +# n_feat_old = n_feat +# down_blocks.append(Identity()) +# skip_features.append(n_feat_old) +# down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) +# down_blocks[-1] = Identity() # no downsampling after the last resolution level +# middle_block = Sequential( +# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), +# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), +# ) +# if i_level in attention_depths: +# middle_block.insert(1, attention_block(n_features[-1])) +# encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + +# decoder_blocks: list[Module] = [] +# up_blocks: list[Module] = [Identity()] +# for i_level, n_feat in reversed(list(enumerate(n_features))): +# decoder_blocks.append( +# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) +# ) +# n_feat_old = n_feat +# for _ in range(encoder_blocks_per_scale): +# decoder_blocks.append( +# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) +# ) +# up_blocks.append(Identity()) +# up_blocks.append(Upsample(dim, scale_factor=2)) +# up_blocks.pop() # no upsampling after the last resolution level +# concat_blocks = [Concat()] * len(decoder_blocks) +# last_block = Sequential( +# GroupNorm(n_features[0]), +# SiLU(), +# ConvND(dim)(n_features[0], out_channels, 3, padding=1), +# ) +# decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + +# super().__init__(encoder, decoder) From c7e588ec223cca3ebf2f7aec3ac148ac498b6b17 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 4 Jun 2025 02:13:22 +0200 Subject: [PATCH 058/205] Refactor AttentionGate and DropPath modules for improved functionality - Updated AttentionGate to ensure consistent interpolation of gate tensors regardless of shape. - Modified DropPath to conditionally scale the mask based on the keep probability, enhancing flexibility in dropout behavior. - Enhanced _fix_shapes function in join.py to support new interpolation modes (linear and nearest) for better tensor shape handling. - Improved documentation in Concat class to reflect new interpolation options and ensure clarity in parameter descriptions. --- src/mrpro/nn/AttentionGate.py | 4 ++-- src/mrpro/nn/DropPath.py | 6 ++--- src/mrpro/nn/LinearSelfAttention.py | 2 +- src/mrpro/nn/NeighborhoodSelfAttention.py | 14 +++++------ src/mrpro/nn/PixelShuffle.py | 2 ++ src/mrpro/nn/RMSNorm.py | 2 +- src/mrpro/nn/Sequential.py | 3 ++- src/mrpro/nn/SpatialTransformerBlock.py | 4 +++- src/mrpro/nn/join.py | 23 ++++++++++++------ src/mrpro/nn/nets/DCAE.py | 6 ++--- src/mrpro/nn/nets/UNet.py | 2 +- src/mrpro/nn/nets/Uformer.py | 29 +++++++++++++++-------- tests/nn/test_resblock.py | 6 ++--- 13 files changed, 62 insertions(+), 41 deletions(-) diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py index e7dd40d6e..1d57fe5ee 100644 --- a/src/mrpro/nn/AttentionGate.py +++ b/src/mrpro/nn/AttentionGate.py @@ -63,10 +63,10 @@ def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: """Apply the attention gate.""" projected_gate = self.project_gate(gate) projected_x = self.project_x(x) - if gate.shape[2:] != x.shape[2:]: - projected_gate = torch.nn.functional.interpolate(projected_gate, size=x.shape[2:], mode='nearest') + projected_gate = torch.nn.functional.interpolate(projected_gate, size=x.shape[2:], mode='nearest') alpha = self.psi(projected_gate + projected_x) x = x * alpha if self.concatenate: + gate = torch.nn.functional.interpolate(gate, size=x.shape[2:], mode='nearest') x = torch.cat([x, gate], dim=1) return x diff --git a/src/mrpro/nn/DropPath.py b/src/mrpro/nn/DropPath.py index 2d0abba1e..b1314904e 100644 --- a/src/mrpro/nn/DropPath.py +++ b/src/mrpro/nn/DropPath.py @@ -49,7 +49,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.droprate == 0 or not self.training: return x shape = (x.shape[0],) + (1,) * (x.ndim - 1) - mask = ( - ((1 - self.droprate) + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_().div_(1 - self.droprate) - ) + mask = ((1 - self.droprate) + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_() + if self.scale_by_keep: + mask = mask.div_(1 - self.droprate) return x * mask diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py index 12dbe1718..612f0ffe2 100644 --- a/src/mrpro/nn/LinearSelfAttention.py +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -100,9 +100,9 @@ def forward(self, x: Tensor) -> Tensor: value_key_query = value_key @ query normalization = value_key_query[..., -1:, :] + self.eps attn = value_key_query[..., :-1, :] / normalization + attn = attn.moveaxis(1, -1).flatten(-2) # join heads and channels out = self.to_out(attn) out = out.to(orig_dtype) - out = out.moveaxis(1, -1).flatten(-2) # join heads and channels out = out.unflatten(-2, spatial_shape) if not self.features_last: out = out.moveaxis(-1, 1) diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index 3c58bcfc1..91151e4bb 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -17,9 +17,9 @@ @cache def neighborhood_mask( input_size: torch.Size, - kernel_size: int | Sequence[int], - dilation: int | Sequence[int] = 1, - circular: bool | Sequence[bool] = False, + kernel_size: int | tuple[int, ...], # tuples instead of Sequence for cache + dilation: int | tuple[int, ...] = 1, + circular: bool | tuple[bool, ...] = False, ) -> BlockMask: """Create a flex attention block mask for neighborhood attention. @@ -152,9 +152,9 @@ def __init__( """ super().__init__() self.n_head = n_heads - self.kernel_size = kernel_size - self.dilation = dilation - self.circular = circular + self.kernel_size = kernel_size if isinstance(kernel_size, int) else tuple(kernel_size) + self.dilation = dilation if isinstance(dilation, int) else tuple(dilation) + self.circular = circular if isinstance(circular, bool) else tuple(circular) self.features_last = features_last channels_per_head = channels_in // n_heads self.to_qkv = Linear(channels_in, 3 * channels_per_head * n_heads) @@ -174,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ if not self.features_last: x = x.moveaxis(1, -1) - spatial_shape = x.shape[2:-1] + spatial_shape = x.shape[1:-1] qkv = self.to_qkv(x) query, key, value = rearrange( qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channel', qkv=3, head=self.n_head diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index 09f6d7ab6..70b0270df 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -49,6 +49,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: new_shape = list(x.shape[:2]) source_positions = [] for i, old in enumerate(x.shape[2:]): + if old % self.downscale_factor: + raise ValueError('Spatial size must be divisible by downscale_factor.') new_shape.append(old // self.downscale_factor) new_shape.append(self.downscale_factor) source_positions.append(2 + 2 * i) diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py index 0b184ebac..28cecbf9f 100644 --- a/src/mrpro/nn/RMSNorm.py +++ b/src/mrpro/nn/RMSNorm.py @@ -39,7 +39,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: ------- Normalized tensor. """ - return self.forward(x) + return super().__call__(x) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply RMSNorm over the channel dimension.""" diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index aaad42b52..33842884c 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -1,6 +1,7 @@ """Sequential container with support for conditioning and Operators.""" from collections import OrderedDict +from typing import cast import torch @@ -35,7 +36,7 @@ def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T """Apply all modules in series to the input.""" for module in self: if isinstance(module, Operator): - x = module(*x) + x = cast(tuple[torch.Tensor, ...], module(*x)) # always tuple else: ret: torch.Tensor | tuple[torch.Tensor, ...] if isinstance(module, CondMixin): diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index eb37daaff..40501dc2f 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -151,7 +151,9 @@ def __init__( self.norm = GroupNorm(channels) self.proj_in = ConvND(dim)(channels, hidden_dim, kernel_size=1, stride=1, padding=0) - blocks = [BasicTransformerBlock(channels, n_heads, p_dropout=dropout, cond_dim=cond_dim) for _ in range(depth)] + blocks = [ + BasicTransformerBlock(hidden_dim, n_heads, p_dropout=dropout, cond_dim=cond_dim) for _ in range(depth) + ] self.transformer_blocks = Sequential(*blocks) self.proj_out = zero_init(ConvND(dim)(hidden_dim, channels, kernel_size=1, stride=1, padding=0)) diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py index 5f08bb371..2749a7397 100644 --- a/src/mrpro/nn/join.py +++ b/src/mrpro/nn/join.py @@ -7,20 +7,25 @@ from torch.nn import Module from mrpro.utils.pad_or_crop import pad_or_crop +from mrpro.utils.interpolate import interpolate def _fix_shapes( - xs: Sequence[torch.Tensor], mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'], dim: Sequence[int] + xs: Sequence[torch.Tensor], + mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'], + dim: Sequence[int], ) -> tuple[torch.Tensor, ...]: """Fix shapes of input tensors by padding or cropping.""" if mode == 'fail': return tuple(xs) shapes = [[x.shape[d] for d in dim] for x in xs] - if mode == 'crop': + if mode == 'crop': # smallest as target target = tuple(min(s) for s in zip(*shapes, strict=True)) - else: + else: # largest as target target = tuple(max(s) for s in zip(*shapes, strict=True)) + if mode == 'linear' or mode == 'nearest': + return tuple(interpolate(x, target, dim=dim, mode=mode) for x in xs) if mode == 'zero' or mode == 'crop': return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) else: @@ -30,23 +35,27 @@ def _fix_shapes( class Concat(Module): """Concatenate tensors along the channel dimension.""" - def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail', dim: int = 1) -> None: + def __init__( + self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'] = 'fail', dim: int = 1 + ) -> None: """Initialize Concat. Parameters ---------- - mode : {'fail', 'crop', 'zero', 'replicate', 'circular'}, default='zero' + mode How to handle mismatched spatial dimensions: - 'fail': do not align, raise error if shapes mismatch - 'crop': center-crop to smallest spatial size - 'zero': zero-pad to largest spatial size - 'replicate': pad by edge value replication - 'circular': circular padding + - 'linear': linear interpolation to largest spatial size + - 'nearest': nearest neighbor interpolation to largest spatial size dim Dimension along which to concatenate. """ super().__init__() - modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + modes = {'fail', 'crop', 'zero', 'replicate', 'circular', 'interpolate'} if mode not in modes: raise ValueError(f'mode must be one of {modes}') self.mode = mode @@ -55,7 +64,7 @@ def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular' def forward(self, *xs: torch.Tensor) -> torch.Tensor: """Concatenate input tensors.""" xs = _fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) - return torch.cat(xs, dim=1) + return torch.cat(xs, dim=self.dim) def __call__(self, *xs: torch.Tensor) -> torch.Tensor: """ diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index 036e8a022..1f49a0297 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -169,11 +169,9 @@ def __init__( case 'CNN': stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] case 'LinearViT': - stage = [EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=True) for _ in range(depth)] + stage = [EfficientViTBlock(dim, width, max(1, width // 32), linear_attn=True) for _ in range(depth)] case 'ViT': - stage = [ - EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=False) for _ in range(depth) - ] + stage = [EfficientViTBlock(dim, width, max(1, width // 32)) for _ in range(depth)] case _: raise ValueError(f'Block type {block_type} not supported') self.append(Sequential(*stage)) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 7419e71ae..3e602e074 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -337,7 +337,7 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: up_blocks.append(Identity()) up_blocks.append(Upsample(dim, scale_factor=2)) up_blocks.pop() # no upsampling after the last resolution level - concat_blocks = [Concat()] * len(decoder_blocks) + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] last_block = Sequential( GroupNorm(n_features[0]), SiLU(), diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 40ec3680d..5bf717477 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -13,9 +13,10 @@ from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder from mrpro.nn.Sequential import Sequential from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.CondMixin import CondMixin -class LeWinTransformerBlock(Module): +class LeWinTransformerBlock(CondMixin, Module): """Locally-enhanced windowed attention transformer block. Part of the Uformer architecture. @@ -36,21 +37,21 @@ def __init__( Parameters ---------- - dim : int + dim Dimension of the input, e.g. 2 or 3 - n_channels_per_head : int + n_channels_per_head Number of features per head - n_heads : int + n_heads Number of attention heads - window_size : int, optional + window_size Size of the attention window - shifted : bool, optional + shifted Whether to use shifted variant of the attention - mlp_ratio : float, optional + mlp_ratio Ratio of the hidden dimension to the input dimension - p_droppath : float, optional + p_droppath Dropout probability for the drop path. - cond_dim : int, optional + cond_dim Dimension of a conditioning tensor. If `0`, no FiLM layers are added. """ super().__init__() @@ -79,7 +80,7 @@ def __init__( torch.nn.init.trunc_normal_(self.modulator) self.drop_path = DropPath(droprate=p_droppath) - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the transformer block. Parameters @@ -95,6 +96,14 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T """ return super().__call__(x, cond=cond) + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the transformer block.""" + modulator = self.modulator.tile([t // s for t, s in zip(x.shape[1:], self.modulator.shape, strict=False)]) + x_mod = self.norm1(x) + modulator + x_attn = self.attn(x_mod) + x_ff = self.ff(self.norm2(x_attn), cond=cond) + return x + self.drop_path(x_ff) + class Uformer(UNetBase): """Uformer: U-Net with window attention. diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index 3787257ce..6df1fce7f 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -35,6 +35,6 @@ def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, cond_sh assert not x.grad.isnan().any(), 'NaN values in input gradients' assert res.block[2].weight.grad is not None, 'No gradient computed for first Conv' if cond is not None: - assert cond.grad is not None, 'No gradient computed for condedding' - assert not cond.isnan().any(), 'NaN values in condedding' - assert not cond.grad.isnan().any(), 'NaN values in condedding gradients' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' From 10f19949ec3ec9b5f2619e3015260dc687e28900 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 4 Jun 2025 02:14:06 +0200 Subject: [PATCH 059/205] Refactor import statements and enhance EMA documentation - Updated import statements in join.py and Uformer.py to ensure consistent usage of the CondMixin class. - Removed unnecessary import of CondMixin in SpatialTransformerBlock.py for cleaner code. - Added a docstring to ema.py to clarify the purpose of the Exponential Moving Average (EMA) dictionary. --- src/mrpro/nn/SpatialTransformerBlock.py | 1 - src/mrpro/nn/join.py | 2 +- src/mrpro/nn/nets/Uformer.py | 2 +- src/mrpro/utils/ema.py | 2 ++ 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 40501dc2f..d7601a9c2 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -10,7 +10,6 @@ from mrpro.nn.MultiHeadAttention import MultiHeadAttention from mrpro.nn.ndmodules import ConvND from mrpro.nn.Sequential import Sequential -from mrpro.nn.CondMixin import CondMixin def zero_init(m: Module) -> Module: diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py index 2749a7397..0aed41b8d 100644 --- a/src/mrpro/nn/join.py +++ b/src/mrpro/nn/join.py @@ -6,8 +6,8 @@ import torch from torch.nn import Module -from mrpro.utils.pad_or_crop import pad_or_crop from mrpro.utils.interpolate import interpolate +from mrpro.utils.pad_or_crop import pad_or_crop def _fix_shapes( diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 5bf717477..dd97efe59 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -6,6 +6,7 @@ import torch from torch.nn import GELU, LeakyReLU, Module +from mrpro.nn.CondMixin import CondMixin from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM from mrpro.nn.join import Concat @@ -13,7 +14,6 @@ from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder from mrpro.nn.Sequential import Sequential from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention -from mrpro.nn.CondMixin import CondMixin class LeWinTransformerBlock(CondMixin, Module): diff --git a/src/mrpro/utils/ema.py b/src/mrpro/utils/ema.py index 28e840016..23b7d6cf1 100644 --- a/src/mrpro/utils/ema.py +++ b/src/mrpro/utils/ema.py @@ -1,3 +1,5 @@ +"""Exponential Moving Average (EMA) dictionary.""" + from collections.abc import ItemsView, KeysView, Mapping, ValuesView from typing import Any From 07f9f7afb7dfba8d7f4cb2168e8edf1b7c6aad73 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 10 Jun 2025 15:11:51 +0200 Subject: [PATCH 060/205] Update src/mrpro/operators/ConjugateGradientOp.py Co-authored-by: Andreas Kofler --- src/mrpro/operators/ConjugateGradientOp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 27d6daee9..6ff032013 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -86,7 +86,7 @@ def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[tor with torch.enable_grad(): rhs = ctx.rhs_factory(*inputs) operator = ctx.operator_factory(*inputs) - inputs_with_grad = tuple(i for i, need_grad in zip(inputs, ctx.needs_input_grad[2:], strict=True) if need_grad) + inputs_with_grad = tuple(x for x, need_grad in zip(inputs, ctx.needs_input_grad[2:], strict=True) if need_grad) if inputs_with_grad: rhs_norm = sum((r.abs().square().sum() for r in grad_output), torch.tensor(0.0)).sqrt().item() bwd_tol = ctx.tolerance * max(rhs_norm, 1e-6) # clip in case rhs is 0 From 44c8ac05433721127e631fc49019ac7a48779d38 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 10 Jun 2025 15:32:59 +0200 Subject: [PATCH 061/205] Apply suggestions from code review Co-authored-by: Andreas Kofler --- src/mrpro/operators/ConjugateGradientOp.py | 17 ++++++++--------- src/mrpro/operators/OptimizerOp.py | 19 ++++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 6ff032013..af9f21b8a 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -28,7 +28,7 @@ class ConjugateGradientCTX(torch.autograd.function.FunctionCtx): class ConjugateGradientFunction(torch.autograd.Function): - """Autograd function for the regularized least squares operator.""" + """Autograd function for the CG operator.""" if TYPE_CHECKING: @@ -111,10 +111,10 @@ def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[tor class ConjugateGradientOp(torch.nn.Module): r"""Solves a linear positive semidefinite system with the conjugate gradient method. - Solves :math: `A x = b` where :math:`A` is a linear operator , :math:`b` is a tensor or a tuple of tensors. + Solves :math: `A x = b` where :math:`A` is a linear operator or a matrix of linear operators , :math:`b` is a tensor or a tuple of tensors. - The operator is autograd differentiable using implicit differentiation. - If this is not needed, consider using `mrpro.algorithms.optimizers.cg` directly. + The operator is autograd differentiable using implicit differentiation, which can be helpfpul for including CG as a method to increase data-consistency within a neural network based on algorithm unrolling. + If the latter property is not needed for your application, consider using `mrpro.algorithms.optimizers.cg` directly. """ def __init__( @@ -133,9 +133,9 @@ def __init__( **Example: Regularized Least Squares** Consider the regularized least squares problem: - :math:`\min_x \|A x - b\|_2^2 + \alpha \|x - x_0\|_2^2`. + :math:`\min_x \|A x - y\|_2^2 + \alpha \|x - x_0\|_2^2`. - The normal equations are :math:`(A^H A + \alpha I) x = A^H b + \alpha x_0`. + The normal equations are :math:`(A^H A + \alpha I) x = A^H y + \alpha x_0`. This can be solved using the ConjugateGradientOp as follows: .. code-block:: python operator_factory = lambda alpha, x0, b: A.gram + alpha @@ -164,8 +164,7 @@ def __init__( implicit differentiation. .. warning:: - If implicit_backward is `True`, the problem has to converge, otherwise the backward - will be wrong. `tolerance` and `max_iter` should be chosen accordingly. + If implicit_backward is `True`, `tolerance` and `max_iter` should be chosen such that the cg algorithm converges, otherwise the backward will be wrong. """ super().__init__() self.operator_factory = operator_factory @@ -195,7 +194,7 @@ def forward( op = self.operator_factory(*parameters) rhs = self.rhs_factory(*parameters) rhs_norm = sum((r.abs().square().sum() for r in rhs), torch.tensor(0.0)).sqrt().item() - fwd_tol = self.tolerance * rhs_norm + forward_tolerance = self.tolerance * rhs_norm if isinstance(op, LinearOperator): if len(rhs) != 1: raise ValueError('LinearOperator requires a single right-hand side tensor.') diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 19849c08f..1b5ea37c6 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -29,8 +29,8 @@ """LBFGS Optimizer""" -class OptimizeCtx(torch.autograd.function.FunctionCtx): - """Rype hinting the CTX object.""" +class OptimizerCtx(torch.autograd.function.FunctionCtx): + """Type hinting the CTX object.""" factory: Callable[ [Unpack[tuple[torch.Tensor, ...]]], Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor]] @@ -40,7 +40,7 @@ class OptimizeCtx(torch.autograd.function.FunctionCtx): saved_tensors: tuple[torch.Tensor, ...] -class OptimizeFunction(torch.autograd.Function): +class OptimizerFunction(torch.autograd.Function): """Implicit Backward.""" if TYPE_CHECKING: @@ -64,7 +64,7 @@ def apply( @staticmethod def forward( - ctx: OptimizeCtx, + ctx: OptimizerCtx, factory: Callable[ [Unpack[tuple[torch.Tensor, ...]]], Callable[[Unpack[tuple[torch.Tensor, ...]]], tuple[torch.Tensor]] ], @@ -84,8 +84,8 @@ def forward( parameters_ = tuple(p.detach().clone() for p in parameters if isinstance(p, torch.Tensor)) initial_values_ = tuple(x.detach().requires_grad_(True) for x in initial_values if isinstance(x, torch.Tensor)) objective = factory(*parameters) - xprime = optimize(objective, initial_values) - ctx.save_for_backward(*xprime, *parameters_) + solution = optimize(objective, initial_values) + ctx.save_for_backward(*solution, *parameters_) ctx.len_x = len(initial_values_) return xprime @@ -108,7 +108,7 @@ def hvp(*v: torch.Tensor) -> tuple[torch.Tensor, ...]: hessian_inverse_grad = cg(hvp, grad_outputs, max_iterations=100, tolerance=1e-7) with torch.enable_grad(): dobjective_dxprime = torch.autograd.grad(objective(*xprime), xprime, create_graph=True) - # - d^2_obective / d_xprime d_params Hessian^-1_grad + # - d^2_obective / d_xprime d_params Hessian^-1 * grad grad_params = list(torch.autograd.grad(dobjective_dxprime, dparams, hessian_inverse_grad)) grad_inputs: list[torch.Tensor | None] = [None, None, None] # factory, x0, optimize for need_grad in ctx.needs_input_grad[3:]: @@ -124,7 +124,8 @@ class OptimizerOp(Operator[Unpack[ArgumentType], VariableType]): """Differentiable Optimization Operator. One of the building blocks of PINQI [ZIMM2024]_ - Finds :math:`x^*=argmin_x f_p(x) + Finds :math:`x^*=argmin_x f_p(x). + The solution :math:`x^*` will be differentiable with respect to some parameters :math:`p` for the functional :math:`f`. References ---------- @@ -146,7 +147,7 @@ def __init__( Function, that given the parameters of the problem returns an objective function. The objective function should be a callable that takes the variable(s) as input and returns a scalar. initializer - Function, that given the parameters of the problem returns a tuple of initial values for the variable(s) + Function that, given the parameters of the problem, returns a tuple of initial values for the variable(s) optimize Function used to perform the optimization, for example `lbfgs`. Use `functools.partial` to setup up all settings besides the objective function and the initial values. From edb3b0fd30c644276acd0609ae7e68cb31ef194f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Jun 2025 16:09:13 +0200 Subject: [PATCH 062/205] review --- src/mrpro/operators/ConjugateGradientOp.py | 16 ++++++++------- src/mrpro/operators/OptimizerOp.py | 23 ++++++++++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index af9f21b8a..74b7eb535 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -111,10 +111,11 @@ def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[tor class ConjugateGradientOp(torch.nn.Module): r"""Solves a linear positive semidefinite system with the conjugate gradient method. - Solves :math: `A x = b` where :math:`A` is a linear operator or a matrix of linear operators , :math:`b` is a tensor or a tuple of tensors. + Solves :math: `A x = b` where :math:`A` is a linear operator or a matrix of linear operators , + :math:`b` is a tensor or a tuple of tensors. - The operator is autograd differentiable using implicit differentiation, which can be helpfpul for including CG as a method to increase data-consistency within a neural network based on algorithm unrolling. - If the latter property is not needed for your application, consider using `mrpro.algorithms.optimizers.cg` directly. + The operator is autograd differentiable using implicit differentiation. This is useful for including CG within a + network. If this is not needed for your application, consider using `mrpro.algorithms.optimizers.cg` directly. """ def __init__( @@ -164,7 +165,8 @@ def __init__( implicit differentiation. .. warning:: - If implicit_backward is `True`, `tolerance` and `max_iter` should be chosen such that the cg algorithm converges, otherwise the backward will be wrong. + If implicit_backward is `True`, `tolerance` and `max_iterations` should be chosen such that the cg algorithm + converges, otherwise the backward will be wrong. """ super().__init__() self.operator_factory = operator_factory @@ -194,17 +196,17 @@ def forward( op = self.operator_factory(*parameters) rhs = self.rhs_factory(*parameters) rhs_norm = sum((r.abs().square().sum() for r in rhs), torch.tensor(0.0)).sqrt().item() - forward_tolerance = self.tolerance * rhs_norm + tolerance = self.tolerance * rhs_norm if isinstance(op, LinearOperator): if len(rhs) != 1: raise ValueError('LinearOperator requires a single right-hand side tensor.') if initial_value is not None and len(initial_value) != 1: raise ValueError('LinearOperator requires a single initial value tensor.') solution = cg( - op, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=self.max_iterations + op, rhs, initial_value=initial_value, tolerance=tolerance, max_iterations=self.max_iterations ) else: solution = cg( - op, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=self.max_iterations + op, rhs, initial_value=initial_value, tolerance=tolerance, max_iterations=self.max_iterations ) return solution diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 1b5ea37c6..23f189e12 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -87,12 +87,12 @@ def forward( solution = optimize(objective, initial_values) ctx.save_for_backward(*solution, *parameters_) ctx.len_x = len(initial_values_) - return xprime + return solution @staticmethod - def backward(ctx: OptimizeCtx, *grad_outputs: torch.Tensor) -> tuple[torch.Tensor | None, ...]: + def backward(ctx: OptimizerCtx, *grad_outputs: torch.Tensor) -> tuple[torch.Tensor | None, ...]: """Calculate the backward pass using implicit differentiation.""" - xprime = tuple(xp.detach().clone().requires_grad_(True) for xp in ctx.saved_tensors[: ctx.len_x]) + solution = tuple(x.detach().clone().requires_grad_(True) for x in ctx.saved_tensors[: ctx.len_x]) parameters = ctx.saved_tensors[ctx.len_x :] parameters = tuple( p.detach().clone().requires_grad_(True) if ctx.needs_input_grad[i + 3] else p.detach() @@ -103,13 +103,13 @@ def backward(ctx: OptimizeCtx, *grad_outputs: torch.Tensor) -> tuple[torch.Tenso objective = ctx.factory(*parameters) def hvp(*v: torch.Tensor) -> tuple[torch.Tensor, ...]: - return torch.autograd.functional.vhp(lambda *x: objective(*x)[0], xprime, v=v)[1] + return torch.autograd.functional.vhp(lambda *x: objective(*x)[0], solution, v=v)[1] hessian_inverse_grad = cg(hvp, grad_outputs, max_iterations=100, tolerance=1e-7) with torch.enable_grad(): - dobjective_dxprime = torch.autograd.grad(objective(*xprime), xprime, create_graph=True) - # - d^2_obective / d_xprime d_params Hessian^-1 * grad - grad_params = list(torch.autograd.grad(dobjective_dxprime, dparams, hessian_inverse_grad)) + dobjective_dsolution = torch.autograd.grad(objective(*solution), solution, create_graph=True) + # - d^2_obective / d_solution d_params Hessian^-1 * grad + grad_params = list(torch.autograd.grad(dobjective_dsolution, dparams, hessian_inverse_grad)) grad_inputs: list[torch.Tensor | None] = [None, None, None] # factory, x0, optimize for need_grad in ctx.needs_input_grad[3:]: if need_grad: @@ -123,9 +123,12 @@ def hvp(*v: torch.Tensor) -> tuple[torch.Tensor, ...]: class OptimizerOp(Operator[Unpack[ArgumentType], VariableType]): """Differentiable Optimization Operator. + Finds :math:`x^*=argmin_x f_p(x). + The solution :math:`x^*` will be differentiable with respect to some parameters :math:`p` + of the functional :math:`f`. + One of the building blocks of PINQI [ZIMM2024]_ - Finds :math:`x^*=argmin_x f_p(x). - The solution :math:`x^*` will be differentiable with respect to some parameters :math:`p` for the functional :math:`f`. + References ---------- @@ -184,7 +187,7 @@ def forward(self, *parameters: Unpack[ArgumentType]) -> VariableType: """ initial_values = self.initializer(*parameters) initial_values_ = tuple(x.clone() if any(x is p for p in parameters) else x for x in initial_values) - result = OptimizeFunction.apply( + result = OptimizerFunction.apply( self.factory, initial_values_, self.optimize, *cast(tuple[torch.Tensor, ...], parameters) ) return cast(VariableType, result) From ca88e3a32865aaa8f30955dc64035d114c41c334 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Tue, 10 Jun 2025 18:24:15 +0200 Subject: [PATCH 063/205] review --- tests/operators/test_optimizer_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/operators/test_optimizer_op.py b/tests/operators/test_optimizer_op.py index 19c408c13..bf63b492e 100644 --- a/tests/operators/test_optimizer_op.py +++ b/tests/operators/test_optimizer_op.py @@ -10,8 +10,8 @@ def test_optimizer_op_gradcheck() -> None: rng = RandomGenerator(seed=42) constraints_op = ConstraintsOp( bounds=( - (-1, 1), # M0 is not constrained - (0.001, 4.0), # T1 is constrained between 1 ms and 3 s + (-1, 1), # M0 in [-1, 1] + (0.001, 4.0), # T1 is constrained between 1 ms and 4 s ) ).double() # everything is double, otherwise the numerical derivative used in gradcheck gives wrong values From 8c9943f1707c4453bc00bb82aedb21f41a548770 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 13 Jun 2025 13:26:46 +0200 Subject: [PATCH 064/205] docstring --- src/mrpro/operators/ConjugateGradientOp.py | 32 ++++++++++++---------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 74b7eb535..6823543a9 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -56,17 +56,17 @@ def forward( operator = operator_factory(*inputs) rhs = rhs_factory(*inputs) rhs_norm = sum((r.abs().square().sum() for r in rhs), torch.tensor(0.0)).sqrt().item() - fwd_tol = tolerance * max(rhs_norm, 1e-6) # clip in case rhs is 0 + tol_ = tolerance * max(rhs_norm, 1e-6) # clip in case rhs is 0 if isinstance(operator, LinearOperator): if len(rhs) != 1: raise ValueError('LinearOperator requires a single right-hand side tensor.') if initial_value is not None and len(initial_value) != 1: raise ValueError('LinearOperator requires a single initial value tensor.') solution: tuple[torch.Tensor, ...] = cg( - operator, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=max_iterations + operator, rhs, initial_value=initial_value, tolerance=tol_, max_iterations=max_iterations ) else: - solution = cg(operator, rhs, initial_value=initial_value, tolerance=fwd_tol, max_iterations=max_iterations) + solution = cg(operator, rhs, initial_value=initial_value, tolerance=tol_, max_iterations=max_iterations) ctx.save_for_backward(*solution, *inputs) ctx.len_solution = len(solution) ctx.tolerance = tolerance @@ -89,12 +89,12 @@ def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[tor inputs_with_grad = tuple(x for x, need_grad in zip(inputs, ctx.needs_input_grad[2:], strict=True) if need_grad) if inputs_with_grad: rhs_norm = sum((r.abs().square().sum() for r in grad_output), torch.tensor(0.0)).sqrt().item() - bwd_tol = ctx.tolerance * max(rhs_norm, 1e-6) # clip in case rhs is 0 + tol_ = ctx.tolerance * max(rhs_norm, 1e-6) # clip in case rhs is 0 with torch.no_grad(): if isinstance(operator, LinearOperatorMatrix): - z = cg(operator.H, grad_output, tolerance=bwd_tol, max_iterations=ctx.max_iterations) + z = cg(operator.H, grad_output, tolerance=tol_, max_iterations=ctx.max_iterations) else: - z = cg(operator.H, grad_output[0], tolerance=bwd_tol, max_iterations=ctx.max_iterations) + z = cg(operator.H, grad_output[0], tolerance=tol_, max_iterations=ctx.max_iterations) if any(zi.isnan().any() for zi in z): raise RuntimeError('NaN in ConjugateGradientFunction.backward') with torch.enable_grad(): @@ -115,7 +115,15 @@ class ConjugateGradientOp(torch.nn.Module): :math:`b` is a tensor or a tuple of tensors. The operator is autograd differentiable using implicit differentiation. This is useful for including CG within a - network. If this is not needed for your application, consider using `mrpro.algorithms.optimizers.cg` directly. + network [MODL]_, [PINQI]_. + If this is not needed for your application, consider using `mrpro.algorithms.optimizers.cg` directly. + + References + ---------- + .. [MODL] Aggarwal, H. K., et al. MoDL: Model-based deep learning architecture for inverse problems. + (2018) IEEE TMI 2018, 38(2), 394-405. https://arxiv.org/abs/1712.02862 + .. [PINQI] Zimmermann, F. F., Kolbitsch, C., Schuenke, P., & Kofler, A. PINQI: an end-to-end physics-informed + approach to learned quantitative MRI reconstruction. IEEE TCI 2024, https://arxiv.org/abs/2306.11023 """ def __init__( @@ -196,17 +204,13 @@ def forward( op = self.operator_factory(*parameters) rhs = self.rhs_factory(*parameters) rhs_norm = sum((r.abs().square().sum() for r in rhs), torch.tensor(0.0)).sqrt().item() - tolerance = self.tolerance * rhs_norm + tol_ = self.tolerance * rhs_norm if isinstance(op, LinearOperator): if len(rhs) != 1: raise ValueError('LinearOperator requires a single right-hand side tensor.') if initial_value is not None and len(initial_value) != 1: raise ValueError('LinearOperator requires a single initial value tensor.') - solution = cg( - op, rhs, initial_value=initial_value, tolerance=tolerance, max_iterations=self.max_iterations - ) + solution = cg(op, rhs, initial_value=initial_value, tolerance=tol_, max_iterations=self.max_iterations) else: - solution = cg( - op, rhs, initial_value=initial_value, tolerance=tolerance, max_iterations=self.max_iterations - ) + solution = cg(op, rhs, initial_value=initial_value, tolerance=tol_, max_iterations=self.max_iterations) return solution From fb6eb419a6d42c2443ddd8b62ee778a2d76af81a Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 22 Jun 2025 12:51:23 +0200 Subject: [PATCH 065/205] update --- src/mrpro/nn/FiLM.py | 2 -- src/mrpro/nn/ResBlock.py | 2 +- src/mrpro/nn/Sequential.py | 2 +- src/mrpro/nn/SpatialTransformerBlock.py | 44 ++++++++++++------------- 4 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index 31108918d..92780aae3 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -42,8 +42,6 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc cond The conditioning tensor. """ - if len(x) != 1: - raise ValueError('FiLM expects a single input tensor') return super().__call__(x, cond=cond) def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index bd09101ac..8f61e6022 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -30,7 +30,7 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, cond_dim: int) """ super().__init__() - self.rezero = torch.nn.Parameter(torch.tensor(1e-2)) + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) self.block = Sequential( GroupNorm(channels_in), SiLU(), diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 33842884c..15b5d0152 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -9,7 +9,7 @@ from mrpro.operators import Operator -class Sequential(torch.nn.Sequential): +class Sequential(CondMixin,torch.nn.Sequential): """Sequential container with support for conditioning and Operators. Allows multiple input tensors and a single output tensor of the sequential block. diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index d7601a9c2..433ae9f81 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -69,22 +69,22 @@ def __init__( Dropout(p_dropout), Linear(hidden_dim, channels), ) - self.crossattention = ( - Sequential( - LayerNorm(channels, features_last=True), - MultiHeadAttention( - channels_in=channels, - channels_out=channels, - n_heads=n_heads, - p_dropout=p_dropout, - channels_cross=cond_dim, - features_last=True, - ), - ) - if cond_dim > 0 - else None - ) - self.cond_dim = cond_dim + # self.crossattention = ( + # Sequential( + # LayerNorm(channels, features_last=True), + # MultiHeadAttention( + # channels_in=channels, + # channels_out=channels, + # n_heads=n_heads, + # p_dropout=p_dropout, + # channels_cross=cond_dim, + # features_last=True, + # ), + # ) + # if cond_dim > 0 + # else None + # ) + # self.cond_dim = cond_dim def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the basic transformer block. @@ -101,14 +101,14 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the basic transformer block.""" if not self.features_last: - x = x.moveaxis(1, -1) + x = x.moveaxis(1, -1).contiguous() x = self.selfattention(x) + x - if cond is not None and self.crossattention is not None: - cond = cond.unflatten(-1, (-1, self.cond_dim)) - x = self.crossattention(x, cond=cond) + x + # if cond is not None and self.crossattention is not None: + # cond = cond.unflatten(-1, (-1, self.cond_dim)) + # x = self.crossattention(x, cond=cond) + x x = self.ff(x) + x if not self.features_last: - x = x.moveaxis(-1, 1) + x = x.moveaxis(-1, 1).contiguous() return x @@ -155,7 +155,7 @@ def __init__( ] self.transformer_blocks = Sequential(*blocks) - self.proj_out = zero_init(ConvND(dim)(hidden_dim, channels, kernel_size=1, stride=1, padding=0)) + self.proj_out = ConvND(dim)(hidden_dim, channels, kernel_size=1, stride=1, padding=0) def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the spatial transformer block. From aaa68e97317797944ba353b8db1a0bab6a46f649 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 23 Jun 2025 22:37:26 +0200 Subject: [PATCH 066/205] separable --- src/mrpro/nn/LayerNorm.py | 44 ++++-- src/mrpro/nn/SpatialTransformerBlock.py | 93 ++++-------- src/mrpro/nn/Upsample.py | 39 +++-- src/mrpro/nn/nets/UNet.py | 188 +++++++++++++++++++++--- 4 files changed, 248 insertions(+), 116 deletions(-) diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 75ada98d3..699de57f0 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -1,15 +1,16 @@ """Layer normalization.""" import torch -from torch.nn import Module, Parameter +from torch.nn import Linear, Module, Parameter -from mrpro.utils.reshape import unsqueeze_right +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_at, unsqueeze_right -class LayerNorm(Module): +class LayerNorm(CondMixin, Module): """Layer normalization.""" - def __init__(self, channels: int | None, features_last: bool = False, bias: bool = True) -> None: + def __init__(self, channels: int | None, features_last: bool = False, cond_dim: int = 0) -> None: """Initialize the layer normalization. Parameters @@ -19,16 +20,25 @@ def __init__(self, channels: int | None, features_last: bool = False, bias: bool affine transformation. features_last If `True`, the channel dimension is the last dimension. - bias - If `False`, only a scaling is applied without an offset if an affine transformation is used. + cond_dim + Number of channels in the conditioning tensor. If `0`, no adaptive scaling is applied. """ super().__init__() - if channels is not None: - self.weight: Parameter | None = Parameter(torch.ones(channels)) - self.bias: Parameter | None = Parameter(torch.zeros(channels)) if bias else None + if channels is None and cond_dim == 0: + self.weight: Parameter | None = None + self.bias: Parameter | None = None + self.cond_proj: Linear | None = None + elif channels is None and cond_dim > 0: + raise ValueError('channels must be provided if cond_dim > 0') + elif channels is not None and cond_dim == 0: + self.weight = Parameter(torch.ones(channels)) + self.bias = Parameter(torch.zeros(channels)) + self.cond_proj = None else: self.weight = None self.bias = None + self.cond_proj = Linear(cond_dim, 2 * channels) + self.features_last = features_last def __call__(self, x: torch.Tensor) -> torch.Tensor: @@ -45,23 +55,25 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: """ return super().__call__(x) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply layer normalization to the input tensor.""" dims = tuple(range(1, x.ndim)) mean = x.mean(dim=dims, keepdim=True) std = x.std(dim=dims, keepdim=True, unbiased=False) x = (x - mean) / (std + 1e-5) - if self.weight is not None: + if self.weight is not None and self.bias is not None: if self.features_last: - x = x * self.weight + x = x * self.weight + self.bias else: - x = x * unsqueeze_right(self.weight, x.ndim - 2) + x = x * unsqueeze_right(self.weight, x.ndim - 2) + unsqueeze_right(self.bias, x.ndim - 2) - if self.bias is not None: + if self.cond_proj is not None and cond is not None: + scale, shift = self.cond_proj(cond).chunk(2, dim=-1) + scale = 1 + scale if self.features_last: - x = x + self.bias + x = x * unsqueeze_at(scale, 1, x.ndim - 2) + unsqueeze_at(shift, 1, x.ndim - 2) else: - x = x + unsqueeze_right(self.bias, x.ndim - 2) + x = x * unsqueeze_right(scale, x.ndim - 2) + unsqueeze_right(shift, x.ndim - 2) return x diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 433ae9f81..2b4c3e6e2 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -1,5 +1,7 @@ """Spatial transformer block.""" +from collections.abc import Sequence + import torch from torch.nn import Dropout, Linear, Module @@ -8,7 +10,7 @@ from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.LayerNorm import LayerNorm from mrpro.nn.MultiHeadAttention import MultiHeadAttention -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.Sequential import Sequential @@ -64,27 +66,11 @@ def __init__( ) hidden_dim = int(channels * mlp_ratio) self.ff = Sequential( - LayerNorm(channels, features_last=True), + LayerNorm(channels, features_last=True, cond_dim=cond_dim), GEGLU(channels, hidden_dim, features_last=True), Dropout(p_dropout), Linear(hidden_dim, channels), ) - # self.crossattention = ( - # Sequential( - # LayerNorm(channels, features_last=True), - # MultiHeadAttention( - # channels_in=channels, - # channels_out=channels, - # n_heads=n_heads, - # p_dropout=p_dropout, - # channels_cross=cond_dim, - # features_last=True, - # ), - # ) - # if cond_dim > 0 - # else None - # ) - # self.cond_dim = cond_dim def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the basic transformer block. @@ -103,10 +89,7 @@ def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch if not self.features_last: x = x.moveaxis(1, -1).contiguous() x = self.selfattention(x) + x - # if cond is not None and self.crossattention is not None: - # cond = cond.unflatten(-1, (-1, self.cond_dim)) - # x = self.crossattention(x, cond=cond) + x - x = self.ff(x) + x + x = self.ff(x, cond=cond) + x if not self.features_last: x = x.moveaxis(-1, 1).contiguous() return x @@ -117,67 +100,49 @@ class SpatialTransformerBlock(CondMixin, Module): def __init__( self, - dim: int, + dim_groups: Sequence[tuple[int, ...]], channels: int, n_heads: int, - channels_per_head: int, depth: int = 1, dropout: float = 0.0, cond_dim: int = 0, ): - """Initialize the spatial transformer block. - + """ Parameters ---------- - dim - Spatial dimension of the input tensor. + dim_groups + Groups of spatial dimensions for separate attention mechanisms. channels Number of channels in the input and output. n_heads - Number of attention heads. - channels_per_head - Number of channels per attention head. + Number of attention heads for each group. depth - Number of transformer blocks. + Number of transformer blocks for each group. dropout Dropout probability. cond_dim - Number of channels in the conditioning tensor. If 0, no conditioning is applied. + Dimension of the conditioning tensor. """ super().__init__() - self.in_channels = channels - hidden_dim = n_heads * channels_per_head + hidden_dim = n_heads * (channels // n_heads) self.norm = GroupNorm(channels) - - self.proj_in = ConvND(dim)(channels, hidden_dim, kernel_size=1, stride=1, padding=0) - blocks = [ - BasicTransformerBlock(hidden_dim, n_heads, p_dropout=dropout, cond_dim=cond_dim) for _ in range(depth) - ] - self.transformer_blocks = Sequential(*blocks) - - self.proj_out = ConvND(dim)(hidden_dim, channels, kernel_size=1, stride=1, padding=0) - - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: - """Apply the spatial transformer block. - - Parameters - ---------- - x - Input tensor - cond - Conditioning tensor. If None, no conditioning is applied. - - Returns - ------- - Output tensor after spatial transformer - """ - return super().__call__(x, cond=cond) + self.proj_in = Linear(channels, hidden_dim) + self.transformer_blocks = Sequential() + for group in (g for _ in range(depth) for g in dim_groups): + block = BasicTransformerBlock(hidden_dim, n_heads, p_dropout=dropout, cond_dim=cond_dim, features_last=True) + self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) + self.proj_out = Linear(hidden_dim, channels) def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the spatial transformer block.""" skip = x - x = self.norm(x) - x = self.proj_in(x) - x = self.transformer_blocks(x, cond=cond) - x = self.proj_out(x) - return x + skip + h = self.norm(x) + h = h.movedim(1, -1) + h = self.proj_in(h) + h = self.transformer_blocks(h, cond=cond) + h = self.proj_out(h) + h = h.movedim(-1, 1) + return skip + h + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py index 74462fbff..ec9b0e032 100644 --- a/src/mrpro/nn/Upsample.py +++ b/src/mrpro/nn/Upsample.py @@ -1,21 +1,26 @@ """Upsampling by interpolation.""" +from collections.abc import Sequence from typing import Literal import torch -from torch.nn import Module +from torch.nn import Module, Sequential + +from mrpro.nn.PermutedBlock import PermutedBlock class Upsample(Module): """Upsampling by interpolation.""" - def __init__(self, dim: int, scale_factor: int = 2, mode: Literal['nearest', 'linear', 'cubic'] = 'linear'): + def __init__( + self, dim: Sequence[int], scale_factor: int = 2, mode: Literal['nearest', 'linear', 'cubic'] = 'linear' + ): """Initialize the upsampling layer. Parameters ---------- dim - Spatial dimensions of the input tensor, i.e. 2 for 2D, 3 for 3D, etc. + Dimensions which to upsample scale_factor Factor by which to upsample mode @@ -24,17 +29,23 @@ def __init__(self, dim: int, scale_factor: int = 2, mode: Literal['nearest', 'li super().__init__() self.scale_factor = scale_factor if mode == 'nearest': - self.mode = 'nearest' - elif dim == 1 and mode == 'linear': - self.mode = 'linear' - elif dim == 2 and mode == 'cubic': - self.mode = 'bicubic' - elif dim == 2 and mode == 'linear': - self.mode = 'bilinear' - elif dim == 3 and mode == 'linear': - self.mode = 'trilinear' - else: - raise ValueError(f'Invalid mode for dimension {dim}: {mode}') + dims = [tuple(d) for d in torch.tensor(dim).split(3)] + modes = ['nearest'] * len(self.dim) + elif mode == 'linear': + dims = [tuple(d) for d in torch.tensor(dim).split(3)] + modes = [{1: 'linear', 2: 'bilinear', 3: 'trilinear'}[len(d)] for d in dims] + elif mode == 'cubic': + if not len(dim) == 2: + raise ValueError('Cubic interpolation is only supported for 2D images.') + dims = [tuple(dim)] + modes = ['bicubic'] + + self.blocks = Sequential( + *[ + PermutedBlock(d, Upsample(d, scale_factor=scale_factor, mode=m)) + for d, m in zip(dims, modes, strict=False) + ] + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Upsample the input tensor.""" diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 3e602e074..00d1ea510 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -291,8 +291,9 @@ def __init__( raise ValueError(f'attention_depths must be unique, got {attention_depths=}') def attention_block(channels: int) -> Module: + dim_groups = (tuple(range(-dim, 0)),) return SpatialTransformerBlock( - dim, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim + dim_groups, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim ) def block(channels_in: int, channels_out: int, attention: bool) -> Module: @@ -410,33 +411,176 @@ def block(channels_in: int, channels_out: int) -> Module: super().__init__(encoder, decoder) -from einops import rearrange +from collections.abc import Sequence +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here +from mrpro.nn.UNet import UNetBase, UNetDecoder, UNetEncoder -class SpatioTemporalBlock(Module): - """Spatio-temporal block. - Applies first a spatial block then a temporal block. - In the spatial block, the time dimension is a batch dimension, - in the temporal block, the spatial dimensions are a batch dimension. +class SeparableUNet(UNetBase): + """ + UNet with separable convolutions and controlled downsampling. """ - def __init__(self, spatial_block: Module, temporal_block: Module): - """Initialize the SpatioTemporalBlock.""" - super().__init__() - self.spatial_block = spatial_block - self.temporal_block = temporal_block - - def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: - batchsize = x.shape[0] - x = rearrange(x, 'batch channel time ... -> (batch time) channel ...') - x = call_with_cond(self.spatial_block, x, cond=cond) - spatial_shape = x.shape[2:] - x = rearrange(x, '(batch time) channel ... -> (batch ...) channel time', batch=batchsize) - x = call_with_cond(self.temporal_block, x, cond=cond) - x = rearrange(x, '(batch spatial) channel time -> batch channel time spatial').unflatten(-1, spatial_shape) - return x + def __init__( + self, + dim: int, # Total number of spatial dimensions (e.g., 2 for 2D, 3 for 3D) + dim_groups: Sequence[tuple[int, ...]], + channels_in: int, + channels_out: int, + n_features: Sequence[int], + cond_dim: int, + downsample_dims: Sequence[Sequence[int]] | None = None, + encoder_blocks_per_scale: int = 2, + ) -> None: + """ + Initialize the SeparableUNet. + + Parameters + ---------- + + """ + class SeparableUNet(UNetBase): + """ + UNet with separable convolutions and attention, and grouped downsampling. + """ + + def __init__( + self, + dim:int, + dim_groups: Sequence[tuple[int, ...]], + channels_in: int, + channels_out: int, + n_features: Sequence[int] = (64, 128, 256, 512), + cond_dim: int = 0, + encoder_blocks_per_scale: int = 2, + attention_depths: Sequence[int] = (-1,), + n_heads: int = 8, + downsample_dims: Sequence[Sequence[int]] | None = None, + ) -> None: + """ + Initialize the SeparableUNet. + + Parameters + ---------- + dim + Total number of non batch, non channel dimensions. + E.g., 2 for 2D images, 3 for 3D volumes or 2D+time for 2D+time images. + dim_groups + A list of tuples, where each tuple contains the spatial dimension + indices for one separable convolution. Each group must contain fewer than 3 dimensions. + channels_in + Number of channels in the input tensor. + channels_out + Number of channels in the output tensor. + n_features + Number of features at each resolution level. + cond_dim + Number of channels in the conditioning tensor. + encoder_blocks_per_scale + Number of encoder blocks per resolution level. + attention_depths + The depths at which to apply attention. + n_heads + Number of attention heads. + downsample_dims + Sequence specifying which absolute spatial dimensions to downsample + at each encoder level. If None, all dimensions in `dim_groups` are combined + and downsampled at each level. + If a downsampling step contains more than 3 dimensions, downsampling is performed separatly for each + dimension. If the length of the sequence is less than the number of resolution levels, the sequence is + repeated. E.g., ``((-1,-2), (-1,-2,-3))`` for 3D data: first level downsamples x,y; second level x,y,z; + third level x,y. + + + """ + depth = len(n_features) + for group in dim_groups: + if len(group)>3: + raise ValueError(f"dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.") + if any(d>dim+2 or d<-dim for d in group): + raise ValueError(f"dim_group {group} contains dimensions that are out of range for dim={dim}") + + attention_depths = tuple(d % depth for d in attention_depths) + if downsample_dims is None: + all_spatial_dims = tuple( + sorted(list(set(d if d<0 else d-dim-2 for group in dim_groups for d in group))) + ) + downsample_dims = (all_spatial_dims,) * (depth - 1) + + + def downsampler(level_dims, c_in, c_out) -> Module: + if len(level_dims)>3: + sequence=Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) + for d in level_dims[1:]: + sequence.append(downsampler(d, c_out, c_out)) + return sequence + return PermutedBlock( + level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) + + def upsampler(level_dims, c_in, c_out) -> Module: + if len(level_dims)>3: + sequence=Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) + for d in level_dims[1:]: + sequence.append(upsampler(d, c_out, c_out)) + return sequence + return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode="nearest")) + + def block(c_in: int, c_out: int, apply_attention: bool) -> Module: + res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) + if not apply_attention: + return res_block + attn_block = SpatialTransformerBlock(dim_groups, c_out, n_heads, cond_dim=cond_dim) + return Sequential(res_block, attn_block) + + # --- Module Construction --- + first_block = PermutedBlock( + all_spatial_dims, ConvND(len(all_spatial_dims))(channels_in, n_features[0], 3, padding=1) + ) + # -- Encoder -- + encoder_blocks, down_blocks, skip_features = [], [], [] + c_feat = n_features[0] + for i_level, n_feat_level in enumerate(n_features): + for _ in range(encoder_blocks_per_scale): + encoder_blocks.append(block(c_feat, n_feat_level, i_level in attention_depths)) + c_feat = n_feat_level + skip_features.append(c_feat) + if i_level < depth - 1: + down_blocks.append(_create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1])) + c_feat = n_features[i_level + 1] + + # -- Middle & Encoder Finalization -- + middle_block = Sequential( + block(c_feat, c_feat, depth - 1 in attention_depths), + block(c_feat, c_feat, depth - 1 in attention_depths), + ) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + + # -- Decoder -- + decoder_blocks, up_blocks = [], [] + for i_level in reversed(range(depth)): + n_feat_level = n_features[i_level] + if i_level > 0: + up_blocks.append(_create_upsampler(downsample_dims_per_level[i_level - 1], c_feat, n_feat_level)) + for _ in range(encoder_blocks_per_scale + 1): + skip_c = skip_features.pop() + decoder_blocks.append(block(c_feat + skip_c, n_feat_level, i_level in attention_depths)) + c_feat = n_feat_level + + decoder_blocks.reverse() + up_blocks.reverse() + + # -- Decoder Finalization -- + concat_blocks = [Concat()] * len(decoder_blocks) + last_block = Sequential( + GroupNorm(n_features[0]), SiLU(), + PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)) + ) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) # class SpatioTemporalUNet(UNetBase): # """UNet where blocks apply separable convolutions in different dimensions. From 8209d117b5d21ae921fe4e276264a76132208009 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 27 Jun 2025 19:02:44 +0200 Subject: [PATCH 067/205] fix --- src/mrpro/phantoms/__init__.py | 1 + src/mrpro/phantoms/brainweb.py | 24 ++++++++++++++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/mrpro/phantoms/__init__.py b/src/mrpro/phantoms/__init__.py index 54e46cf29..61b1600ff 100644 --- a/src/mrpro/phantoms/__init__.py +++ b/src/mrpro/phantoms/__init__.py @@ -3,5 +3,6 @@ from mrpro.phantoms.EllipsePhantom import EllipsePhantom from mrpro.phantoms.phantom_elements import EllipseParameters from mrpro.phantoms import brainweb +from mrpro.phantoms import coils __all__ = ["EllipseParameters", "EllipsePhantom", "brainweb"] diff --git a/src/mrpro/phantoms/brainweb.py b/src/mrpro/phantoms/brainweb.py index cc9616dac..639d489bf 100644 --- a/src/mrpro/phantoms/brainweb.py +++ b/src/mrpro/phantoms/brainweb.py @@ -4,6 +4,7 @@ import gzip import io import re +import time from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from os import PathLike @@ -242,7 +243,7 @@ def trim_indices(mask: torch.Tensor) -> tuple[slice, slice]: ) """Tissue values for 3T with wide randomization ranges.""" -DEFAULT_VALUES = {'r1': 0.0, 'm0': 0.0, 'r2': 0.0, 'mask': 0, 'tissueclass': -1} +DEFAULT_VALUES = {'r1': 0.0, 'm0': 0.0, 'r2': 0.0, 'mask': 0, 'tissueclass': -1, 't1': torch.inf, 't2': torch.inf} """Default values for masked out regions.""" @@ -264,11 +265,22 @@ def download_brainweb( depending on the system and access pattern. """ - def load_file(url: str, timeout: float = 60) -> bytes: - """Load url content.""" - response = requests.get(url, timeout=timeout) - response.raise_for_status() - return response.content + def load_file( + url: str, + timeout: float = 60, + max_retries: int = 3, + retry_delay: float = 30, + ) -> bytes: + """Load url content with retries for network errors.""" + for attempt in range(max_retries): + try: + response = requests.get(url, timeout=timeout) + response.raise_for_status() + return response.content + except requests.exceptions.RequestException: + if attempt == max_retries - 1: + raise + time.sleep(retry_delay) def unpack(data: bytes, dtype: np.typing.DTypeLike, shape: Sequence[int]) -> np.ndarray: """Unpack gzipped data.""" From 41e421668f93de1dcab60e105429f3afea96b541 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 28 Jun 2025 17:30:48 +0200 Subject: [PATCH 068/205] wip --- examples/scripts/pinqi.py | 140 ++++++++++++++++++++++++++++++++++++ src/mrpro/data/Dataclass.py | 10 +++ src/mrpro/nn/nets/UNet.py | 78 +++++++------------- 3 files changed, 175 insertions(+), 53 deletions(-) create mode 100644 examples/scripts/pinqi.py diff --git a/examples/scripts/pinqi.py b/examples/scripts/pinqi.py new file mode 100644 index 000000000..73d57e3c9 --- /dev/null +++ b/examples/scripts/pinqi.py @@ -0,0 +1,140 @@ +# %% +import einops +import einops.layers +import mrpro +import torch + + +# %% +class Dataset(torch.utils.data.Dataset): + def __init__(self, size=64, acceleration=8, n_coils=8, random=True): + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=mrpro.phantoms.brainweb.augment(size=size), + ) + self.signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2, 8)) + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + + @property + def n_images(self): + return 5 + + @property + def n_parameters(self): + return 2 + + def __len__(self): + return len(self.phantom) + + def __getitem__(self, index): + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = torch.randint(0, 1000000, (1,)).item() if self._random else index + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + fwhm_ratio=2, + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + header.ti = self.signalmodel.saturation_time.tolist() + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + csm = mrpro.data.CsmData(mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), header) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + @staticmethod + def collate_fn(batch): + return torch.utils.data._utils.collate.collate( + batch, + collate_fn_map={ + mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), + **torch.utils.data._utils.collate.default_collate_fn_map, + }, + ) + + +# %% +ds = Dataset() +dl = torch.utils.data.DataLoader( + ds, batch_size=4, collate_fn=ds.collate_fn, num_workers=4, worker_init_fn=lambda *_: torch.set_num_threads(1) +) + +# %% + + +class PINQI(torch.nn.Module): + def __init__(self, signalmodel, n_parameters, n_images, n_iterations=2): + super().__init__() + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel + self._n_parameters = n_parameters + self._n_images = n_images + self.parameter_net = torch.nn.Conv2d(n_images * 2, n_parameters, kernel_size=1) + self.image_net = torch.nn.Conv3d(2, 2, kernel_size=1) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus() + + def objective_factory(parameter_reg, lambda_parameters, image): + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.functionals.L2NormSquared(parameter_reg) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp(objective_factory, lambda parameter_reg, *_: parameter_reg) + self.linear_solver = mrpro.operators.ConjugateGradientOp( + operator_factory=lambda gram, lambda_image, lambda_q, *_: gram + lambda_image + lambda_q, + rhs_factory=lambda _gram, lambda_image, lambda_q, image_reg, signal, zero_filled_image: ( + zero_filled_image + lambda_image * image_reg + lambda_q * signal, + ), + ) + + def get_parameter_reg(self, image): + image = einops.rearrange(torch.view_as_real(image), 'batch t 1 1 y x complex-> batch (t complex) y x') + parameters = self.parameter_net(image) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + return tuple(parameters) + + def get_image_reg(self, image): + image = einops.rearrange(torch.view_as_real(image), 'batch t 1 1 y x complex-> batch complex t y x') + image = image + self.image_net(image) + image = einops.rearrange(image, 'batch complex t y x-> batch t 1 1 y x complex') + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + aquisition_op = fourier_op @ csm_op + gram = aquisition_op.gram + (zero_filled_image,) = aquisition_op.H(kdata.data) + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=3)) + parameters = [self.get_parameter_reg(images[-1])] + for lambda_image, lambda_q, lambda_parameter in self.softplus(self.lambdas_raw): + # subproblem 1 + image_reg = self.get_image_reg(images[-1]) + (signal,) = self.signalmodel(*parameters[-1]) + images.append(self.linear_solver(gram, lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + # subproblem 2 + parameters_reg = self.get_parameter_reg(images[-1]) + parameters.append(self.nonlinear_solver(parameters_reg, lambda_parameter, images[-1])) + + return images, parameters + + +# %% +from tqdm import tqdm + +pinqi = PINQI(ds.signalmodel, ds.n_parameters, ds.n_images) + +for batch in tqdm(dl): + pred = pinqi(batch['kdata'], batch['csm']) +# %% diff --git a/src/mrpro/data/Dataclass.py b/src/mrpro/data/Dataclass.py index e11744b83..9db906359 100644 --- a/src/mrpro/data/Dataclass.py +++ b/src/mrpro/data/Dataclass.py @@ -823,6 +823,16 @@ def concatenate(self, *others: Self, dim: int) -> Self: new._reduce_repeats_(recurse=True) return new + def stack(self, *others: Self) -> Self: + """Stack other along new first dimension + + Parameters + ---------- + others + other instance to stack. + """ + return self[None].concatenate(*[o[None] for o in others], dim=0) + def __eq__(self, other: object) -> bool: """Check deep equality of two dataclasses. diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 00d1ea510..dd6917704 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -13,7 +13,9 @@ from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.join import Concat from mrpro.nn.ndmodules import ConvND, MaxPoolND +from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here from mrpro.nn.Sequential import Sequential from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock from mrpro.nn.Upsample import Upsample @@ -411,44 +413,12 @@ def block(channels_in: int, channels_out: int) -> Module: super().__init__(encoder, decoder) -from collections.abc import Sequence - -from mrpro.nn.PermutedBlock import PermutedBlock -from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here -from mrpro.nn.UNet import UNetBase, UNetDecoder, UNetEncoder - - class SeparableUNet(UNetBase): - """ - UNet with separable convolutions and controlled downsampling. - """ - - def __init__( - self, - dim: int, # Total number of spatial dimensions (e.g., 2 for 2D, 3 for 3D) - dim_groups: Sequence[tuple[int, ...]], - channels_in: int, - channels_out: int, - n_features: Sequence[int], - cond_dim: int, - downsample_dims: Sequence[Sequence[int]] | None = None, - encoder_blocks_per_scale: int = 2, - ) -> None: - """ - Initialize the SeparableUNet. - - Parameters - ---------- - - """ - class SeparableUNet(UNetBase): - """ - UNet with separable convolutions and attention, and grouped downsampling. - """ + """UNet with separable convolutions and attention, and grouped downsampling.""" def __init__( self, - dim:int, + dim: int, dim_groups: Sequence[tuple[int, ...]], channels_in: int, channels_out: int, @@ -497,35 +467,33 @@ def __init__( """ depth = len(n_features) for group in dim_groups: - if len(group)>3: - raise ValueError(f"dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.") - if any(d>dim+2 or d<-dim for d in group): - raise ValueError(f"dim_group {group} contains dimensions that are out of range for dim={dim}") + if len(group) > 3: + raise ValueError(f'dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.') + if any(d > dim + 2 or d < -dim for d in group): + raise ValueError(f'dim_group {group} contains dimensions that are out of range for dim={dim}') attention_depths = tuple(d % depth for d in attention_depths) if downsample_dims is None: all_spatial_dims = tuple( - sorted(list(set(d if d<0 else d-dim-2 for group in dim_groups for d in group))) + sorted(list(set(d if d < 0 else d - dim - 2 for group in dim_groups for d in group))) ) downsample_dims = (all_spatial_dims,) * (depth - 1) - def downsampler(level_dims, c_in, c_out) -> Module: - if len(level_dims)>3: - sequence=Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) - for d in level_dims[1:]: - sequence.append(downsampler(d, c_out, c_out)) - return sequence - return PermutedBlock( - level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) + if len(level_dims) > 3: + sequence = Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) + for d in level_dims[1:]: + sequence.append(downsampler(d, c_out, c_out)) + return sequence + return PermutedBlock(level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) def upsampler(level_dims, c_in, c_out) -> Module: - if len(level_dims)>3: - sequence=Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) + if len(level_dims) > 3: + sequence = Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) for d in level_dims[1:]: sequence.append(upsampler(d, c_out, c_out)) return sequence - return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode="nearest")) + return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode='nearest')) def block(c_in: int, c_out: int, apply_attention: bool) -> Module: res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) @@ -548,7 +516,9 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: c_feat = n_feat_level skip_features.append(c_feat) if i_level < depth - 1: - down_blocks.append(_create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1])) + down_blocks.append( + _create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1]) + ) c_feat = n_features[i_level + 1] # -- Middle & Encoder Finalization -- @@ -575,13 +545,15 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: # -- Decoder Finalization -- concat_blocks = [Concat()] * len(decoder_blocks) last_block = Sequential( - GroupNorm(n_features[0]), SiLU(), - PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)) + GroupNorm(n_features[0]), + SiLU(), + PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)), ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) + # class SpatioTemporalUNet(UNetBase): # """UNet where blocks apply separable convolutions in different dimensions. # U-shaped convolutional network with optional patch attention. From a242e55fe3a5a28cf2ff19ae3e0e7ab66f96cdf5 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 28 Jun 2025 17:33:48 +0200 Subject: [PATCH 069/205] add --- src/mrpro/nn/PermutedBlock.py | 55 ++++++++++ src/mrpro/nn/SeparableResBlock.py | 170 ++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 src/mrpro/nn/PermutedBlock.py create mode 100644 src/mrpro/nn/SeparableResBlock.py diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py new file mode 100644 index 000000000..80ab0aec7 --- /dev/null +++ b/src/mrpro/nn/PermutedBlock.py @@ -0,0 +1,55 @@ +from collections.abc import Sequence + +import torch +from torch import nn + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class PermutedBlock(CondMixin, nn.Module): + """Apply a submodule along selected spatial dimensions.""" + + apply_along_dim: tuple[int, ...] + module: nn.Module + + def __init__(self, apply_along_dim: Sequence[int], module: nn.Module, features_last: bool = False): + """Initialize the PermutedBlock. + + Parameters + ---------- + apply_along_dim + Spatial dimension indices to use when applying the module. + These will be moved to the last dimensions. + module + Module to apply on the selected dims. + features_last + If True, the features dimension is assumed to be the last dimension, as common in transformer models. + """ + super().__init__() + self.apply_along_dim = tuple(sorted(apply_along_dim)) + self.module = module + self.features_last = features_last + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions.""" + keep = tuple(d % x.ndim for d in self.apply_along_dim) + if 0 in keep: + raise ValueError('Batch dimension should not be in apply_along_dim.') + if self.features_last: + if x.ndim - 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + keep = tuple(d % (x.ndim - 1) for d in self.apply_along_dim) + batch_dim = tuple(d for d in range(x.ndim - 1) if d not in keep) + permute = (0, *batch_dim, *keep, -1) + else: + if 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(2, x.ndim) if d not in keep) + permute = (0, *batch_dim, 1, *keep) + h = x.permute(permute) + batch_shape = h.shape[: 1 + len(batch_dim)] + h = h.flatten(0, len(batch_dim) + 1) + h = call_with_cond(self.module, h, cond=cond) + h = h.unflatten(0, batch_shape) + permute_back = torch.tensor(permute).argsort().tolist() + return h.permute(permute_back) diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py new file mode 100644 index 000000000..c26a012cf --- /dev/null +++ b/src/mrpro/nn/SeparableResBlock.py @@ -0,0 +1,170 @@ +from collections.abc import Sequence + +import torch +from torch.nn import Module, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.Sequential import Sequential + + +class SeparableResBlock(Module): + """Residual block with separable convolutions and ReZero.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + ) -> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SilU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + channels_in + Number of input channels. + channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, ConvND(len(dims))(channels_in, channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, channels_in if i == 0 else channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(channels_out, cond_dim)) + blocks.extend(block(d, channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if channels_in != channels_out: + self.skip_connection = torch.nn.Linear(channels_in, channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h + + +from collections.abc import Sequence + +import torch +from torch.nn import Module + + +class SeparableResBlock(Module): + """Residual block with separable convolutions and ReZero.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + ) -> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SilU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + channels_in + Number of input channels. + channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, ConvND(len(dims))(channels_in, channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, channels_in if i == 0 else channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(channels_out, cond_dim)) + blocks.extend(block(d, channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if channels_in != channels_out: + self.skip_connection = torch.nn.Linear(channels_in, channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h From 6861ca905258f5bb42549b39feb8bff065ba2cc4 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:43:12 +0200 Subject: [PATCH 070/205] change tol --- src/mrpro/operators/ConjugateGradientOp.py | 4 ++-- src/mrpro/operators/OptimizerOp.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 6823543a9..9da2e025b 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -131,8 +131,8 @@ def __init__( operator_factory: Callable[..., LinearOperatorMatrix | LinearOperator], rhs_factory: Callable[..., tuple[torch.Tensor, ...]], implicit_backward: bool = True, - tolerance: float = 1e-8, - max_iterations: int = 10000, + tolerance: float = 1e-6, + max_iterations: int = 100, ): r"""Initialize a conjugate gradient operator. diff --git a/src/mrpro/operators/OptimizerOp.py b/src/mrpro/operators/OptimizerOp.py index 23f189e12..56bb23a6c 100644 --- a/src/mrpro/operators/OptimizerOp.py +++ b/src/mrpro/operators/OptimizerOp.py @@ -21,8 +21,8 @@ lbfgs, learning_rate=1.0, max_iterations=40, - tolerance_change=1e-8, - tolerance_grad=1e-7, + tolerance_change=1e-6, + tolerance_grad=1e-6, history_size=20, line_search_fn='strong_wolfe', ) @@ -81,9 +81,9 @@ def forward( """Optimize.""" ctx.factory = factory - parameters_ = tuple(p.detach().clone() for p in parameters if isinstance(p, torch.Tensor)) - initial_values_ = tuple(x.detach().requires_grad_(True) for x in initial_values if isinstance(x, torch.Tensor)) - objective = factory(*parameters) + parameters_ = tuple(p.detach().clone() for p in parameters) + initial_values_ = tuple(x.detach().requires_grad_(True) for x in initial_values) + objective = factory(*parameters_) solution = optimize(objective, initial_values) ctx.save_for_backward(*solution, *parameters_) ctx.len_x = len(initial_values_) @@ -105,7 +105,7 @@ def backward(ctx: OptimizerCtx, *grad_outputs: torch.Tensor) -> tuple[torch.Tens def hvp(*v: torch.Tensor) -> tuple[torch.Tensor, ...]: return torch.autograd.functional.vhp(lambda *x: objective(*x)[0], solution, v=v)[1] - hessian_inverse_grad = cg(hvp, grad_outputs, max_iterations=100, tolerance=1e-7) + hessian_inverse_grad = cg(hvp, grad_outputs, max_iterations=50, tolerance=1e-6) with torch.enable_grad(): dobjective_dsolution = torch.autograd.grad(objective(*solution), solution, create_graph=True) # - d^2_obective / d_solution d_params Hessian^-1 * grad From eadd3a6f86c1eccb74636094550390e717fb1743 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:44:39 +0200 Subject: [PATCH 071/205] fix brainweb --- src/mrpro/phantoms/__init__.py | 4 +- src/mrpro/phantoms/brainweb.py | 77 +++++++++++++++++++++------------- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/mrpro/phantoms/__init__.py b/src/mrpro/phantoms/__init__.py index 1c9ca5643..501689cfe 100644 --- a/src/mrpro/phantoms/__init__.py +++ b/src/mrpro/phantoms/__init__.py @@ -15,6 +15,6 @@ "FastMRIKDataDataset", "M4RawDataset", "brainweb", - "mdcnn", - "coils" + "coils", + "mdcnn" ] \ No newline at end of file diff --git a/src/mrpro/phantoms/brainweb.py b/src/mrpro/phantoms/brainweb.py index e507efde8..ab1e754b0 100644 --- a/src/mrpro/phantoms/brainweb.py +++ b/src/mrpro/phantoms/brainweb.py @@ -156,7 +156,7 @@ def augment_fn(data: torch.Tensor, rng: torch.Generator | None = None) -> torch. scale *= 1 + max_random_scaling_factor * rand[3] translate = rand[4:6] # subpixel translation for edge aliasing if trim: - data = data[trim_indices(data.sum(-1) > 0.1 * data.amax())] + data = data[[slice(None), *trim_indices(data.sum(0) > 0.1 * data.amax())]] data = torchvision.transforms.functional.affine( data, @@ -244,7 +244,7 @@ def trim_indices(mask: torch.Tensor) -> tuple[slice, slice]: ) """Tissue values for 3T with wide randomization ranges.""" -DEFAULT_VALUES = {'r1': 0.0, 'm0': 0.0, 'r2': 0.0, 'mask': 0, 'tissueclass': -1, 't1': torch.inf, 't2': torch.inf} +DEFAULT_VALUES = {'r1': 0.0, 'm0': 0.0, 'r2': 0.0, 'mask': 0, 'tissueclass': -1, 't1': 10.0, 't2': 0.0} """Default values for masked out regions.""" @@ -289,8 +289,9 @@ def load_file( return response.content except requests.exceptions.RequestException: if attempt == max_retries - 1: - raise + break time.sleep(retry_delay) + raise ConnectionError(f'Failed to download {url} after {max_retries} attempts.') def unpack(data: bytes, dtype: np.typing.DTypeLike, shape: Sequence[int]) -> np.ndarray: """Unpack gzipped data.""" @@ -553,7 +554,11 @@ def __init__( folder: str | Path = CACHE_DIR_BRAINWEB, what: Sequence[Literal['r1', 'r2', 'm0', 't1', 't2', 'mask', 'tissueclass'] | TClassNames] = ('m0', 'r1', 'r2'), parameters: Mapping[TClassNames, BrainwebTissue] = VALUES_3T_RANDOMIZED, - orientation: Literal['axial', 'coronal', 'sagittal'] = 'axial', + orientation: Literal['axial', 'coronal', 'sagittal'] | Sequence[Literal['axial', 'coronal', 'sagittal']] = ( + 'axial', + 'coronal', + 'sagittal', + ), skip_slices: tuple[tuple[int, int], tuple[int, int], tuple[int, int]] = ((80, 80), (100, 100), (100, 100)), step: int = 1, slice_preparation: Callable[[torch.Tensor, torch.Generator | None], torch.Tensor] = DEFAULT_AUGMENT_256, @@ -610,23 +615,36 @@ def __init__( self.what = what self.mask_values = mask_values + if isinstance(orientation, str): + orientation = [orientation] + elif len(orientation) != len(set(orientation)): + raise ValueError('Orientations must be unique.') try: - self._axis = {'axial': 0, 'coronal': 1, 'sagittal': 2}[orientation] + self.axes = [{'axial': 0, 'coronal': 1, 'sagittal': 2}[o] for o in orientation] except KeyError: raise ValueError(f'Invalid axis: {orientation}.') from None - self._skip_slices = skip_slices[self._axis] - - files = [] - ns_slices = [0] - for fn in Path(folder).glob('s??.h5'): - with h5py.File(fn) as f: - n_slices = f['classes'].shape[self._axis] - self._skip_slices[0] - self._skip_slices[1] - ns_slices.append(n_slices) - files.append(fn) - if not files: + + self.skip_slices = skip_slices + + files_and_axes = [] + ns_slices = [] + h5_files = sorted(Path(folder).glob('s??.h5')) + if not h5_files: raise FileNotFoundError(f'No files found in {folder}.') - self._files = tuple(files) - self._ns_slices = np.cumsum(ns_slices) + + for axis in self.axes: + skip_start, skip_end = self.skip_slices[axis] + for fn in h5_files: + with h5py.File(fn) as f: + n_slices = f['classes'].shape[axis] - skip_start - skip_end + if n_slices > 0: + files_and_axes.append((fn, axis)) + ns_slices.append(n_slices) + if not files_and_axes: + raise FileNotFoundError(f'After skipping {self.skip_slices} slices, no images are left.') + + self._files_and_axes = tuple(files_and_axes) + self._ns_slices = np.cumsum([0, *ns_slices]) self.slice_preparation = slice_preparation @@ -645,21 +663,19 @@ def __getitem__( self, index: int ) -> dict[Literal['r1', 'r2', 'm0', 't1', 't2', 'mask', 'tissueclass'] | TClassNames, torch.Tensor]: """Get a single slice.""" - if index * self.step >= self._ns_slices[-1]: + if index < 0: + index = len(self) + index + if not 0 <= index < len(self): raise IndexError - elif index < 0: - index = self._ns_slices[-1] + index * self.step - else: - index = index * self.step + index = index * self.step - file_id = np.searchsorted(self._ns_slices, index, 'right') - 1 - slice_id = index - self._ns_slices[file_id] + self._skip_slices[0] + chunk_id = np.searchsorted(self._ns_slices, index, 'right') - 1 + file_path, axis = self._files_and_axes[chunk_id] + slice_id = index - self._ns_slices[chunk_id] + self.skip_slices[axis][0] - with h5py.File(self._files[file_id]) as file: - where = [slice(self._skip_slices[0], file['classes'].shape[i] - self._skip_slices[1]) for i in range(3)] + [ - slice(None) - ] - where[self._axis] = slice_id + with h5py.File(file_path) as file: + where = [slice(None)] * 3 + where[axis] = slice_id data = torch.as_tensor(np.array(file['classes'][tuple(where)], dtype=np.uint8)) classnames = tuple(file.attrs['classnames']) @@ -691,9 +707,10 @@ def __getitem__( elif el in classnames: result[el] = data[..., classnames.index(el)] elif el == 'mask': - result[el] = ~( + result[el] = ( torch.nn.functional.conv2d((~mask)[None, None].float(), torch.ones(1, 1, 3, 3), padding=1)[0, 0] < 1 ) + else: raise NotImplementedError(f'what=({el},) is not implemented.') From 7d9dd14f13585d9949b602e49c109b0a98217d84 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:44:54 +0200 Subject: [PATCH 072/205] fix sat rec --- src/mrpro/operators/models/SaturationRecovery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/operators/models/SaturationRecovery.py b/src/mrpro/operators/models/SaturationRecovery.py index b86c97377..ebe278336 100644 --- a/src/mrpro/operators/models/SaturationRecovery.py +++ b/src/mrpro/operators/models/SaturationRecovery.py @@ -11,7 +11,7 @@ class SaturationRecovery(SignalModel[torch.Tensor, torch.Tensor]): """Signal model for saturation recovery.""" - def __init__(self, saturation_time: float | torch.Tensor | Sequence[int]) -> None: + def __init__(self, saturation_time: float | torch.Tensor | Sequence[float]) -> None: """Initialize saturation recovery signal model for T1 mapping. Parameters From 1322b515d9f13025ee5be3852743f057edb6ff77 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:45:16 +0200 Subject: [PATCH 073/205] fix csmdata init typing --- src/mrpro/data/CsmData.py | 2 +- src/mrpro/data/Dataclass.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mrpro/data/CsmData.py b/src/mrpro/data/CsmData.py index 5f9612522..0f3059663 100644 --- a/src/mrpro/data/CsmData.py +++ b/src/mrpro/data/CsmData.py @@ -48,7 +48,7 @@ def get_downsampled_size( ) -class CsmData(QData): +class CsmData(QData, init=False): """Coil sensitivity map class.""" @classmethod diff --git a/src/mrpro/data/Dataclass.py b/src/mrpro/data/Dataclass.py index 9db906359..3fb264b94 100644 --- a/src/mrpro/data/Dataclass.py +++ b/src/mrpro/data/Dataclass.py @@ -100,6 +100,7 @@ def __init_subclass__( # noqa: D417 cls, no_new_attributes: bool = True, auto_reduce_repeats: bool = True, + init: bool = True, *args, **kwargs, ) -> None: @@ -113,7 +114,7 @@ def __init_subclass__( # noqa: D417 If `True`, try to reduce dimensions only containing repeats to singleton. This will be done after init and post_init. """ - dataclasses.dataclass(cls, repr=False, eq=False) # type: ignore[call-overload] + dataclasses.dataclass(cls, repr=False, eq=False, init=init) # type: ignore[call-overload] super().__init_subclass__(**kwargs) child_post_init = vars(cls).get('__post_init__') From d0d333e5621a015cfd8f4afec456b972eb9530a6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:45:53 +0200 Subject: [PATCH 074/205] fix nn --- src/mrpro/nn/LayerNorm.py | 4 ++-- src/mrpro/nn/PermutedBlock.py | 13 +++++++------ src/mrpro/nn/Sequential.py | 2 +- src/mrpro/nn/SpatialTransformerBlock.py | 1 + src/mrpro/nn/Upsample.py | 14 +++++--------- src/mrpro/nn/nets/UNet.py | 16 +++++++++------- src/mrpro/nn/nets/__init__.py | 2 +- 7 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 699de57f0..a90ffa690 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -41,7 +41,7 @@ def __init__(self, channels: int | None, features_last: bool = False, cond_dim: self.features_last = features_last - def __call__(self, x: torch.Tensor) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply layer normalization to the input tensor. Parameters @@ -53,7 +53,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: ------- Normalized output tensor """ - return super().__call__(x) + return super().__call__(x, cond=cond) def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply layer normalization to the input tensor.""" diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py index 80ab0aec7..8b65ee62b 100644 --- a/src/mrpro/nn/PermutedBlock.py +++ b/src/mrpro/nn/PermutedBlock.py @@ -38,9 +38,8 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Te if self.features_last: if x.ndim - 1 in keep: raise ValueError('Features dimension should not be in apply_along_dim.') - keep = tuple(d % (x.ndim - 1) for d in self.apply_along_dim) - batch_dim = tuple(d for d in range(x.ndim - 1) if d not in keep) - permute = (0, *batch_dim, *keep, -1) + batch_dim = tuple(d for d in range(1, x.ndim - 1) if d not in keep) + permute = (0, *batch_dim, *keep, x.ndim - 1) else: if 1 in keep: raise ValueError('Features dimension should not be in apply_along_dim.') @@ -48,8 +47,10 @@ def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Te permute = (0, *batch_dim, 1, *keep) h = x.permute(permute) batch_shape = h.shape[: 1 + len(batch_dim)] - h = h.flatten(0, len(batch_dim) + 1) + h = h.flatten(0, len(batch_dim)) h = call_with_cond(self.module, h, cond=cond) h = h.unflatten(0, batch_shape) - permute_back = torch.tensor(permute).argsort().tolist() - return h.permute(permute_back) + permute_back = [0] * x.ndim + for i, p in enumerate(permute): + permute_back[p] = i + return h.permute(tuple(permute_back)) diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 15b5d0152..fb56bd43f 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -9,7 +9,7 @@ from mrpro.operators import Operator -class Sequential(CondMixin,torch.nn.Sequential): +class Sequential(CondMixin, torch.nn.Sequential): """Sequential container with support for conditioning and Operators. Allows multiple input tensors and a single output tensor of the sequential block. diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 2b4c3e6e2..906560c24 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -129,6 +129,7 @@ def __init__( self.proj_in = Linear(channels, hidden_dim) self.transformer_blocks = Sequential() for group in (g for _ in range(depth) for g in dim_groups): + group = tuple(g - 1 if g < 0 else g for g in group) block = BasicTransformerBlock(hidden_dim, n_heads, p_dropout=dropout, cond_dim=cond_dim, features_last=True) self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) self.proj_out = Linear(hidden_dim, channels) diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py index ec9b0e032..acced8d48 100644 --- a/src/mrpro/nn/Upsample.py +++ b/src/mrpro/nn/Upsample.py @@ -29,10 +29,10 @@ def __init__( super().__init__() self.scale_factor = scale_factor if mode == 'nearest': - dims = [tuple(d) for d in torch.tensor(dim).split(3)] - modes = ['nearest'] * len(self.dim) + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = ['nearest'] * len(dim) elif mode == 'linear': - dims = [tuple(d) for d in torch.tensor(dim).split(3)] + dims = [d.tolist() for d in torch.tensor(dim).split(3)] modes = [{1: 'linear', 2: 'bilinear', 3: 'trilinear'}[len(d)] for d in dims] elif mode == 'cubic': if not len(dim) == 2: @@ -42,18 +42,14 @@ def __init__( self.blocks = Sequential( *[ - PermutedBlock(d, Upsample(d, scale_factor=scale_factor, mode=m)) + PermutedBlock(d, torch.nn.Upsample(scale_factor=len(d) * (scale_factor,), mode=m)) for d, m in zip(dims, modes, strict=False) ] ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Upsample the input tensor.""" - return torch.nn.functional.interpolate( - x, - mode=self.mode, - scale_factor=self.scale_factor, - ) + return self.blocks(x) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Upsample the input tensor. diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index dd6917704..5258e9cd2 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -219,7 +219,11 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Se encoder_blocks.append(ResBlock(dim, n_feat, n_feat, cond_dim)) decoder_blocks.append(ResBlock(dim, 2 * n_feat, n_feat, cond_dim)) down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) - up_blocks.append(Sequential(Upsample(dim, scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1))) + up_blocks.append( + Sequential( + Upsample(tuple(range(-dim, 0)), scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1) + ) + ) concat_blocks.append(Concat()) up_blocks = up_blocks[::-1] decoder_blocks = decoder_blocks[::-1] @@ -256,9 +260,9 @@ def __init__( dim: int, channels_in: int, channels_out: int, - attention_depths: Sequence[int] = (-1, -2), + attention_depths: Sequence[int] = (-1,), n_features: Sequence[int] = (64, 128, 192, 256), - n_heads: int = 4, + n_heads: int = 8, cond_dim: int = 0, encoder_blocks_per_scale: int = 2, ) -> None: @@ -294,9 +298,7 @@ def __init__( def attention_block(channels: int) -> Module: dim_groups = (tuple(range(-dim, 0)),) - return SpatialTransformerBlock( - dim_groups, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim - ) + return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) def block(channels_in: int, channels_out: int, attention: bool) -> Module: if not attention: @@ -338,7 +340,7 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) ) up_blocks.append(Identity()) - up_blocks.append(Upsample(dim, scale_factor=2)) + up_blocks.append(Upsample(tuple(range(-dim, 0)), scale_factor=2)) up_blocks.pop() # no upsampling after the last resolution level concat_blocks = [Concat() for _ in range(len(decoder_blocks))] last_block = Sequential( diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 6f540e118..50baa2573 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -2,7 +2,7 @@ from mrpro.nn.nets.Uformer import Uformer from mrpro.nn.nets.DCAE import DCVAE from mrpro.nn.nets.VAE import VAE -from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet +from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet, BasicUNet, SeparableUNet from mrpro.nn.nets.SwinIR import SwinIR __all__ = [ From c8340650547f45536968d82ba9047537b0f2302c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:46:12 +0200 Subject: [PATCH 075/205] train_pinqi --- examples/scripts/train_pinqi.py | 483 ++++++++++++++++++++++++++++++++ 1 file changed, 483 insertions(+) create mode 100644 examples/scripts/train_pinqi.py diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py new file mode 100644 index 000000000..453de5c37 --- /dev/null +++ b/examples/scripts/train_pinqi.py @@ -0,0 +1,483 @@ +# %% +import collections +from collections.abc import Sequence +from copy import deepcopy +from pathlib import Path +from typing import cast + +import einops +import matplotlib.pyplot as plt +import mrpro +import numpy as np +import pytorch_lightning as pl +import torch +import torch.utils.data._utils +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import NeptuneLogger +from pytorch_lightning.strategies import DDPStrategy + +# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) + + +class Dataset(torch.utils.data.Dataset): + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int = 192, + acceleration: int = 10, + n_coils: int = 8, + random: bool = True, + max_noise: float = 0.1, + ): + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=mrpro.phantoms.brainweb.augment(size=size), + ) + self.signalmodel = signalmodel + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + self.max_noise = max_noise + self._n_images = n_images + + def __len__(self) -> int: + return len(self.phantom) + + def __getitem__(self, index: int): + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = int(torch.randint(0, 1000000, (1,))) if self._random else index + + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=2, + n_center=8, + n_other=(self._n_images,), + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + + if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery): + header.ti = self.signalmodel.saturation_time.tolist() + elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery): + header.ti = self.signalmodel.ti.tolist() + + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + csm = mrpro.data.CsmData( + mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), + header, + ) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + @staticmethod + def collate_fn(batch): + return torch.utils.data._utils.collate.collate( + batch, + collate_fn_map={ + mrpro.data.Dataclass: lambda batch, *, _: batch[0].stack(*batch[1:]), + **torch.utils.data._utils.collate.default_collate_fn_map, + }, + ) + + +class PINQI(torch.nn.Module): + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int = 4, + n_features_parameter_net: Sequence[int] = (64, 128, 192, 256), + n_features_image_net: Sequence[int] = (16, 32, 48, 64), + ): + super().__init__() + self.signalmodel = ( + mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ deepcopy(signalmodel) @ constraints_op + ) + self.constraints_op = constraints_op + self._n_images = n_images + self._parameter_is_complex = parameter_is_complex + real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex) + self.parameter_net = mrpro.nn.nets.UNet( + dim=2, + channels_in=n_images * 2, + channels_out=real_parameters, + attention_depths=(-1,), + n_features=n_features_parameter_net, + ) + self.image_net = mrpro.nn.nets.UNet( + 2, + channels_in=2, + channels_out=2, + attention_depths=(), + n_features=n_features_image_net, + ) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus() + + def objective_factory(lambda_parameters: torch.Tensor, image: torch.Tensor, *parameter_reg: torch.Tensor): + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.ProximableFunctionalSeparableSum( + *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] + ) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp( + objective_factory, + lambda _l, _i, *parameter_reg: parameter_reg, + ) + + def get_linear_solver(self, gram: mrpro.operators.LinearOperator): + def operator_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + *_, + ): + return gram + lambda_image + lambda_q + + def rhs_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + image_reg: torch.Tensor, + signal: torch.Tensor, + zero_filled_image: torch.Tensor, + ): + return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,) + + return mrpro.operators.ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=rhs_factory, + ) + + def get_parameter_reg(self, image: torch.Tensor) -> tuple[torch.Tensor, ...]: + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> batch (t complex) y x', + ) + parameters = self.parameter_net(image.contiguous()) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + i = 0 + result = [] + for is_complex in self._parameter_is_complex: + if is_complex: + result.append(torch.complex(parameters[i], parameters[i + 1])) + i += 2 + else: + result.append(parameters[i]) + i += 1 + return tuple(result) + + def get_image_reg(self, image: torch.Tensor) -> torch.Tensor: + batch = image.shape[0] + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> (batch t) complex y x', + ) + image = image + self.image_net(image.contiguous()) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + aquisition_op = fourier_op @ csm_op + gram = aquisition_op.gram + (zero_filled_image,) = aquisition_op.H(kdata.data) + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + parameters = [self.get_parameter_reg(images[-1])] + linear_solver = self.get_linear_solver(gram) + + for lambda_image, lambda_q, lambda_parameter in self.softplus(self.lambdas_raw): + image_reg = self.get_image_reg(images[-1]) + (signal,) = self.signalmodel(*parameters[-1]) + images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + parameters_reg = self.get_parameter_reg(images[-1]) + parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + if self.constraints_op is not None: + parameters = [self.constraints_op(*p) for p in parameters] + return images, parameters + + +class DataModule(pl.LightningDataModule): + def __init__( + self, + batch_size: int = 8, + num_workers: int = 4, + signalmodel: mrpro.operators.SignalModel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 6.0)), + n_images: int = 5, + **kwargs, + ): + super().__init__() + self.batch_size = batch_size + self.num_workers = num_workers + self.train_dataset = Dataset(signalmodel=signalmodel, n_images=n_images, **kwargs, random=True) + self.val_dataset = Dataset(signalmodel=signalmodel, n_images=n_images, **kwargs, random=False) + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=self.num_workers > 0, + collate_fn=self.train_dataset.collate_fn, + worker_init_fn=lambda *_: torch.set_num_threads(1), + ) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=self.num_workers > 0, + collate_fn=self.val_dataset.collate_fn, + ) + + +class Module(pl.LightningModule): + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int = 4, + n_features_parameter_net: Sequence[int] = (64, 128, 192, 256), + n_features_image_net: Sequence[int] = (16, 32, 48, 64), + lr: float = 3e-4, + weight_decay: float = 1e-4, + loss_weights: Sequence[float] = (0.1, 0.1, 0.1, 0.2, 0.5), + ): + super().__init__() + self.save_hyperparameters() + if len(loss_weights) != n_iterations + 1: + raise ValueError(f'loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations') + + self.pinqi = PINQI( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + n_iterations=n_iterations, + n_features_parameter_net=n_features_parameter_net, + n_features_image_net=n_features_image_net, + ) + + self.validation_step_outputs = collections.defaultdict(list) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + return self.pinqi(kdata, csm) + + def loss(self, predictions, batch): + loss = torch.tensor(0.0, device=self.device) + target_m0 = batch['m0'] + target_t1 = batch['t1'] + mask = batch['mask'] + for prediction, weight in zip(predictions, self.hparams.loss_weights, strict=False): + prediction_m0, prediction_t1 = prediction + loss_t1 = torch.nn.functional.mse_loss(prediction_t1.squeeze()[mask], target_t1[mask]) + loss_m0 = torch.nn.functional.mse_loss( + torch.view_as_real((prediction_m0).squeeze()[mask]), + torch.view_as_real(target_m0[mask]), + ) + loss_outside = prediction_m0[~mask].abs().mean() + loss = loss + weight * (loss_t1 + 0.5 * loss_m0 + 0.1 * loss_outside) + return loss + + def training_step(self, batch, batch_idx): + images, parameters = self(batch['kdata'], batch['csm']) + loss = self.loss(parameters, batch) + self.log( + 'train/loss', + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + return loss + + def validation_step(self, batch, batch_idx): + images, parameters = self(batch['kdata'], batch['csm']) + loss = self.loss(parameters, batch) + + pred_m0, pred_t1 = parameters[-1] + target_t1, target_m0 = batch['t1'], batch['m0'] + mask = batch['mask'] + (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1) + (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1) + (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0) + self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True) + self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True) + self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True) + self.log('val/loss', loss, on_epoch=True, sync_dist=True) + + if batch_idx == 0: + self.validation_step_outputs['target_t1'].append(batch['t1']) + self.validation_step_outputs['pred_t1'].append(pred_t1) + self.validation_step_outputs['pred_m0'].append(pred_m0) + self.validation_step_outputs['target_m0'].append(target_m0) + self.validation_step_outputs['mask'].append(batch['mask']) + + def on_validation_epoch_end(self): + outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()} + self.validation_step_outputs.clear() + outputs = cast(dict[str, torch.Tensor], self.all_gather(outputs)) + + if not self.trainer.is_global_zero: + return + outputs = {k: v.flatten(0, 1).cpu() for k, v in outputs.items()} + + samples = 4 + fig, axes = plt.subplots(3, samples, figsize=(4 * samples, 12)) + for i in range(samples): + self.result_plot(outputs['target_t1'][i], outputs['pred_t1'][i], outputs['mask'][i], axes[:, i]) + fig.suptitle(f'T1 Epoch {self.current_epoch}') + self.logger.run['val/images/t1'].log(fig) + plt.close(fig) + + fig, axes = plt.subplots(3, samples, figsize=(4 * samples, 12)) + for i in range(samples): + self.result_plot(outputs['target_m0'][i].abs(), outputs['pred_m0'][i].abs(), outputs['mask'][i], axes[:, i]) + fig.suptitle(f'|M0| Epoch {self.current_epoch}') + self.logger.run['val/images/m0'].log(fig) + plt.close(fig) + + def result_plot(self, target, pred, mask, axes): + target = target.squeeze().numpy() + pred = pred.squeeze().detach().numpy() + mask = mask.squeeze().detach().numpy().astype(bool) + + target[~mask] = np.nan + pred[~mask] = np.nan + difference = target - pred + vmax = np.nanmax(target) + + im1 = axes[0].imshow(target, vmin=0, vmax=vmax) + axes[0].set_title('Target') + axes[0].axis('off') + axes[0].colorbar(im1) + + im2 = axes[1].imshow(pred, vmin=0, vmax=vmax) + axes[1].set_title('Predicted') + axes[1].axis('off') + axes[1].colorbar(im2) + + diff_vmax = np.nanmax(np.abs(difference)) + im3 = axes[2].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax) + axes[2].set_title('Difference') + axes[2].axis('off') + axes[2].colorbar(im3) + return axes + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.hparams.lr, + weight_decay=self.hparams.weight_decay, + ) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=self.hparams.lr, + total_steps=self.trainer.max_steps, + pct_start=0.1, + div_factor=10, + final_div_factor=200, + ) + return { + 'optimizer': optimizer, + 'lr_scheduler': scheduler, + } + + +# %% +if __name__ == '__main__': + torch.set_float32_matmul_precision('high') + torch._inductor.config.worker_start_method = 'fork' + torch._inductor.config.compile_threads = 4 + torch._dynamo.config.capture_scalar_outputs = True + torch._functorch.config.activation_memory_budget = 0.9 + torch._dynamo.config.cache_size_limit = 256 + + signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 6.0)) + constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (-2, 2), # M0 in [-2, 2] + (0.01, 6.0), # T1 is constrained between 10 ms and 6 s + ) + ) + n_images = len(signalmodel.saturation_time) + parameter_is_complex = [True, False] + + dm = DataModule( + signalmodel=signalmodel, + n_images=n_images, + batch_size=16, + num_workers=8, + pin_memory=True, + size=192, + acceleration=10, + n_coils=8, + max_noise=0.1, + ) + + model = Module( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + lr=3e-4, + weight_decay=1e-4, + n_iterations=4, + ) + + neptune_logger = NeptuneLogger( + log_model_checkpoints=False, + dependencies='infer', + ) + neptune_logger.log_hyperparams(model.hparams) + + checkpoint_callback = ModelCheckpoint( + monitor='val/loss', + mode='min', + save_top_k=2, + dirpath=Path('checkpoints') / str(neptune_logger.version), + filename='{epoch:02d}-{val/loss:.4f}', + save_last=True, + ) + + strategy = DDPStrategy(find_unused_parameters=False) + + trainer = pl.Trainer( + max_epochs=50, + accelerator='gpu', + devices=4, + strategy=strategy, + logger=neptune_logger, + callbacks=[ + LearningRateMonitor(logging_interval='step'), + checkpoint_callback, + ], + log_every_n_steps=10, + precision='16-mixed', + ) + + trainer.fit(model, datamodule=dm) From 096038e220c95ed4d57ddd49c3da08edc5a51abd Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:59:03 +0200 Subject: [PATCH 076/205] update nn --- src/mrpro/nn/LayerNorm.py | 6 +- src/mrpro/nn/PermutedBlock.py | 2 + src/mrpro/nn/nets/UNet.py | 130 +--------------------------------- src/mrpro/nn/nets/__init__.py | 2 + 4 files changed, 11 insertions(+), 129 deletions(-) diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index a90ffa690..84e1d56e4 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -34,10 +34,12 @@ def __init__(self, channels: int | None, features_last: bool = False, cond_dim: self.weight = Parameter(torch.ones(channels)) self.bias = Parameter(torch.zeros(channels)) self.cond_proj = None - else: + elif channels is not None: self.weight = None self.bias = None self.cond_proj = Linear(cond_dim, 2 * channels) + else: + raise ValueError('cond_dim must be zero or positive.') self.features_last = features_last @@ -48,6 +50,8 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc ---------- x Input tensor + cond + Conditioning tensor. If `None`, no conditioning is applied. Returns ------- diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py index 8b65ee62b..99a27f36a 100644 --- a/src/mrpro/nn/PermutedBlock.py +++ b/src/mrpro/nn/PermutedBlock.py @@ -1,3 +1,5 @@ +"""Block that applies a submodule along selected spatial dimensions.""" + from collections.abc import Sequence import torch diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 5258e9cd2..e7a8f07bb 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -460,7 +460,7 @@ def __init__( Sequence specifying which absolute spatial dimensions to downsample at each encoder level. If None, all dimensions in `dim_groups` are combined and downsampled at each level. - If a downsampling step contains more than 3 dimensions, downsampling is performed separatly for each + If a downsampling step contains more than 3 dimensions, downsampling is performed separately for each dimension. If the length of the sequence is less than the number of resolution levels, the sequence is repeated. E.g., ``((-1,-2), (-1,-2,-3))`` for 3D data: first level downsamples x,y; second level x,y,z; third level x,y. @@ -476,9 +476,7 @@ def __init__( attention_depths = tuple(d % depth for d in attention_depths) if downsample_dims is None: - all_spatial_dims = tuple( - sorted(list(set(d if d < 0 else d - dim - 2 for group in dim_groups for d in group))) - ) + all_spatial_dims = tuple(sorted(set(d if d < 0 else d - dim - 2 for group in dim_groups for d in group))) downsample_dims = (all_spatial_dims,) * (depth - 1) def downsampler(level_dims, c_in, c_out) -> Module: @@ -554,127 +552,3 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) - - -# class SpatioTemporalUNet(UNetBase): -# """UNet where blocks apply separable convolutions in different dimensions. -# U-shaped convolutional network with optional patch attention. -# Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [UNET]_, [LDM]_, -# Based on the pseudo-3D residual network of [QUI]_, [TRAN]_, [HO]_, and the residual blocks of [ZIM]_. - -# References -# ---------- -# .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image -# segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 -# .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py -# .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal -# convolutions for action recognition. CVPR 2018. https://arxiv.org/abs/1711.11248 -# .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. -# ICCV 2017. https://arxiv.org/abs/1711.10305 -# .. [HO] Ho, J., Salimans, T., Gritsenko, A., Chan, W., Norouzi, M., & Fleet, D. J. Video diffusion models. -# NeurIPS 2022. https://arxiv.org/abs/2209.11168 -# .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction -# without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 -# """ - - -# def __init__( -# self, -# dim: int, -# in_channels: int, -# out_channels: int, -# attention_depths: Sequence[int] = (-1, -2), -# n_features: Sequence[int] = (64, 128, 192, 256), -# n_heads: int = 4, -# cond_dim: int = 0, -# encoder_blocks_per_scale: int = 2, -# temporal_downsampling: bool = False, -# ) -> None: -# """Initialize the UNet. - -# Parameters -# ---------- -# dim -# Spatial dimension of the input tensor. -# channels_in -# Number of channels in the input tensor. -# channels_out -# Number of channels in the output tensor. -# attention_depths -# The depths at which to apply attention. -# n_features -# Number of features at each resolution level. The length determines the number of resolution levels. -# n_heads -# Number of attention heads. -# cond_dim -# Number of channels in the conditioning tensor. If 0, no conditioning is applied. -# encoder_blocks_per_scale -# Number of encoder blocks per resolution level. The number of decoder blocks is one more. -# temporal_downsampling -# Whether to downsample the temporal dimension. -# """ -# depth = len(n_features) -# if not all(-depth <= d < depth for d in attention_depths): -# raise ValueError( -# f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' -# ) -# attention_depths = tuple(d % depth for d in attention_depths) -# if len(attention_depths) != len(set(attention_depths)): -# raise ValueError(f'attention_depths must be unique, got {attention_depths=}') - -# def attention_block(channels: int) -> Module: -# SpatioTemporalBlock(SpatialTransformerBlock( -# dim, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim -# ) - -# def block(channels_in: int, channels_out: int, attention: bool) -> Module: -# if not attention: -# return ResBlock(dim, channels_in, channels_out, cond_dim) -# return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) - -# first_block = ConvND(dim)(in_channels, n_features[0], 3, padding=1) -# encoder_blocks: list[Module] = [] -# down_blocks: list[Module] = [] -# skip_features = [] -# n_feat_old = n_features[0] -# for i_level, n_feat in enumerate(n_features): -# encoder_blocks.append(Identity()) -# skip_features.append(n_feat_old) -# for _ in range(encoder_blocks_per_scale): -# encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) -# n_feat_old = n_feat -# down_blocks.append(Identity()) -# skip_features.append(n_feat_old) -# down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) -# down_blocks[-1] = Identity() # no downsampling after the last resolution level -# middle_block = Sequential( -# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), -# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), -# ) -# if i_level in attention_depths: -# middle_block.insert(1, attention_block(n_features[-1])) -# encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) - -# decoder_blocks: list[Module] = [] -# up_blocks: list[Module] = [Identity()] -# for i_level, n_feat in reversed(list(enumerate(n_features))): -# decoder_blocks.append( -# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) -# ) -# n_feat_old = n_feat -# for _ in range(encoder_blocks_per_scale): -# decoder_blocks.append( -# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) -# ) -# up_blocks.append(Identity()) -# up_blocks.append(Upsample(dim, scale_factor=2)) -# up_blocks.pop() # no upsampling after the last resolution level -# concat_blocks = [Concat()] * len(decoder_blocks) -# last_block = Sequential( -# GroupNorm(n_features[0]), -# SiLU(), -# ConvND(dim)(n_features[0], out_channels, 3, padding=1), -# ) -# decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) - -# super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 50baa2573..291ea52ca 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -7,8 +7,10 @@ __all__ = [ "AttentionGatedUNet", + "BasicUNet", "DCVAE", "Restormer", + "SeparableUNet", "SwinIR", "UNet", "Uformer", From a66f0dbae7ff76fa117b5e4c85ab6bfc93bdaba6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:59:25 +0200 Subject: [PATCH 077/205] update dataclass --- src/mrpro/data/Dataclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/data/Dataclass.py b/src/mrpro/data/Dataclass.py index 3fb264b94..10d30a638 100644 --- a/src/mrpro/data/Dataclass.py +++ b/src/mrpro/data/Dataclass.py @@ -825,7 +825,7 @@ def concatenate(self, *others: Self, dim: int) -> Self: return new def stack(self, *others: Self) -> Self: - """Stack other along new first dimension + """Stack other along new first dimension. Parameters ---------- From 4b9508decde5892dbff999f97947b21449e372f4 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 14:59:44 +0200 Subject: [PATCH 078/205] train pinqi --- examples/scripts/pinqi.py | 140 ------------ examples/scripts/pinqi.py.bak | 375 ++++++++++++++++++++++++++++++++ examples/scripts/train_pinqi.py | 23 +- 3 files changed, 391 insertions(+), 147 deletions(-) delete mode 100644 examples/scripts/pinqi.py create mode 100644 examples/scripts/pinqi.py.bak diff --git a/examples/scripts/pinqi.py b/examples/scripts/pinqi.py deleted file mode 100644 index 73d57e3c9..000000000 --- a/examples/scripts/pinqi.py +++ /dev/null @@ -1,140 +0,0 @@ -# %% -import einops -import einops.layers -import mrpro -import torch - - -# %% -class Dataset(torch.utils.data.Dataset): - def __init__(self, size=64, acceleration=8, n_coils=8, random=True): - self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( - what=('m0', 't1', 'mask'), - seed='index' if not random else 'random', - slice_preparation=mrpro.phantoms.brainweb.augment(size=size), - ) - self.signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2, 8)) - self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) - self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) - self.acceleration = acceleration - self.n_coils = n_coils - self._random = random - - @property - def n_images(self): - return 5 - - @property - def n_parameters(self): - return 2 - - def __len__(self): - return len(self.phantom) - - def __getitem__(self, index): - phantom = self.phantom[index] - (images,) = self.signalmodel(phantom['m0'], phantom['t1']) - seed = torch.randint(0, 1000000, (1,)).item() if self._random else index - traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( - encoding_matrix=self.encoding_matrix, - seed=seed, - fwhm_ratio=2, - ) - header = mrpro.data.KHeader( - encoding_matrix=self.encoding_matrix, - recon_matrix=self.encoding_matrix, - recon_fov=self.fov, - encoding_fov=self.fov, - ) - header.ti = self.signalmodel.saturation_time.tolist() - fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) - csm = mrpro.data.CsmData(mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), header) - images = einops.rearrange(images, 't y x -> t 1 1 y x') - (data,) = (fourier_op @ csm.as_operator())(images) - kdata = mrpro.data.KData(header, data, traj) - return {'kdata': kdata, 'csm': csm, **phantom} - - @staticmethod - def collate_fn(batch): - return torch.utils.data._utils.collate.collate( - batch, - collate_fn_map={ - mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), - **torch.utils.data._utils.collate.default_collate_fn_map, - }, - ) - - -# %% -ds = Dataset() -dl = torch.utils.data.DataLoader( - ds, batch_size=4, collate_fn=ds.collate_fn, num_workers=4, worker_init_fn=lambda *_: torch.set_num_threads(1) -) - -# %% - - -class PINQI(torch.nn.Module): - def __init__(self, signalmodel, n_parameters, n_images, n_iterations=2): - super().__init__() - self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel - self._n_parameters = n_parameters - self._n_images = n_images - self.parameter_net = torch.nn.Conv2d(n_images * 2, n_parameters, kernel_size=1) - self.image_net = torch.nn.Conv3d(2, 2, kernel_size=1) - self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) - self.softplus = torch.nn.Softplus() - - def objective_factory(parameter_reg, lambda_parameters, image): - dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel - reg = mrpro.operators.functionals.L2NormSquared(parameter_reg) - return dc + lambda_parameters * reg - - self.nonlinear_solver = mrpro.operators.OptimizerOp(objective_factory, lambda parameter_reg, *_: parameter_reg) - self.linear_solver = mrpro.operators.ConjugateGradientOp( - operator_factory=lambda gram, lambda_image, lambda_q, *_: gram + lambda_image + lambda_q, - rhs_factory=lambda _gram, lambda_image, lambda_q, image_reg, signal, zero_filled_image: ( - zero_filled_image + lambda_image * image_reg + lambda_q * signal, - ), - ) - - def get_parameter_reg(self, image): - image = einops.rearrange(torch.view_as_real(image), 'batch t 1 1 y x complex-> batch (t complex) y x') - parameters = self.parameter_net(image) - parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') - return tuple(parameters) - - def get_image_reg(self, image): - image = einops.rearrange(torch.view_as_real(image), 'batch t 1 1 y x complex-> batch complex t y x') - image = image + self.image_net(image) - image = einops.rearrange(image, 'batch complex t y x-> batch t 1 1 y x complex') - return torch.view_as_complex(image.contiguous()) - - def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): - csm_op = csm.as_operator() - fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) - aquisition_op = fourier_op @ csm_op - gram = aquisition_op.gram - (zero_filled_image,) = aquisition_op.H(kdata.data) - images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=3)) - parameters = [self.get_parameter_reg(images[-1])] - for lambda_image, lambda_q, lambda_parameter in self.softplus(self.lambdas_raw): - # subproblem 1 - image_reg = self.get_image_reg(images[-1]) - (signal,) = self.signalmodel(*parameters[-1]) - images.append(self.linear_solver(gram, lambda_image, lambda_q, image_reg, signal, zero_filled_image)) - # subproblem 2 - parameters_reg = self.get_parameter_reg(images[-1]) - parameters.append(self.nonlinear_solver(parameters_reg, lambda_parameter, images[-1])) - - return images, parameters - - -# %% -from tqdm import tqdm - -pinqi = PINQI(ds.signalmodel, ds.n_parameters, ds.n_images) - -for batch in tqdm(dl): - pred = pinqi(batch['kdata'], batch['csm']) -# %% diff --git a/examples/scripts/pinqi.py.bak b/examples/scripts/pinqi.py.bak new file mode 100644 index 000000000..c80c0e035 --- /dev/null +++ b/examples/scripts/pinqi.py.bak @@ -0,0 +1,375 @@ +# %% +import einops +import matplotlib.pyplot as plt +import mrpro +import torch + +# %matplotlib inline + +# %% +# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) +# %% + + +class Dataset(torch.utils.data.Dataset): + def __init__(self, size=192, acceleration=10, n_coils=8, random=True, max_noise=0.1): + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=mrpro.phantoms.brainweb.augment(size=size), + ) + self.signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2, 6)) + self.constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (-1, 1), # M0 in [-1, 1] + (0.001, 4.0), # T1 is constrained between 1 ms and 4 s + ) + ) + + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + self.max_noise = max_noise + + @property + def n_images(self): + return 5 + + @property + def complex_parameters(self): + return [True, False] + + @property + def n_parameters(self): + return len(self.complex_parameters) + + def __len__(self): + return len(self.phantom) + + def __getitem__(self, index): + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = torch.randint(0, 1000000, (1,)).item() if self._random else index + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=2, + n_center=8, + n_other=(self.n_images,), + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + header.ti = self.signalmodel.saturation_time.tolist() + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + csm = mrpro.data.CsmData(mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), header) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + @staticmethod + def collate_fn(batch): + return torch.utils.data._utils.collate.collate( + batch, + collate_fn_map={ + mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), + **torch.utils.data._utils.collate.default_collate_fn_map, + }, + ) + + +# %% +ds = Dataset() +dl = torch.utils.data.DataLoader( + ds, + batch_size=8, + collate_fn=ds.collate_fn, + num_workers=16, + worker_init_fn=lambda *_: torch.set_num_threads(1), + shuffle=True, +) + +# %% +from copy import deepcopy + + +class PINQI(torch.nn.Module): + def __init__(self, signalmodel, parameter_is_complex, n_images, n_iterations=4, constraints_op=None): + super().__init__() + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ deepcopy(signalmodel) + if constraints_op is not None: + self.signalmodel = self.signalmodel @ constraints_op + self.constraints_op = constraints_op + self._n_images = n_images + self._parameter_is_complex = parameter_is_complex + real_parameters = sum(parameter_is_complex) + len(parameter_is_complex) + self.parameter_net = torch.compile( + mrpro.nn.nets.UNet( + dim=2, + channels_in=n_images * 2, + channels_out=real_parameters, + attention_depths=(-1,), + n_features=(64, 128, 192, 256), + ), + dynamic=False, + fullgraph=True, + ) + self.image_net = torch.compile( + mrpro.nn.nets.UNet( + 2, + channels_in=2, + channels_out=2, + attention_depths=(), + n_features=(16, 32, 48, 64), + ), + dynamic=False, + fullgraph=True, + ) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus() + + def objective_factory(lambda_parameters, image, *parameter_reg): + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.ProximableFunctionalSeparableSum( + *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] + ) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp( + objective_factory, lambda _l, _i, *parameter_reg: parameter_reg + ) + + def get_linear_solver(self, gram): + def operator_factory(lambda_image, lambda_q, _image_reg, _signal, _zero_filled_image): + return gram + lambda_image + lambda_q + + def rhs_factory(lambda_image, lambda_q, image_reg, signal, zero_filled_image): + return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,) + + return mrpro.operators.ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=rhs_factory, + ) + + def get_parameter_reg(self, image: torch.Tensor) -> tuple[torch.Tensor, ...]: + image = einops.rearrange(torch.view_as_real(image), 'batch t 1 1 y x complex-> batch (t complex) y x') + parameters = self.parameter_net(image.contiguous()) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + i = 0 + result = [] + for is_complex in self._parameter_is_complex: + if is_complex: + result.append(torch.complex(parameters[i], parameters[i + 1])) + i += 2 + else: + result.append(parameters[i]) + i += 1 + return tuple(result) + + def get_image_reg(self, image): + batch = image.shape[0] + image = einops.rearrange(torch.view_as_real(image), 'batch t 1 1 y x complex-> (batch t) complex y x') + image = image + self.image_net(image.contiguous()) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm_op + gram = acquisition_op.gram + (zero_filled_image,) = acquisition_op.H(kdata.data) + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + parameters = [self.get_parameter_reg(images[-1])] + linear_solver = self.get_linear_solver(gram) + + for lambda_image, lambda_q, lambda_parameter in self.softplus(self.lambdas_raw): + # subproblem 1 + image_reg = self.get_image_reg(images[-1]) + (signal,) = self.signalmodel(*parameters[-1]) + images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + # subproblem 2 + parameters_reg = self.get_parameter_reg(images[-1]) + parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + if self.constraints_op is not None: + parameters = [self.constraints_op(*p) for p in parameters] + return images, parameters + + +# %% +from typing import TypeVar + +T = TypeVar('T') + + +def to_device(batch: T, device: torch.device | str) -> T: + """Moves tensors and Mrpro data to the specified device recursively.""" + if isinstance(batch, torch.Tensor | mrpro.data.Dataclass): + return batch.to(device) + if isinstance(batch, dict): + return {k: to_device(v, device) for k, v in batch.items()} + if isinstance(batch, list): + return [to_device(v, device) for v in batch] + if isinstance(batch, tuple): + return tuple(to_device(v, device) for v in batch) + + return batch + + +# %% +# from tqdm import tqdm + +# pinqi = PINQI(ds.signalmodel, ds.n_parameters, ds.n_images, constraints_op=ds.constraints_op).cuda() +# for epoch in tqdm(range(10)): +# pbar = tqdm(dl, leave=False) +# optim = torch.optim.Adam(pinqi.parameters(), lr=1e-4) +# for batch in pbar: +# batch = to_device(batch, 'cuda') +# images, parameters = pinqi(batch['kdata'], batch['csm']) +# prediction_m0, prediction_t1 = parameters[-1] +# loss_t1 = torch.nn.functional.mse_loss(prediction_t1.squeeze()[batch['mask']], batch['t1'][batch['mask']]) + +# loss_m0 = torch.nn.functional.mse_loss( +# torch.view_as_real((prediction_m0 + 0j).squeeze()[batch['mask']]), +# torch.view_as_real(batch['m0'][batch['mask']]), +# ) + +# loss = loss_t1 + loss_m0 +# pbar.set_postfix(loss=loss.item()) +# loss.backward() +# optim.step() +# optim.zero_grad() + + +# %% +import numpy as np +import torch +from IPython.display import clear_output, display +from tqdm.notebook import tqdm + + +def plot_results( + fig: plt.Figure, + axes: np.ndarray, + losses: list[float], + target_t1: torch.Tensor, + pred_t1: torch.Tensor, + mask: torch.Tensor, +) -> None: + """ + Updates and displays the training plot. + + Parameters + ---------- + fig + The matplotlib figure object. + axes + The array of matplotlib axes objects. + losses + losses for each step + target_t1 + The ground truth T1 map from the last batch. + pred_t1 + The predicted T1 map from the last batch. + mask + The mask from the last batch. + """ + clear_output(wait=True) + + axes[0].clear() + axes[0].semilogy(losses) + axes[0].set_title('Loss') + axes[0].set_xlabel('Step') + axes[0].set_ylabel('Loss') + axes[0].grid(True) + + target_t1_viz = target_t1[1].squeeze().cpu().numpy() + pred_t1_viz = pred_t1[1].squeeze().detach().cpu().numpy() + mask_viz = mask[1].squeeze().detach().cpu().numpy() + target_t1_viz[~mask_viz] = np.nan + pred_t1_viz[~mask_viz] = np.nan + difference = target_t1_viz - pred_t1_viz + vmax = np.nanmax(target_t1_viz) + + axes[1].clear() + axes[1].imshow(target_t1_viz, vmin=0, vmax=vmax) + axes[1].set_title('Target T1') + axes[1].axis('off') + + axes[2].clear() + axes[2].imshow(pred_t1_viz, vmin=0, vmax=vmax) + axes[2].set_title(f'Predicted T1 (Epoch {epoch + 1})') + axes[2].axis('off') + + axes[3].clear() + axes[3].imshow(difference, cmap='coolwarm') + axes[3].set_title('Difference') + axes[3].axis('off') + + fig.tight_layout() + display(fig) + + +# %% +def calculate_loss(predictions, batch, weights=(0.2, 0.1, 0.1, 0.1, 0.5)) -> torch.Tensor: + loss = torch.tensor(0.0) + target_m0 = batch['m0'] + target_t1 = batch['t1'] + mask = batch['mask'] + for prediction, weight in zip(predictions, weights, strict=False): + prediction_m0, prediction_t1 = prediction + loss_t1 = torch.nn.functional.mse_loss(prediction_t1.squeeze()[mask], target_t1[mask]) + loss_m0 = torch.nn.functional.mse_loss( + torch.view_as_real((prediction_m0).squeeze()[mask]), + torch.view_as_real(target_m0[mask]), + ) + loss = loss + weight * (loss_t1 + loss_m0) + return loss + + +# %% +torch.set_float32_matmul_precision('high') +torch._inductor.config.worker_start_method = 'fork' +torch._inductor.config.compile_threads = 4 +torch._dynamo.config.capture_scalar_outputs = True +torch._functorch.config.activation_memory_budget = 0.9 + +pinqi = PINQI(ds.signalmodel, ds.complex_parameters, ds.n_images, constraints_op=ds.constraints_op).to('cuda') +optim = torch.optim.AdamW(pinqi.parameters(), lr=3e-4, weight_decay=1e-4) +torch._dynamo.config.cache_size_limit = 256 +n_epochs = 10 +losses = [] +fig, axes = plt.subplots(1, 4, figsize=(20, 5)) + +for epoch in range(n_epochs): + epoch_losses = [] + pbar = tqdm(dl, desc=f'Epoch {epoch + 1}/{n_epochs}', leave=False) + + for batch in pbar: + batch = to_device(batch, 'cuda') + optim.zero_grad() + + images, parameters = pinqi(batch['kdata'], batch['csm']) + loss = calculate_loss(parameters, batch) + + loss.backward() + optim.step() + epoch_losses.append(loss.item()) + pbar.set_postfix(epoch_loss=f'{np.mean(epoch_losses):.3f}', loss=f'{epoch_losses[-1]:.3f}') + + losses.extend(epoch_losses) + prediction_t1 = parameters[-1][1] + plot_results(fig, axes, losses, batch['t1'], prediction_t1, batch['mask']) + +plt.close(fig) + + +# %% diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py index 453de5c37..8787c1a6e 100644 --- a/examples/scripts/train_pinqi.py +++ b/examples/scripts/train_pinqi.py @@ -1,9 +1,10 @@ -# %% +# ruff: noqa: D102, ANN201 + import collections from collections.abc import Sequence from copy import deepcopy from pathlib import Path -from typing import cast +from typing import TypedDict, cast import einops import matplotlib.pyplot as plt @@ -19,6 +20,14 @@ # mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) +class BatchType(TypedDict): + kdata: mrpro.data.KData + csm: mrpro.data.CsmData + m0: torch.Tensor + t1: torch.Tensor + mask: torch.Tensor + + class Dataset(torch.utils.data.Dataset): def __init__( self, @@ -195,9 +204,9 @@ def get_image_reg(self, image: torch.Tensor) -> torch.Tensor: def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): csm_op = csm.as_operator() fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) - aquisition_op = fourier_op @ csm_op - gram = aquisition_op.gram - (zero_filled_image,) = aquisition_op.H(kdata.data) + acquisition_op = fourier_op @ csm_op + gram = acquisition_op.gram + (zero_filled_image,) = acquisition_op.H(kdata.data) images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) parameters = [self.get_parameter_reg(images[-1])] linear_solver = self.get_linear_solver(gram) @@ -315,7 +324,7 @@ def training_step(self, batch, batch_idx): ) return loss - def validation_step(self, batch, batch_idx): + def validation_step(self, batch: BatchType, batch_idx: int) -> None: images, parameters = self(batch['kdata'], batch['csm']) loss = self.loss(parameters, batch) @@ -388,7 +397,7 @@ def result_plot(self, target, pred, mask, axes): axes[2].colorbar(im3) return axes - def configure_optimizers(self): + def configure_optimizers(self) -> dict[str, torch.optim.Optimizer | torch.optim.lr_scheduler.LRScheduler]: optimizer = torch.optim.AdamW( self.parameters(), lr=self.hparams.lr, From da8aef1b109e989a1f16beb0bba99d230197f3d3 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 1 Jul 2025 16:01:15 +0200 Subject: [PATCH 079/205] update pinqi --- examples/scripts/train_pinqi.py | 224 +++++++++++++++++++++----------- 1 file changed, 149 insertions(+), 75 deletions(-) diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py index 8787c1a6e..b5df7258a 100644 --- a/examples/scripts/train_pinqi.py +++ b/examples/scripts/train_pinqi.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from copy import deepcopy from pathlib import Path -from typing import TypedDict, cast +from typing import Literal, TypedDict, cast import einops import matplotlib.pyplot as plt @@ -17,6 +17,7 @@ from pytorch_lightning.loggers import NeptuneLogger from pytorch_lightning.strategies import DDPStrategy + # mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) @@ -31,6 +32,7 @@ class BatchType(TypedDict): class Dataset(torch.utils.data.Dataset): def __init__( self, + folder: Path, signalmodel: mrpro.operators.SignalModel, n_images: int, size: int = 192, @@ -38,11 +40,18 @@ def __init__( n_coils: int = 8, random: bool = True, max_noise: float = 0.1, + orientation: Sequence[Literal["axial", "coronal", "sagittal"]] = ( + "axial", + "coronal", + "sagittal", + ), ): self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( - what=('m0', 't1', 'mask'), - seed='index' if not random else 'random', + folder=folder, + what=("m0", "t1", "mask"), + seed="index" if not random else "random", slice_preparation=mrpro.phantoms.brainweb.augment(size=size), + orientation=orientation, ) self.signalmodel = signalmodel self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) @@ -58,16 +67,18 @@ def __len__(self) -> int: def __getitem__(self, index: int): phantom = self.phantom[index] - (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + (images,) = self.signalmodel(phantom["m0"], phantom["t1"]) seed = int(torch.randint(0, 1000000, (1,))) if self._random else index - traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( - encoding_matrix=self.encoding_matrix, - seed=seed, - acceleration=self.acceleration, - fwhm_ratio=2, - n_center=8, - n_other=(self._n_images,), + traj = ( + mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=2, + n_center=8, + n_other=(self._n_images,), + ) ) header = mrpro.data.KHeader( encoding_matrix=self.encoding_matrix, @@ -81,16 +92,20 @@ def __getitem__(self, index: int): elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery): header.ti = self.signalmodel.ti.tolist() - fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + fourier_op = mrpro.operators.FourierOp( + self.encoding_matrix, self.encoding_matrix, traj + ) csm = mrpro.data.CsmData( mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), header, ) - images = einops.rearrange(images, 't y x -> t 1 1 y x') + images = einops.rearrange(images, "t y x -> t 1 1 y x") (data,) = (fourier_op @ csm.as_operator())(images) - data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + data = ( + data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + ) kdata = mrpro.data.KData(header, data, traj) - return {'kdata': kdata, 'csm': csm, **phantom} + return {"kdata": kdata, "csm": csm, **phantom} @staticmethod def collate_fn(batch): @@ -116,12 +131,16 @@ def __init__( ): super().__init__() self.signalmodel = ( - mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ deepcopy(signalmodel) @ constraints_op + mrpro.operators.RearrangeOp("t batch ... -> batch t ...") + @ deepcopy(signalmodel) + @ constraints_op ) self.constraints_op = constraints_op self._n_images = n_images self._parameter_is_complex = parameter_is_complex - real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex) + real_parameters = sum(1 for c in parameter_is_complex if c) + len( + parameter_is_complex + ) self.parameter_net = mrpro.nn.nets.UNet( dim=2, channels_in=n_images * 2, @@ -139,7 +158,11 @@ def __init__( self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) self.softplus = torch.nn.Softplus() - def objective_factory(lambda_parameters: torch.Tensor, image: torch.Tensor, *parameter_reg: torch.Tensor): + def objective_factory( + lambda_parameters: torch.Tensor, + image: torch.Tensor, + *parameter_reg: torch.Tensor, + ): dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel reg = mrpro.operators.ProximableFunctionalSeparableSum( *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] @@ -176,10 +199,12 @@ def rhs_factory( def get_parameter_reg(self, image: torch.Tensor) -> tuple[torch.Tensor, ...]: image = einops.rearrange( torch.view_as_real(image), - 'batch t 1 1 y x complex-> batch (t complex) y x', + "batch t 1 1 y x complex-> batch (t complex) y x", ) parameters = self.parameter_net(image.contiguous()) - parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + parameters = einops.rearrange( + parameters, "batch parameters y x-> parameters batch 1 1 y x" + ) i = 0 result = [] for is_complex in self._parameter_is_complex: @@ -195,10 +220,12 @@ def get_image_reg(self, image: torch.Tensor) -> torch.Tensor: batch = image.shape[0] image = einops.rearrange( torch.view_as_real(image), - 'batch t 1 1 y x complex-> (batch t) complex y x', + "batch t 1 1 y x complex-> (batch t) complex y x", ) image = image + self.image_net(image.contiguous()) - image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + image = einops.rearrange( + image, "(batch t) complex y x-> batch t 1 1 y x complex", batch=batch + ) return torch.view_as_complex(image.contiguous()) def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): @@ -207,16 +234,24 @@ def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): acquisition_op = fourier_op @ csm_op gram = acquisition_op.gram (zero_filled_image,) = acquisition_op.H(kdata.data) - images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + images = list( + mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2) + ) parameters = [self.get_parameter_reg(images[-1])] linear_solver = self.get_linear_solver(gram) for lambda_image, lambda_q, lambda_parameter in self.softplus(self.lambdas_raw): image_reg = self.get_image_reg(images[-1]) (signal,) = self.signalmodel(*parameters[-1]) - images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + images.extend( + linear_solver( + lambda_image, lambda_q, image_reg, signal, zero_filled_image + ) + ) parameters_reg = self.get_parameter_reg(images[-1]) - parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + parameters.append( + self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg) + ) if self.constraints_op is not None: parameters = [self.constraints_op(*p) for p in parameters] return images, parameters @@ -225,17 +260,33 @@ def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): class DataModule(pl.LightningDataModule): def __init__( self, + folder: Path, batch_size: int = 8, num_workers: int = 4, - signalmodel: mrpro.operators.SignalModel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 6.0)), + signalmodel: mrpro.operators.SignalModel = mrpro.operators.models.SaturationRecovery( + (0.5, 1.0, 1.5, 2.0, 6.0) + ), n_images: int = 5, **kwargs, ): super().__init__() self.batch_size = batch_size self.num_workers = num_workers - self.train_dataset = Dataset(signalmodel=signalmodel, n_images=n_images, **kwargs, random=True) - self.val_dataset = Dataset(signalmodel=signalmodel, n_images=n_images, **kwargs, random=False) + self.train_dataset = Dataset( + folder=folder, + signalmodel=signalmodel, + n_images=n_images, + **kwargs, + random=True, + ) + self.val_dataset = Dataset( + folder=folder, + orientation=("axial",), + signalmodel=signalmodel, + n_images=n_images, + **kwargs, + random=False, + ) def train_dataloader(self): return torch.utils.data.DataLoader( @@ -278,7 +329,9 @@ def __init__( super().__init__() self.save_hyperparameters() if len(loss_weights) != n_iterations + 1: - raise ValueError(f'loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations') + raise ValueError( + f"loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations" + ) self.pinqi = PINQI( signalmodel=signalmodel, @@ -297,12 +350,16 @@ def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): def loss(self, predictions, batch): loss = torch.tensor(0.0, device=self.device) - target_m0 = batch['m0'] - target_t1 = batch['t1'] - mask = batch['mask'] - for prediction, weight in zip(predictions, self.hparams.loss_weights, strict=False): + target_m0 = batch["m0"] + target_t1 = batch["t1"] + mask = batch["mask"] + for prediction, weight in zip( + predictions, self.hparams.loss_weights, strict=False + ): prediction_m0, prediction_t1 = prediction - loss_t1 = torch.nn.functional.mse_loss(prediction_t1.squeeze()[mask], target_t1[mask]) + loss_t1 = torch.nn.functional.mse_loss( + prediction_t1.squeeze()[mask], target_t1[mask] + ) loss_m0 = torch.nn.functional.mse_loss( torch.view_as_real((prediction_m0).squeeze()[mask]), torch.view_as_real(target_m0[mask]), @@ -312,10 +369,10 @@ def loss(self, predictions, batch): return loss def training_step(self, batch, batch_idx): - images, parameters = self(batch['kdata'], batch['csm']) + images, parameters = self(batch["kdata"], batch["csm"]) loss = self.loss(parameters, batch) self.log( - 'train/loss', + "train/loss", loss, on_step=True, on_epoch=True, @@ -325,26 +382,26 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch: BatchType, batch_idx: int) -> None: - images, parameters = self(batch['kdata'], batch['csm']) + images, parameters = self(batch["kdata"], batch["csm"]) loss = self.loss(parameters, batch) pred_m0, pred_t1 = parameters[-1] - target_t1, target_m0 = batch['t1'], batch['m0'] - mask = batch['mask'] + target_t1, target_m0 = batch["t1"], batch["m0"] + mask = batch["mask"] (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1) (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1) (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0) - self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True) - self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True) - self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True) - self.log('val/loss', loss, on_epoch=True, sync_dist=True) + self.log("val/ssim_t1", ssim_t1, on_epoch=True, sync_dist=True) + self.log("val/l1_t1", l1_t1, on_epoch=True, sync_dist=True) + self.log("val/l1_m0", l1_m0, on_epoch=True, sync_dist=True) + self.log("val/loss", loss, on_epoch=True, sync_dist=True) if batch_idx == 0: - self.validation_step_outputs['target_t1'].append(batch['t1']) - self.validation_step_outputs['pred_t1'].append(pred_t1) - self.validation_step_outputs['pred_m0'].append(pred_m0) - self.validation_step_outputs['target_m0'].append(target_m0) - self.validation_step_outputs['mask'].append(batch['mask']) + self.validation_step_outputs["target_t1"].append(batch["t1"]) + self.validation_step_outputs["pred_t1"].append(pred_t1) + self.validation_step_outputs["pred_m0"].append(pred_m0) + self.validation_step_outputs["target_m0"].append(target_m0) + self.validation_step_outputs["mask"].append(batch["mask"]) def on_validation_epoch_end(self): outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()} @@ -358,16 +415,26 @@ def on_validation_epoch_end(self): samples = 4 fig, axes = plt.subplots(3, samples, figsize=(4 * samples, 12)) for i in range(samples): - self.result_plot(outputs['target_t1'][i], outputs['pred_t1'][i], outputs['mask'][i], axes[:, i]) - fig.suptitle(f'T1 Epoch {self.current_epoch}') - self.logger.run['val/images/t1'].log(fig) + self.result_plot( + outputs["target_t1"][i], + outputs["pred_t1"][i], + outputs["mask"][i], + axes[:, i], + ) + fig.suptitle(f"T1 Epoch {self.current_epoch}") + self.logger.run["val/images/t1"].log(fig) plt.close(fig) fig, axes = plt.subplots(3, samples, figsize=(4 * samples, 12)) for i in range(samples): - self.result_plot(outputs['target_m0'][i].abs(), outputs['pred_m0'][i].abs(), outputs['mask'][i], axes[:, i]) - fig.suptitle(f'|M0| Epoch {self.current_epoch}') - self.logger.run['val/images/m0'].log(fig) + self.result_plot( + outputs["target_m0"][i].abs(), + outputs["pred_m0"][i].abs(), + outputs["mask"][i], + axes[:, i], + ) + fig.suptitle(f"|M0| Epoch {self.current_epoch}") + self.logger.run["val/images/m0"].log(fig) plt.close(fig) def result_plot(self, target, pred, mask, axes): @@ -381,23 +448,27 @@ def result_plot(self, target, pred, mask, axes): vmax = np.nanmax(target) im1 = axes[0].imshow(target, vmin=0, vmax=vmax) - axes[0].set_title('Target') - axes[0].axis('off') + axes[0].set_title("Target") + axes[0].axis("off") axes[0].colorbar(im1) im2 = axes[1].imshow(pred, vmin=0, vmax=vmax) - axes[1].set_title('Predicted') - axes[1].axis('off') + axes[1].set_title("Predicted") + axes[1].axis("off") axes[1].colorbar(im2) diff_vmax = np.nanmax(np.abs(difference)) - im3 = axes[2].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax) - axes[2].set_title('Difference') - axes[2].axis('off') + im3 = axes[2].imshow( + difference, cmap="coolwarm", vmin=-diff_vmax, vmax=diff_vmax + ) + axes[2].set_title("Difference") + axes[2].axis("off") axes[2].colorbar(im3) return axes - def configure_optimizers(self) -> dict[str, torch.optim.Optimizer | torch.optim.lr_scheduler.LRScheduler]: + def configure_optimizers( + self, + ) -> dict[str, torch.optim.Optimizer | torch.optim.lr_scheduler.LRScheduler]: optimizer = torch.optim.AdamW( self.parameters(), lr=self.hparams.lr, @@ -412,20 +483,22 @@ def configure_optimizers(self) -> dict[str, torch.optim.Optimizer | torch.optim. final_div_factor=200, ) return { - 'optimizer': optimizer, - 'lr_scheduler': scheduler, + "optimizer": optimizer, + "lr_scheduler": scheduler, } # %% -if __name__ == '__main__': - torch.set_float32_matmul_precision('high') - torch._inductor.config.worker_start_method = 'fork' +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + torch._inductor.config.worker_start_method = "fork" torch._inductor.config.compile_threads = 4 torch._dynamo.config.capture_scalar_outputs = True torch._functorch.config.activation_memory_budget = 0.9 torch._dynamo.config.cache_size_limit = 256 + data_folder = Path("/scratch/zimmer08/brainweb") + signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 6.0)) constraints_op = mrpro.operators.ConstraintsOp( bounds=( @@ -437,6 +510,7 @@ def configure_optimizers(self) -> dict[str, torch.optim.Optimizer | torch.optim. parameter_is_complex = [True, False] dm = DataModule( + folder=data_folder, signalmodel=signalmodel, n_images=n_images, batch_size=16, @@ -460,16 +534,16 @@ def configure_optimizers(self) -> dict[str, torch.optim.Optimizer | torch.optim. neptune_logger = NeptuneLogger( log_model_checkpoints=False, - dependencies='infer', + dependencies="infer", ) neptune_logger.log_hyperparams(model.hparams) checkpoint_callback = ModelCheckpoint( - monitor='val/loss', - mode='min', + monitor="val/loss", + mode="min", save_top_k=2, - dirpath=Path('checkpoints') / str(neptune_logger.version), - filename='{epoch:02d}-{val/loss:.4f}', + dirpath=Path("checkpoints") / str(neptune_logger.version), + filename="{epoch:02d}-{val/loss:.4f}", save_last=True, ) @@ -477,16 +551,16 @@ def configure_optimizers(self) -> dict[str, torch.optim.Optimizer | torch.optim. trainer = pl.Trainer( max_epochs=50, - accelerator='gpu', + accelerator="gpu", devices=4, strategy=strategy, logger=neptune_logger, callbacks=[ - LearningRateMonitor(logging_interval='step'), + LearningRateMonitor(logging_interval="step"), checkpoint_callback, ], log_every_n_steps=10, - precision='16-mixed', + precision="16-mixed", ) trainer.fit(model, datamodule=dm) From dfa72826b946665d6c8c51e7a9227205e628921c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 2 Jul 2025 15:34:16 +0200 Subject: [PATCH 080/205] modl --- examples/scripts/modl.py | 102 ++++++++++++++++++++++++++++++++++ src/mrpro/nn/__init__.py | 9 ++- src/mrpro/nn/nets/BasicCNN.py | 65 ++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 6 +- 4 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 examples/scripts/modl.py create mode 100644 src/mrpro/nn/nets/BasicCNN.py diff --git a/examples/scripts/modl.py b/examples/scripts/modl.py new file mode 100644 index 000000000..464cf744e --- /dev/null +++ b/examples/scripts/modl.py @@ -0,0 +1,102 @@ +# %% +from collections.abc import Sequence +from pathlib import Path +from typing import TypedDict + +import mrpro +import torch + + +class BatchType(TypedDict): + data: mrpro.data.KData + target: mrpro.data.IData + csm: mrpro.data.CsmData + + +class AcceleratedFastMRI(torch.utils.data.Dataset): + def __init__(self, path: Path, acceleration: int = 4): + self.acceleration = acceleration + self.dataset = mrpro.phantoms.FastMRIKDataDataset(path) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index: int) -> BatchType: + data = self.dataset[index] + reconstruction = mrpro.algorithms.reconstruction.DirectReconstruction( + data, + csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64), + ) + csm = reconstruction.csm + target = reconstruction(data) + data_undersampled = data[..., :: self.acceleration, :] + assert csm is not None # for mypy + if csm.data.isnan().any(): + print('csm nan') + csm = mrpro.data.CsmData.from_kdata_inati(data, downsampled_size=64) + + return {'data': data_undersampled, 'target': target, 'csm': csm} + + +class MODL(torch.nn.Module): + def __init__(self, iterations: int = 10, n_features: Sequence[int] = (64, 64, 64)): + super().__init__() + cnn = mrpro.nn.nets.BasicCNN( + dim=2, + channels_in=2, + channels_out=2, + batch_norm=True, + n_features=(64, 64, 64), + ) + self.network = mrpro.nn.Residual(mrpro.nn.ComplexAsChannel(mrpro.nn.PermutedBlock((-1, -2), cnn))) + self.iterations = iterations + self.regularization_weight = torch.nn.Parameter(torch.tensor(1.0)) + + def prepare_dataconsistency( + self, + gram: mrpro.operators.LinearOperator, + zero_filled_image: torch.Tensor, + ): + return mrpro.operators.ConjugateGradientOp( + operator_factory=lambda _: gram + self.regularization_weight, + rhs_factory=lambda regularization_image: zero_filled_image + + self.regularization_weight * regularization_image, + ) + + def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: + return super().__call__(kdata, csm) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm.as_operator() + + (image,) = acquisition_op.H(kdata.data) + data_consistency_op = self.prepare_dataconsistency(acquisition_op.gram, image) + + for _ in range(self.iterations): + regularization = self.network(image) + (image,) = data_consistency_op(regularization) + if image.isnan().any(): + raise ValueError('NaN in image') + + return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header)) + + +# %% +from tqdm import tqdm + +path = Path('/echo/allgemein/resources/publicTrainingData/fastmri/brain_multicoil_train/') +dataset = AcceleratedFastMRI(path) +dataloader = torch.utils.data.DataLoader(dataset, num_workers=0, shuffle=True, collate_fn=lambda batch: batch[0]) +modl = MODL().cuda() +optimizer = torch.optim.Adam(modl.parameters(), lr=1e-4) +pbar = tqdm(dataloader) +for batch in pbar: + optimizer.zero_grad() + kdata, csm, target = batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda() + pred = modl(kdata, csm) + (loss,) = mrpro.operators.functionals.MSE(target.data)(pred.data) + loss.backward() + optimizer.step() + pbar.set_postfix(loss=loss.item()) +# %% diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index e59e4efde..dd8afc33e 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -20,8 +20,13 @@ from mrpro.nn.SqueezeExcitation import SqueezeExcitation from mrpro.nn.TransposedAttention import TransposedAttention from mrpro.nn.DropPath import DropPath +from mrpro.nn.Residual import Residual +from mrpro.nn.ComplexAsChannel import ComplexAsChannel from mrpro.nn import nets +from mrpro.nn.PermutedBlock import PermutedBlock + __all__ = [ + "ComplexAsChannel", "AdaptiveAvgPoolND", "AttentionGate", "AvgPoolND", @@ -30,6 +35,7 @@ "ConvND", "ConvTransposeND", "DropPath", + "ComplexAsChannel", "FiLM", "GroupNorm", "InstanceNormND", @@ -40,5 +46,6 @@ "ShiftedWindowAttention", "SqueezeExcitation", "TransposedAttention", - "nets" + "nets", + "PermutedBlock", ] \ No newline at end of file diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py new file mode 100644 index 000000000..b2671c121 --- /dev/null +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -0,0 +1,65 @@ +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import ReLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import BatchNormND, ConvND +from mrpro.nn.Sequential import Sequential + + +class BasicCNN(Sequential): + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + batch_norm: bool = True, + n_features: Sequence[int] = (64, 64, 64), + cond_dim: int = 0, + ): + """Initialize a basic CNN. + + Parameters + ---------- + dim + The number of spatial dimensions of the input tensor. + channels_in + The number of input channels. + channels_out + The number of output channels. + batch_norm + Whether to use batch normalization. + n_features + The number of features in the hidden layers. The length of this sequence determines the number of hidden layers. + cond_dim + The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + """ + super().__init__() + use_film = cond_dim > 0 + self.append(ConvND(dim)(channels_in, n_features[0], kernel_size=3, padding='same')) + for c_in, c_out in pairwise((*n_features, channels_out)): + if batch_norm: + self.append(BatchNormND(dim)(c_in, affine=not use_film)) + if use_film: + self.append(FiLM(c_in, cond_dim)) + self.append(ReLU(True)) + self.append(ConvND(dim)(c_in, c_out, kernel_size=3, padding='same')) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None) -> torch.Tensor: + """Apply the basic CNN to the input tensor. + + Parameters + ---------- + x + The input tensor. Should be of shape `(batch_size, channels_in, *spatial dimensions)` + with `spatial dimensions` being of length `dim`. + cond + The condition tensor. If None, no FiLM conditioning is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 291ea52ca..78e7fa82e 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -4,6 +4,7 @@ from mrpro.nn.nets.VAE import VAE from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet, BasicUNet, SeparableUNet from mrpro.nn.nets.SwinIR import SwinIR +from mrpro.nn.nets.BasicCNN import BasicCNN __all__ = [ "AttentionGatedUNet", @@ -14,5 +15,6 @@ "SwinIR", "UNet", "Uformer", - "VAE" -] \ No newline at end of file + "VAE", + "BasicCNN", +] From 78a6322ef11d7b62254a869e6916e52ab9edba13 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 3 Jul 2025 13:10:04 +0200 Subject: [PATCH 081/205] oberator subtractino --- src/mrpro/operators/LinearOperator.py | 8 ++-- src/mrpro/operators/LinearOperatorMatrix.py | 6 +-- src/mrpro/operators/Operator.py | 48 +++++++++++++++++++-- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index 8556a6907..15a94093a 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -254,7 +254,7 @@ def __matmul__( return OperatorComposition(self, cast(Operator[Unpack[Tin2], tuple[torch.Tensor,]], other)) return NotImplemented # type: ignore[unreachable] - def __radd__(self, other: torch.Tensor) -> LinearOperator: + def __radd__(self, other: torch.Tensor | complex) -> LinearOperator: """Operator addition. Returns ``lambda x: self(x) + other*x`` @@ -262,7 +262,7 @@ def __radd__(self, other: torch.Tensor) -> LinearOperator: return self + other @overload # type: ignore[override] - def __add__(self, other: LinearOperator | torch.Tensor) -> LinearOperator: ... + def __add__(self, other: LinearOperator | torch.Tensor | complex) -> LinearOperator: ... @overload def __add__( @@ -270,14 +270,14 @@ def __add__( ) -> Operator[torch.Tensor, tuple[torch.Tensor,]]: ... def __add__( - self, other: Operator[torch.Tensor, tuple[torch.Tensor,]] | LinearOperator | torch.Tensor + self, other: Operator[torch.Tensor, tuple[torch.Tensor,]] | LinearOperator | torch.Tensor | complex ) -> Operator[torch.Tensor, tuple[torch.Tensor,]] | LinearOperator: """Operator addition. Returns ``lambda x: self(x) + other(x)`` if other is a operator, ``lambda x: self(x) + other`` if other is a tensor """ - if isinstance(other, torch.Tensor): + if isinstance(other, torch.Tensor | complex | int | float): # tensor addition return LinearOperatorSum(self, mrpro.operators.IdentityOp() * other) elif isinstance(self, mrpro.operators.ZeroOp): diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py index 8ace21e9b..d77aa0c87 100644 --- a/src/mrpro/operators/LinearOperatorMatrix.py +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -144,7 +144,7 @@ def __repr__(self): return f'LinearOperatorMatrix(shape={self._shape}, operators={self._operators})' # Note: The type ignores are needed because we currently cannot do arithmetic operations with non-linear operators. - def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: ignore[override] + def __add__(self, other: Self | LinearOperator | torch.Tensor | complex) -> Self: # type: ignore[override] """Addition.""" operators: list[list[LinearOperator]] = [] if isinstance(other, LinearOperatorMatrix): @@ -152,7 +152,7 @@ def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: raise ValueError('OperatorMatrix shapes do not match.') for self_row, other_row in zip(self._operators, other._operators, strict=False): operators.append([s + o for s, o in zip(self_row, other_row, strict=False)]) - elif isinstance(other, LinearOperator | torch.Tensor): + elif isinstance(other, LinearOperator | torch.Tensor | complex): if not self.shape[0] == self.shape[1]: raise NotImplementedError('Cannot add a LinearOperator to a non-square OperatorMatrix.') for i, self_row in enumerate(self._operators): @@ -161,7 +161,7 @@ def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: return NotImplemented # type: ignore[unreachable] return self.__class__(operators) - def __radd__(self, other: Self | LinearOperator | torch.Tensor) -> Self: + def __radd__(self, other: Self | LinearOperator | torch.Tensor | complex) -> Self: """Right addition.""" return self.__add__(other) diff --git a/src/mrpro/operators/Operator.py b/src/mrpro/operators/Operator.py index d52b4aabc..ea82f90af 100644 --- a/src/mrpro/operators/Operator.py +++ b/src/mrpro/operators/Operator.py @@ -53,7 +53,7 @@ def __matmul__( return OperatorComposition(self, cast(Operator[Unpack[Tin2], tuple[Unpack[Tin]]], other)) def __radd__( - self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor | complex ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: """Operator right addition. @@ -65,18 +65,18 @@ def __radd__( def __add__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: ... @overload def __add__( - self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor | complex ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: ... def __add__( - self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | mrpro.operators.ZeroOp + self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | complex | mrpro.operators.ZeroOp ) -> Operator[Unpack[Tin], Tout] | Operator[Unpack[Tin], tuple[Unpack[Tin]]]: """Operator addition. Returns ``lambda x: self(x) + other(x)`` if other is a operator, ``lambda x: self(x) + other*x`` if other is a tensor """ - if isinstance(other, torch.Tensor): + if isinstance(other, torch.Tensor | complex | int | float): s = cast(Operator[Unpack[Tin], tuple[Unpack[Tin]]], self) o = cast(Operator[Unpack[Tin], tuple[Unpack[Tin]]], mrpro.operators.MultiIdentityOp() * other) return OperatorSum(s, o) @@ -102,6 +102,46 @@ def __rmul__(self, other: torch.Tensor | complex) -> Operator[Unpack[Tin], Tout] """ return OperatorElementwiseProductRight(self, other) + @overload + def __sub__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: ... + + @overload + def __sub__( + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor | complex + ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: ... + + def __sub__( + self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | complex | mrpro.operators.ZeroOp + ) -> Operator[Unpack[Tin], Tout] | Operator[Unpack[Tin], tuple[Unpack[Tin]]]: + """Operator subtraction. + + Returns ``lambda x: self(x) - other(x)`` if other is a operator, + ``lambda x: self(x) - other*x`` if other is a tensor + """ + if isinstance(other, mrpro.operators.ZeroOp): + return self + return self + (-1.0) * other + + @overload + def __rsub__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: ... + + @overload + def __rsub__( + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor | complex + ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: ... + + def __rsub__( + self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | complex | mrpro.operators.ZeroOp + ) -> Operator[Unpack[Tin], Tout] | Operator[Unpack[Tin], tuple[Unpack[Tin]]]: + """Operator subtraction. + + Returns ``lambda x: self(x) - other(x)`` if other is a operator, + ``lambda x: self(x) - other*x`` if other is a tensor + """ + if isinstance(other, mrpro.operators.ZeroOp): + return self + return (-1.0) * self + other + class OperatorComposition(Operator[Unpack[Tin2], Tout]): """Operator composition.""" From 5b40ad1d32d67c6442750bd7607d624edbaee4c9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 3 Jul 2025 13:10:40 +0200 Subject: [PATCH 082/205] fastmri: fix padding undo --- src/mrpro/phantoms/fastmri.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/phantoms/fastmri.py b/src/mrpro/phantoms/fastmri.py index 522a5b22b..814ac8937 100644 --- a/src/mrpro/phantoms/fastmri.py +++ b/src/mrpro/phantoms/fastmri.py @@ -89,7 +89,7 @@ def __getitem__(self, idx: int) -> KData: n_k0=n_k0, k0_center=n_k0 // 2, k1_idx=info.idx.k1, - k1_center=n_k1 // 2, + k1_center=first + n_k1 // 2, k2_idx=torch.tensor(0), k2_center=0, ) @@ -212,4 +212,4 @@ def __getitem__(self, idx: int) -> torch.Tensor: img = (img * csm.conj()).sum(dim=0, keepdim=True) if self.augment is not None: img = self.augment(img, idx) - return rearrange(img, 'coils y x -> 1 coils 1 y x') # , rearrange(csm, 'coils y x -> 1 coils 1 y x') + return rearrange(img, 'coils y x -> 1 coils 1 y x') From cc34beb2860eef339c71d047436d19e15222fc72 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 3 Jul 2025 13:11:25 +0200 Subject: [PATCH 083/205] fix dataclass error --- src/mrpro/data/Dataclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/data/Dataclass.py b/src/mrpro/data/Dataclass.py index 10d30a638..549a15c30 100644 --- a/src/mrpro/data/Dataclass.py +++ b/src/mrpro/data/Dataclass.py @@ -41,7 +41,7 @@ def concatenate(self, *others: Self, dim: int) -> Self: """Concatenate other instances to self.""" -class InconsistentDeviceError(ValueError): +class InconsistentDeviceError(RuntimeError): """Raised if the devices of different fields differ. There is no single device that all fields are on, thus From 777443f908a6eb99f17c6f58fa8687e73322fab4 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 3 Jul 2025 13:11:47 +0200 Subject: [PATCH 084/205] nn --- src/mrpro/nn/__init__.py | 8 ++++---- src/mrpro/nn/nets/__init__.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index dd8afc33e..9d2b1dff4 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -26,26 +26,26 @@ from mrpro.nn.PermutedBlock import PermutedBlock __all__ = [ - "ComplexAsChannel", "AdaptiveAvgPoolND", "AttentionGate", "AvgPoolND", "BatchNormND", + "ComplexAsChannel", "CondMixin", "ConvND", "ConvTransposeND", "DropPath", - "ComplexAsChannel", "FiLM", "GroupNorm", "InstanceNormND", "MaxPoolND", "NeighborhoodSelfAttention", + "PermutedBlock", "ResBlock", + "Residual", "Sequential", "ShiftedWindowAttention", "SqueezeExcitation", "TransposedAttention", - "nets", - "PermutedBlock", + "nets" ] \ No newline at end of file diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 78e7fa82e..228596dc8 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -8,6 +8,7 @@ __all__ = [ "AttentionGatedUNet", + "BasicCNN", "BasicUNet", "DCVAE", "Restormer", @@ -15,6 +16,5 @@ "SwinIR", "UNet", "Uformer", - "VAE", - "BasicCNN", -] + "VAE" +] \ No newline at end of file From 09b70c55a568f07c46be7ab95cfb3f0276fb59e7 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 3 Jul 2025 13:12:00 +0200 Subject: [PATCH 085/205] inati: no nans --- src/mrpro/algorithms/csm/inati.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/mrpro/algorithms/csm/inati.py b/src/mrpro/algorithms/csm/inati.py index 0e860ae2e..6e20a3bba 100644 --- a/src/mrpro/algorithms/csm/inati.py +++ b/src/mrpro/algorithms/csm/inati.py @@ -32,10 +32,13 @@ def inati( """ # After 10 power iterations we will have a very good estimate of the singular vector n_power_iterations = 10 + eps = 1e-8 # for numerical stability if isinstance(smoothing_width, int): smoothing_width = SpatialDimension( - z=smoothing_width if coil_img.shape[-3] > 1 else 1, y=smoothing_width, x=smoothing_width + z=smoothing_width if coil_img.shape[-3] > 1 else 1, + y=smoothing_width, + x=smoothing_width, ) if any(ks % 2 != 1 for ks in [smoothing_width.z, smoothing_width.y, smoothing_width.x]): @@ -44,22 +47,33 @@ def inati( ks_halved = [ks // 2 for ks in smoothing_width.zyx] padded_coil_img = torch.nn.functional.pad( coil_img, - (ks_halved[-1], ks_halved[-1], ks_halved[-2], ks_halved[-2], ks_halved[-3], ks_halved[-3]), + ( + ks_halved[-1], + ks_halved[-1], + ks_halved[-2], + ks_halved[-2], + ks_halved[-3], + ks_halved[-3], + ), mode='replicate', ) # Get the voxels in an ROI defined by the smoothing_width around each voxel leading to shape # (z y x coils window=prod(smoothing_width)) coil_img_roi = sliding_window(padded_coil_img, smoothing_width.zyx, dim=(-3, -2, -1)).flatten(-3) - coil_img_cov = einsum(coil_img_roi.conj(), coil_img_roi, '... coils1 window,... coils2 window->... coils1 coils2') + coil_img_cov = einsum( + coil_img_roi.conj(), + coil_img_roi, + '... coils1 window,... coils2 window->... coils1 coils2', + ) singular_vector = torch.sum(coil_img_roi, dim=-1) # z y x coils - singular_vector /= singular_vector.norm(dim=-1, keepdim=True) + singular_vector /= singular_vector.norm(dim=-1, keepdim=True) + eps for _ in range(n_power_iterations): singular_vector = einsum(coil_img_cov, singular_vector, '... coils1 coils2,... coils2->... coils1') - singular_vector /= singular_vector.norm(dim=-1, keepdim=True) + singular_vector /= singular_vector.norm(dim=-1, keepdim=True) + eps singular_value = einsum(coil_img_roi, singular_vector, '... coils window,... coils->... window') phase = singular_value.sum(-1) - phase /= phase.abs() + phase /= phase.abs() + eps csm = einsum(singular_vector.conj(), phase, '... coils,...->coils ...') # coils z y x return csm From b0a85171ffa2ee72e35ee999184ecc96787a2fb2 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 3 Jul 2025 13:12:11 +0200 Subject: [PATCH 086/205] modl --- examples/scripts/modl.py | 97 +++++++++++++++++++++++++++++++--------- 1 file changed, 76 insertions(+), 21 deletions(-) diff --git a/examples/scripts/modl.py b/examples/scripts/modl.py index 464cf744e..ea36268ba 100644 --- a/examples/scripts/modl.py +++ b/examples/scripts/modl.py @@ -1,10 +1,14 @@ # %% +# %matplotlib inline from collections.abc import Sequence from pathlib import Path from typing import TypedDict +import matplotlib.axes +import matplotlib.pyplot as plt import mrpro import torch +from tqdm import tqdm class BatchType(TypedDict): @@ -14,49 +18,58 @@ class BatchType(TypedDict): class AcceleratedFastMRI(torch.utils.data.Dataset): - def __init__(self, path: Path, acceleration: int = 4): + def __init__(self, path: Path, acceleration: float = 16, noise_level: float = 0.2): self.acceleration = acceleration self.dataset = mrpro.phantoms.FastMRIKDataDataset(path) + self.noise_level = noise_level def __len__(self): return len(self.dataset) def __getitem__(self, index: int) -> BatchType: data = self.dataset[index] + data = data.remove_readout_os() + data.data /= data.data.std() reconstruction = mrpro.algorithms.reconstruction.DirectReconstruction( data, csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64), ) csm = reconstruction.csm target = reconstruction(data) - data_undersampled = data[..., :: self.acceleration, :] - assert csm is not None # for mypy - if csm.data.isnan().any(): - print('csm nan') - csm = mrpro.data.CsmData.from_kdata_inati(data, downsampled_size=64) + n = max(data.data.shape[-2:]) + distance = (torch.linspace(-1, 1, n)[:, None] ** 2 + torch.linspace(-1, 1, n) ** 2).sqrt() + random = 0.1 / (distance + 0.1) + torch.rand_like(distance) + threshold = torch.kthvalue(random.ravel(), int(n**2 * (1 - 1 / self.acceleration))).values + undersampling_mask = mrpro.utils.pad_or_crop(random > threshold, data.data.shape[-2:]) + data_undersampled = data[..., undersampling_mask].rearrange('k ... 1 -> ... k') + + noise = mrpro.utils.RandomGenerator(seed=index).randn_like(data_undersampled.data) + data_undersampled.data += self.noise_level * noise + + assert csm is not None # for mypy return {'data': data_undersampled, 'target': target, 'csm': csm} class MODL(torch.nn.Module): - def __init__(self, iterations: int = 10, n_features: Sequence[int] = (64, 64, 64)): + def __init__(self, iterations: int = 10, n_features: Sequence[int] = (64, 64, 64, 64)): super().__init__() cnn = mrpro.nn.nets.BasicCNN( dim=2, channels_in=2, channels_out=2, - batch_norm=True, - n_features=(64, 64, 64), + n_features=n_features, ) self.network = mrpro.nn.Residual(mrpro.nn.ComplexAsChannel(mrpro.nn.PermutedBlock((-1, -2), cnn))) + self.network = torch.compile(self.network, dynamic=True, fullgraph=True) self.iterations = iterations self.regularization_weight = torch.nn.Parameter(torch.tensor(1.0)) - def prepare_dataconsistency( + def _prepare_dataconsistency( self, gram: mrpro.operators.LinearOperator, zero_filled_image: torch.Tensor, - ): + ) -> mrpro.operators.ConjugateGradientOp: return mrpro.operators.ConjugateGradientOp( operator_factory=lambda _: gram + self.regularization_weight, rhs_factory=lambda regularization_image: zero_filled_image @@ -71,32 +84,74 @@ def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.dat acquisition_op = fourier_op @ csm.as_operator() (image,) = acquisition_op.H(kdata.data) - data_consistency_op = self.prepare_dataconsistency(acquisition_op.gram, image) + data_consistency_op = self._prepare_dataconsistency(acquisition_op.gram, image) for _ in range(self.iterations): regularization = self.network(image) (image,) = data_consistency_op(regularization) - if image.isnan().any(): - raise ValueError('NaN in image') return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header)) +def plot(batch: BatchType, prediction: mrpro.data.IData): + target = batch['target'].rss().cpu().squeeze() + direct = mrpro.algorithms.reconstruction.DirectReconstruction(batch['data'], csm=batch['csm'])(batch['data']) + direct = direct.rss().cpu().squeeze() + direct *= target.std() / direct.std() + sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(batch['data'], csm=batch['csm'])(batch['data']) + sense = sense.rss().cpu().squeeze() + prediction_ = prediction.rss().cpu().squeeze().detach() + + ssim = mrpro.operators.functionals.SSIM(mrpro.utils.pad_or_crop(target[None], (320, 320))) + + def show(ax: matplotlib.axes.Axes, data: torch.Tensor, label: str): + data = mrpro.utils.pad_or_crop(data, (320, 320)) + ax.imshow(data, vmin=0, vmax=target.max().item(), cmap='gray') + if label != 'Ground Truth': + (ssim_value,) = ssim(data[None]) + ax.text( + 0.98, + 0.1, + f'{ssim_value.item():.2f}', + color='white', + horizontalalignment='right', + verticalalignment='top', + transform=ax.transAxes, + ) + ax.set_title(label) + ax.set_axis_off() + + fig, ax = plt.subplots(1, 4) + show(ax[0], direct, 'Direct') + show(ax[1], sense, 'CG-SENSE') + show(ax[2], prediction_, 'MODL') + show(ax[3], target, 'Ground Truth') + fig.tight_layout() + plt.show() + + # %% -from tqdm import tqdm path = Path('/echo/allgemein/resources/publicTrainingData/fastmri/brain_multicoil_train/') dataset = AcceleratedFastMRI(path) -dataloader = torch.utils.data.DataLoader(dataset, num_workers=0, shuffle=True, collate_fn=lambda batch: batch[0]) +dataloader = torch.utils.data.DataLoader(dataset, num_workers=8, shuffle=True, collate_fn=lambda batch: batch[0]) modl = MODL().cuda() optimizer = torch.optim.Adam(modl.parameters(), lr=1e-4) pbar = tqdm(dataloader) -for batch in pbar: - optimizer.zero_grad() +for i, batch in enumerate(pbar): kdata, csm, target = batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda() - pred = modl(kdata, csm) - (loss,) = mrpro.operators.functionals.MSE(target.data)(pred.data) + prediction = modl(kdata, csm) + objective = mrpro.operators.functionals.MSE(target.data) - mrpro.operators.functionals.SSIM(target.data) + (loss,) = objective(prediction.data) loss.backward() - optimizer.step() + + if i % 4 == 0: + optimizer.step() + optimizer.zero_grad() + pbar.set_postfix(loss=loss.item()) + + if i % 100 == 0: + plot(batch, prediction) + # %% From 816f3a308fac63b4124c6c494752eaffef51fd31 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 4 Jul 2025 17:14:57 +0200 Subject: [PATCH 087/205] pinqi --- examples/scripts/train_pinqi.py | 543 +++++++++++++++++++------------- 1 file changed, 328 insertions(+), 215 deletions(-) diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py index b5df7258a..bbb88cdfe 100644 --- a/examples/scripts/train_pinqi.py +++ b/examples/scripts/train_pinqi.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from copy import deepcopy from pathlib import Path -from typing import Literal, TypedDict, cast +from typing import Any, Literal, TypedDict, cast import einops import matplotlib.pyplot as plt @@ -17,11 +17,12 @@ from pytorch_lightning.loggers import NeptuneLogger from pytorch_lightning.strategies import DDPStrategy - # mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) class BatchType(TypedDict): + """Typehint for a batch of data.""" + kdata: mrpro.data.KData csm: mrpro.data.CsmData m0: torch.Tensor @@ -30,27 +31,37 @@ class BatchType(TypedDict): class Dataset(torch.utils.data.Dataset): + """A brainweb based cartesian qMRI dataset.""" + def __init__( self, folder: Path, signalmodel: mrpro.operators.SignalModel, n_images: int, - size: int = 192, - acceleration: int = 10, - n_coils: int = 8, + size: int, + acceleration: int, + n_coils: int, + max_noise: float, + orientation: Sequence[Literal['axial', 'coronal', 'sagittal']], random: bool = True, - max_noise: float = 0.1, - orientation: Sequence[Literal["axial", "coronal", "sagittal"]] = ( - "axial", - "coronal", - "sagittal", - ), ): + """Initialize the dataset.""" + if random: + augment = mrpro.phantoms.brainweb.augment(size=size) + else: + augment = mrpro.phantoms.brainweb.augment( + size=size, + max_random_shear=0, + max_random_rotation=0, + max_random_scaling_factor=0, + p_horizontal_flip=0, + p_vertical_flip=1.0, + ) self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( folder=folder, - what=("m0", "t1", "mask"), - seed="index" if not random else "random", - slice_preparation=mrpro.phantoms.brainweb.augment(size=size), + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=augment, orientation=orientation, ) self.signalmodel = signalmodel @@ -63,22 +74,22 @@ def __init__( self._n_images = n_images def __len__(self) -> int: + """Get the length of the dataset.""" return len(self.phantom) def __getitem__(self, index: int): + """Get an item from the dataset.""" phantom = self.phantom[index] - (images,) = self.signalmodel(phantom["m0"], phantom["t1"]) + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) seed = int(torch.randint(0, 1000000, (1,))) if self._random else index - traj = ( - mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( - encoding_matrix=self.encoding_matrix, - seed=seed, - acceleration=self.acceleration, - fwhm_ratio=2, - n_center=8, - n_other=(self._n_images,), - ) + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=1.5, + n_center=10, + n_other=(self._n_images,), ) header = mrpro.data.KHeader( encoding_matrix=self.encoding_matrix, @@ -92,71 +103,71 @@ def __getitem__(self, index: int): elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery): header.ti = self.signalmodel.ti.tolist() - fourier_op = mrpro.operators.FourierOp( - self.encoding_matrix, self.encoding_matrix, traj - ) + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) csm = mrpro.data.CsmData( mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), header, ) - images = einops.rearrange(images, "t y x -> t 1 1 y x") + images = einops.rearrange(images, 't y x -> t 1 1 y x') (data,) = (fourier_op @ csm.as_operator())(images) - data = ( - data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() - ) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() kdata = mrpro.data.KData(header, data, traj) - return {"kdata": kdata, "csm": csm, **phantom} - - @staticmethod - def collate_fn(batch): - return torch.utils.data._utils.collate.collate( - batch, - collate_fn_map={ - mrpro.data.Dataclass: lambda batch, *, _: batch[0].stack(*batch[1:]), - **torch.utils.data._utils.collate.default_collate_fn_map, - }, - ) + return {'kdata': kdata, 'csm': csm, **phantom} + + +def collate_fn(batch: Any): # noqa: ANN401 + """Join dataclasses to a batch.""" + return torch.utils.data._utils.collate.collate( + batch, + collate_fn_map={ + mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), # noqa: ARG005 + **torch.utils.data._utils.collate.default_collate_fn_map, + }, + ) class PINQI(torch.nn.Module): + """PINQI model.""" + def __init__( self, signalmodel: mrpro.operators.SignalModel, constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, parameter_is_complex: Sequence[bool], n_images: int, - n_iterations: int = 4, - n_features_parameter_net: Sequence[int] = (64, 128, 192, 256), - n_features_image_net: Sequence[int] = (16, 32, 48, 64), + n_iterations: int, + n_features_parameter_net: Sequence[int], + n_features_image_net: Sequence[int], ): + """Initialize the PINQI model.""" super().__init__() - self.signalmodel = ( - mrpro.operators.RearrangeOp("t batch ... -> batch t ...") - @ deepcopy(signalmodel) - @ constraints_op - ) + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op self.constraints_op = constraints_op self._n_images = n_images self._parameter_is_complex = parameter_is_complex - real_parameters = sum(1 for c in parameter_is_complex if c) + len( - parameter_is_complex - ) - self.parameter_net = mrpro.nn.nets.UNet( - dim=2, - channels_in=n_images * 2, - channels_out=real_parameters, - attention_depths=(-1,), - n_features=n_features_parameter_net, + real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex) + self.parameter_net = torch.compile( + mrpro.nn.nets.UNet( + dim=2, + channels_in=n_images * 2, + channels_out=real_parameters, + attention_depths=(-1, -2), + n_features=n_features_parameter_net, + cond_dim=128, + ), + dynamic=False, + fullgraph=True, ) - self.image_net = mrpro.nn.nets.UNet( - 2, - channels_in=2, - channels_out=2, - attention_depths=(), - n_features=n_features_image_net, + self.image_net = torch.compile( + mrpro.nn.nets.UNet( + 2, channels_in=2, channels_out=2, attention_depths=(), n_features=n_features_image_net, cond_dim=128 + ), + dynamic=False, + fullgraph=True, ) self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) - self.softplus = torch.nn.Softplus() + self.softplus = torch.nn.Softplus(beta=5) + self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128) def objective_factory( lambda_parameters: torch.Tensor, @@ -196,15 +207,14 @@ def rhs_factory( rhs_factory=rhs_factory, ) - def get_parameter_reg(self, image: torch.Tensor) -> tuple[torch.Tensor, ...]: + def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]: image = einops.rearrange( torch.view_as_real(image), - "batch t 1 1 y x complex-> batch (t complex) y x", - ) - parameters = self.parameter_net(image.contiguous()) - parameters = einops.rearrange( - parameters, "batch parameters y x-> parameters batch 1 1 y x" + 'batch t 1 1 y x complex-> batch (t complex) y x', ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + parameters = self.parameter_net(image.contiguous(), cond=cond) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') i = 0 result = [] for is_complex in self._parameter_is_complex: @@ -216,16 +226,15 @@ def get_parameter_reg(self, image: torch.Tensor) -> tuple[torch.Tensor, ...]: i += 1 return tuple(result) - def get_image_reg(self, image: torch.Tensor) -> torch.Tensor: + def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor: batch = image.shape[0] image = einops.rearrange( torch.view_as_real(image), - "batch t 1 1 y x complex-> (batch t) complex y x", - ) - image = image + self.image_net(image.contiguous()) - image = einops.rearrange( - image, "(batch t) complex y x-> batch t 1 1 y x complex", batch=batch + 'batch t 1 1 y x complex-> (batch t) complex y x', ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + image = image + self.image_net(image.contiguous(), cond=cond) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) return torch.view_as_complex(image.contiguous()) def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): @@ -234,58 +243,71 @@ def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): acquisition_op = fourier_op @ csm_op gram = acquisition_op.gram (zero_filled_image,) = acquisition_op.H(kdata.data) - images = list( - mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2) - ) - parameters = [self.get_parameter_reg(images[-1])] + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + parameters = [self.get_parameter_reg(images[-1], 0)] linear_solver = self.get_linear_solver(gram) - for lambda_image, lambda_q, lambda_parameter in self.softplus(self.lambdas_raw): - image_reg = self.get_image_reg(images[-1]) + for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)): + image_reg = self.get_image_reg(images[-1], i + 1) (signal,) = self.signalmodel(*parameters[-1]) - images.extend( - linear_solver( - lambda_image, lambda_q, image_reg, signal, zero_filled_image - ) - ) - parameters_reg = self.get_parameter_reg(images[-1]) - parameters.append( - self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg) - ) + images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + parameters_reg = self.get_parameter_reg(images[-1], i + 1) + parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) if self.constraints_op is not None: parameters = [self.constraints_op(*p) for p in parameters] return images, parameters class DataModule(pl.LightningDataModule): + """Data module for training the PINQI model.""" + def __init__( self, folder: Path, - batch_size: int = 8, - num_workers: int = 4, - signalmodel: mrpro.operators.SignalModel = mrpro.operators.models.SaturationRecovery( - (0.5, 1.0, 1.5, 2.0, 6.0) + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int = 192, + acceleration: int = 10, + n_coils: int = 8, + max_noise: float = 0.1, + orientation_train: Sequence[Literal['axial', 'coronal', 'sagittal']] = ( + 'axial', + 'coronal', + 'sagittal', ), - n_images: int = 5, - **kwargs, + orientation_val: Sequence[Literal['axial', 'coronal', 'sagittal']] = ('axial',), + batch_size: int = 16, + num_workers: int = 4, ): + """Initialize the data module.""" super().__init__() + self.save_hyperparameters(ignore=['signalmodel', 'folder', 'num_workers']) self.batch_size = batch_size self.num_workers = num_workers self.train_dataset = Dataset( folder=folder, signalmodel=signalmodel, n_images=n_images, - **kwargs, + size=size, + acceleration=acceleration, + n_coils=n_coils, + max_noise=max_noise, + orientation=orientation_train, random=True, ) - self.val_dataset = Dataset( - folder=folder, - orientation=("axial",), - signalmodel=signalmodel, - n_images=n_images, - **kwargs, - random=False, + self.val_dataset = torch.utils.data.Subset( + Dataset( + folder=folder, + signalmodel=signalmodel, + n_images=n_images, + size=size, + acceleration=acceleration, + n_coils=n_coils, + max_noise=max_noise, + orientation=orientation_val, + random=False, + ), + list(range(30, 500, 20)), ) def train_dataloader(self): @@ -294,25 +316,27 @@ def train_dataloader(self): batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, - pin_memory=True, + pin_memory=False, persistent_workers=self.num_workers > 0, - collate_fn=self.train_dataset.collate_fn, + collate_fn=collate_fn, worker_init_fn=lambda *_: torch.set_num_threads(1), ) def val_dataloader(self): return torch.utils.data.DataLoader( self.val_dataset, - batch_size=self.batch_size, + batch_size=1, shuffle=False, num_workers=self.num_workers, - pin_memory=True, + pin_memory=False, persistent_workers=self.num_workers > 0, - collate_fn=self.val_dataset.collate_fn, + collate_fn=collate_fn, ) -class Module(pl.LightningModule): +class PinqiModule(pl.LightningModule): + """Module for training the PINQI model.""" + def __init__( self, signalmodel: mrpro.operators.SignalModel, @@ -320,19 +344,18 @@ def __init__( parameter_is_complex: Sequence[bool], n_images: int, n_iterations: int = 4, - n_features_parameter_net: Sequence[int] = (64, 128, 192, 256), + n_features_parameter_net: Sequence[int] = (64, 128, 192, 224, 256), n_features_image_net: Sequence[int] = (16, 32, 48, 64), - lr: float = 3e-4, - weight_decay: float = 1e-4, - loss_weights: Sequence[float] = (0.1, 0.1, 0.1, 0.2, 0.5), + lr: float = 4e-4, # noqa: ARG002 + weight_decay: float = 1e-3, # noqa: ARG002 + loss_weights: Sequence[float] = (0.2, 0.1, 0.1, 0.1, 0.8), ): + """Initialize the PINQI module.""" super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=['signalmodel', 'constraints_op']) if len(loss_weights) != n_iterations + 1: - raise ValueError( - f"loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations" - ) - + raise ValueError(f'loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations') + signalmodel, constraints_op = map(deepcopy, (signalmodel, constraints_op)) self.pinqi = PINQI( signalmodel=signalmodel, constraints_op=constraints_op, @@ -344,66 +367,77 @@ def __init__( ) self.validation_step_outputs = collections.defaultdict(list) + self.baseline = Baseline(signalmodel, constraints_op, parameter_is_complex) def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + """Apply the PINQI model to the data.""" return self.pinqi(kdata, csm) - def loss(self, predictions, batch): + def loss(self, predictions: Sequence[torch.Tensor], batch: BatchType) -> torch.Tensor: + """Compute the loss.""" loss = torch.tensor(0.0, device=self.device) - target_m0 = batch["m0"] - target_t1 = batch["t1"] - mask = batch["mask"] - for prediction, weight in zip( - predictions, self.hparams.loss_weights, strict=False - ): - prediction_m0, prediction_t1 = prediction - loss_t1 = torch.nn.functional.mse_loss( - prediction_t1.squeeze()[mask], target_t1[mask] - ) + target_m0, target_t1, mask = map(torch.squeeze, (batch['m0'], batch['t1'], batch['mask'])) + for prediction, weight in zip(predictions, self.hparams.loss_weights, strict=False): + prediction_m0, prediction_t1 = map(torch.squeeze, prediction) + loss_t1 = torch.nn.functional.mse_loss(prediction_t1[mask], target_t1[mask]) loss_m0 = torch.nn.functional.mse_loss( - torch.view_as_real((prediction_m0).squeeze()[mask]), + torch.view_as_real(prediction_m0[mask]), torch.view_as_real(target_m0[mask]), ) loss_outside = prediction_m0[~mask].abs().mean() loss = loss + weight * (loss_t1 + 0.5 * loss_m0 + 0.1 * loss_outside) return loss - def training_step(self, batch, batch_idx): - images, parameters = self(batch["kdata"], batch["csm"]) + def training_step(self, batch: BatchType, _batch_idx: int) -> torch.Tensor: + """Training step.""" + images, parameters = self(batch['kdata'], batch['csm']) loss = self.loss(parameters, batch) self.log( - "train/loss", + 'train/loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True, + batch_size=len(batch['mask']), ) return loss def validation_step(self, batch: BatchType, batch_idx: int) -> None: - images, parameters = self(batch["kdata"], batch["csm"]) + """Validate. + + Needs to be adapted for other signal models than Saturation Recovery. + """ + images, parameters = self(batch['kdata'], batch['csm']) loss = self.loss(parameters, batch) pred_m0, pred_t1 = parameters[-1] - target_t1, target_m0 = batch["t1"], batch["m0"] - mask = batch["mask"] + target_t1, target_m0 = batch['t1'], batch['m0'] + mask = batch['mask'] + batch_size = len(batch['mask']) (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1) (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1) (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0) - self.log("val/ssim_t1", ssim_t1, on_epoch=True, sync_dist=True) - self.log("val/l1_t1", l1_t1, on_epoch=True, sync_dist=True) - self.log("val/l1_m0", l1_m0, on_epoch=True, sync_dist=True) - self.log("val/loss", loss, on_epoch=True, sync_dist=True) + self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size) if batch_idx == 0: - self.validation_step_outputs["target_t1"].append(batch["t1"]) - self.validation_step_outputs["pred_t1"].append(pred_t1) - self.validation_step_outputs["pred_m0"].append(pred_m0) - self.validation_step_outputs["target_m0"].append(target_m0) - self.validation_step_outputs["mask"].append(batch["mask"]) + self.validation_step_outputs['target_t1'].append(batch['t1']) + self.validation_step_outputs['pred_t1'].append(pred_t1) + self.validation_step_outputs['pred_m0'].append(pred_m0) + self.validation_step_outputs['target_m0'].append(target_m0) + self.validation_step_outputs['mask'].append(batch['mask']) + baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm']) + self.validation_step_outputs['baseline_t1'].append(baseline_t1) + self.validation_step_outputs['baseline_m0'].append(baseline_m0) def on_validation_epoch_end(self): + """Validate. + + Needs to be adapted for other signal models than Saturation Recovery. + """ outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()} self.validation_step_outputs.clear() outputs = cast(dict[str, torch.Tensor], self.all_gather(outputs)) @@ -412,94 +446,176 @@ def on_validation_epoch_end(self): return outputs = {k: v.flatten(0, 1).cpu() for k, v in outputs.items()} - samples = 4 - fig, axes = plt.subplots(3, samples, figsize=(4 * samples, 12)) + samples = len(outputs['mask']) + fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16)) + for i in range(samples): self.result_plot( - outputs["target_t1"][i], - outputs["pred_t1"][i], - outputs["mask"][i], + outputs['target_t1'][i], + outputs['pred_t1'][i], + outputs['mask'][i], axes[:, i], + outputs['baseline_t1'][i], + '$T_1$ (s)', ) - fig.suptitle(f"T1 Epoch {self.current_epoch}") - self.logger.run["val/images/t1"].log(fig) + fig.suptitle(f'$T_1$ Epoch {self.current_epoch}') + self.logger.run['val/images/t1'].log(fig) plt.close(fig) - fig, axes = plt.subplots(3, samples, figsize=(4 * samples, 12)) + fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 12)) for i in range(samples): self.result_plot( - outputs["target_m0"][i].abs(), - outputs["pred_m0"][i].abs(), - outputs["mask"][i], + outputs['target_m0'][i].abs(), + outputs['pred_m0'][i].abs(), + outputs['mask'][i], axes[:, i], + outputs['baseline_m0'][i].abs(), + '$|M_0|$ (a.u.)', ) - fig.suptitle(f"|M0| Epoch {self.current_epoch}") - self.logger.run["val/images/m0"].log(fig) + fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}') + self.logger.run['val/images/m0'].log(fig) plt.close(fig) - def result_plot(self, target, pred, mask, axes): + def result_plot( + self, + target: torch.Tensor, + pred: torch.Tensor, + mask: torch.Tensor, + axes: Sequence[plt.Axes], + baseline: torch.Tensor, + label: str, + ) -> None: + """Plot the results.""" target = target.squeeze().numpy() pred = pred.squeeze().detach().numpy() mask = mask.squeeze().detach().numpy().astype(bool) + baseline = baseline.squeeze().detach().numpy() target[~mask] = np.nan pred[~mask] = np.nan - difference = target - pred + baseline[~mask] = np.nan + difference = (target - pred) / target * 100 vmax = np.nanmax(target) - im1 = axes[0].imshow(target, vmin=0, vmax=vmax) - axes[0].set_title("Target") - axes[0].axis("off") - axes[0].colorbar(im1) + im0 = axes[0].imshow(target, vmin=0, vmax=vmax) + axes[0].set_title('Ground Truth') + axes[0].axis('off') + plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label=label) - im2 = axes[1].imshow(pred, vmin=0, vmax=vmax) - axes[1].set_title("Predicted") - axes[1].axis("off") - axes[1].colorbar(im2) + im1 = axes[1].imshow(baseline, vmin=0, vmax=vmax) + axes[1].set_title('SENSE + Regression') + axes[1].axis('off') + plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label=label) - diff_vmax = np.nanmax(np.abs(difference)) - im3 = axes[2].imshow( - difference, cmap="coolwarm", vmin=-diff_vmax, vmax=diff_vmax - ) - axes[2].set_title("Difference") - axes[2].axis("off") - axes[2].colorbar(im3) - return axes + im2 = axes[2].imshow(pred, vmin=0, vmax=vmax) + axes[2].set_title('PINQI') + axes[2].axis('off') + plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04, label=label) + + diff_vmax = np.nanpercentile(np.abs(difference), 90) + im3 = axes[3].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax) + axes[3].set_title('rel. Error') + axes[3].axis('off') + plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04, label='%') def configure_optimizers( self, - ) -> dict[str, torch.optim.Optimizer | torch.optim.lr_scheduler.LRScheduler]: + ) -> dict: + """Configure the optimizer and the learning rate scheduler.""" + scalars = ('lambdas_raw', 'rezero') + params, scalar_params = [], [] + for n, p in self.named_parameters(): + if not p.requires_grad: + continue + if any(s in n for s in scalars): + scalar_params.append(p) + else: + params.append(p) optimizer = torch.optim.AdamW( - self.parameters(), - lr=self.hparams.lr, - weight_decay=self.hparams.weight_decay, + [ + {'params': params, 'weight_decay': self.hparams.weight_decay, 'lr': self.hparams.lr}, + {'params': scalar_params, 'weight_decay': 0.0, 'lr': self.hparams.lr * 10}, + ], ) scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, - max_lr=self.hparams.lr, - total_steps=self.trainer.max_steps, + max_lr=[self.hparams.lr, 10 * self.hparams.lr], + total_steps=self.trainer.estimated_stepping_batches, pct_start=0.1, - div_factor=10, - final_div_factor=200, + div_factor=30, + final_div_factor=300, ) return { - "optimizer": optimizer, - "lr_scheduler": scheduler, + 'optimizer': optimizer, + 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}, } -# %% -if __name__ == "__main__": - torch.set_float32_matmul_precision("high") - torch._inductor.config.worker_start_method = "fork" +class Baseline(torch.nn.Module): + """Baseline solution using SENSE + Regression.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + ): + """Initialize the baseline.""" + super().__init__() + self.signalmodel = signalmodel + self.constraints_op = constraints_op + self.parameter_is_complex = parameter_is_complex + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]: + """Compute the baseline solution.""" + sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm) + images = sense(kdata).rearrange('batch time ...-> time batch ...') + + objective = mrpro.operators.functionals.L2NormSquared(images.data) @ self.signalmodel @ self.constraints_op + initial_values = tuple( + torch.zeros(images.shape[1:], device=images.device, dtype=torch.complex64 if is_complex else torch.float32) + for is_complex in self.parameter_is_complex + ) + solution = self.constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values)) + return solution + + +class LogLambdasCallback(pl.Callback): + """Log the lambdas.""" + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: PinqiModule, + _outputs: dict, + _batch: BatchType, + _batch_idx: int, + ) -> None: + if trainer.global_step % 10 == 0: + lambdas = pl_module.pinqi.softplus(pl_module.pinqi.lambdas_raw).detach().cpu().numpy() + for iteration, (lambda_image, lambda_q, lambda_parameter) in enumerate(lambdas): + self.log_dict( + { + f'parameter/lambda_image_{iteration}': lambda_image, + f'parameter/lambda_q_{iteration}': lambda_q, + f'parameter/lambda_parameter_{iteration}': lambda_parameter, + }, + on_step=True, + on_epoch=False, + ) + + +if __name__ == '__main__': + torch.set_float32_matmul_precision('high') torch._inductor.config.compile_threads = 4 + torch._inductor.config.worker_start_method = 'fork' torch._dynamo.config.capture_scalar_outputs = True - torch._functorch.config.activation_memory_budget = 0.9 torch._dynamo.config.cache_size_limit = 256 + torch._functorch.config.activation_memory_budget = 0.95 - data_folder = Path("/scratch/zimmer08/brainweb") + data_folder = Path('/scratch/zimmer08/brainweb') - signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 6.0)) + signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0)) constraints_op = mrpro.operators.ConstraintsOp( bounds=( (-2, 2), # M0 in [-2, 2] @@ -514,53 +630,50 @@ def configure_optimizers( signalmodel=signalmodel, n_images=n_images, batch_size=16, - num_workers=8, - pin_memory=True, + num_workers=16, size=192, - acceleration=10, + acceleration=8, n_coils=8, max_noise=0.1, ) - model = Module( + model = PinqiModule( signalmodel=signalmodel, constraints_op=constraints_op, parameter_is_complex=parameter_is_complex, n_images=n_images, - lr=3e-4, - weight_decay=1e-4, - n_iterations=4, ) neptune_logger = NeptuneLogger( log_model_checkpoints=False, - dependencies="infer", + dependencies='infer', ) - neptune_logger.log_hyperparams(model.hparams) + neptune_logger.log_model_summary(model=model, max_depth=-1) checkpoint_callback = ModelCheckpoint( - monitor="val/loss", - mode="min", + monitor='val/loss', + mode='min', save_top_k=2, - dirpath=Path("checkpoints") / str(neptune_logger.version), - filename="{epoch:02d}-{val/loss:.4f}", + dirpath=Path('checkpoints') / str(neptune_logger.version), + filename='{epoch:02d}-{val/loss:.4f}', save_last=True, ) strategy = DDPStrategy(find_unused_parameters=False) - trainer = pl.Trainer( - max_epochs=50, - accelerator="gpu", + max_epochs=100, + accelerator='gpu', devices=4, strategy=strategy, logger=neptune_logger, callbacks=[ - LearningRateMonitor(logging_interval="step"), + LearningRateMonitor(logging_interval='step'), checkpoint_callback, + LogLambdasCallback(), ], log_every_n_steps=10, - precision="16-mixed", + gradient_clip_algorithm='norm', + gradient_clip_val=5.0, ) trainer.fit(model, datamodule=dm) From d8bb3051051cf7dde2009bc683900f29efaf0fa5 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 4 Jul 2025 17:15:27 +0200 Subject: [PATCH 088/205] fix ssim --- src/mrpro/operators/functionals/SSIM.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/mrpro/operators/functionals/SSIM.py b/src/mrpro/operators/functionals/SSIM.py index 663baf089..0b62e4995 100644 --- a/src/mrpro/operators/functionals/SSIM.py +++ b/src/mrpro/operators/functionals/SSIM.py @@ -80,14 +80,14 @@ def ssim3d( return (real_ssim + imag_ssim) / 2 if target.ndim < 3: raise ValueError('Input must be at least 3D (z, y, x)') - + window = tuple(window_size if s > 1 else 1 for s in target.shape[-3:]) # To support 1D and 2D uses if weight is not None: if (weight < 0).any(): raise ValueError('Mask contains negative values') target, prediction, weight = cast( tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.broadcast_tensors(target, prediction, weight) ) - weight = sliding_window(weight, window_shape=window_size, dim=(-3, -2, -1)) + weight = sliding_window(weight, window_shape=window, dim=(-3, -2, -1)) # Set weights to 0 for windows that are not fully inside the mask weight = weight * ~torch.isclose(weight, torch.tensor(0, dtype=weight.dtype)).any((-3, -2, -1), keepdim=True) weight = weight.mean((-1, -2, -3), dtype=torch.float32).moveaxis((0, 1, 2), (-3, -2, -1)) @@ -96,7 +96,6 @@ def ssim3d( else: target, prediction = cast(tuple[torch.Tensor, torch.Tensor], torch.broadcast_tensors(target, prediction)) - window = tuple(window_size if s > 1 else 1 for s in target.shape[-3:]) # To support 1D and 2D uses target_window = sliding_window(target, window_shape=window, dim=(-3, -2, -1)).movedim((0, 1, 2), (-6, -5, -4)) if data_range is None: From 6b46b3aa8b948879a074ee5d3749dea10617be5c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 3 Jul 2025 21:38:25 +0200 Subject: [PATCH 089/205] fix cg --- src/mrpro/operators/ConjugateGradientOp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mrpro/operators/ConjugateGradientOp.py b/src/mrpro/operators/ConjugateGradientOp.py index 9da2e025b..617f15770 100644 --- a/src/mrpro/operators/ConjugateGradientOp.py +++ b/src/mrpro/operators/ConjugateGradientOp.py @@ -83,6 +83,7 @@ def backward(ctx: ConjugateGradientCTX, *grad_output: torch.Tensor) -> tuple[tor ctx.saved_tensors[: ctx.len_solution], ctx.saved_tensors[ctx.len_solution :], ) + inputs = tuple(x.detach().clone().requires_grad_(x.requires_grad) for x in inputs) with torch.enable_grad(): rhs = ctx.rhs_factory(*inputs) operator = ctx.operator_factory(*inputs) From 7cd0d7f8473506b93b06400b51e789e48531251d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 4 Jul 2025 17:25:08 +0200 Subject: [PATCH 090/205] modl --- examples/scripts/modl.py | 71 +++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/examples/scripts/modl.py b/examples/scripts/modl.py index ea36268ba..5039d233a 100644 --- a/examples/scripts/modl.py +++ b/examples/scripts/modl.py @@ -18,9 +18,10 @@ class BatchType(TypedDict): class AcceleratedFastMRI(torch.utils.data.Dataset): - def __init__(self, path: Path, acceleration: float = 16, noise_level: float = 0.2): + def __init__(self, path: Path, acceleration: float = 12, noise_level: float = 0.1): self.acceleration = acceleration - self.dataset = mrpro.phantoms.FastMRIKDataDataset(path) + files = list(path.glob('*AXT1*')) + self.dataset = mrpro.phantoms.FastMRIKDataDataset(files) self.noise_level = noise_level def __len__(self): @@ -31,8 +32,7 @@ def __getitem__(self, index: int) -> BatchType: data = data.remove_readout_os() data.data /= data.data.std() reconstruction = mrpro.algorithms.reconstruction.DirectReconstruction( - data, - csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64), + data, csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64) ) csm = reconstruction.csm target = reconstruction(data) @@ -52,29 +52,19 @@ def __getitem__(self, index: int) -> BatchType: class MODL(torch.nn.Module): - def __init__(self, iterations: int = 10, n_features: Sequence[int] = (64, 64, 64, 64)): + def __init__(self, iterations: int = 8, n_features: Sequence[int] = (64, 64, 64, 64)): super().__init__() cnn = mrpro.nn.nets.BasicCNN( dim=2, channels_in=2, channels_out=2, n_features=n_features, + batch_norm=True, ) self.network = mrpro.nn.Residual(mrpro.nn.ComplexAsChannel(mrpro.nn.PermutedBlock((-1, -2), cnn))) self.network = torch.compile(self.network, dynamic=True, fullgraph=True) self.iterations = iterations - self.regularization_weight = torch.nn.Parameter(torch.tensor(1.0)) - - def _prepare_dataconsistency( - self, - gram: mrpro.operators.LinearOperator, - zero_filled_image: torch.Tensor, - ) -> mrpro.operators.ConjugateGradientOp: - return mrpro.operators.ConjugateGradientOp( - operator_factory=lambda _: gram + self.regularization_weight, - rhs_factory=lambda regularization_image: zero_filled_image - + self.regularization_weight * regularization_image, - ) + self.regularization_weights = torch.nn.Parameter(0.2 * torch.ones(iterations)) def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: return super().__call__(kdata, csm) @@ -82,18 +72,23 @@ def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.da def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) acquisition_op = fourier_op @ csm.as_operator() + (zero_filled_image,) = acquisition_op.H(kdata.data) + gram = acquisition_op.gram + data_consistency_op = mrpro.operators.ConjugateGradientOp( + operator_factory=lambda _image, weight: gram + weight, + rhs_factory=lambda image, weight: zero_filled_image + weight * image, + ) - (image,) = acquisition_op.H(kdata.data) - data_consistency_op = self._prepare_dataconsistency(acquisition_op.gram, image) - - for _ in range(self.iterations): + (image,) = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=5) + for iteration in range(self.iterations): regularization = self.network(image) - (image,) = data_consistency_op(regularization) + (image,) = data_consistency_op(regularization, self.regularization_weights[iteration]) return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header)) -def plot(batch: BatchType, prediction: mrpro.data.IData): +def plot(batch: BatchType, prediction: mrpro.data.IData, step: int): + """Plot the direct, sense, and modl reconstructions.""" target = batch['target'].rss().cpu().squeeze() direct = mrpro.algorithms.reconstruction.DirectReconstruction(batch['data'], csm=batch['csm'])(batch['data']) direct = direct.rss().cpu().squeeze() @@ -112,7 +107,7 @@ def show(ax: matplotlib.axes.Axes, data: torch.Tensor, label: str): ax.text( 0.98, 0.1, - f'{ssim_value.item():.2f}', + f'SSIM: {ssim_value.item():.2f}', color='white', horizontalalignment='right', verticalalignment='top', @@ -127,31 +122,31 @@ def show(ax: matplotlib.axes.Axes, data: torch.Tensor, label: str): show(ax[2], prediction_, 'MODL') show(ax[3], target, 'Ground Truth') fig.tight_layout() - plt.show() + fig.savefig(f'modl_{step}.pdf', bbox_inches='tight', pad_inches=0) -# %% - +# %%. path = Path('/echo/allgemein/resources/publicTrainingData/fastmri/brain_multicoil_train/') dataset = AcceleratedFastMRI(path) -dataloader = torch.utils.data.DataLoader(dataset, num_workers=8, shuffle=True, collate_fn=lambda batch: batch[0]) +dataloader = torch.utils.data.DataLoader(dataset, num_workers=16, shuffle=True, collate_fn=lambda batch: batch[0]) modl = MODL().cuda() -optimizer = torch.optim.Adam(modl.parameters(), lr=1e-4) +optimizer = torch.optim.Adam(modl.parameters(), lr=1e-3) pbar = tqdm(dataloader) for i, batch in enumerate(pbar): - kdata, csm, target = batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda() + optimizer.zero_grad() + kdata, csm, target = (batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda()) prediction = modl(kdata, csm) - objective = mrpro.operators.functionals.MSE(target.data) - mrpro.operators.functionals.SSIM(target.data) + objective = 0.5 * mrpro.operators.functionals.MSE(target.data) - mrpro.operators.functionals.SSIM(target.data) (loss,) = objective(prediction.data) loss.backward() - - if i % 4 == 0: - optimizer.step() - optimizer.zero_grad() + torch.nn.utils.clip_grad_norm_(modl.parameters(), 5.0) + optimizer.step() pbar.set_postfix(loss=loss.item()) - - if i % 100 == 0: - plot(batch, prediction) + if i % 200 == 0: + plot(batch, prediction, i) + print(modl.regularization_weights) + state = {'modl': modl.state_dict(), 'optimizer': optimizer.state_dict()} + torch.save(state, f'modl_{i}.pt') # %% From 3a91b5d4a10b19681908e12dde0ef2e9a3c68ffd Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 8 Jul 2025 18:14:46 +0200 Subject: [PATCH 091/205] fix test --- tests/data/test_kdata.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index 22bc6b20e..1cca97c3c 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -10,6 +10,7 @@ from einops import repeat from mrpro.data import KData, KHeader, KTrajectory, SpatialDimension from mrpro.data.acq_filters import has_n_coils, is_coil_calibration_acquisition, is_image_acquisition +from mrpro.data.Dataclass import InconsistentDeviceError from mrpro.data.traj_calculators import KTrajectoryIsmrmrd from mrpro.data.traj_calculators.KTrajectoryCalculator import DummyTrajectory from mrpro.operators import FastFourierOp @@ -248,7 +249,7 @@ def test_KData_inconsistentdevice(ismrmrd_cart) -> None: kdata_mix = KData(data=kdata_cuda.data, header=kdata_cpu.header, traj=kdata_cpu.traj) assert not kdata_mix.is_cuda assert not kdata_mix.is_cpu - with pytest.raises(ValueError): + with pytest.raises(InconsistentDeviceError): _ = kdata_mix.device From 3b17a711ca58a19cf67d2942915ed107cc48db0d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 9 Jul 2025 00:42:57 +0200 Subject: [PATCH 092/205] apply pinqi --- examples/scripts/apply_pinqi.py | 496 ++++++++++++++++++++++++++++++++ 1 file changed, 496 insertions(+) create mode 100644 examples/scripts/apply_pinqi.py diff --git a/examples/scripts/apply_pinqi.py b/examples/scripts/apply_pinqi.py new file mode 100644 index 000000000..075a5bac8 --- /dev/null +++ b/examples/scripts/apply_pinqi.py @@ -0,0 +1,496 @@ +# %% +from collections.abc import Sequence +from copy import deepcopy +from pathlib import Path +from typing import Literal, TypedDict + +import einops +import mrpro +import torch + +# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) + + +class BatchType(TypedDict): + """Typehint for a batch of data.""" + + kdata: mrpro.data.KData + csm: mrpro.data.CsmData + m0: torch.Tensor + t1: torch.Tensor + mask: torch.Tensor + + +class Dataset(torch.utils.data.Dataset[BatchType]): + """A brainweb based cartesian qMRI dataset.""" + + def __init__( + self, + folder: Path, + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int, + acceleration: int, + n_coils: int, + max_noise: float, + orientation: Sequence[Literal['axial', 'coronal', 'sagittal']], + random: bool = True, + ): + """Initialize the dataset.""" + if random: + augment = mrpro.phantoms.brainweb.augment(size=size) + else: + augment = mrpro.phantoms.brainweb.augment( + size=size, + max_random_shear=0, + max_random_rotation=0, + max_random_scaling_factor=0, + p_horizontal_flip=0, + p_vertical_flip=1.0, + ) + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + folder=folder, + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=augment, + orientation=orientation, + ) + self.signalmodel = deepcopy(signalmodel) + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + self.max_noise = max_noise + self._n_images = n_images + + def __len__(self) -> int: + """Get the length of the dataset.""" + return len(self.phantom) + + def __getitem__(self, index: int): + """Get an item from the dataset.""" + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = int(torch.randint(0, 1000000, (1,))) if self._random else index + + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=1.5, + n_center=10, + n_other=(self._n_images,), + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + + if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery): + header.ti = self.signalmodel.saturation_time.tolist() + elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery): + header.ti = self.signalmodel.ti.tolist() + + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + csm = mrpro.data.CsmData( + mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), + header, + ) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + +class PINQI(torch.nn.Module): + """PINQI model.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int, + n_features_parameter_net: Sequence[int], + n_features_image_net: Sequence[int], + ): + """Initialize the PINQI model.""" + super().__init__() + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op + self.constraints_op = constraints_op + self._n_images = n_images + self._parameter_is_complex = parameter_is_complex + real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex) + self.parameter_net = mrpro.nn.nets.UNet( + dim=2, + channels_in=n_images * 2, + channels_out=real_parameters, + attention_depths=(-1, -2), + n_features=n_features_parameter_net, + cond_dim=128, + ) + + self.image_net = mrpro.nn.nets.UNet( + 2, channels_in=2, channels_out=2, attention_depths=(), n_features=n_features_image_net, cond_dim=128 + ) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus(beta=5) + self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128) + + def objective_factory( + lambda_parameters: torch.Tensor, + image: torch.Tensor, + *parameter_reg: torch.Tensor, + ): + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.ProximableFunctionalSeparableSum( + *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] + ) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp( + objective_factory, + lambda _l, _i, *parameter_reg: parameter_reg, + ) + + def get_linear_solver(self, gram: mrpro.operators.LinearOperator): + def operator_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + *_, + ): + return gram + lambda_image + lambda_q + + def rhs_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + image_reg: torch.Tensor, + signal: torch.Tensor, + zero_filled_image: torch.Tensor, + ): + return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,) + + return mrpro.operators.ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=rhs_factory, + ) + + def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]: + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> batch (t complex) y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + parameters = self.parameter_net(image.contiguous(), cond=cond) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + i = 0 + result = [] + for is_complex in self._parameter_is_complex: + if is_complex: + result.append(torch.complex(parameters[i], parameters[i + 1])) + i += 2 + else: + result.append(parameters[i]) + i += 1 + return tuple(result) + + def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor: + batch = image.shape[0] + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> (batch t) complex y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + image = image + self.image_net(image.contiguous(), cond=cond) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm_op + gram = acquisition_op.gram + (zero_filled_image,) = acquisition_op.H(kdata.data) + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + parameters = [self.get_parameter_reg(images[-1], 0)] + linear_solver = self.get_linear_solver(gram) + + for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)): + image_reg = self.get_image_reg(images[-1], i + 1) + (signal,) = self.signalmodel(*parameters[-1]) + images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + parameters_reg = self.get_parameter_reg(images[-1], i + 1) + parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + if self.constraints_op is not None: + parameters = [self.constraints_op(*p) for p in parameters] + return images, parameters + + +# def validation_step(self, batch: BatchType, batch_idx: int) -> None: +# """Validate. + +# Needs to be adapted for other signal models than Saturation Recovery. +# """ +# images, parameters = self(batch['kdata'], batch['csm']) +# loss = self.loss(parameters, batch) + +# pred_m0, pred_t1 = parameters[-1] +# target_t1, target_m0 = batch['t1'], batch['m0'] +# mask = batch['mask'] +# batch_size = len(batch['mask']) +# (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1) +# (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1) +# (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0) +# self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) +# self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) +# self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size) +# self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size) + +# if batch_idx == 0: +# self.validation_step_outputs['target_t1'].append(batch['t1']) +# self.validation_step_outputs['pred_t1'].append(pred_t1) +# self.validation_step_outputs['pred_m0'].append(pred_m0) +# self.validation_step_outputs['target_m0'].append(target_m0) +# self.validation_step_outputs['mask'].append(batch['mask']) +# baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm']) +# self.validation_step_outputs['baseline_t1'].append(baseline_t1) +# self.validation_step_outputs['baseline_m0'].append(baseline_m0) + +# def on_validation_epoch_end(self): +# """Validate. + +# Needs to be adapted for other signal models than Saturation Recovery. +# """ +# outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()} +# self.validation_step_outputs.clear() +# outputs = cast(dict[str, torch.Tensor], self.all_gather(outputs)) + +# if not self.trainer.is_global_zero: +# return +# outputs = {k: v.flatten(0, 1).cpu() for k, v in outputs.items()} + +# samples = len(outputs['mask']) +# fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16)) + +# for i in range(samples): +# self.result_plot( +# outputs['target_t1'][i], +# outputs['pred_t1'][i], +# outputs['mask'][i], +# axes[:, i], +# outputs['baseline_t1'][i], +# '$T_1$ (s)', +# ) +# fig.suptitle(f'$T_1$ Epoch {self.current_epoch}') +# self.logger.run['val/images/t1'].log(fig) +# plt.close(fig) + +# fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 12)) +# for i in range(samples): +# self.result_plot( +# outputs['target_m0'][i].abs(), +# outputs['pred_m0'][i].abs(), +# outputs['mask'][i], +# axes[:, i], +# outputs['baseline_m0'][i].abs(), +# '$|M_0|$ (a.u.)', +# ) +# fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}') +# self.logger.run['val/images/m0'].log(fig) +# plt.close(fig) + +# def result_plot( +# self, +# target: torch.Tensor, +# pred: torch.Tensor, +# mask: torch.Tensor, +# axes: Sequence[plt.Axes], +# baseline: torch.Tensor, +# label: str, +# ) -> None: +# """Plot the results.""" +# target = target.squeeze().numpy() +# pred = pred.squeeze().detach().numpy() +# mask = mask.squeeze().detach().numpy().astype(bool) +# baseline = baseline.squeeze().detach().numpy() + +# target[~mask] = np.nan +# pred[~mask] = np.nan +# baseline[~mask] = np.nan +# difference = (target - pred) / target * 100 +# vmax = np.nanmax(target) + +# im0 = axes[0].imshow(target, vmin=0, vmax=vmax) +# axes[0].set_title('Ground Truth') +# axes[0].axis('off') +# plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label=label) + +# im1 = axes[1].imshow(baseline, vmin=0, vmax=vmax) +# axes[1].set_title('SENSE + Regression') +# axes[1].axis('off') +# plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label=label) + +# im2 = axes[2].imshow(pred, vmin=0, vmax=vmax) +# axes[2].set_title('PINQI') +# axes[2].axis('off') +# plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04, label=label) + +# diff_vmax = np.nanpercentile(np.abs(difference), 90) +# im3 = axes[3].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax) +# axes[3].set_title('rel. Error') +# axes[3].axis('off') +# plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04, label='%') + + +# %% +# As a baseline methods for comparision, we use a simple non-learned approach. We reconstruct the qualitative images at different saturation times using iterative SENSE. +# We then perform a constrained non-linear least squares regression usingL-BFGS to obtain the parameter maps. +# %% +def baseline_solution( + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + kdata: mrpro.data.KData, + csm: mrpro.data.CsmData, +) -> tuple[torch.Tensor, ...]: + """Compute a baseline solution using SENSE + Regression.""" + sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm) + images = sense(kdata) + objective = mrpro.operators.functionals.L2NormSquared(images.data) @ signalmodel @ constraints_op + initial_values = tuple( + torch.zeros(images.shape[1:], device=images.device, dtype=torch.complex64 if is_complex else torch.float32) + for is_complex in parameter_is_complex + ) + solution = constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values)) + return solution + + +# %% +data_folder = Path('/home/zimmer08/.cache/mrpro/brainweb') + +signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0)) +constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (-2, 2), # M0 in [-2, 2] + (0.01, 6.0), # T1 is constrained between 10 ms and 6 s + ) +) +n_images = len(signalmodel.saturation_time) +parameter_is_complex = [True, False] + + +dataset = torch.utils.data.Subset( + Dataset( + folder=data_folder, + signalmodel=signalmodel, + n_images=n_images, + size=192, + acceleration=8, + n_coils=8, + max_noise=0.05, + orientation=('axial',), + random=False, + ), + list(range(500)), +) +# %% +checkpoint = torch.load('last.ckpt', map_location='cpu') +hyper_parameters = checkpoint['hyper_parameters'] + + +pinqi = PINQI( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + n_iterations=hyper_parameters['n_iterations'], + n_features_parameter_net=hyper_parameters['n_features_parameter_net'], + n_features_image_net=hyper_parameters['n_features_image_net'], +) +state_dict = { + k.replace('pinqi.', '').replace('_orig_mod.', ''): v + for k, v in checkpoint['state_dict'].items() + if 'baseline' not in k +} +pinqi.load_state_dict(state_dict) +# %% +batch = dataset[40] +csm, kdata = batch['csm'], batch['kdata'] + +if torch.cuda.is_available(): + pinqi, csm, kdata = pinqi.cuda(), csm.cuda(), kdata.cuda() +images, parameters = pinqi(kdata[None], csm[None]) +with torch.no_grad(): + predicted_m0, predicted_t1 = (p.cpu().detach().squeeze() for p in parameters[-1]) +baseline_m0, baseline_t1 = baseline_solution(signalmodel, constraints_op, parameter_is_complex, kdata, csm) +# %% +(ssim_t1,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(predicted_t1[None]) +(mse_t1,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(predicted_t1) + +(mse_baseline,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(baseline_t1) +nrmse_t1 = torch.sqrt(mse_t1) / batch['t1'][batch['mask']].max() +(ssim_baseline,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(baseline_t1[None]) +nrmse_baseline = torch.sqrt(mse_baseline) / batch['t1'][batch['mask']].max() + + +# %% +import matplotlib.pyplot as plt +from cmap import Colormap + +cmap = Colormap('lipari').to_matplotlib() + +print(f'SSIM: {ssim_baseline.item():.4f}, NRMSE: {nrmse_baseline.item():.4f}') +print(f'SSIM: {ssim_t1.item():.4f}, NRMSE: {nrmse_t1.item():.4f}') + + +fig, ax = plt.subplots(1, 4, gridspec_kw={'width_ratios': [1, 1, 1, 0.075]}, figsize=(6, 2)) +baseline_t1 = baseline_t1.squeeze() +baseline_t1[~batch['mask']] = torch.nan +ax[0].imshow(baseline_t1, vmin=0, vmax=2, cmap=cmap) +ax[0].axis('off') +ax[0].set_title('SENSE + Regression') +ax[0].text( + 0.5, + -0.05, + f'SSIM: {ssim_baseline.item():.2f}', + color='black', + horizontalalignment='center', + verticalalignment='top', + transform=ax[0].transAxes, +) +predicted_t1 = predicted_t1.squeeze() +predicted_t1[~batch['mask']] = torch.nan +ax[1].imshow(predicted_t1, vmin=0, vmax=2, cmap=cmap) +ax[1].axis('off') +ax[1].set_title('PINQI') +ax[1].text( + 0.5, + -0.05, + f'SSIM: {ssim_t1.item():.2f}', + color='black', + horizontalalignment='center', + verticalalignment='top', + transform=ax[1].transAxes, +) + +target_t1 = batch['t1'].squeeze() +target_t1[~batch['mask']] = torch.nan +im = ax[2].imshow(target_t1, vmin=0, vmax=2, cmap=cmap) +ax[2].axis('off') +ax[2].set_title('Ground Truth') + +plt.colorbar(im, cax=ax[3], label='$T_1$ (s)') +plt.savefig('/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1.pdf', bbox_inches='tight') +plt.show() + + +# %% + + +# %% +# %% From 1be7d321a6c2f3e6767cf754068ad10ebf607393 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 10 Jul 2025 17:38:49 +0200 Subject: [PATCH 093/205] update --- examples/scripts/apply_pinqi.py | 131 ++------------------------------ 1 file changed, 8 insertions(+), 123 deletions(-) diff --git a/examples/scripts/apply_pinqi.py b/examples/scripts/apply_pinqi.py index 075a5bac8..daad3137c 100644 --- a/examples/scripts/apply_pinqi.py +++ b/examples/scripts/apply_pinqi.py @@ -231,122 +231,6 @@ def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): return images, parameters -# def validation_step(self, batch: BatchType, batch_idx: int) -> None: -# """Validate. - -# Needs to be adapted for other signal models than Saturation Recovery. -# """ -# images, parameters = self(batch['kdata'], batch['csm']) -# loss = self.loss(parameters, batch) - -# pred_m0, pred_t1 = parameters[-1] -# target_t1, target_m0 = batch['t1'], batch['m0'] -# mask = batch['mask'] -# batch_size = len(batch['mask']) -# (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1) -# (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1) -# (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0) -# self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) -# self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) -# self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size) -# self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size) - -# if batch_idx == 0: -# self.validation_step_outputs['target_t1'].append(batch['t1']) -# self.validation_step_outputs['pred_t1'].append(pred_t1) -# self.validation_step_outputs['pred_m0'].append(pred_m0) -# self.validation_step_outputs['target_m0'].append(target_m0) -# self.validation_step_outputs['mask'].append(batch['mask']) -# baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm']) -# self.validation_step_outputs['baseline_t1'].append(baseline_t1) -# self.validation_step_outputs['baseline_m0'].append(baseline_m0) - -# def on_validation_epoch_end(self): -# """Validate. - -# Needs to be adapted for other signal models than Saturation Recovery. -# """ -# outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()} -# self.validation_step_outputs.clear() -# outputs = cast(dict[str, torch.Tensor], self.all_gather(outputs)) - -# if not self.trainer.is_global_zero: -# return -# outputs = {k: v.flatten(0, 1).cpu() for k, v in outputs.items()} - -# samples = len(outputs['mask']) -# fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16)) - -# for i in range(samples): -# self.result_plot( -# outputs['target_t1'][i], -# outputs['pred_t1'][i], -# outputs['mask'][i], -# axes[:, i], -# outputs['baseline_t1'][i], -# '$T_1$ (s)', -# ) -# fig.suptitle(f'$T_1$ Epoch {self.current_epoch}') -# self.logger.run['val/images/t1'].log(fig) -# plt.close(fig) - -# fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 12)) -# for i in range(samples): -# self.result_plot( -# outputs['target_m0'][i].abs(), -# outputs['pred_m0'][i].abs(), -# outputs['mask'][i], -# axes[:, i], -# outputs['baseline_m0'][i].abs(), -# '$|M_0|$ (a.u.)', -# ) -# fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}') -# self.logger.run['val/images/m0'].log(fig) -# plt.close(fig) - -# def result_plot( -# self, -# target: torch.Tensor, -# pred: torch.Tensor, -# mask: torch.Tensor, -# axes: Sequence[plt.Axes], -# baseline: torch.Tensor, -# label: str, -# ) -> None: -# """Plot the results.""" -# target = target.squeeze().numpy() -# pred = pred.squeeze().detach().numpy() -# mask = mask.squeeze().detach().numpy().astype(bool) -# baseline = baseline.squeeze().detach().numpy() - -# target[~mask] = np.nan -# pred[~mask] = np.nan -# baseline[~mask] = np.nan -# difference = (target - pred) / target * 100 -# vmax = np.nanmax(target) - -# im0 = axes[0].imshow(target, vmin=0, vmax=vmax) -# axes[0].set_title('Ground Truth') -# axes[0].axis('off') -# plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label=label) - -# im1 = axes[1].imshow(baseline, vmin=0, vmax=vmax) -# axes[1].set_title('SENSE + Regression') -# axes[1].axis('off') -# plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label=label) - -# im2 = axes[2].imshow(pred, vmin=0, vmax=vmax) -# axes[2].set_title('PINQI') -# axes[2].axis('off') -# plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04, label=label) - -# diff_vmax = np.nanpercentile(np.abs(difference), 90) -# im3 = axes[3].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax) -# axes[3].set_title('rel. Error') -# axes[3].axis('off') -# plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04, label='%') - - # %% # As a baseline methods for comparision, we use a simple non-learned approach. We reconstruct the qualitative images at different saturation times using iterative SENSE. # We then perform a constrained non-linear least squares regression usingL-BFGS to obtain the parameter maps. @@ -448,7 +332,7 @@ def baseline_solution( print(f'SSIM: {ssim_t1.item():.4f}, NRMSE: {nrmse_t1.item():.4f}') -fig, ax = plt.subplots(1, 4, gridspec_kw={'width_ratios': [1, 1, 1, 0.075]}, figsize=(6, 2)) +fig, ax = plt.subplots(1, 5, gridspec_kw={'width_ratios': [1, 1, 1, 0.01, 0.075], 'wspace': 0.0}, figsize=(5, 2)) baseline_t1 = baseline_t1.squeeze() baseline_t1[~batch['mask']] = torch.nan ax[0].imshow(baseline_t1, vmin=0, vmax=2, cmap=cmap) @@ -456,7 +340,7 @@ def baseline_solution( ax[0].set_title('SENSE + Regression') ax[0].text( 0.5, - -0.05, + -0.00, f'SSIM: {ssim_baseline.item():.2f}', color='black', horizontalalignment='center', @@ -470,12 +354,13 @@ def baseline_solution( ax[1].set_title('PINQI') ax[1].text( 0.5, - -0.05, + -0.0, f'SSIM: {ssim_t1.item():.2f}', color='black', horizontalalignment='center', verticalalignment='top', transform=ax[1].transAxes, + size=10, ) target_t1 = batch['t1'].squeeze() @@ -483,10 +368,10 @@ def baseline_solution( im = ax[2].imshow(target_t1, vmin=0, vmax=2, cmap=cmap) ax[2].axis('off') ax[2].set_title('Ground Truth') - -plt.colorbar(im, cax=ax[3], label='$T_1$ (s)') -plt.savefig('/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1.pdf', bbox_inches='tight') -plt.show() +fig.tight_layout() +ax[-2].axis('off') +plt.colorbar(im, cax=ax[-1], label='$T_1$ (s)') +fig.savefig('/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1_2.pdf', bbox_inches='tight') # %% From 3386923bc91314407bd9d3c430090ccee8da892c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 11 Jul 2025 13:00:51 +0200 Subject: [PATCH 094/205] Squashed commit of the following: commit db742b0588aa3c6b50c3a493a754b751956b8a2f Merge: aaa68e97 73b40335 Author: Felix Zimmermann Date: Fri Jul 11 12:21:02 2025 +0200 Merge branch 'main' into nn commit aaa68e97317797944ba353b8db1a0bab6a46f649 Author: Felix Zimmermann Date: Mon Jun 23 22:37:26 2025 +0200 separable commit fb6eb419a6d42c2443ddd8b62ee778a2d76af81a Author: Felix Zimmermann Date: Sun Jun 22 12:51:23 2025 +0200 update commit 10f19949ec3ec9b5f2619e3015260dc687e28900 Author: Felix Zimmermann Date: Wed Jun 4 02:14:06 2025 +0200 Refactor import statements and enhance EMA documentation - Updated import statements in join.py and Uformer.py to ensure consistent usage of the CondMixin class. - Removed unnecessary import of CondMixin in SpatialTransformerBlock.py for cleaner code. - Added a docstring to ema.py to clarify the purpose of the Exponential Moving Average (EMA) dictionary. commit c7e588ec223cca3ebf2f7aec3ac148ac498b6b17 Author: Felix Zimmermann Date: Wed Jun 4 02:13:22 2025 +0200 Refactor AttentionGate and DropPath modules for improved functionality - Updated AttentionGate to ensure consistent interpolation of gate tensors regardless of shape. - Modified DropPath to conditionally scale the mask based on the keep probability, enhancing flexibility in dropout behavior. - Enhanced _fix_shapes function in join.py to support new interpolation modes (linear and nearest) for better tensor shape handling. - Improved documentation in Concat class to reflect new interpolation options and ensure clarity in parameter descriptions. commit e20d6f7edaa50ef6de13877fea649eac940cf777 Author: Felix Zimmermann Date: Tue Jun 3 23:31:04 2025 +0200 wip commit 4c1ec0feaa67ca540f08126c35e756b88dfb0dba Author: Felix Zimmermann Date: Tue Jun 3 17:20:25 2025 +0200 Refactor parameter documentation in encoding and normalization modules - Updated parameter documentation in FourierFeatures, AbsolutePositionEncoding, GEGLU, and LayerNorm classes to remove type hints from docstrings for consistency. - Enhanced clarity in parameter descriptions while maintaining the overall structure of the documentation. commit ea3109ef8caff466e4d4cb680cc65f992e2489f8 Author: Felix Zimmermann Date: Mon Jun 2 22:56:20 2025 +0200 update commit 6b942ca1dc336193a61f04ffe4619eae434292d9 Author: Felix Zimmermann Date: Mon Jun 2 17:22:00 2025 +0200 wip commit 59569dc72c6c567fcdef479f4ce721c1a33c39d4 Author: Felix Zimmermann Date: Mon Jun 2 16:25:15 2025 +0200 - Updated AttentionGate to include a new 'concatenate' parameter, allowing for optional concatenation of gated and gate signals in the channel dimension. - Adjusted ResBlock to modify the rezero parameter for better stability during training. - Refactored forward methods in both AttentionGate and ResBlock to ensure compatibility with the new features and maintain clarity in tensor operations. - Updated import statements and class references in UNet and related modules to reflect the new AttentionGatedUNet class. commit dca726b02843484b9f6dcc5bad22c78b4d6fcc41 Author: Felix Zimmermann Date: Mon Jun 2 02:25:23 2025 +0200 Enhance UNet architecture with improved attention handling and modularity - Updated UNet class to include configurable attention depths and encoder blocks per scale, enhancing flexibility in model design. - Introduced new attention block functionality and refined block initialization for better clarity and maintainability. - Adjusted forward method to accept conditioning tensors explicitly, improving modularity in the encoder-decoder structure. - Integrated GroupNorm and SiLU into the final layers for improved performance and consistency. commit b1ff7f84991d2dfbcd1baf0c4eb9314464dfb7fe Author: Felix Zimmermann Date: Mon Jun 2 02:25:10 2025 +0200 Refactor neural network modules to standardize feature dimension handling - Updated parameter names from 'channel_last' to 'features_last' across multiple modules for consistency. - Adjusted related logic in GEGLU, LayerNorm, LinearSelfAttention, NeighborhoodSelfAttention, RMSNorm, and BasicTransformerBlock to reflect the new parameter naming. - Enhanced clarity in the handling of feature dimensions, improving modularity and maintainability of the codebase. commit 376be37273f3e3369f89871d164dc4c8cc820b2b Author: Felix Zimmermann Date: Mon Jun 2 02:24:51 2025 +0200 Add Upsample module for tensor resizing functionality - Introduced the Upsample class to facilitate tensor upsampling with configurable scale factors and modes (nearest, linear). - Implemented the forward method to compute new tensor sizes based on the specified dimensions and scale factor, enhancing flexibility in tensor manipulation. commit 24bfbc9306db963ca80c7e6b0ea931dc4e1fed04 Author: Felix Zimmermann Date: Fri May 23 01:46:07 2025 +0200 Refactor ZeroPadOp and pad_or_crop utility for improved functionality and clarity - Updated import statement in ZeroPadOp to directly import pad_or_crop function. - Enhanced pad_or_crop function to include a new 'mode' parameter for padding options, improving flexibility in data manipulation. commit afd7a4573c26567da47d61f92fba77e32ca657c9 Author: Felix Zimmermann Date: Fri May 23 01:45:39 2025 +0200 Refactor FiLM and Uformer modules for improved clarity and functionality - Simplified the FiLM class by removing unnecessary Sequential and Identity layers. - Updated the Uformer architecture to conditionally include FiLM layers based on the provided conditioning dimension. - Enhanced the forward method of LeWinTransformerBlock to accept conditioning tensors, improving modularity and flexibility. commit eafbfc62a6432066e01c4cd8dcb03fbcecf7a7ab Author: Felix Zimmermann Date: Fri May 23 01:45:24 2025 +0200 Add SpatialTransformerBlock and integrate into UNet architecture - Introduced SpatialTransformerBlock for enhanced attention mechanisms. - Updated MultiHeadAttention to support cross-attention channels. - Modified UNet to include SpatialTransformerBlock in encoder and decoder stages based on specified attention depths. - Improved modularity and flexibility of the UNet structure. commit 4c497345b111d92e4b239391abafb712a8212f01 Author: Felix Zimmermann Date: Thu May 22 16:48:14 2025 +0200 Refactor Restormer and Uformer networks to utilize UNetEncoder and UNetDecoder classes for improved modularity and clarity. Update import statements and streamline block initialization for better readability and maintainability. commit f3aaa6ab527e765a0f89081ef27aad2d3fe890f5 Author: Felix Zimmermann Date: Thu May 22 10:00:41 2025 +0200 Refactor EfficientViTBlock and Encoder/Decoder stages to use dynamic head counts based on width; improve sequential structure in PixelUnshuffleDownsample. commit 3757b1fa3a883c5413f1f94c24c6068bbb30335b Author: Felix Zimmermann Date: Thu May 22 02:11:50 2025 +0200 Add EMADict class for Exponential Moving Average functionality and update imports - Introduced EMADict class to maintain exponential moving averages for various data types. - Updated __all__ lists in utils and nn modules to include new EMADict class. - Added tests for EMADict to ensure correct functionality and error handling. commit 7d608b5de0efaf063a36b137ac112106146f02cd Author: Felix Zimmermann Date: Thu May 22 00:04:31 2025 +0200 Refactor method signatures in neural network modules to use keyword-only arguments for conditioning tensors. Update import statements for clarity and add new SwinIR network to the module. Enhance VAE with a mode method for improved functionality. commit 4f6a603a53e05323be4219c049c84a479e139f6d Author: Felix Zimmermann Date: Wed May 21 18:19:46 2025 +0200 Enhance conversion functions between Linear and Conv layers by refining method signatures and improving test structure. Update parameter names for clarity and ensure compatibility with multiple input tensors. commit 01881fef9076635036b932e2a94c01e90145cbb9 Author: Felix Zimmermann Date: Wed May 21 16:04:02 2025 +0200 Refactor imports to use lowercase 'ndmodules' and update method signatures for better compatibility with multiple input tensors. Introduce conversion functions between Linear and Conv layers, along with corresponding tests. commit d626bbbdd420c3b0a59c1deef62444afcf784170 Author: Felix Zimmermann Date: Wed May 21 01:26:18 2025 +0200 update commit 52c8630c89ce48b82120ab4a5f22597389cb5778 Author: Felix Zimmermann Date: Tue May 20 22:43:25 2025 +0200 update commit 3d259bbc3122f3d312f9a0a628c994e452820d81 Author: Felix Zimmermann Date: Tue May 20 02:09:41 2025 +0200 update commit 7e9d12183fae6d953967b3409a64e729848274bf Author: Felix Zimmermann Date: Mon May 19 21:52:03 2025 +0200 update commit 62e04a1f4685464bda7a35f2d15ed35e7a1fc2ff Author: Felix Zimmermann Date: Mon May 19 17:20:20 2025 +0200 update commit 33d95572ecb7e46cf5049d31a26ed5e4d6d69885 Author: Felix Zimmermann Date: Mon May 19 14:41:48 2025 +0200 update commit 97115e77559e36e4adf02c09cba5b96cc53c4100 Author: Felix Zimmermann Date: Mon May 19 02:23:57 2025 +0200 update commit 7f37fa99ba4392cd89a52cca08b5e18ff0fe50a3 Author: Felix Zimmermann Date: Mon May 19 01:52:18 2025 +0200 update commit 3d0122093a1acf71bbeb167005c90398d2bd7eca Author: Felix Zimmermann Date: Sun May 18 16:30:43 2025 +0200 update commit b6a1db3c809c8f2338e7891ab897406da9a0a08f Author: Felix Zimmermann Date: Fri May 16 14:02:00 2025 +0200 doc commit 912d7c8ecfa0b1a54d9e1c53f819b0dc550067c2 Author: Felix Zimmermann Date: Fri May 16 00:21:58 2025 +0200 update commit 3ae37d1ceaf5b8f1ff22e98517b0873efb04e07f Author: Felix Zimmermann Date: Thu May 15 02:33:40 2025 +0200 update commit 54a66b61321c58dd95c6056030ac7e32986608f6 Author: Felix Zimmermann Date: Wed May 14 17:17:38 2025 +0200 fix commit cf4be7f27b769f1a6070bf0169037ad996a452e9 Author: Felix Zimmermann Date: Wed May 14 17:14:33 2025 +0200 uformer commit 9cfae55b33c34de10f606019ca3e0d03fc560ee0 Author: Felix Zimmermann Date: Wed May 14 02:22:29 2025 +0200 update commit 420cdc1b29a6604a91eec2b3578d6b5b6447ad83 Author: Felix Zimmermann Date: Wed May 14 01:18:30 2025 +0200 update commit 633682b1959c07dec272277ade9ea50c5b8898f5 Author: Felix Zimmermann Date: Wed May 14 00:53:02 2025 +0200 update commit c39a9afcb1849005e635c16c74344296acc511d5 Author: Felix Zimmermann Date: Tue May 13 22:27:48 2025 +0200 update commit 26467bf5714f6ff57883c714fed9e26583698aeb Author: Felix Zimmermann Date: Tue May 13 21:36:23 2025 +0200 update commit 7e83be7af248fcc288d69b435de48db1e8ad703c Author: Felix Zimmermann Date: Mon May 12 23:03:54 2025 +0200 update commit a458855117baf3b21bdc0738230ecf60f7256bb2 Author: Felix Zimmermann Date: Sat May 10 21:09:00 2025 +0200 start commit 904f3c941e308fad00dbb0ec857a7c4b23b9060c Author: Felix Zimmermann Date: Sat May 10 21:05:34 2025 +0200 fix doc --- pyproject.toml | 2 + src/mrpro/__init__.py | 6 +- src/mrpro/nn/AttentionGate.py | 72 +++ src/mrpro/nn/ComplexAsChannel.py | 59 ++ src/mrpro/nn/CondMixin.py | 22 + src/mrpro/nn/DropPath.py | 55 ++ src/mrpro/nn/FiLM.py | 54 ++ src/mrpro/nn/GEGLU.py | 56 ++ src/mrpro/nn/GluMBConvResBlock.py | 118 ++++ src/mrpro/nn/GroupNorm.py | 47 ++ src/mrpro/nn/LayerNorm.py | 79 +++ src/mrpro/nn/LinearSelfAttention.py | 109 ++++ src/mrpro/nn/MultiHeadAttention.py | 87 +++ src/mrpro/nn/NeighborhoodSelfAttention.py | 192 ++++++ src/mrpro/nn/PixelShuffle.py | 226 +++++++ src/mrpro/nn/RMSNorm.py | 53 ++ src/mrpro/nn/ResBlock.py | 70 +++ src/mrpro/nn/Residual.py | 45 ++ src/mrpro/nn/RoPE.py | 141 +++++ src/mrpro/nn/Sequential.py | 60 ++ src/mrpro/nn/ShiftedWindowAttention.py | 103 ++++ src/mrpro/nn/SpatialTransformerBlock.py | 148 +++++ src/mrpro/nn/SqueezeExcitation.py | 57 ++ src/mrpro/nn/TransposedAttention.py | 76 +++ src/mrpro/nn/Upsample.py | 70 +++ src/mrpro/nn/__init__.py | 44 ++ src/mrpro/nn/convert_linear_conv.py | 100 +++ src/mrpro/nn/encoding.py | 115 ++++ src/mrpro/nn/join.py | 125 ++++ src/mrpro/nn/ndmodules.py | 176 ++++++ src/mrpro/nn/nets/CNN.py | 60 ++ src/mrpro/nn/nets/DCAE.py | 280 +++++++++ src/mrpro/nn/nets/Restormer.py | 208 +++++++ src/mrpro/nn/nets/SwinIR.py | 247 ++++++++ src/mrpro/nn/nets/UNet.py | 706 ++++++++++++++++++++++ src/mrpro/nn/nets/Uformer.py | 230 +++++++ src/mrpro/nn/nets/VAE.py | 64 ++ src/mrpro/nn/nets/__init__.py | 16 + src/mrpro/operators/ZeroPadOp.py | 2 +- src/mrpro/utils/__init__.py | 5 +- src/mrpro/utils/ema.py | 94 +++ src/mrpro/utils/pad_or_crop.py | 9 +- src/mrpro/utils/to_tuple.py | 36 ++ tests/nn/test_attentiongate.py | 41 ++ tests/nn/test_complexaschannel.py | 30 + tests/nn/test_convert_linear_conv.py | 150 +++++ tests/nn/test_film.py | 38 ++ tests/nn/test_groupnorm32.py | 34 ++ tests/nn/test_resblock.py | 40 ++ tests/nn/test_sequential.py | 41 ++ tests/nn/test_shiftedwindowattention.py | 37 ++ tests/nn/test_sqeezeexcitation.py | 26 + tests/nn/test_transposedattention.py | 36 ++ tests/utils/test_ema.py | 89 +++ 54 files changed, 5077 insertions(+), 9 deletions(-) create mode 100644 src/mrpro/nn/AttentionGate.py create mode 100644 src/mrpro/nn/ComplexAsChannel.py create mode 100644 src/mrpro/nn/CondMixin.py create mode 100644 src/mrpro/nn/DropPath.py create mode 100644 src/mrpro/nn/FiLM.py create mode 100644 src/mrpro/nn/GEGLU.py create mode 100644 src/mrpro/nn/GluMBConvResBlock.py create mode 100644 src/mrpro/nn/GroupNorm.py create mode 100644 src/mrpro/nn/LayerNorm.py create mode 100644 src/mrpro/nn/LinearSelfAttention.py create mode 100644 src/mrpro/nn/MultiHeadAttention.py create mode 100644 src/mrpro/nn/NeighborhoodSelfAttention.py create mode 100644 src/mrpro/nn/PixelShuffle.py create mode 100644 src/mrpro/nn/RMSNorm.py create mode 100644 src/mrpro/nn/ResBlock.py create mode 100644 src/mrpro/nn/Residual.py create mode 100644 src/mrpro/nn/RoPE.py create mode 100644 src/mrpro/nn/Sequential.py create mode 100644 src/mrpro/nn/ShiftedWindowAttention.py create mode 100644 src/mrpro/nn/SpatialTransformerBlock.py create mode 100644 src/mrpro/nn/SqueezeExcitation.py create mode 100644 src/mrpro/nn/TransposedAttention.py create mode 100644 src/mrpro/nn/Upsample.py create mode 100644 src/mrpro/nn/__init__.py create mode 100644 src/mrpro/nn/convert_linear_conv.py create mode 100644 src/mrpro/nn/encoding.py create mode 100644 src/mrpro/nn/join.py create mode 100644 src/mrpro/nn/ndmodules.py create mode 100644 src/mrpro/nn/nets/CNN.py create mode 100644 src/mrpro/nn/nets/DCAE.py create mode 100644 src/mrpro/nn/nets/Restormer.py create mode 100644 src/mrpro/nn/nets/SwinIR.py create mode 100644 src/mrpro/nn/nets/UNet.py create mode 100644 src/mrpro/nn/nets/Uformer.py create mode 100644 src/mrpro/nn/nets/VAE.py create mode 100644 src/mrpro/nn/nets/__init__.py create mode 100644 src/mrpro/utils/ema.py create mode 100644 src/mrpro/utils/to_tuple.py create mode 100644 tests/nn/test_attentiongate.py create mode 100644 tests/nn/test_complexaschannel.py create mode 100644 tests/nn/test_convert_linear_conv.py create mode 100644 tests/nn/test_film.py create mode 100644 tests/nn/test_groupnorm32.py create mode 100644 tests/nn/test_resblock.py create mode 100644 tests/nn/test_sequential.py create mode 100644 tests/nn/test_shiftedwindowattention.py create mode 100644 tests/nn/test_sqeezeexcitation.py create mode 100644 tests/nn/test_transposedattention.py create mode 100644 tests/utils/test_ema.py diff --git a/pyproject.toml b/pyproject.toml index 85c92deba..c8f90da87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ docs = [ "sphinx-autodoc-typehints>=3, <3.1", "sphinx-copybutton>=0.5, <0.6", "sphinx-last-updated-by-git>=0.3, <0.4", + "snowballstemmer>=2.2, <3.0", ] notebooks = [ "zenodo_get>=2.0", @@ -226,6 +227,7 @@ iy = "iy" arange = "arange" # torch.arange Ba = "Ba" wht = "wht" # Brainweb tissue class +ND = "ND" # Short for N-dimensional [tool.typos.files] extend-exclude = [ diff --git a/src/mrpro/__init__.py b/src/mrpro/__init__.py index 729ae188c..bbd401f1f 100644 --- a/src/mrpro/__init__.py +++ b/src/mrpro/__init__.py @@ -1,10 +1,12 @@ from mrpro._version import __version__ -from mrpro import algorithms, operators, data, phantoms, utils +from mrpro import algorithms, operators, data, phantoms, utils, nn + __all__ = [ "__version__", "algorithms", "data", + "nn", "operators", "phantoms", "utils" -] +] \ No newline at end of file diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py new file mode 100644 index 000000000..1d57fe5ee --- /dev/null +++ b/src/mrpro/nn/AttentionGate.py @@ -0,0 +1,72 @@ +"""Attention gate from Attention UNet.""" + +import torch +from torch.nn import Module, ReLU, Sequential, Sigmoid + +from mrpro.nn.ndmodules import ConvND + + +class AttentionGate(Module): + """Attention gate from Attention UNet. + + The attention mechanism from the attention UNet [OKT18]_. + + References + ---------- + ..[OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + def __init__(self, dim: int, channels_gate: int, channels_in: int, channels_hidden: int, concatenate: bool = False): + """Initialize the attention gate. + + Parameters + ---------- + dim + The dimension, i.e. 1, 2 or 3. + channels_gate + The number of channels in the gate tensor. + channels_in + The number of channels in the input tensor. + channels_hidden + The number of internal, hidden channels. + concatenate + Whether to concatenate the gated signal with the gate signal in the channel dimension (1) + """ + super().__init__() + self.project_gate = ConvND(dim)(channels_gate, channels_hidden, kernel_size=1) + self.project_x = ConvND(dim)(channels_in, channels_hidden, kernel_size=1) + self.psi = Sequential( + ReLU(), + ConvND(dim)(channels_hidden, 1, kernel_size=1), + Sigmoid(), + ) + self.concatenate = concatenate + + def __call__(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Apply the attention gate. + + Parameters + ---------- + x + The input tensor. + gate + The gate tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, gate) + + def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Apply the attention gate.""" + projected_gate = self.project_gate(gate) + projected_x = self.project_x(x) + projected_gate = torch.nn.functional.interpolate(projected_gate, size=x.shape[2:], mode='nearest') + alpha = self.psi(projected_gate + projected_x) + x = x * alpha + if self.concatenate: + gate = torch.nn.functional.interpolate(gate, size=x.shape[2:], mode='nearest') + x = torch.cat([x, gate], dim=1) + return x diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py new file mode 100644 index 000000000..7c1bec0fd --- /dev/null +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -0,0 +1,59 @@ +"""ComplexAsChannel: handling complex-valued tensors as channels.""" + +import torch +from einops import rearrange +from torch.nn import Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class ComplexAsChannel(CondMixin, Module): + """Wrap module to treat complex numbers as a channel dimension.""" + + def __init__(self, module: Module, convert_back: bool = True): + """Initialize the ComplexAsChannel module. + + Wraps a module to treat complex numbers as a channel dimension. + For each complex tensor in the input, real and imaginary parts are concatenated along the channel dimension + before being passed to the wrapped module. + + + Parameters + ---------- + module : Module + The module to wrap. Should output a single real tensor. + convert_back : bool + If True, the output is converted back to a complex tensor. + The output should have a number of channels that is a multiple of 2. + """ + super().__init__() + self.module = module + self.convert_back = convert_back + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + cond : torch.Tensor | None + The conditioning tensor (if used by the wrapped module) + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + x_real = [ + rearrange(torch.view_as_real(c), 'batch channel ... complex -> batch (channel complex) ...') + if c.is_complex() + else c + for c in x + ] + + y = call_with_cond(self.module, *x_real, cond=cond) + + if self.convert_back: + y = rearrange(y, 'b (channel complex) ... -> b channel ... complex', complex=2).contiguous() + y = torch.view_as_complex(y) + return y diff --git a/src/mrpro/nn/CondMixin.py b/src/mrpro/nn/CondMixin.py new file mode 100644 index 000000000..6a902c413 --- /dev/null +++ b/src/mrpro/nn/CondMixin.py @@ -0,0 +1,22 @@ +"""Base class for modules using a conditioning.""" + +import torch +from torch.nn import Module + + +def call_with_cond(module: Module, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Call a module with conditioning if it is a CondMixin.""" + if isinstance(module, CondMixin): + return module(*x, cond=cond) + return module(*x) + + +class CondMixin(Module): + """Mixin for modules using a conditioning. + + Used to determine if a module uses a conditioning within a Sequential container. + """ + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module to the input.""" + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/DropPath.py b/src/mrpro/nn/DropPath.py new file mode 100644 index 000000000..b1314904e --- /dev/null +++ b/src/mrpro/nn/DropPath.py @@ -0,0 +1,55 @@ +"""DropPath (stochastic depth).""" + +import torch +from torch.nn import Module + + +class DropPath(Module): + """Drop path or stochastic depth. + + Drops full samples from batch with probability `droprate`. + Should be used in the main path of a Resblock. + + References + ---------- + .. [HUANG16] Huang, G., Sun, Y., Liu, Z., Sedra, D., & Weinberger, K. Q. Deep networks with stochastic depth. + ECCV 2016. https://link.springer.com/chapter/10.1007/978-3-319-46493-0_39 + """ + + def __init__(self, droprate: float = 0.0, scale_by_keep: bool = False): + """Initialize the DropPath module. + + Parameters + ---------- + droprate : float, optional + Drop probability + scale_by_keep : bool, optional + If True, the kept samples are scaled by `1/(1-droprate)` + """ + super().__init__() + self.droprate = droprate + self.scale_by_keep = scale_by_keep + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath. + + Parameters + ---------- + x : torch.Tensor + Input tensor + + Returns + ------- + Tensor with + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath.""" + if self.droprate == 0 or not self.training: + return x + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = ((1 - self.droprate) + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_() + if self.scale_by_keep: + mask = mask.div_(1 - self.droprate) + return x * mask diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py new file mode 100644 index 000000000..92780aae3 --- /dev/null +++ b/src/mrpro/nn/FiLM.py @@ -0,0 +1,54 @@ +"""Feature-wise Linear Modulation.""" + +import torch +from torch.nn import Linear, Module + +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_tensors_right + + +class FiLM(CondMixin, Module): + """Feature-wise Linear Modulation. + + Feature-wise Linear Modulation from [FiLM]_ to condition a network on a conditioning tensor. + + + References + ---------- + ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "FiLM : Visual reasoning with a general + conditioning layer." AAAI (2018). https://arxiv.org/abs/1709.07871 + """ + + def __init__(self, channels: int, cond_dim: int) -> None: + """Initialize FiLM. + + Parameters + ---------- + channels + The number of channels in the input tensor. + cond_dim + The dimension of the conditioning tensor. + """ + super().__init__() + self.project = Linear(cond_dim, 2 * channels) if cond_dim > 0 else None + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply FiLM. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply FiLM.""" + if cond is None or self.project is None: + return x + scale, shift = self.project(cond).chunk(2, dim=1) + + scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) + return x * (1 + scale) + shift diff --git a/src/mrpro/nn/GEGLU.py b/src/mrpro/nn/GEGLU.py new file mode 100644 index 000000000..0310c6a76 --- /dev/null +++ b/src/mrpro/nn/GEGLU.py @@ -0,0 +1,56 @@ +"""Gated linear unit activation function.""" + +import torch +from torch.nn import Linear, Module + + +class GEGLU(Module): + r"""Gated linear unit activation function. + + References + ---------- + ..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202 + """ + + def __init__(self, in_features: int, out_features: int | None = None, features_last: bool = False): + """Initialize the GEGLU activation function. + + Parameters + ---------- + in_features + The number of input features. + out_features + The number of output features. If None, the number of + output features is the same as the number of input features. + features_last + If True, the channel dimension is the last dimension, else in the second dimension. + """ + super().__init__() + out_features_ = in_features if out_features is None else out_features + self.proj = Linear(in_features, out_features_ * 2) # gate and output stacked + self.features_last = features_last + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation.""" + if not self.features_last: + x = x.moveaxis(1, -1) + h, gate = self.proj(x).chunk(2, dim=-1) + gate = torch.nn.functional.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + out = h * gate + if not self.features_last: + out = out.moveaxis(-1, 1) + return out + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Activated tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py new file mode 100644 index 000000000..3eaf3b9d4 --- /dev/null +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -0,0 +1,118 @@ +"""Gateded MBConv Residual Block.""" + +import torch +from torch.nn import Identity, Module, Sequential, SiLU + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.RMSNorm import RMSNorm + + +class GluMBConvResBlock(CondMixin, Module): + """Gated MBConv residual block. + + Gated variant [DCAE]_ of the MBConv block [EffNet]_ with a residual connection. + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + .. [EffNet] Tan et al. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. ICML 2019 + https://arxiv.org/abs/1905.11946 + """ + + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + expand_ratio: int = 6, + stride: int = 1, + kernel_size: int = 3, + cond_dim: int = 0, + ): + """Initialize MBConv block. + + Parameters + ---------- + dim + Number of spatial dimensions. + channels_in + Number of input channels. + channels_out + Number of output channels. + expand_ratio + Expansion ratio inside the block. + stride + Stride of the depthwise convolution. + kernel_size + Kernel size of the depthwise convolution. + cond_dim + Dimension of the conditioning tensor used in a FiLM. If 0, no FiLM is used. + """ + super().__init__() + channels_mid = channels_in * expand_ratio + if stride == 1 and channels_in == channels_out: + self.skip: Module = Identity() + else: + self.skip = ConvND(dim)(channels_in, channels_out, kernel_size=1, stride=stride) + self.inverted_conv = Sequential( + ConvND(dim)( + channels_in, + channels_mid * 2, + kernel_size=1, + ), + SiLU(), + ) + self.depth_conv = Sequential( + ConvND(dim)( + channels_mid * 2, + channels_mid * 2, + kernel_size=kernel_size, + stride=stride, + padding='same', + groups=channels_mid * 2, + ), + SiLU(), + ) + self.point_conv = Sequential( + ConvND(dim)( + channels_mid, + channels_out, + kernel_size=1, + ), + RMSNorm(channels_out), + SiLU(), + ) + if cond_dim > 0: + self.film: FiLM | None = FiLM(channels_mid, cond_dim) + else: + self.film = None + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply MBConv block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply MBConv block.""" + h = self.inverted_conv(x) + h = self.depth_conv(h) + h, gate = torch.chunk(h, 2, dim=1) + h = h * torch.nn.functional.silu(gate) + if self.film is not None: + h = self.film(h, cond=cond) + h = self.point_conv(h) + return self.skip(x) + h diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py new file mode 100644 index 000000000..09e91cf11 --- /dev/null +++ b/src/mrpro/nn/GroupNorm.py @@ -0,0 +1,47 @@ +"""GroupNorm with 32-bit precision.""" + +import torch + + +class GroupNorm(torch.nn.GroupNorm): + """A 32-bit GroupNorm. + + Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. + """ + + def __init__(self, channels: int, groups: int | None = None): + """Initialize GroupNorm32. + + Parameters + ---------- + channels + The number of channels in the input tensor. + groups + The number of groups to use. If None, the number of groups is determined automatically as + a power of 2 that is less than or equal to 32 and leaves at least 4 channels per group. + """ + if groups is None: + groups_, candidate = 1, 2 + while (candidate <= min(32, channels // 4)) and (channels % candidate == 0): + groups_, candidate = candidate, groups_ * 2 + else: + groups_ = groups + super().__init__(groups_, channels) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm32. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x.float()).type(x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm32.""" + return super().forward(x.float()).type(x.dtype) diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py new file mode 100644 index 000000000..699de57f0 --- /dev/null +++ b/src/mrpro/nn/LayerNorm.py @@ -0,0 +1,79 @@ +"""Layer normalization.""" + +import torch +from torch.nn import Linear, Module, Parameter + +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_at, unsqueeze_right + + +class LayerNorm(CondMixin, Module): + """Layer normalization.""" + + def __init__(self, channels: int | None, features_last: bool = False, cond_dim: int = 0) -> None: + """Initialize the layer normalization. + + Parameters + ---------- + channels + Number of channels in the input tensor. If `None`, the layer normalization does not do an elementwise + affine transformation. + features_last + If `True`, the channel dimension is the last dimension. + cond_dim + Number of channels in the conditioning tensor. If `0`, no adaptive scaling is applied. + """ + super().__init__() + if channels is None and cond_dim == 0: + self.weight: Parameter | None = None + self.bias: Parameter | None = None + self.cond_proj: Linear | None = None + elif channels is None and cond_dim > 0: + raise ValueError('channels must be provided if cond_dim > 0') + elif channels is not None and cond_dim == 0: + self.weight = Parameter(torch.ones(channels)) + self.bias = Parameter(torch.zeros(channels)) + self.cond_proj = None + else: + self.weight = None + self.bias = None + self.cond_proj = Linear(cond_dim, 2 * channels) + + self.features_last = features_last + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply layer normalization to the input tensor. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Normalized output tensor + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply layer normalization to the input tensor.""" + dims = tuple(range(1, x.ndim)) + mean = x.mean(dim=dims, keepdim=True) + std = x.std(dim=dims, keepdim=True, unbiased=False) + x = (x - mean) / (std + 1e-5) + + if self.weight is not None and self.bias is not None: + if self.features_last: + x = x * self.weight + self.bias + else: + x = x * unsqueeze_right(self.weight, x.ndim - 2) + unsqueeze_right(self.bias, x.ndim - 2) + + if self.cond_proj is not None and cond is not None: + scale, shift = self.cond_proj(cond).chunk(2, dim=-1) + scale = 1 + scale + if self.features_last: + x = x * unsqueeze_at(scale, 1, x.ndim - 2) + unsqueeze_at(shift, 1, x.ndim - 2) + else: + x = x * unsqueeze_right(scale, x.ndim - 2) + unsqueeze_right(shift, x.ndim - 2) + + return x diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py new file mode 100644 index 000000000..612f0ffe2 --- /dev/null +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -0,0 +1,109 @@ +"""Linear self-attention.""" + +import torch +from einops import rearrange +from torch import Tensor +from torch.nn import Linear, Module, ReLU + + +class LinearSelfAttention(Module): + """Linear multi-head self-attention via kernel trick. + + Uses a ReLU kernel to compute attention in O(N) [KAT20]_ time and space. + + + References + ---------- + .. [KAT20] Katharopoulos, Angelos, et al. Transformers are rnns: Fast autoregressive transformers with linear + attention. ICML 2020. https://arxiv.org/abs/2006.16236 + + Parameters + ---------- + channels + Input and output channel dimension. + n_heads + Number of attention heads. + bias + Whether to use bias in the QKV projection. + eps + Small epsilon for numerical stability in normalization. + """ + + def __init__( + self, + channels_in: int, + channels_out: int, + n_heads: int, + eps: float = 1e-6, + features_last: bool = False, + ): + """Initialize linear self-attention layer. + + Parameters + ---------- + channels_in + Input channel dimension. + channels_out + Output channel dimension. + n_heads + Number of attention heads. + eps + Small epsilon for numerical stability in normalization. + features_last + Whether the channel dimension is the last dimension, as common in transformer models, + or the second dimension, as common in image models. + """ + super().__init__() + self.features_last = features_last + self.eps = eps + self.n_heads = n_heads + channels_per_head = channels_in // n_heads + self.to_qkv = Linear(channels_in, 3 * channels_per_head * n_heads) + self.kernel_function = ReLU() + self.to_out = Linear(channels_per_head * n_heads, channels_out) + + def __call__(self, x: Tensor) -> Tensor: + """Apply linear self-attention. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or (`batch, *spatial_dims, channels` if `features_last`) + + Returns + ------- + Tensor after attention, same shape as input. + """ + return super().__call__(x) + + def forward(self, x: Tensor) -> Tensor: + """Apply linear self-attention.""" + orig_dtype = x.dtype + if x.dtype == torch.float16: + x = x.float() + if not self.features_last: + x = x.moveaxis(1, -1) + spatial_shape = x.shape[1:-1] + + qkv = self.to_qkv(x) + query, key, value = rearrange( + qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channels', qkv=3, head=self.n_heads + ) + + query = self.kernel_function(query) + key = self.kernel_function(key) + + # trick to avoid second attention calculation: add normalization slot + value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode='constant', value=1.0) + + value_key = value @ key.transpose(-1, -2) + value_key_query = value_key @ query + normalization = value_key_query[..., -1:, :] + self.eps + attn = value_key_query[..., :-1, :] / normalization + attn = attn.moveaxis(1, -1).flatten(-2) # join heads and channels + out = self.to_out(attn) + out = out.to(orig_dtype) + out = out.unflatten(-2, spatial_shape) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py new file mode 100644 index 000000000..c7b5e1cde --- /dev/null +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -0,0 +1,87 @@ +"""Multi-head Attention.""" + +import torch +from torch.nn import Linear, Module + + +class MultiHeadAttention(Module): + """Multi-head Attention. + + Implements multihead scaled dot-product attention and supports "image-like" inputs, + i.e. `batch, channels, *spatial_dims` as well as "transformer-like" inputs, `batch, sequence, features`. + """ + + def __init__( + self, + channels_in: int, + channels_out: int, + n_heads: int, + features_last: bool = False, + p_dropout: float = 0.0, + channels_cross: int | None = None, + ): + """Initialize the Multi-head Attention. + + Parameters + ---------- + dim + Number of spatial dimensions. + channels_in + Number of input channels. + channels_out + Number of output channels. + n_heads + number of attention heads + features_last + Whether the features dimension is the last dimension, as common in transformer models, + or the second dimension, as common in image models. + p_dropout + Dropout probability. + channels_cross + Number of channels for cross-attention. If `None`, use `channels_in`. + """ + super().__init__() + self.mha = torch.nn.MultiheadAttention( + embed_dim=channels_in, + num_heads=n_heads, + batch_first=True, + dropout=p_dropout, + kdim=channels_cross, + vdim=channels_cross, + ) + self.features_last = features_last + self.to_out = Linear(channels_in, channels_out) + + def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + """Apply multi-head attention. + + Parameters + ---------- + x + The input tensor. + cross_attention + The key and value tensors for cross-attention. If `None`, self-attention is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cross_attention) + + def _reshape(self, x: torch.Tensor) -> torch.Tensor: + if not self.features_last: + x = x.moveaxis(1, -1) + return x.flatten(1, -2) + + def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + """Apply multi-head attention.""" + reshaped_x = self._reshape(x) + reshaped_cross_attention = self._reshape(cross_attention) if cross_attention is not None else reshaped_x + + y = self.mha(reshaped_cross_attention, reshaped_cross_attention, reshaped_x, need_weights=False)[0] + out: torch.Tensor = self.to_out(y) + + if not self.features_last: + out = out.moveaxis(-1, 1) + + return out.reshape(x.shape) diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py new file mode 100644 index 000000000..91151e4bb --- /dev/null +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -0,0 +1,192 @@ +"""Neighborhood Self Attention.""" + +from collections.abc import Sequence +from functools import cache, reduce +from typing import TypeVar + +import torch +from einops import rearrange +from torch.nn import Linear, Module +from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + +from mrpro.utils.to_tuple import to_tuple + +T = TypeVar('T') + + +@cache +def neighborhood_mask( + input_size: torch.Size, + kernel_size: int | tuple[int, ...], # tuples instead of Sequence for cache + dilation: int | tuple[int, ...] = 1, + circular: bool | tuple[bool, ...] = False, +) -> BlockMask: + """Create a flex attention block mask for neighborhood attention. + + This function defines which key/value pairs a query can attend to based + on a local neighborhood. The neighborhood is defined by `kernel_size` + and `dilation` and can be circular (wrapping around edges). + + Parameters + ---------- + input_size + The dimensions of the input tensor (e.g., (H, W) for 2D). + kernel_size + The size of the attention neighborhood window. Can be a single + integer for a symmetric window or a sequence of integers for + each dimension. + dilation + The dilation factor for the neighborhood + Can be a single integer for a symmetric window or a sequence + of integers for each dimension. + circular + Whether the neighborhood wraps around the edges (circular padding). + Can be a single boolean or a sequence of booleans. + + Returns + ------- + A mask object suitable for `flex_attention` that defines the + allowed attention connections. + """ + kernel_size_tuple, dilation_tuple, circular_tuple = ( + to_tuple(len(input_size), x) for x in (kernel_size, dilation, circular) + ) + + def unravel_index(idx: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Convert a flat 1D index into multi-dimensional coordinates.""" + idx = idx.clone() + coords = [] + for dim in reversed(input_size): + coords.append(idx % dim) + idx = (idx / dim).floor().long() + coords.reverse() + return tuple(coords) + + def mask( + _batch: torch.Tensor, + _head: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + """Determine if a query can attend to a key/value pair.""" + q_coord = unravel_index(q_idx) + kv_coord = unravel_index(kv_idx) + + masks = [] + for input_, kernel_, dilation_, circular_, q_, kv_ in zip( + input_size, + kernel_size_tuple, + dilation_tuple, + circular_tuple, + q_coord, + kv_coord, + strict=False, + ): + masks.append((q_ % dilation_) == (kv_ % dilation_)) + kernel_dilation = kernel_ * dilation_ + window_left = kernel_dilation // 2 + window_right = (kernel_dilation // 2) + ((kernel_dilation % 2) - 1) + if circular_: + left = (q_ - kv_ + input_) % input_ + right = (kv_ - q_ + input_) % input_ + masks.append((left <= window_left) | (right <= window_right)) + else: + center = q_.clamp(window_left, input_ - 1 - window_right) + left = center - kv_ + right = kv_ - center + masks.append(((left >= 0) & (left <= window_left)) | ((right >= 0) & (right <= window_right))) + return reduce(lambda x, y: x & y, masks) + + qkv_len = input_size.numel() + return create_block_mask(mask, B=None, H=None, Q_LEN=qkv_len, KV_LEN=qkv_len, _compile=True) + + +class NeighborhoodSelfAttention(Module): + """Attention where each query attends to a neighborhood of the key and value. + + Neighborhood attention is a type of attention where each query attends to a neighborhood of the key and value. + It is a more efficient alternative to regular attention, especially for large input sizes [NAT]_. + + This implementation uses `~torch.nn.attention.flex_attention`. For a more efficient implementation, + see also [NATTEN]_. + + + References + ---------- + .. [NAT] Hassani, A. et al. "Neighborhood Attention Transformer" CVPR, 2023, https://arxiv.org/abs/2204.07143 + .. [NATTEN] https://github.com/SHI-Labs/NATTEN/ + """ + + def __init__( + self, + channels_in: int, + channels_out: int, + n_heads: int, + kernel_size: int | Sequence[int], + dilation: int | Sequence[int] = 1, + circular: bool | Sequence[bool] = False, + features_last: bool = False, + ) -> None: + """Initialize a neighborhood attention module. + + The parameters `kernel_size`, `dilation`, and `circular` can either be sequences, interpreted as per-dimension + values, or scalars, interpreted as the same value for all dimensions. + + Parameters + ---------- + channels_in + The number of channels in the input tensor. + channels_out + The number of channels in the output tensor. + n_heads + The number of attention heads. + kernel_size + The size of the attention neighborhood window. + dilation + The dilation factor for the neighborhood. + circular + Whether the neighborhood wraps around the edges (circular padding) + features_last + Whether the channels are in the last dimension of the tensor, as common in visíon transformers. + Otherwise, assume the channels are in the second dimension, as common in CNN models. + """ + super().__init__() + self.n_head = n_heads + self.kernel_size = kernel_size if isinstance(kernel_size, int) else tuple(kernel_size) + self.dilation = dilation if isinstance(dilation, int) else tuple(dilation) + self.circular = circular if isinstance(circular, bool) else tuple(circular) + self.features_last = features_last + channels_per_head = channels_in // n_heads + self.to_qkv = Linear(channels_in, 3 * channels_per_head * n_heads) + self.to_out = Linear(channels_per_head * n_heads, channels_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply neighborhood attention to the input tensor. + + Parameters + ---------- + x + The input tensor, with shape `batch, channels, *spatial_dims`. + + Returns + ------- + The output tensor after attention, with the same shape as the input tensor. + """ + if not self.features_last: + x = x.moveaxis(1, -1) + spatial_shape = x.shape[1:-1] + qkv = self.to_qkv(x) + query, key, value = rearrange( + qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channel', qkv=3, head=self.n_head + ) + # the mask depends on the input size. To be more flexible if used within CNNs, we compute it here. + # The computation is cached.. + mask = neighborhood_mask( + input_size=spatial_shape, kernel_size=self.kernel_size, dilation=self.dilation, circular=self.circular + ) + out: torch.Tensor = flex_attention(query.contiguous(), key.contiguous(), value.contiguous(), block_mask=mask) # type: ignore[assignment] # wrong type hints + out = self.to_out(out) + out = out.unflatten(-2, spatial_shape) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py new file mode 100644 index 000000000..70b0270df --- /dev/null +++ b/src/mrpro/nn/PixelShuffle.py @@ -0,0 +1,226 @@ +"""ND-version of PixelShuffle and PixelUnshuffle.""" + +import torch +from torch.nn import Module + +from mrpro.nn.ndmodules import ConvND + + +class PixelUnshuffle(Module): + """ND-version of PixelUnshuffle downscaling.""" + + def __init__(self, downscale_factor: int): + """Initialize PixelUnshuffle. + + Reduces spatial dimensions and increases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are downscaled. + + See `mrpro.nn.PixelShuffle` for the inverse operation. + + Parameters + ---------- + downscale_factor : int + The factor by which to downscale the input tensor. + """ + super().__init__() + self.downscale_factor = downscale_factor + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels * downscale_factor**dim, *spatial_dims/downscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input.""" + dim = x.ndim - 2 + if dim == 2: # fast path for 2D + return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor) + + new_shape = list(x.shape[:2]) + source_positions = [] + for i, old in enumerate(x.shape[2:]): + if old % self.downscale_factor: + raise ValueError('Spatial size must be divisible by downscale_factor.') + new_shape.append(old // self.downscale_factor) + new_shape.append(self.downscale_factor) + source_positions.append(2 + 2 * i) + + x = x.view(new_shape) + x = x.moveaxis(source_positions, tuple(range(-dim, 0))) + x = x.flatten(1, -dim - 1) + return x + + +class PixelUnshuffleDownsample(Module): + """PixelUnshuffle Downsampling. + + PixelUnshuffle followed by a convolution. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, dim: int, channels_in: int, channels_out: int, downscale_factor: int = 2, residual: bool = False + ): + """Initialize a PixelUnshuffleDownsample layer. + + Parameters + ---------- + dim : int + Dimension of the input tensor. + channels_in : int + Number of channels in the input tensor. + channels_out : int + Number of channels in the output tensor. + downscale_factor : int, optional + Factor by which to downscale the input tensor. + residual : bool, optional + Whether to use a residual connection as proposed in [DCAE]_. + """ + super().__init__() + self.pixel_unshuffle = PixelUnshuffle(downscale_factor) + out_ratio = downscale_factor**dim + if channels_out % out_ratio != 0: + raise ValueError(f'channels_out must be divisible by downscale_factor**{dim}.') + self.conv = ConvND(dim)(channels_in, channels_out // out_ratio, kernel_size=3, padding='same') + self.residual = residual + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims/downscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling.""" + h = self.conv(x) + h = self.pixel_unshuffle(h) + + if self.residual: + x = self.pixel_unshuffle(x) + h = h + x.unflatten(1, (h.shape[1], -1)).mean(2) + return h + + +class PixelShuffleUpsample(Module): + """PixelShuffle Upsampling. + + Convolution followed by PixelShuffle. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__(self, dim: int, channels_in: int, channels_out: int, upscale_factor: int = 2, residual: bool = False): + """Initialize a PixelShuffleUpsample layer. + + Parameters + ---------- + dim : int + Dimension of the input tensor. + channels_in : int + Number of channels in the input tensor. + channels_out : int + Number of channels in the output tensor. + upscale_factor : int, optional + Factor by which to upscale the input tensor. + residual : bool, optional + Whether to use a residual connection as proposed in [DCAE]_. + """ + super().__init__() + self.conv = ConvND(dim)(channels_in, channels_out * upscale_factor**dim, kernel_size=3, padding='same') + self.pixel_shuffle = PixelShuffle(upscale_factor) + self.residual = residual + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims * upscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling.""" + h = self.conv(x) + if self.residual: + h = h + x.repeat_interleave(h.shape[1] // x.shape[1], dim=1) + out = self.pixel_shuffle(h) + return out + + +class PixelShuffle(Module): + """ND-version of PixelShuffle upscaling.""" + + def __init__(self, upscale_factor: int): + """Initialize PixelShuffle. + + Upscales spatial dimensions and decreases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are upscaled. + + See `mrpro.nn.PixelUnshuffle` for the inverse operation. + + Parameters + ---------- + upscale_factor : int + The factor by which to upscale the spatial dimensions. + """ + super().__init__() + self.upscale_factor = upscale_factor + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels / upscale_factor**dim, *spatial_dims * upscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input.""" + dim = x.ndim - 2 + if dim == 2: # fast path for 2D + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + new_shape = (x.shape[0], -1, *(old * self.upscale_factor for old in x.shape[-dim:])) + + x = x.unflatten(1, (-1, *(self.upscale_factor,) * dim)) + x = x.moveaxis(tuple(range(2, 2 + dim)), tuple(range(-2 * dim + 1, 0, 2))) + x = x.reshape(new_shape) + return x diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py new file mode 100644 index 000000000..28cecbf9f --- /dev/null +++ b/src/mrpro/nn/RMSNorm.py @@ -0,0 +1,53 @@ +"""RMSNorm module for root mean square normalization.""" + +import torch +from torch.nn import Module, Parameter + + +class RMSNorm(Module): + """RMSNorm over the channel dimension.""" + + def __init__(self, channels: int, eps: float = 1e-8, features_last: bool = False): + """Initialize RMSNorm. + + Includes a learnable weight and bias. + + Parameters + ---------- + channels + Number of channels. + eps + Epsilon value to avoid division by zero. + features_last + If True, the channel dimension is the last dimension. + """ + super().__init__() + self.weight = Parameter(torch.zeros(channels)) + self.bias = Parameter(torch.zeros(channels)) + self.eps = eps + self.channel_dim = -1 if features_last else 1 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension. + + Parameters + ---------- + x + Input tensor. + + Returns + ------- + Normalized tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension.""" + mean_square = x.square().mean(dim=self.channel_dim, keepdim=True) + scale = (mean_square + self.eps).rsqrt() + x = x * scale + shape = [1] * x.ndim + shape[self.channel_dim] = -1 + weight = (self.weight + 1).view(shape) + bias = self.bias.view(shape) + return x * weight + bias diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py new file mode 100644 index 000000000..8f61e6022 --- /dev/null +++ b/src/mrpro/nn/ResBlock.py @@ -0,0 +1,70 @@ +"""Residual convolution block with two convolutions.""" + +import torch +from torch.nn import Identity, Module, SiLU + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.Sequential import Sequential + + +class ResBlock(CondMixin, Module): + """Residual convolution block with two convolutions.""" + + def __init__(self, dim: int, channels_in: int, channels_out: int, cond_dim: int) -> None: + """Initialize the ResBlock. + + Parameters + ---------- + dim + The dimension, i.e. 1, 2 or 3. + channels_in + The number of channels in the input tensor. + channels_out + The number of channels in the output tensor. + cond_dim + The number of features in the conditioning tensor used in a FiLM. + If set to 0 no FiLM is used. + + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + self.block = Sequential( + GroupNorm(channels_in), + SiLU(), + ConvND(dim)(channels_in, channels_out, kernel_size=3, padding=1), + GroupNorm(channels_out), + SiLU(), + ConvND(dim)(channels_out, channels_out, kernel_size=3, padding=1), + ) + if cond_dim > 0: + self.block.insert(-3, FiLM(channels_out, cond_dim)) + + if channels_out == channels_in: + self.skip_connection: Module = Identity() + else: + self.skip_connection = ConvND(dim)(channels_in, channels_out, kernel_size=1) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the ResBlock. + + Parameters + ---------- + x + The input tensor. + cond + A conditioning tensor to be used for FiLM. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the ResBlock.""" + h = self.block(x, cond=cond) + x = self.skip_connection(x) + self.rezero * h + return x diff --git a/src/mrpro/nn/Residual.py b/src/mrpro/nn/Residual.py new file mode 100644 index 000000000..e524fe169 --- /dev/null +++ b/src/mrpro/nn/Residual.py @@ -0,0 +1,45 @@ +"""Residual connection.""" + +import torch +from torch.nn import Identity, Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class Residual(CondMixin, Module): + """Residual connection.""" + + def __init__(self, module: Module, skip: Module | None = None): + """Initialize the residual connection. + + Parameters + ---------- + module + The main path of the residual connection. + skip + The skip path of the residual connection. If None, the identity function is used. + """ + super().__init__() + self.module = module + self.skip = Identity() if skip is None else skip + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x + The input tensor. + cond + The optional conditioning tensor. If the modules are an instance of `CondMixin`, + the conditioning is passed to the modules. + + Returns + ------- + The output tensor. + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + return call_with_cond(self.module, *x, cond=cond) + call_with_cond(self.skip, *x, cond=cond) diff --git a/src/mrpro/nn/RoPE.py b/src/mrpro/nn/RoPE.py new file mode 100644 index 000000000..90ecb8739 --- /dev/null +++ b/src/mrpro/nn/RoPE.py @@ -0,0 +1,141 @@ +"""Rotary Position Embedding (RoPE).""" + +import torch +from torch.nn import Module + + +@torch.compile +def apply_rotary_emb_(x: torch.Tensor, theta: torch.Tensor, conjugated: bool) -> None: + """Add rotary embedding to the input tensor (inplace). + + This is a helper function for the `AxialRoPE` class. + + Parameters + ---------- + x : torch.Tensor + Input tensor to modify + theta : torch.Tensor + Rotation angles + conjugated : bool + Whether to use conjugated rotation + """ + n_emb = theta.shape[-1] * 2 + if n_emb > x.shape[-1]: + raise ValueError(f'Embedding dimension {n_emb} is larger than input dimension {x.shape[-1]}') + x1, x2 = x[..., :n_emb].chunk(2, dim=-1) + if conjugated: + x1, x2 = x2, x1 + x[..., :n_emb] = torch.cat([x1 * theta.cos() - x2 * theta.sin(), x2 * theta.cos() + x1 * theta.sin()], dim=-1) + + +class RotaryEmbedding(torch.autograd.Function): + """Custom autograd function for rotary embeddings.""" + + @staticmethod + def forward( + x: torch.Tensor, + theta: torch.Tensor, + conjugated: bool, + ) -> torch.Tensor: + """Apply rotary embedding in forward pass.""" + apply_rotary_emb_(x, theta, conjugated) + return x + + @staticmethod + def setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple[torch.Tensor, torch.Tensor, bool], _output: torch.Tensor + ) -> None: + """Save tensors for backward pass.""" + _, theta, conjugated = inputs + ctx.save_for_backward(theta) + ctx.conjugated = conjugated # type: ignore[attr-defined] + + @staticmethod + def backward( # type: ignore[override] + ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor + ) -> tuple[torch.Tensor, None, None]: + """Apply backward pass.""" + (theta,) = ctx.saved_tensors # type: ignore[attr-defined] + apply_rotary_emb_(grad_output, theta, ctx.conjugated) # type: ignore[attr-defined] + return grad_output, None, None + + +class AxialRoPE(Module): + """Axial Rotary Position Embedding. + + Applies rotary position embeddings along each axis independently. + """ + + freqs: torch.Tensor + + def __init__(self, dim: int, d_head: int, n_heads: int, headpos: int = -2, non_embed_fraction: float = 0.5): + """Initialize AxialRoPE. + + Parameters + ---------- + dim : int + Dimension of the input space + d_head : int + Dimension of each attention head + n_heads : int + Number of attention heads + headpos : int, optional + Position of the head dimension + non_embed_fraction : float, optional + Fraction of dimensions to not embed + """ + super().__init__() + log_min = torch.log(torch.tensor(torch.pi)) + log_max = torch.log(torch.tensor(10000.0)) + freqs = torch.exp(torch.linspace(log_min, log_max, d_head // 2)) + self.register_buffer('freqs', freqs) + self.headpos = headpos + + def get_theta(self, pos: torch.Tensor) -> torch.Tensor: + """Get rotation angles for given positions. + + Parameters + ---------- + pos : torch.Tensor + Position tensor + + Returns + ------- + torch.Tensor + Rotation angles + """ + return (self.freqs * pos[..., None, :, None]).flatten(start_dim=-2).movedim(-2, self.headpos) + + def forward(self, pos: torch.Tensor, *tensors: torch.Tensor) -> None: + """Apply rotary embeddings to input tensors. + + Parameters + ---------- + pos : torch.Tensor + Position tensor + *tensors : torch.Tensor + Tensors to apply rotary embeddings to + """ + theta = self.get_theta(pos) + tuple(RotaryEmbedding.apply(x, theta, False) for x in tensors) + + @staticmethod + def make_axial_positions(*shape: int) -> torch.Tensor: + """Create axial position tensors. + + Parameters + ---------- + *shape : int + Shape of the position tensor + + Returns + ------- + torch.Tensor + Position tensor + """ + m = torch.as_tensor(shape).max() + pos = torch.stack( + [torch.arange(s, device=m.device) - s // 2 for s in shape], + dim=-1, + ) + return pos diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py new file mode 100644 index 000000000..15b5d0152 --- /dev/null +++ b/src/mrpro/nn/Sequential.py @@ -0,0 +1,60 @@ +"""Sequential container with support for conditioning and Operators.""" + +from collections import OrderedDict +from typing import cast + +import torch + +from mrpro.nn.CondMixin import CondMixin +from mrpro.operators import Operator + + +class Sequential(CondMixin,torch.nn.Sequential): + """Sequential container with support for conditioning and Operators. + + Allows multiple input tensors and a single output tensor of the sequential block. + + """ + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input. + + Parameters + ---------- + x + The input tensor. + cond + The (optional) conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input.""" + for module in self: + if isinstance(module, Operator): + x = cast(tuple[torch.Tensor, ...], module(*x)) # always tuple + else: + ret: torch.Tensor | tuple[torch.Tensor, ...] + if isinstance(module, CondMixin): + ret = module(*x, cond=cond) + else: + ret = module(*x) + if isinstance(ret, tuple): + x = ret + else: + x = (ret,) + return x[0] + + def __getitem__(self, idx: slice | int) -> 'Sequential': + """Get a slice or item from the Sequential container. + + Subclasses will decompose to `Sequential` on indexing. + """ + if isinstance(idx, slice): + return Sequential(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py new file mode 100644 index 000000000..f66e2277e --- /dev/null +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -0,0 +1,103 @@ +"""Shifted Window Attention.""" + +import torch +from einops import rearrange +from torch.nn import Module + +from mrpro.nn.ndmodules import ConvND +from mrpro.utils.reshape import ravel_multi_index +from mrpro.utils.sliding_window import sliding_window + + +class ShiftedWindowAttention(Module): + """Shifted Window Attention. + + (Shifted) Window Attention calculates attention over windows of the input. + It was introduced in Swin Transformer [SWIN]_ and is used in Uformer. + + References + ---------- + .. [SWIN] Liu, Ze, et al. "Swin transformer: Hierarchical vision transformer using shifted windows." ICCV 2021. + """ + + rel_position_index: torch.Tensor + + def __init__( + self, dim: int, channels_in: int, channels_out: int, n_heads: int, window_size: int = 7, shifted: bool = True + ): + """Initialize the ShiftedWindowAttention module. + + Parameters + ---------- + dim : int + The dimension of the input. + channels_in : int + The number of channels in the input tensor. + channels_out : int + The number of channels in the output tensor. + n_heads : int + The number of attention heads. The number if channels per head is ``channels // n_heads``. + window_size : int + The size of the window. + shifted : bool + Whether to shift the window. + """ + super().__init__() + self.n_heads = n_heads + self.window_size = window_size + self.shifted = shifted + channels_per_head = channels_in // n_heads + self.to_qkv = ConvND(dim)(channels_per_head * n_heads, 3 * channels_per_head * n_heads, 1) + self.to_out = ConvND(dim)(channels_per_head * n_heads, channels_out, 1) + self.dim = dim + coords_1d = torch.arange(window_size) + coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * dim), indexing='ij'), 0).flatten(1) + rel_coords = coords_nd[:, :, None] - coords_nd[:, None, :] # (dim, window_size**dim, window_size**dim) + rel_coords += window_size - 1 # shift to >=0 + rel_position_index = ravel_multi_index(tuple(rel_coords), (2 * window_size - 1,) * dim) + self.register_buffer('rel_position_index', rel_position_index) + + self.relative_position_bias_table = torch.nn.Parameter(torch.empty((2 * window_size - 1) ** dim, n_heads)) + torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02, a=-0.04, b=0.04) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the ShiftedWindowAttention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the ShiftedWindowAttention.""" + if self.shifted: + x = torch.roll(x, (-(self.window_size // 2),) * self.dim, dims=tuple(range(-self.dim, 0))) + qkv = self.to_qkv(x) + windowed = sliding_window(qkv, window_shape=self.window_size, stride=self.window_size, dim=range(-self.dim, 0)) + flat = windowed.flatten(0, self.dim - 1).flatten(-self.dim) + q, k, v = rearrange( + flat, + 'spatial batch (qkv heads channels) window->qkv spatial batch heads window channels', + heads=self.n_heads, + qkv=3, + ) + bias = rearrange(self.relative_position_bias_table[self.rel_position_index], 'wd1 wd2 heads -> 1 heads wd1 wd2') + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) + attention = rearrange(attention, 'spatial batch head window channels->batch (head channels) spatial window') + attention = attention.unflatten(-2, windowed.shape[: self.dim]).unflatten(-1, (self.window_size,) * self.dim) + # permute (in 3d) batch channels z y x wz wy wx -> batch channels wz z wy y wx x + attention = attention.moveaxis(list(range(-self.dim, 0)), list(range(3, 3 + 2 * self.dim, 2))) + attention = attention.reshape(x.shape) + if self.shifted: + attention = torch.roll(attention, (self.window_size // 2,) * self.dim, dims=tuple(range(-self.dim, 0))) + out = self.to_out(attention) + return out + + +'' diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py new file mode 100644 index 000000000..2b4c3e6e2 --- /dev/null +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -0,0 +1,148 @@ +"""Spatial transformer block.""" + +from collections.abc import Sequence + +import torch +from torch.nn import Dropout, Linear, Module + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.GEGLU import GEGLU +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.Sequential import Sequential + + +def zero_init(m: Module) -> Module: + """Initialize module weights and bias to zero.""" + if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor): + torch.nn.init.zeros_(m.weight) + if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): + torch.nn.init.zeros_(m.bias) + return m + + +class BasicTransformerBlock(CondMixin, Module): + """Basic vision transformer block.""" + + def __init__( + self, + channels: int, + n_heads: int, + p_dropout: float = 0.0, + cond_dim: int = 0, + mlp_ratio: float = 4, + features_last: bool = False, + ): + """Initialize the basic transformer block. + + Parameters + ---------- + channels + Number of channels in the input and output. + n_heads + Number of attention heads. + p_dropout + Dropout probability. + cond_dim + Number of channels in the conditioning tensor. + mlp_ratio + Ratio of the hidden dimension to the input dimension. + features_last + Whether the features are last in the input tensor. + """ + super().__init__() + self.features_last = features_last + self.selfattention = Sequential( + LayerNorm(channels, features_last=True), + MultiHeadAttention( + channels_in=channels, + channels_out=channels, + n_heads=n_heads, + p_dropout=p_dropout, + features_last=True, + ), + ) + hidden_dim = int(channels * mlp_ratio) + self.ff = Sequential( + LayerNorm(channels, features_last=True, cond_dim=cond_dim), + GEGLU(channels, hidden_dim, features_last=True), + Dropout(p_dropout), + Linear(hidden_dim, channels), + ) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the basic transformer block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the basic transformer block.""" + if not self.features_last: + x = x.moveaxis(1, -1).contiguous() + x = self.selfattention(x) + x + x = self.ff(x, cond=cond) + x + if not self.features_last: + x = x.moveaxis(-1, 1).contiguous() + return x + + +class SpatialTransformerBlock(CondMixin, Module): + """Spatial transformer block.""" + + def __init__( + self, + dim_groups: Sequence[tuple[int, ...]], + channels: int, + n_heads: int, + depth: int = 1, + dropout: float = 0.0, + cond_dim: int = 0, + ): + """ + Parameters + ---------- + dim_groups + Groups of spatial dimensions for separate attention mechanisms. + channels + Number of channels in the input and output. + n_heads + Number of attention heads for each group. + depth + Number of transformer blocks for each group. + dropout + Dropout probability. + cond_dim + Dimension of the conditioning tensor. + """ + super().__init__() + hidden_dim = n_heads * (channels // n_heads) + self.norm = GroupNorm(channels) + self.proj_in = Linear(channels, hidden_dim) + self.transformer_blocks = Sequential() + for group in (g for _ in range(depth) for g in dim_groups): + block = BasicTransformerBlock(hidden_dim, n_heads, p_dropout=dropout, cond_dim=cond_dim, features_last=True) + self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) + self.proj_out = Linear(hidden_dim, channels) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the spatial transformer block.""" + skip = x + h = self.norm(x) + h = h.movedim(1, -1) + h = self.proj_in(h) + h = self.transformer_blocks(h, cond=cond) + h = self.proj_out(h) + h = h.movedim(-1, 1) + return skip + h + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/SqueezeExcitation.py b/src/mrpro/nn/SqueezeExcitation.py new file mode 100644 index 000000000..787817173 --- /dev/null +++ b/src/mrpro/nn/SqueezeExcitation.py @@ -0,0 +1,57 @@ +"""Squeeze-and-Excitation block.""" + +import torch +from torch.nn import Module, ReLU, Sigmoid + +from mrpro.nn.ndmodules import AdaptiveAvgPoolND, ConvND +from mrpro.nn.Sequential import Sequential + + +class SqueezeExcitation(Module): + """Squeeze-and-Excitation block. + + Sequeeze-and-Excitation block from [SE]_. + + References + ---------- + ..[SE] Hu, Jie, Li Shen, and Gang Sun. "Squeeze-and-excitation networks." CVPR 2018, https://arxiv.org/abs/1709.01507 + """ + + def __init__(self, dim: int, input_channels: int, squeeze_channels: int) -> None: + """Initialize SqueezeExcitation. + + Parameters + ---------- + dim + The dimension of the input tensor. + input_channels + The number of channels in the input tensor. + squeeze_channels + The number of channels in the squeeze tensor. + """ + super().__init__() + self.scale = Sequential( + AdaptiveAvgPoolND(dim)(1), + ConvND(dim)(input_channels, squeeze_channels, kernel_size=1), + ReLU(), + ConvND(dim)(squeeze_channels, input_channels, kernel_size=1), + Sigmoid(), + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation.""" + return x * self.scale(x) diff --git a/src/mrpro/nn/TransposedAttention.py b/src/mrpro/nn/TransposedAttention.py new file mode 100644 index 000000000..043afa750 --- /dev/null +++ b/src/mrpro/nn/TransposedAttention.py @@ -0,0 +1,76 @@ +"""Transposed Attention from Restormer.""" + +import torch +from einops import rearrange +from torch.nn import Module, Parameter + +from mrpro.nn.ndmodules import ConvND + + +class TransposedAttention(Module): + """Transposed Self Attention from Restormer. + + Implements the transposed self-attention, i.e. channel-wise multihead self-attention, + layer from Restormer [ZAM22]_. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + """ + + def __init__(self, dim: int, channels_in: int, channels_out: int, n_heads: int): + """Initialize a TransposedAttention layer. + + Parameters + ---------- + dim + input dimension + channels_in + Number of channels in the input tensor. + channels_out + Number of channels in the output tensor. + n_heads + Number of attention heads. + """ + super().__init__() + self.n_heads = n_heads + self.temperature = Parameter(torch.ones(n_heads, 1, 1)) + channels_per_head = channels_in // n_heads + self.to_qkv = ConvND(dim)(channels_in, channels_per_head * n_heads * 3, kernel_size=1) + self.qkv_dwconv = ConvND(dim)( + channels_per_head * n_heads * 3, + channels_per_head * n_heads * 3, + kernel_size=3, + groups=channels_in * 3, + padding=1, + bias=False, + ) + self.to_out = ConvND(dim)(channels_per_head * n_heads, channels_out, kernel_size=1) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed attention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed attention.""" + qkv = self.qkv_dwconv(self.to_qkv(x)) + q, k, v = rearrange(qkv, 'b (qkv heads channels) ... -> qkv b heads (...) channels', heads=self.n_heads, qkv=3) + q = torch.nn.functional.normalize(q, dim=-1) * self.temperature + k = torch.nn.functional.normalize(k, dim=-1) + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0) + out = rearrange(attention, '... heads points channels -> ... (heads channels) points').unflatten( + -1, x.shape[2:] + ) + out = self.to_out(out) + return out diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py new file mode 100644 index 000000000..ec9b0e032 --- /dev/null +++ b/src/mrpro/nn/Upsample.py @@ -0,0 +1,70 @@ +"""Upsampling by interpolation.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module, Sequential + +from mrpro.nn.PermutedBlock import PermutedBlock + + +class Upsample(Module): + """Upsampling by interpolation.""" + + def __init__( + self, dim: Sequence[int], scale_factor: int = 2, mode: Literal['nearest', 'linear', 'cubic'] = 'linear' + ): + """Initialize the upsampling layer. + + Parameters + ---------- + dim + Dimensions which to upsample + scale_factor + Factor by which to upsample + mode + Interpolation mode. See `torch.nn.functional.interpolate` for details. + """ + super().__init__() + self.scale_factor = scale_factor + if mode == 'nearest': + dims = [tuple(d) for d in torch.tensor(dim).split(3)] + modes = ['nearest'] * len(self.dim) + elif mode == 'linear': + dims = [tuple(d) for d in torch.tensor(dim).split(3)] + modes = [{1: 'linear', 2: 'bilinear', 3: 'trilinear'}[len(d)] for d in dims] + elif mode == 'cubic': + if not len(dim) == 2: + raise ValueError('Cubic interpolation is only supported for 2D images.') + dims = [tuple(dim)] + modes = ['bicubic'] + + self.blocks = Sequential( + *[ + PermutedBlock(d, Upsample(d, scale_factor=scale_factor, mode=m)) + for d, m in zip(dims, modes, strict=False) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor.""" + return torch.nn.functional.interpolate( + x, + mode=self.mode, + scale_factor=self.scale_factor, + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Upsampled tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py new file mode 100644 index 000000000..e59e4efde --- /dev/null +++ b/src/mrpro/nn/__init__.py @@ -0,0 +1,44 @@ +"""Neural network modules and utilities.""" + +from mrpro.nn.AttentionGate import AttentionGate +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import ( + AdaptiveAvgPoolND, + AvgPoolND, + BatchNormND, + ConvND, + ConvTransposeND, + InstanceNormND, + MaxPoolND, +) +from mrpro.nn.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.SqueezeExcitation import SqueezeExcitation +from mrpro.nn.TransposedAttention import TransposedAttention +from mrpro.nn.DropPath import DropPath +from mrpro.nn import nets +__all__ = [ + "AdaptiveAvgPoolND", + "AttentionGate", + "AvgPoolND", + "BatchNormND", + "CondMixin", + "ConvND", + "ConvTransposeND", + "DropPath", + "FiLM", + "GroupNorm", + "InstanceNormND", + "MaxPoolND", + "NeighborhoodSelfAttention", + "ResBlock", + "Sequential", + "ShiftedWindowAttention", + "SqueezeExcitation", + "TransposedAttention", + "nets" +] \ No newline at end of file diff --git a/src/mrpro/nn/convert_linear_conv.py b/src/mrpro/nn/convert_linear_conv.py new file mode 100644 index 000000000..a6dac5f33 --- /dev/null +++ b/src/mrpro/nn/convert_linear_conv.py @@ -0,0 +1,100 @@ +"""Convert Linear layers to kernel size 1 ConvNd layers and vice versa.""" + +from typing import Literal, overload + +import torch +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + +from mrpro.nn.ndmodules import ConvND + + +@overload +def linear_to_conv(linear_layer: Linear, dim: Literal[1]) -> Conv1d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, dim: Literal[2]) -> Conv2d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, dim: Literal[3]) -> Conv3d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: ... + + +def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: + """Convert a Linear layer to a ConvNd layer with kernel size 1. + + Rearranging the spatial dimensions to the batch dimension, + applying the linear layer and rearranging the spatial dimensions back + is equivalent to applying a kernel size 1 ConvNd layer. + + This function will create the Conv1d, Conv2d, or Conv3d with the correct weights and bias. + + See :func:`conv_to_linear` for the reverse operation. + + + + Parameters + ---------- + linear_layer : nn.Linear + The linear layer to convert. + dim : int + The convolution dimension (1, 2, or 3). + + Returns + ------- + A Conv layer with equivalent weights and bias. + """ + conv = ConvND(dim)( + in_channels=linear_layer.in_features, + out_channels=linear_layer.out_features, + kernel_size=1, + bias=linear_layer.bias is not None, + device=linear_layer.weight.device, + dtype=linear_layer.weight.dtype, + ) + + with torch.no_grad(): + conv.weight.copy_(linear_layer.weight.view_as(conv.weight)) + if conv.bias is not None and linear_layer.bias is not None: + conv.bias.copy_(linear_layer.bias) + + return conv + + +def conv_to_linear(conv_layer: Conv1d | Conv2d | Conv3d) -> Linear: + """ + Convert a Conv1d, Conv2d, or Conv3d layer with kernel size 1 to a Linear layer. + + Applying a kernel size 1 ConvNd layer is equivalent to applying a Linear layer to each voxel. + This function will create the Linear layer with the correct weights and bias. + + See :func:`linear_to_conv` for the reverse operation. + + Parameters + ---------- + conv_layer : nn.Module + The convolutional layer to convert. Must have kernel size 1. + + Returns + ------- + A linear layer with equivalent weights and bias. + """ + if not all(k == 1 for k in conv_layer.kernel_size): + raise ValueError('Kernel size must be 1 for conversion.') + linear = Linear( + conv_layer.in_channels, + conv_layer.out_channels, + bias=conv_layer.bias is not None, + device=conv_layer.weight.device, + dtype=conv_layer.weight.dtype, + ) + with torch.no_grad(): + linear.weight.copy_(conv_layer.weight.view_as(linear.weight)) + if linear.bias is not None and conv_layer.bias is not None: + linear.bias.copy_(conv_layer.bias) + + return linear diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py new file mode 100644 index 000000000..39f48c51e --- /dev/null +++ b/src/mrpro/nn/encoding.py @@ -0,0 +1,115 @@ +"""Encoding modules for neural networks.""" + +from itertools import combinations +from math import ceil + +import torch +from torch.nn import Module + +from mrpro.utils.reshape import unsqueeze_right + + +class FourierFeatures(Module): + """Fourier feature encoding layer. + + Projects input features into a higher dimensional space using random Fourier features. + This is useful for encoding positional information in neural networks. + """ + + weight: torch.Tensor + + def __init__(self, in_features: int, out_features: int, std: float = 1.0): + """Initialize Fourier feature encoding layer. + + Parameters + ---------- + in_features + Number of input features + out_features + Number of output features (must be even) + std + Standard deviation for random initialization + """ + if out_features % 2 != 0: + raise ValueError('out_features must be even.') + super().__init__() + self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding. + + Parameters + ---------- + x + Input tensor of shape (..., in_features) + + Returns + ------- + Encoded features of shape (..., out_features) + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding.""" + f = 2 * torch.pi * x @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + + +class AbsolutePositionEncoding(Module): + """Absolute position encoding layer. + + Encodes absolute positions in a grid using learned embeddings. + """ + + encoding: torch.Tensor + + def __init__(self, dim: int, features: int, include_radii: bool = True, base_resolution: int = 128): + """Initialize absolute position encoding layer. + + Parameters + ---------- + dim + Dimension of the input space (1, 2, or 3) + features + Number of output features + include_radii + Whether to include radius features + base_resolution + Base resolution for position encoding + """ + super().__init__() + + coords = [unsqueeze_right(torch.linspace(-1, 1, base_resolution), i) for i in range(dim)] + if include_radii: + for n in range(2, dim + 1): + for combination in combinations(coords, n): + coords.append((2 * sum([c**2 for c in combination])) ** 0.5 - 1) + n_freqs = ceil(features / len(coords) / 2) + freqs = unsqueeze_right((base_resolution) ** torch.linspace(0, 1, n_freqs), dim) + encoding = [] + for coord in coords: + encoding.append(torch.sin(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * dim))) + encoding.append(torch.cos(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * dim))) + self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :features]) + self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][dim - 1] + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for encoding. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Encoded tensor with absolute position information + """ + features = self.encoding.shape[1] + if features > x.shape[1]: + raise ValueError(f'x has {x.shape[1]} features, but {features} are required') + + x_enc, x_unenc = x.split([features, x.shape[1] - features], dim=1) + encoding = torch.nn.functional.interpolate(self.encoding, size=x_unenc.shape[2:], mode=self.interpolation_mode) + return torch.cat((x_enc + encoding, x_unenc), dim=1) diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py new file mode 100644 index 000000000..0aed41b8d --- /dev/null +++ b/src/mrpro/nn/join.py @@ -0,0 +1,125 @@ +"""Modules for concatenating or adding tensors.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module + +from mrpro.utils.interpolate import interpolate +from mrpro.utils.pad_or_crop import pad_or_crop + + +def _fix_shapes( + xs: Sequence[torch.Tensor], + mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'], + dim: Sequence[int], +) -> tuple[torch.Tensor, ...]: + """Fix shapes of input tensors by padding or cropping.""" + if mode == 'fail': + return tuple(xs) + + shapes = [[x.shape[d] for d in dim] for x in xs] + if mode == 'crop': # smallest as target + target = tuple(min(s) for s in zip(*shapes, strict=True)) + else: # largest as target + target = tuple(max(s) for s in zip(*shapes, strict=True)) + if mode == 'linear' or mode == 'nearest': + return tuple(interpolate(x, target, dim=dim, mode=mode) for x in xs) + if mode == 'zero' or mode == 'crop': + return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) + else: + return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) + + +class Concat(Module): + """Concatenate tensors along the channel dimension.""" + + def __init__( + self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'] = 'fail', dim: int = 1 + ) -> None: + """Initialize Concat. + + Parameters + ---------- + mode + How to handle mismatched spatial dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + - 'linear': linear interpolation to largest spatial size + - 'nearest': nearest neighbor interpolation to largest spatial size + dim + Dimension along which to concatenate. + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular', 'interpolate'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.dim = dim + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Concatenate input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) + return torch.cat(xs, dim=self.dim) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Concatenate input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Concatenated tensor + """ + return super().__call__(*xs) + + +class Add(Module): + """Add tensors.""" + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: + """Initialize Add. + + Parameters + ---------- + mode : {'fail', 'crop', 'zero', 'replicate', 'circular'}, default='zero' + How to handle mismatched spatial dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Add input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=range(max(x.ndim for x in xs))) + return sum(xs, start=torch.tensor(0.0)) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Add input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Summed tensor + """ + return super().__call__(*xs) diff --git a/src/mrpro/nn/ndmodules.py b/src/mrpro/nn/ndmodules.py new file mode 100644 index 000000000..b4bf089ea --- /dev/null +++ b/src/mrpro/nn/ndmodules.py @@ -0,0 +1,176 @@ +"""Helper functions to get the correct N-dimensional module.""" + +import torch + + +def ConvND(dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 + """Get the `dim`-dimensional convolution class. + + Parameters + ---------- + dim + The dimension of the convolution. + + Returns + ------- + The convolution class. + """ + match dim: + case 1: + return torch.nn.Conv1d + case 2: + return torch.nn.Conv2d + case 3: + return torch.nn.Conv3d + case _: + raise NotImplementedError(f'ConvND for dim {dim} not implemented. Raise an issue if you need this.') + + +def ConvTransposeND( # noqa: N802 + dim: int, +) -> type[torch.nn.ConvTranspose1d] | type[torch.nn.ConvTranspose2d] | type[torch.nn.ConvTranspose3d]: + """Get the `dim`-dimensional transposed convolution class. + + Parameters + ---------- + dim + The dimension of the transposed convolution. + + Returns + ------- + The transposed convolution class. + """ + match dim: + case 1: + return torch.nn.ConvTranspose1d + case 2: + return torch.nn.ConvTranspose2d + case 3: + return torch.nn.ConvTranspose3d + case _: + raise NotImplementedError( + f'ConvTransposeND for dim {dim} not implemented. Raise an issue if you need this.' + ) + + +def MaxPoolND(dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 + """Get the `dim`-dimensional max pooling class. + + Parameters + ---------- + dim + The dimension of the max pooling. + + Returns + ------- + The max pooling class. + """ + match dim: + case 1: + return torch.nn.MaxPool1d + case 2: + return torch.nn.MaxPool2d + case 3: + return torch.nn.MaxPool3d + case _: + raise NotImplementedError(f'MaxPoolNd for dim {dim} not implemented. Raise an issue if you need this.') + + +def AvgPoolND(dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 + """Get the `dim`-dimensional average pooling class. + + Parameters + ---------- + dim + The dimension of the average pooling. + + Returns + ------- + The average pooling class. + """ + match dim: + case 1: + return torch.nn.AvgPool1d + case 2: + return torch.nn.AvgPool2d + case 3: + return torch.nn.AvgPool3d + case _: + raise NotImplementedError(f'AvgPoolNd for dim {dim} not implemented. Raise an issue if you need this.') + + +def AdaptiveAvgPoolND( # noqa: N802 + dim: int, +) -> type[torch.nn.AdaptiveAvgPool1d] | type[torch.nn.AdaptiveAvgPool2d] | type[torch.nn.AdaptiveAvgPool3d]: + """Get the `dim`-dimensional adaptive average pooling class. + + Parameters + ---------- + dim + The dimension of the adaptive average pooling. + + Returns + ------- + The adaptive average pooling class. + """ + match dim: + case 1: + return torch.nn.AdaptiveAvgPool1d + case 2: + return torch.nn.AdaptiveAvgPool2d + case 3: + return torch.nn.AdaptiveAvgPool3d + case _: + raise NotImplementedError( + f'AdaptiveAvgPoolNd for dim {dim} not implemented. Raise an issue if you need this.' + ) + + +def InstanceNormND( # noqa: N802 + dim: int, +) -> type[torch.nn.InstanceNorm1d] | type[torch.nn.InstanceNorm2d] | type[torch.nn.InstanceNorm3d]: + """Get the `dim`-dimensional instance normalization class. + + Parameters + ---------- + dim + The dimension of the instance normalization. + + Returns + ------- + The instance normalization class. + """ + match dim: + case 1: + return torch.nn.InstanceNorm1d + case 2: + return torch.nn.InstanceNorm2d + case 3: + return torch.nn.InstanceNorm3d + case _: + raise NotImplementedError(f'InstanceNormNd for dim {dim} not implemented. Raise an issue if you need this.') + + +def BatchNormND( # noqa: N802 + dim: int, +) -> type[torch.nn.BatchNorm1d] | type[torch.nn.BatchNorm2d] | type[torch.nn.BatchNorm3d]: + """Get the `dim`-dimensional batch normalization class. + + Parameters + ---------- + dim + The dimension of the batch normalization. + + Returns + ------- + The batch normalization class. + """ + match dim: + case 1: + return torch.nn.BatchNorm1d + case 2: + return torch.nn.BatchNorm2d + case 3: + return torch.nn.BatchNorm3d + case _: + raise NotImplementedError(f'BatchNormNd for dim {dim} not implemented. Raise an issue if you need this.') diff --git a/src/mrpro/nn/nets/CNN.py b/src/mrpro/nn/nets/CNN.py new file mode 100644 index 000000000..3fbabcc4f --- /dev/null +++ b/src/mrpro/nn/nets/CNN.py @@ -0,0 +1,60 @@ +"""Simple Convolutional Neural Network.""" + +from collections.abc import Sequence +from itertools import pairwise + +from torch.nn import ReLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.Residual import Residual +from mrpro.nn.Sequential import Sequential + + +class CNN(Sequential): + """A simple CNN network.""" + + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + features: Sequence[int], + norm: bool = True, + residual: bool = True, + cond_dim: int = 0, + ): + """Initialize the CNN. + + Parameters + ---------- + dim + The number of spatial dimensions. + channels_in + The number of input channels. + channels_out + The number of output channels. + features + The number of features in each layer. The length of the list is the number of hidden layers. + norm + Whether to use layer normalization. + residual + Whether to use residual connections. + cond_dim + The dimension of the conditioning tensor. If 0, no FiLM is used. + """ + super().__init__() + channels = [channels_in, *features] + for i, (channels_current, channels_next) in enumerate(pairwise(channels)): + block = Sequential(ConvND(dim)(channels_current, channels_next, 3, padding=1), ReLU(True)) + if norm: + block.append(GroupNorm(1)) + if cond_dim > 0 and i % 2 == 0: + block.append(FiLM(channels_next, cond_dim)) + if residual: + self.append(Residual(block)) + else: + self.append(block) + + self.append(ConvND(dim)(channels_next, channels_out, 3, padding=1)) diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py new file mode 100644 index 000000000..1f49a0297 --- /dev/null +++ b/src/mrpro/nn/nets/DCAE.py @@ -0,0 +1,280 @@ +"""Deep Compression Autoencoder.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module, ReLU, SiLU + +from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock +from mrpro.nn.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.nets.VAE import VAE +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Residual import Residual +from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Sequential import Sequential + + +class CNNBlock(Residual): + """Block with two convolutions and normalization. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + dim: int, + channels: int, + ): + """Initialize the CNNBlock. + + Parameters + ---------- + dim : int + The spatial dimension of the input tensor. + channels : int + The number of channels in the input tensor. + """ + super().__init__( + Sequential( + ConvND(dim)(channels, channels, kernel_size=3, padding=1), + SiLU(True), + ConvND(dim)(channels, channels, kernel_size=3, padding=1, bias=False), + RMSNorm(channels), + ) + ) + + +class EfficientViTBlock(Module): + """Efficient Vision Transformer block with optional linear attention. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + dim: int, + channels: int, + n_heads: int, + expand_ratio: int = 4, + linear_attn: bool = False, + ): + """Initialize the EfficientViTBlock. + + Parameters + ---------- + dim : int + The spatial dimension of the input tensor. + channels : int + The number of channels in the input tensor. + n_heads : int + The number of attention heads. + expand_ratio : int + The expansion ratio of the GluMBConvResBlock. + linear_attn : bool + Whether to use linear attention instead of softmax attention with quadratic complexity. + """ + super().__init__() + if linear_attn: + attention: Module = LinearSelfAttention(channels, channels, n_heads) + else: + attention = MultiHeadAttention(channels, channels, n_heads, features_last=False) + self.context_module = Residual(Sequential(attention, RMSNorm(channels))) + self.local_module = GluMBConvResBlock( + dim=dim, + channels_in=channels, + channels_out=channels, + expand_ratio=expand_ratio, + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the EfficientViTBlock. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Output tensor + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for EfficientViTBlock.""" + x = self.context_module(x) + x = self.local_module(x) + return x + + +class Encoder(Sequential): + """Encoder for DCAE. + + As used in the DC-Autoencoder [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + dim: int = 2, + channels_in: int = 3, + channels_out: int = 32, + block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), + widths: Sequence[int] = (256, 512, 512, 1024, 1024), + depths: Sequence[int] = (4, 6, 2, 2, 2), + ): + """Initialize the Encoder. + + The length of the `block_types`, `widths`, and `depths` must be the same and determine + the number of stages in the encoder. Between the stages, downsampling is performed. + + Parameters + ---------- + dim : int + The spatial dimension of the input tensor. + channels_in : int + The number of channels in the input tensor, i.e. the latent space + channels_out : int + The number of channels in the output tensor, i.e. the original space + block_types : Sequence[str] + The types of blocks to use in the decoder. + widths : Sequence[int] + The widths of the blocks in the decoder, i.e. the number of channels in the blocks + depths : Sequence[int] + The depths of the blocks in the decoder, i.e. the number blocks in the stage + """ + super().__init__() + self.append(PixelUnshuffleDownsample(dim, channels_in, widths[0], downscale_factor=2, residual=False)) + if len(block_types) != len(widths) or len(block_types) != len(depths): + raise ValueError('block_types, widths, and depths must have the same length') + for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): + match block_type: + case 'CNN': + stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] + case 'LinearViT': + stage = [EfficientViTBlock(dim, width, max(1, width // 32), linear_attn=True) for _ in range(depth)] + case 'ViT': + stage = [EfficientViTBlock(dim, width, max(1, width // 32)) for _ in range(depth)] + case _: + raise ValueError(f'Block type {block_type} not supported') + self.append(Sequential(*stage)) + if next_width: + self.append(PixelUnshuffleDownsample(dim, width, next_width, downscale_factor=2, residual=True)) + self.append( + Sequential( + RMSNorm(widths[-1]), + ReLU(), + PixelUnshuffleDownsample(dim, widths[-1], channels_out, downscale_factor=1, residual=True), + ) + ) + + +class Decoder(Sequential): + """Decoder for DCAE. + + As used in the DC-Autoencoder [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + dim: int = 2, + channels_in: int = 32, + channels_out: int = 3, + block_types: Sequence[Literal['ViT', 'LinearViT', 'CNN']] = ('ViT', 'LinearViT', 'LinearViT', 'CNN', 'CNN'), + widths: Sequence[int] = (1024, 1024, 512, 512, 256), + depths: Sequence[int] = (2, 2, 2, 6, 4), + ): + """Initialize the Decoder. + + The length of the `block_types`, `widths`, and `depths` must be the same and determine + the number of stages in the decoder. Between the stages, upsampling is performed. + + Parameters + ---------- + dim : int + The spatial dimension of the input tensor. + channels_in : int + The number of channels in the input tensor, i.e. the latent space + channels_out : int + The number of channels in the output tensor, i.e. the original space + block_types : Sequence[str] + The types of blocks to use in the decoder. + widths : Sequence[int] + The widths of the blocks in the decoder, i.e. the number of channels in the blocks + depths : Sequence[int] + The depths of the blocks in the decoder, i.e. the number blocks in the stage + """ + super().__init__() + if not (len(block_types) == len(widths) == len(depths)): + raise ValueError('block_types, widths, and depths must have the same length') + self.append(PixelShuffleUpsample(dim, channels_in, widths[0], upscale_factor=1, residual=True)) + + for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): + match block_type: + case 'CNN': + stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] + case 'LinearViT': + stage = [EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=True) for _ in range(depth)] + case 'ViT': + stage = [ + EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=False) for _ in range(depth) + ] + case _: + raise ValueError(f'Block type {block_type} not supported') + self.append(Sequential(*stage)) + if next_width: + self.append(PixelShuffleUpsample(dim, width, next_width, upscale_factor=2, residual=True)) + + self.append( + Sequential( + RMSNorm(widths[-1]), + ReLU(), + PixelShuffleUpsample(dim, widths[-1], channels_out, upscale_factor=2), + ) + ) + + +class DCVAE(VAE): + """Variational Autoencoder based on DCAE. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + dim: int, + channels: int, + latent_dim: int = 32, + block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), + widths: Sequence[int] = (256, 512, 512, 1024, 1024), + depths: Sequence[int] = (4, 6, 2, 2, 2), + ): + """Initialize the DCVAE.""" + encoder = Encoder(dim, channels, latent_dim * 2, block_types, widths, depths) + decoder = Decoder(dim, latent_dim, channels, block_types[::-1], widths[::-1], depths[::-1]) + super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py new file mode 100644 index 000000000..f95fd2f98 --- /dev/null +++ b/src/mrpro/nn/nets/Restormer.py @@ -0,0 +1,208 @@ +"""Restormer implementation.""" + +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import Module + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import ConvND, InstanceNormND +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Sequential import Sequential +from mrpro.nn.TransposedAttention import TransposedAttention + + +class GDFN(Module): + """Gated depthwise feed forward network. + + As used in the Restormer architecture. + """ + + def __init__(self, dim: int, channels: int, mlp_ratio: float): + """Initialize GDFN. + + Parameters + ---------- + dim : int + Dimension of the input space + channels : int + Number of input/output channels + mlp_ratio : float + Ratio for hidden dimension expansion + """ + super().__init__() + + hidden_features = int(channels * mlp_ratio) + self.project_in = ConvND(dim)(channels, hidden_features * 2, kernel_size=1) + self.depthwise_conv = ConvND(dim)( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + ) + self.project_out = ConvND(dim)(hidden_features, channels, kernel_size=1) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the gated depthwise feed forward network. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Output tensor + """ + x = self.project_in(x) + x1, x2 = self.depthwise_conv(x).chunk(2, dim=1) + x = x1 * torch.sigmoid(x2) + x = self.project_out(x) + return x + + +class RestormerBlock(CondMixin, Module): + """Transformer block with transposed attention and gated depthwise feed forward network.""" + + def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond_dim: int = 0): + """Initialize RestormerBlock. + + Parameters + ---------- + dim + Dimension of the input space + channels : int + Number of input/output channels + n_heads + Number of attention heads + mlp_ratio + Ratio for hidden dimension expansion + cond_dim + Dimension of conditioning input. If 0, no conditioning is applied. + """ + super().__init__() + self.norm1 = Sequential(InstanceNormND(dim)(channels)) + self.attn = TransposedAttention(dim, channels, channels, n_heads) + self.norm2 = Sequential(InstanceNormND(dim)(channels)) + self.ffn = GDFN(dim, channels, mlp_ratio) + if cond_dim > 0: + self.norm2.append(FiLM(channels=channels, cond_dim=cond_dim)) + + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply Restormer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Forward pass for RestormerBlock.""" + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x, cond=cond)) + return x + + +class Restormer(UNetBase): + """Restormer architecture. + + Implements the Restormer [ZAM22]_ network, which is a U-shaped transformer + with channel wise attention and depthwise convolutions in the feed forward network. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + """ + + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + n_blocks: Sequence[int] = (4, 6, 6, 8), + n_refinement_blocks: int = 4, + n_heads: Sequence[int] = (1, 2, 4, 8), + n_channels_per_head: int = 48, + mlp_ratio: float = 2.66, + cond_dim: int = 0, + ): + """Initialize Restormer. + + Parameters + ---------- + dim + Dimension of the input space + channels_in + Number of input channels + channels_out + Number of output channels + n_blocks + Number of blocks in each stage + n_refinement_blocks + Number of refinement blocks + n_heads + Number of attention heads in each stage + n_channels_per_head + Number of channels per attention head + mlp_ratio + Ratio for hidden dimension expansion + cond_dim + Dimension of conditioning input. If 0, no conditioning is applied. + """ + + def blocks(n_heads: int, n_blocks: int): + layers = Sequential( + *(RestormerBlock(dim, n_channels_per_head * n_heads, n_heads, mlp_ratio) for _ in range(n_blocks)) + ) + + if cond_dim > 0 and n_blocks > 1: + layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) + return layers + + first_block = ConvND(dim)(channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) + encoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)] + down_blocks = [ + PixelUnshuffleDownsample(dim, n_channels_per_head * head_current, n_channels_per_head * head_next) + for head_current, head_next in pairwise(n_heads) + ] + middle_block = blocks(n_heads[-1], n_blocks[-1]) + encoder = UNetEncoder( + first_block=first_block, + blocks=encoder_blocks, + down_blocks=down_blocks, + middle_block=middle_block, + ) + + up_blocks = [ + PixelShuffleUpsample(dim, n_channels_per_head * head_next, n_channels_per_head * head_current) + for head_current, head_next in pairwise(n_heads) + ][::-1] + concat_blocks = [Concat() for _ in range(len(encoder_blocks))] + decoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1] + last_block = Sequential( + *(RestormerBlock(dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)), + ConvND(dim)(n_channels_per_head, channels_out, kernel_size=3, stride=1, padding=1), + ) + decoder = UNetDecoder( + blocks=decoder_blocks, + up_blocks=up_blocks, + concat_blocks=concat_blocks, + last_block=last_block, + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py new file mode 100644 index 000000000..e3e8a440a --- /dev/null +++ b/src/mrpro/nn/nets/SwinIR.py @@ -0,0 +1,247 @@ +"""SwinIR implementation.""" + +import torch +from torch.nn import GELU, Module + +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import ConvND, InstanceNormND +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention + + +class SwinTransformerLayer(Module): + """Swin Transformer layer. + + Implements a single layer of the Swin Transformer architecture. + """ + + def __init__( + self, + dim: int, + channels: int, + n_heads: int, + window_size: int, + mlp_ratio: int = 4, + emb_dim: int = 0, + p_droppath: float = 0.0, + ): + """Initialize SwinTransformerLayer. + + Parameters + ---------- + dim : int + Dimension of the input space + channels : int + Number of input/output channels + n_heads : int + Number of attention heads + window_size : int + Size of the attention window + mlp_ratio : int + Ratio for hidden dimension expansion in MLP + emb_dim : int + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath : float + Droppath probability for MLP + """ + super().__init__() + self.norm1 = InstanceNormND(dim)(channels) + self.attn = ShiftedWindowAttention(dim, channels, channels, n_heads, window_size) + self.norm2 = Sequential(InstanceNormND(dim)(channels)) + if emb_dim > 0: + self.norm2.append(FiLM(channels=channels, cond_dim=emb_dim)) + self.mlp = Sequential( + ConvND(dim)(channels, channels * mlp_ratio, 1), + GELU('tanh'), + ConvND(dim)(channels * mlp_ratio, channels, 1), + DropPath(p_droppath), + ) + + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the Swin Transformer layer. + + Parameters + ---------- + x : torch.Tensor + Input tensor + cond : torch.Tensor | None, optional + Conditioning input + + Returns + ------- + torch.Tensor + Output tensor + """ + return super().__call__(x, cond) + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the Swin Transformer layer.""" + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x, cond=cond)) + return x + + +class ResidualSwinTransformerBlock(Module): + """Residual Swin Transformer block (RSTB). + + Combines a Swin Transformer layer with a residual connection, + as used in the SwinIR architecture. + """ + + def __init__( + self, + dim: int, + channels: int, + n_heads: int, + window_size: int, + depth: int, + emb_dim: int = 0, + p_droppath: float = 0.0, + mlp_ratio: int = 4, + ): + """Initialize ResidualSwinTransformerBlock. + + Parameters + ---------- + dim : int + Dimension of the input space + channels : int + Number of input/output channels + n_heads : int + Number of attention heads + window_size : int + Size of the attention window + depth : int + Number of Swin Transformer layers + emb_dim : int, optional + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath : float, optional + Droppath probability for MLP. + mlp_ratio : int, optional + Ratio for hidden dimension expansion in MLP + """ + super().__init__() + self.layers = Sequential( + *( + SwinTransformerLayer( + dim, channels, n_heads, window_size, emb_dim=emb_dim, p_droppath=p_droppath, mlp_ratio=mlp_ratio + ) + for _ in range(depth) + ) + ) + self.conv = ConvND(dim)(channels, channels, 3, padding=1) + + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the residual Swin Transformer block. + + Parameters + ---------- + x : torch.Tensor + Input tensor + cond : torch.Tensor | None, optional + Conditioning input. If None, no FiLM conditioning is used. + + Returns + ------- + torch.Tensor + Output tensor + """ + return super().__call__(x, cond) + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the residual Swin Transformer block.""" + return x + self.conv(self.layers(x, cond=cond)) + + +class SwinIR(Module): + """SwinIR architecture. + + Implements the SwinIR [LZL21]_ network, which is a Swin Transformer based + image restoration network. + + References + ---------- + .. [LZL21] Liang, Jie, et al. "SwinIR: Image restoration using swin transformer." + ICCVW 2021, https://arxiv.org/pdf/2108.10257.pdf + """ + + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + channels_per_head: int = 16, + n_heads: int = 6, + window_size: int = 64, + n_blocks: int = 6, + n_attn_per_block: int = 6, + emb_dim: int = 0, + p_droppath: float = 0.0, + mlp_ratio: int = 4, + ): + """Initialize SwinIR. + + Parameters + ---------- + dim : int + Dimension of the input space + channels_in : int + Number of input channels + channels_out : int + Number of output channels + channels_per_head : int, optional + Number of channels per attention head + n_heads : int, optional + Number of attention heads + window_size : int + Size of the attention window. Inputs sizes must be divisible by this value. + n_blocks : int + Number of residual blocks + n_attn_per_block : int + Number of attention layers per block + emb_dim : int, optional + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath : float, optional + Droppath probability for MLP. + mlp_ratio : int, optional + Ratio for hidden dimension expansion in MLP. + """ + super().__init__() + self.first = ConvND(dim)(channels_in, channels_per_head * n_heads, kernel_size=3, padding=1) + self.blocks = Sequential( + *( + ResidualSwinTransformerBlock( + dim, + channels_per_head * n_heads, + n_heads, + window_size, + n_attn_per_block, + emb_dim, + p_droppath, + mlp_ratio, + ) + for _ in range(n_blocks) + ) + ) + self.last = ConvND(dim)(channels_per_head * n_heads, channels_out, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply SwinIR. + + Parameters + ---------- + x : torch.Tensor + Input tensor + cond : torch.Tensor | None, optional + Conditioning input. If None, no FiLM conditioning is used. + + Returns + ------- + torch.Tensor + Output tensor + """ + x = self.first(x) + x = self.blocks(x, cond=cond) + x = self.last(x) + return x diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py new file mode 100644 index 000000000..00d1ea510 --- /dev/null +++ b/src/mrpro/nn/nets/UNet.py @@ -0,0 +1,706 @@ +"""UNet variants.""" + +from collections.abc import Sequence +from functools import partial +from itertools import pairwise + +import torch +from torch.nn import Identity, Module, ModuleList, ReLU, SiLU + +from mrpro.nn.AttentionGate import AttentionGate +from mrpro.nn.CondMixin import call_with_cond +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import ConvND, MaxPoolND +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.Upsample import Upsample + + +class UNetEncoder(Module): + """Encoder.""" + + def __init__( + self, + first_block: Module, + blocks: Sequence[Module], + down_blocks: Sequence[Module], + middle_block: Module, + ) -> None: + """Initialize the UNetEncoder.""" + super().__init__() + self.first = first_block + """The first block. Should expand from the number of input channels.""" + + self.blocks = ModuleList(blocks) + """The encoder blocks. Order is highest resolution to lowest resolution.""" + + self.down_blocks = ModuleList(down_blocks) + """The downsampling blocks""" + + self.middle_block = middle_block + """Also called bottleneck block""" + + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.down_blocks) + 1 + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = call(self.first, x) + + xs = [] + for block, down in zip(self.blocks, self.down_blocks, strict=True): + x = call(block, x) + xs.append(x) + x = call(down, x) + + x = call(self.middle_block, x) + + return (*xs, x) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The tensors at the different resolutions, highest resolution first. + """ + return super().__call__(x, cond=cond) + + +class UNetDecoder(Module): + """Decoder.""" + + def __init__( + self, + blocks: Sequence[Module], + up_blocks: Sequence[Module], + concat_blocks: Sequence[Module], + last_block: Module, + ) -> None: + """Initialize the UNetDecoder.""" + super().__init__() + self.blocks = ModuleList(blocks) + """The decoder blocks. Order is lowest resolution to highest resolution.""" + + self.up_blocks = ModuleList(up_blocks) + """The upsampling blocks""" + + self.concat_blocks = ModuleList(concat_blocks) + """Joins the skip connections with the upsampled features from a lower resolution level""" + + self.last_block = last_block + """The last block. Should reduce to the number of output channels.""" + + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.up_blocks) + 1 + + def forward(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = hs[-1] # lowest resolution, from middle block + for block, up, concat, h in zip(self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True): + x = call(up, x) + x = concat(h, x) + x = call(block, x) + x = call(self.last_block, x) + return x + + def __call__(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + hs + The tensors at the different resolutions, highest resolution first. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(hs, cond=cond) + + +class UNetBase(Module): + """Base class for U-shaped networks.""" + + def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequence[Module] | None = None) -> None: + """Initialize the UNetBase.""" + super().__init__() + self.encoder = encoder + """The encoder.""" + + self.decoder = decoder + """The decoder.""" + + self.skip_blocks = ModuleList() + """Modifications of the skip connections.""" + + if len(decoder) != len(encoder): + raise ValueError( + 'The number of resolutions in the encoder and decoder must be the same, ' + f'got {len(decoder)} and {len(encoder)}' + ) + + if skip_blocks is None: + self.skip_blocks.extend(Identity() for _ in range(len(decoder))) + elif len(skip_blocks) != len(decoder): + raise ValueError( + f'The number of skip blocks must be the same as the number of resolutions, ' + f'got {len(skip_blocks)} and {len(encoder)}' + ) + else: + self.skip_blocks.extend(skip_blocks) + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network.""" + xs = self.encoder(x, cond=cond) + xs = tuple( + call_with_cond(self.skip_blocks[i], x, cond=cond) if i < len(self.skip_blocks) else x + for i, x in enumerate(xs) + ) + x = self.decoder(xs, cond=cond) + return x + + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + +class BasicUNet(UNetBase): + """Basic UNet. + + A Basic UNet with residual blocks, convolutional downsampling, and nearest neighbor upsampling. + + References + ---------- + .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image + segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 + """ + + def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int): + """Initialize the BasicUNet.""" + encoder_blocks: list[Module] = [] + decoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + up_blocks: list[Module] = [] + concat_blocks: list[Module] = [] + for n_feat, n_feat_next in pairwise(n_features): + encoder_blocks.append(ResBlock(dim, n_feat, n_feat, cond_dim)) + decoder_blocks.append(ResBlock(dim, 2 * n_feat, n_feat, cond_dim)) + down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) + up_blocks.append(Sequential(Upsample(dim, scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1))) + concat_blocks.append(Concat()) + up_blocks = up_blocks[::-1] + decoder_blocks = decoder_blocks[::-1] + first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) + last_block = Sequential( + GroupNorm(n_features[0]), SiLU(), ConvND(dim)(n_features[0], channels_out, 3, padding=1) + ) + middle_block = ResBlock(dim, n_features[-1], n_features[-1], cond_dim) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + super().__init__(encoder, decoder) + + +class UNet(UNetBase): + """UNet. + + U-shaped convolutional network with optional patch attention. + Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_, + significant differences to the vanilla UNet [UNET]_ include: + - Spatial transformer blocks + - Multiple skip connections per resolution + - Convolutional downsampling, nearest neighbor upsampling + - Residual convolution blocks with group normalization and SiLU activation + + References + ---------- + .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image + segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 + .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py + """ + + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + attention_depths: Sequence[int] = (-1, -2), + n_features: Sequence[int] = (64, 128, 192, 256), + n_heads: int = 4, + cond_dim: int = 0, + encoder_blocks_per_scale: int = 2, + ) -> None: + """Initialize the UNet. + + Parameters + ---------- + dim + Spatial dimension of the input tensor. + channels_in + Number of channels in the input tensor. + channels_out + Number of channels in the output tensor. + attention_depths + The depths at which to apply attention. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + n_heads + Number of attention heads. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + encoder_blocks_per_scale + Number of encoder blocks per resolution level. The number of decoder blocks is one more. + """ + depth = len(n_features) + if not all(-depth <= d < depth for d in attention_depths): + raise ValueError( + f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' + ) + attention_depths = tuple(d % depth for d in attention_depths) + if len(attention_depths) != len(set(attention_depths)): + raise ValueError(f'attention_depths must be unique, got {attention_depths=}') + + def attention_block(channels: int) -> Module: + dim_groups = (tuple(range(-dim, 0)),) + return SpatialTransformerBlock( + dim_groups, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim + ) + + def block(channels_in: int, channels_out: int, attention: bool) -> Module: + if not attention: + return ResBlock(dim, channels_in, channels_out, cond_dim) + return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) + + first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + skip_features = [] + n_feat_old = n_features[0] + for i_level, n_feat in enumerate(n_features): + encoder_blocks.append(Identity()) + skip_features.append(n_feat_old) + for _ in range(encoder_blocks_per_scale): + encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) + n_feat_old = n_feat + down_blocks.append(Identity()) + skip_features.append(n_feat_old) + down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) + down_blocks[-1] = Identity() # no downsampling after the last resolution level + middle_block = Sequential( + ResBlock(dim, n_features[-1], n_features[-1], cond_dim), + ResBlock(dim, n_features[-1], n_features[-1], cond_dim), + ) + if i_level in attention_depths: + middle_block.insert(1, attention_block(n_features[-1])) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [Identity()] + for i_level, n_feat in reversed(list(enumerate(n_features))): + decoder_blocks.append( + block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) + ) + n_feat_old = n_feat + for _ in range(encoder_blocks_per_scale): + decoder_blocks.append( + block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) + ) + up_blocks.append(Identity()) + up_blocks.append(Upsample(dim, scale_factor=2)) + up_blocks.pop() # no upsampling after the last resolution level + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] + last_block = Sequential( + GroupNorm(n_features[0]), + SiLU(), + ConvND(dim)(n_features[0], channels_out, 3, padding=1), + ) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) + + +class AttentionGatedUNet(UNetBase): + """UNet with attention gates. + + Basic UNet with attention gating of the skip signals by the lower resolution features [OKT18]_. + + References + ---------- + .. [OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int = 0): + """Initialize the AttentionGatedUNet. + + Parameters + ---------- + dim + Spatial dimension of the input tensor. + channels_in + Number of channels in the input tensor. + channels_out + Number of channels in the output tensor. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + + def block(channels_in: int, channels_out: int) -> Module: + block = Sequential( + ConvND(dim)(channels_in, channels_out, 3, padding=1), + ReLU(True), + ConvND(dim)(channels_out, channels_out, 3, padding=1), + ReLU(True), + ) + if cond_dim > 0: + block.insert(2, FiLM(channels_out, cond_dim)) + return block + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + n_feat_old = channels_in + for n_feat in n_features[:-1]: + encoder_blocks.append(block(n_feat_old, n_feat)) + down_blocks.append(MaxPoolND(dim)(2)) + n_feat_old = n_feat + middle_block = block(n_features[-2], n_features[-1]) + encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) + + concat_blocks = [] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [] + for n_feat, n_feat_skip in pairwise(n_features[::-1]): + concat_blocks.append(AttentionGate(dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) + decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) + up_blocks.append(Upsample(dim, scale_factor=2)) + last_block = ConvND(dim)(n_features[0], channels_out, 1) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) + + +from collections.abc import Sequence + +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here +from mrpro.nn.UNet import UNetBase, UNetDecoder, UNetEncoder + + +class SeparableUNet(UNetBase): + """ + UNet with separable convolutions and controlled downsampling. + """ + + def __init__( + self, + dim: int, # Total number of spatial dimensions (e.g., 2 for 2D, 3 for 3D) + dim_groups: Sequence[tuple[int, ...]], + channels_in: int, + channels_out: int, + n_features: Sequence[int], + cond_dim: int, + downsample_dims: Sequence[Sequence[int]] | None = None, + encoder_blocks_per_scale: int = 2, + ) -> None: + """ + Initialize the SeparableUNet. + + Parameters + ---------- + + """ + class SeparableUNet(UNetBase): + """ + UNet with separable convolutions and attention, and grouped downsampling. + """ + + def __init__( + self, + dim:int, + dim_groups: Sequence[tuple[int, ...]], + channels_in: int, + channels_out: int, + n_features: Sequence[int] = (64, 128, 256, 512), + cond_dim: int = 0, + encoder_blocks_per_scale: int = 2, + attention_depths: Sequence[int] = (-1,), + n_heads: int = 8, + downsample_dims: Sequence[Sequence[int]] | None = None, + ) -> None: + """ + Initialize the SeparableUNet. + + Parameters + ---------- + dim + Total number of non batch, non channel dimensions. + E.g., 2 for 2D images, 3 for 3D volumes or 2D+time for 2D+time images. + dim_groups + A list of tuples, where each tuple contains the spatial dimension + indices for one separable convolution. Each group must contain fewer than 3 dimensions. + channels_in + Number of channels in the input tensor. + channels_out + Number of channels in the output tensor. + n_features + Number of features at each resolution level. + cond_dim + Number of channels in the conditioning tensor. + encoder_blocks_per_scale + Number of encoder blocks per resolution level. + attention_depths + The depths at which to apply attention. + n_heads + Number of attention heads. + downsample_dims + Sequence specifying which absolute spatial dimensions to downsample + at each encoder level. If None, all dimensions in `dim_groups` are combined + and downsampled at each level. + If a downsampling step contains more than 3 dimensions, downsampling is performed separatly for each + dimension. If the length of the sequence is less than the number of resolution levels, the sequence is + repeated. E.g., ``((-1,-2), (-1,-2,-3))`` for 3D data: first level downsamples x,y; second level x,y,z; + third level x,y. + + + """ + depth = len(n_features) + for group in dim_groups: + if len(group)>3: + raise ValueError(f"dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.") + if any(d>dim+2 or d<-dim for d in group): + raise ValueError(f"dim_group {group} contains dimensions that are out of range for dim={dim}") + + attention_depths = tuple(d % depth for d in attention_depths) + if downsample_dims is None: + all_spatial_dims = tuple( + sorted(list(set(d if d<0 else d-dim-2 for group in dim_groups for d in group))) + ) + downsample_dims = (all_spatial_dims,) * (depth - 1) + + + def downsampler(level_dims, c_in, c_out) -> Module: + if len(level_dims)>3: + sequence=Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) + for d in level_dims[1:]: + sequence.append(downsampler(d, c_out, c_out)) + return sequence + return PermutedBlock( + level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) + + def upsampler(level_dims, c_in, c_out) -> Module: + if len(level_dims)>3: + sequence=Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) + for d in level_dims[1:]: + sequence.append(upsampler(d, c_out, c_out)) + return sequence + return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode="nearest")) + + def block(c_in: int, c_out: int, apply_attention: bool) -> Module: + res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) + if not apply_attention: + return res_block + attn_block = SpatialTransformerBlock(dim_groups, c_out, n_heads, cond_dim=cond_dim) + return Sequential(res_block, attn_block) + + # --- Module Construction --- + first_block = PermutedBlock( + all_spatial_dims, ConvND(len(all_spatial_dims))(channels_in, n_features[0], 3, padding=1) + ) + + # -- Encoder -- + encoder_blocks, down_blocks, skip_features = [], [], [] + c_feat = n_features[0] + for i_level, n_feat_level in enumerate(n_features): + for _ in range(encoder_blocks_per_scale): + encoder_blocks.append(block(c_feat, n_feat_level, i_level in attention_depths)) + c_feat = n_feat_level + skip_features.append(c_feat) + if i_level < depth - 1: + down_blocks.append(_create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1])) + c_feat = n_features[i_level + 1] + + # -- Middle & Encoder Finalization -- + middle_block = Sequential( + block(c_feat, c_feat, depth - 1 in attention_depths), + block(c_feat, c_feat, depth - 1 in attention_depths), + ) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + + # -- Decoder -- + decoder_blocks, up_blocks = [], [] + for i_level in reversed(range(depth)): + n_feat_level = n_features[i_level] + if i_level > 0: + up_blocks.append(_create_upsampler(downsample_dims_per_level[i_level - 1], c_feat, n_feat_level)) + for _ in range(encoder_blocks_per_scale + 1): + skip_c = skip_features.pop() + decoder_blocks.append(block(c_feat + skip_c, n_feat_level, i_level in attention_depths)) + c_feat = n_feat_level + + decoder_blocks.reverse() + up_blocks.reverse() + + # -- Decoder Finalization -- + concat_blocks = [Concat()] * len(decoder_blocks) + last_block = Sequential( + GroupNorm(n_features[0]), SiLU(), + PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)) + ) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) + +# class SpatioTemporalUNet(UNetBase): +# """UNet where blocks apply separable convolutions in different dimensions. +# U-shaped convolutional network with optional patch attention. +# Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [UNET]_, [LDM]_, +# Based on the pseudo-3D residual network of [QUI]_, [TRAN]_, [HO]_, and the residual blocks of [ZIM]_. + +# References +# ---------- +# .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image +# segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 +# .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py +# .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal +# convolutions for action recognition. CVPR 2018. https://arxiv.org/abs/1711.11248 +# .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. +# ICCV 2017. https://arxiv.org/abs/1711.10305 +# .. [HO] Ho, J., Salimans, T., Gritsenko, A., Chan, W., Norouzi, M., & Fleet, D. J. Video diffusion models. +# NeurIPS 2022. https://arxiv.org/abs/2209.11168 +# .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction +# without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 +# """ + + +# def __init__( +# self, +# dim: int, +# in_channels: int, +# out_channels: int, +# attention_depths: Sequence[int] = (-1, -2), +# n_features: Sequence[int] = (64, 128, 192, 256), +# n_heads: int = 4, +# cond_dim: int = 0, +# encoder_blocks_per_scale: int = 2, +# temporal_downsampling: bool = False, +# ) -> None: +# """Initialize the UNet. + +# Parameters +# ---------- +# dim +# Spatial dimension of the input tensor. +# channels_in +# Number of channels in the input tensor. +# channels_out +# Number of channels in the output tensor. +# attention_depths +# The depths at which to apply attention. +# n_features +# Number of features at each resolution level. The length determines the number of resolution levels. +# n_heads +# Number of attention heads. +# cond_dim +# Number of channels in the conditioning tensor. If 0, no conditioning is applied. +# encoder_blocks_per_scale +# Number of encoder blocks per resolution level. The number of decoder blocks is one more. +# temporal_downsampling +# Whether to downsample the temporal dimension. +# """ +# depth = len(n_features) +# if not all(-depth <= d < depth for d in attention_depths): +# raise ValueError( +# f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' +# ) +# attention_depths = tuple(d % depth for d in attention_depths) +# if len(attention_depths) != len(set(attention_depths)): +# raise ValueError(f'attention_depths must be unique, got {attention_depths=}') + +# def attention_block(channels: int) -> Module: +# SpatioTemporalBlock(SpatialTransformerBlock( +# dim, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim +# ) + +# def block(channels_in: int, channels_out: int, attention: bool) -> Module: +# if not attention: +# return ResBlock(dim, channels_in, channels_out, cond_dim) +# return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) + +# first_block = ConvND(dim)(in_channels, n_features[0], 3, padding=1) +# encoder_blocks: list[Module] = [] +# down_blocks: list[Module] = [] +# skip_features = [] +# n_feat_old = n_features[0] +# for i_level, n_feat in enumerate(n_features): +# encoder_blocks.append(Identity()) +# skip_features.append(n_feat_old) +# for _ in range(encoder_blocks_per_scale): +# encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) +# n_feat_old = n_feat +# down_blocks.append(Identity()) +# skip_features.append(n_feat_old) +# down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) +# down_blocks[-1] = Identity() # no downsampling after the last resolution level +# middle_block = Sequential( +# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), +# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), +# ) +# if i_level in attention_depths: +# middle_block.insert(1, attention_block(n_features[-1])) +# encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + +# decoder_blocks: list[Module] = [] +# up_blocks: list[Module] = [Identity()] +# for i_level, n_feat in reversed(list(enumerate(n_features))): +# decoder_blocks.append( +# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) +# ) +# n_feat_old = n_feat +# for _ in range(encoder_blocks_per_scale): +# decoder_blocks.append( +# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) +# ) +# up_blocks.append(Identity()) +# up_blocks.append(Upsample(dim, scale_factor=2)) +# up_blocks.pop() # no upsampling after the last resolution level +# concat_blocks = [Concat()] * len(decoder_blocks) +# last_block = Sequential( +# GroupNorm(n_features[0]), +# SiLU(), +# ConvND(dim)(n_features[0], out_channels, 3, padding=1), +# ) +# decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + +# super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py new file mode 100644 index 000000000..dd97efe59 --- /dev/null +++ b/src/mrpro/nn/nets/Uformer.py @@ -0,0 +1,230 @@ +"""Uformer: U-Net with window attention.""" + +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import GELU, LeakyReLU, Module + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import ConvND, ConvTransposeND, InstanceNormND +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention + + +class LeWinTransformerBlock(CondMixin, Module): + """Locally-enhanced windowed attention transformer block. + + Part of the Uformer architecture. + """ + + def __init__( + self, + dim: int, + n_channels_per_head: int, + n_heads: int, + window_size: int = 8, + shifted: bool = False, + mlp_ratio: float = 4.0, + p_droppath: float = 0.0, + cond_dim: int = 0, + ) -> None: + """Initialize the LeWinTransformerBlock module. + + Parameters + ---------- + dim + Dimension of the input, e.g. 2 or 3 + n_channels_per_head + Number of features per head + n_heads + Number of attention heads + window_size + Size of the attention window + shifted + Whether to use shifted variant of the attention + mlp_ratio + Ratio of the hidden dimension to the input dimension + p_droppath + Dropout probability for the drop path. + cond_dim + Dimension of a conditioning tensor. If `0`, no FiLM layers are added. + """ + super().__init__() + channels = n_channels_per_head * n_heads + hidden_dim = int(channels * mlp_ratio) + self.norm1 = InstanceNormND(dim)(channels) + self.attn = ShiftedWindowAttention( + dim=dim, + channels_in=channels, + channels_out=channels, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, + ) + self.norm2 = InstanceNormND(dim)(channels) + self.ff = Sequential( + ConvND(dim)(channels, hidden_dim, 1), + GELU(), + ConvND(dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), + GELU(), + ConvND(dim)(hidden_dim, channels, 1), + ) + if cond_dim > 0: + self.ff.append(FiLM(channels, cond_dim)) + self.modulator = torch.nn.Parameter(torch.empty(channels, *((window_size,) * dim))) + torch.nn.init.trunc_normal_(self.modulator) + self.drop_path = DropPath(droprate=p_droppath) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the transformer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor + + Returns + ------- + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the transformer block.""" + modulator = self.modulator.tile([t // s for t, s in zip(x.shape[1:], self.modulator.shape, strict=False)]) + x_mod = self.norm1(x) + modulator + x_attn = self.attn(x_mod) + x_ff = self.ff(self.norm2(x_attn), cond=cond) + return x + self.drop_path(x_ff) + + +class Uformer(UNetBase): + """Uformer: U-Net with window attention. + + Implements the Uformer network proposed in [WANG21]_ + It is SWin-Transformer/U-Net hybrid consisting of (shifted) windows attention transformer layers at different + resolution levels, extended by FiLM layers for conditioning. + + References + ---------- + .. [WANG21] Wang, Z., Cun, X., Bao, J., Zhou, W., Liu, J., & Li, H. Uformer: A general u-shaped transformer for + image restoration. CVPR 2022. https://doi.org/10.48550/arXiv.2106.03106 + """ + + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + n_channels_per_head: int = 32, + n_heads: Sequence[int] = (1, 2, 4, 8), + n_blocks: int = 2, + cond_dim: int = 0, + window_size: int = 8, + mlp_ratio: float = 4.0, + max_droppath_rate: float = 0.1, + ): + """Initialize the Uformer module. + + Parameters + ---------- + dim : int + Dimension of the input, e.g. 2 or 3 + channels_in : int + Number of input channels + channels_out : int + Number of output channels + n_channels_per_head : int, optional + Number of features per head. The number of features at a resolution level is given by + `n_channels_per_head * n_heads`. + n_heads : Sequence[int], optional + Number of attention heads at each resolution level. + n_blocks : int, optional + Number of transformer blocks at each resolution level in the input and output path + cond_dim : int, optional + Dimension of a conditioning tensor. If `0`, no FiLM layers are added. + window_size : int, optional + Size of the attention windows in the (shifted) window attention layers. + mlp_ratio : float, optional + Ratio of the hidden dimension to the input dimension in the feed-forward blocks + max_droppath_rate : float, optional + Maximum drop path rate. As in the original implementation, the drop path rate in the input path + is linearly increased from `0` to `max_droppath_rate` with decreasing resolution. The rate in output + blocks is fixed to `max_droppath_rate`. + """ + + def blocks(n_heads: int, p_droppath: float = 0.0): + return Sequential( + *( + LeWinTransformerBlock( + dim=dim, + n_heads=n_heads, + n_channels_per_head=n_channels_per_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + shifted=bool(i % 2), + p_droppath=p_droppath, + cond_dim=cond_dim, + ) + for i in range(n_blocks) + ) + ) + + first_block = torch.nn.Sequential( + ConvND(dim)(channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), + LeakyReLU(), + ) + drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() + encoder_blocks = [ + blocks(n_heads=n_head, p_droppath=p_droppath_input) + for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True) + ] + down_blocks = [ + ConvND(dim)( + n_channels_per_head * n_head_current, + n_channels_per_head * n_head_next, + kernel_size=4, + stride=2, + padding=1, + ) + for n_head_current, n_head_next in pairwise(n_heads) + ] + middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) + encoder = UNetEncoder( + first_block=first_block, + blocks=encoder_blocks, + down_blocks=down_blocks, + middle_block=middle_block, + ) + + decoder_blocks = [blocks(n_heads=2 * n_head, p_droppath=max_droppath_rate) for n_head in reversed(n_heads[:-1])] + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] + up_blocks = [ + ConvTransposeND(dim)( + n_channels_per_head * n_heads[-1], n_channels_per_head * n_heads[-2], kernel_size=2, stride=2 + ) + ] + for n_head_current, n_head_next in pairwise(reversed(n_heads[:-1])): + up_blocks.append( + ConvTransposeND(dim)( + 2 * n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=2, stride=2 + ) + ) + last_block = ConvND(dim)( + 2 * n_channels_per_head * n_heads[0], channels_out, kernel_size=3, stride=1, padding='same' + ) + decoder = UNetDecoder( + blocks=decoder_blocks, + concat_blocks=concat_blocks, + up_blocks=up_blocks, + last_block=last_block, + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py new file mode 100644 index 000000000..cd4a1260a --- /dev/null +++ b/src/mrpro/nn/nets/VAE.py @@ -0,0 +1,64 @@ +"""Variational Autoencoder with a Gaussian latent space.""" + +import torch +from torch.nn import Module + + +class VAE(Module): + """Basic Variational Autoencoder. + + Consists of an encoder to transform the input into a latent space and a decoder to transform the latent space back + into the original space. The encoder should return twice the number of channels as the decoder needs to reconstruct + the input: half of the channels are the mean and the other half the log variance of the latent space. + The reparameterization trick is used to sample from the latent space. + The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard + normal distribution. + """ + + def __init__(self, encoder: Module, decoder: Module): + """Initialize the VAE. + + Parameters + ---------- + encoder : Module + Encoder module. Should return double the number of channels of the latent space. + decoder : Module + Decoder module + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE. + + Calculates the reconstruction as well as the KL divergence between the latent space and the + standard normal distribution. + + Parameters + ---------- + x : torch.Tensor + Input tensor + + Returns + ------- + tuple of the reconstructed image and + the KL divergence between the latent space and the standard normal distribution. + """ + return self.forward(x) + + def mode(self, x: torch.Tensor) -> torch.Tensor: + """Mode of the VAE.""" + z = self.encoder(x) + mean, _ = z.chunk(2, dim=1) + return self.decoder(mean) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE.""" + z = self.encoder(x) + mean, logvar = z.chunk(2, dim=1) + std = torch.exp(0.5 * logvar) + sample = mean + torch.randn_like(std) * std + reconstruction = self.decoder(sample) + kl = -0.5 * torch.sum(1 + logvar - mean.square() - std.square()) + return reconstruction, kl diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py new file mode 100644 index 000000000..6f540e118 --- /dev/null +++ b/src/mrpro/nn/nets/__init__.py @@ -0,0 +1,16 @@ +from mrpro.nn.nets.Restormer import Restormer +from mrpro.nn.nets.Uformer import Uformer +from mrpro.nn.nets.DCAE import DCVAE +from mrpro.nn.nets.VAE import VAE +from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet +from mrpro.nn.nets.SwinIR import SwinIR + +__all__ = [ + "AttentionGatedUNet", + "DCVAE", + "Restormer", + "SwinIR", + "UNet", + "Uformer", + "VAE" +] \ No newline at end of file diff --git a/src/mrpro/operators/ZeroPadOp.py b/src/mrpro/operators/ZeroPadOp.py index c4adfc831..19f19b23e 100644 --- a/src/mrpro/operators/ZeroPadOp.py +++ b/src/mrpro/operators/ZeroPadOp.py @@ -5,7 +5,7 @@ import torch from mrpro.operators.LinearOperator import LinearOperator -from mrpro.utils import pad_or_crop +from mrpro.utils.pad_or_crop import pad_or_crop class ZeroPadOp(LinearOperator): diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index ad52f0571..04eb9cd92 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -15,8 +15,10 @@ from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin from mrpro.utils.interpolate import interpolate, apply_lowres from mrpro.utils.RandomGenerator import RandomGenerator - +from mrpro.utils.to_tuple import to_tuple +from mrpro.utils.ema import EMADict __all__ = [ + "EMADict", "Indexer", "RandomGenerator", "TensorAttributeMixin", @@ -37,6 +39,7 @@ "split_idx", "summarize_object", "summarize_values", + "to_tuple", "typing", "unit_conversion", "unsqueeze_at", diff --git a/src/mrpro/utils/ema.py b/src/mrpro/utils/ema.py new file mode 100644 index 000000000..23b7d6cf1 --- /dev/null +++ b/src/mrpro/utils/ema.py @@ -0,0 +1,94 @@ +"""Exponential Moving Average (EMA) dictionary.""" + +from collections.abc import ItemsView, KeysView, Mapping, ValuesView +from typing import Any + +import torch + + +class EMADict: + """ + Exponential Moving Average (EMA) dictionary. + + Maintains an EMA of values for each key. On update, existing keys are + updated with EMA, and new keys are added directly. + + Detaches the values from the autograd graph. + + + """ + + def __init__( + self, + decay: float, + ): + """Initialize the EMA dictionary. + + Parameters + ---------- + decay : float + Decay rate for EMA (between 0 and 1). + """ + self.decay: float = decay + if not 0 <= decay <= 1: + raise ValueError(f'Decay must be between 0 and 1, got {decay}') + self._data: dict[str, Any] = {} + + def __getitem__(self, key: str) -> Any: # noqa: ANN401 + """Get the value of the EMA dict for a given key.""" + return self._data[key] + + def __setitem__(self, key: str, value: Any) -> None: # noqa: ANN401 + """Set the value of the EMA dict for a given key.""" + if key in self._data: + old_v = self._data[key] + if isinstance(value, torch.Tensor): + if isinstance(old_v, torch.Tensor) and isinstance(value, torch.Tensor): + if torch.is_floating_point(old_v) or torch.is_complex(old_v): + old_v.mul_(self.decay).add_(value.detach().to(old_v.device), alpha=1.0 - self.decay) + else: + old_v.copy_(value) + return + elif isinstance(old_v, float) and isinstance(value, float): # noqa: SIM114 + self._data[key] = self.decay * old_v + (1.0 - self.decay) * value + return + elif isinstance(old_v, complex) and isinstance(value, complex): + self._data[key] = self.decay * old_v + (1.0 - self.decay) * value + return + if isinstance(value, torch.Tensor): + self._data[key] = value.detach().clone() + else: + self._data[key] = value + + def __delitem__(self, key: str) -> None: + """Delete a key from the EMA dict.""" + del self._data[key] + + def __contains__(self, key: str) -> bool: + """Check if a key is in the EMA dict.""" + return key in self._data + + def values(self) -> ValuesView[Any]: + """Get the values of the EMA dict.""" + return self._data.values() + + def keys(self) -> KeysView[str]: + """Get the keys of the EMA dict.""" + return self._data.keys() + + def items(self) -> ItemsView[str, Any]: + """Get the items of the EMA dict.""" + return self._data.items() + + def update(self, other: Mapping[Any, Any]) -> None: + """Update the EMA dict with another dictionary. + + For existing keys, performs EMA update. For new keys, adds them directly. + + Parameters + ---------- + other : dict + Dictionary to update from. + """ + for k, v in other.items(): + self.__setitem__(k, v) diff --git a/src/mrpro/utils/pad_or_crop.py b/src/mrpro/utils/pad_or_crop.py index af61e7354..5bb82f21d 100644 --- a/src/mrpro/utils/pad_or_crop.py +++ b/src/mrpro/utils/pad_or_crop.py @@ -5,7 +5,6 @@ from typing import Literal import torch -import torch.nn.functional as F # noqa: N812 def normalize_index(ndim: int, index: int) -> int: @@ -35,7 +34,7 @@ def pad_or_crop( data: torch.Tensor, new_shape: Sequence[int] | torch.Size, dim: None | Sequence[int] = None, - mode: Literal['constant', 'replicate', 'circular'] = 'constant', + mode: Literal['constant', 'reflect', 'replicate', 'circular'] = 'constant', value: float = 0.0, ) -> torch.Tensor: """Change shape of data by center cropping or symmetric padding. @@ -50,9 +49,9 @@ def pad_or_crop( Dimensions the `new_shape` corresponds to. `None` is interpreted as last ``len(new_shape)`` dimensions. mode - Mode for padding. + Mode to use for padding. value - Value to use for padding. + Value to use for constant padding. Returns ------- @@ -82,5 +81,5 @@ def pad_or_crop( if any(npad): # F.pad expects paddings in reversed order - data = F.pad(data, npad[::-1], mode=mode, value=value) + data = torch.nn.functional.pad(data, npad[::-1], value=value, mode=mode) return data diff --git a/src/mrpro/utils/to_tuple.py b/src/mrpro/utils/to_tuple.py new file mode 100644 index 000000000..657d7bf56 --- /dev/null +++ b/src/mrpro/utils/to_tuple.py @@ -0,0 +1,36 @@ +"""Standardize an argument to a fixed-length tuple.""" + +from collections.abc import Sequence +from typing import TypeVar + +T = TypeVar('T') + + +def to_tuple(length: int, arg: Sequence[T] | T) -> tuple[T, ...]: + """Standardize an argument to a fixed-length tuple. + + If the argument is a sequence, it checks if its length matches the + specified dimension. If it's a single value, it replicates it `dim` times. + + Parameters + ---------- + length + The expected length of the sequence. + arg + The argument to check. Can be a single value of type T or a + sequence of T. + + Returns + ------- + A tuple of length `dim` containing elements of type T. + + Raises + ------ + ValueError + If `arg` is a sequence and its length does not match `length`. + """ + if isinstance(arg, Sequence): + if not len(arg) == length: + raise ValueError(f'The arguments must be either single values or have length {length}. Got {arg}.') + return tuple(arg) + return (arg,) * length diff --git a/tests/nn/test_attentiongate.py b/tests/nn/test_attentiongate.py new file mode 100644 index 000000000..4b470be1c --- /dev/null +++ b/tests/nn/test_attentiongate.py @@ -0,0 +1,41 @@ +"""Tests for AttentionGate module.""" + +import pytest +from mrpro.nn.AttentionGate import AttentionGate +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels_gate', 'channels_in', 'channels_hidden', 'input_shape', 'gate_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 32, 16, 16)), + (3, 32, 4, 8, (2, 4, 16, 16, 16), (2, 32, 16, 16, 16)), + ], +) +def test_attention_gate(dim, channels_gate, channels_in, channels_hidden, input_shape, gate_shape, device): + """Test AttentionGate output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + gate = rng.float32_tensor(gate_shape).to(device).requires_grad_(True) + attn = AttentionGate( + dim=dim, channels_gate=channels_gate, channels_in=channels_in, channels_hidden=channels_hidden + ).to(device) + output = attn(x, gate) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert gate.grad is not None, 'No gradient computed for gate' + assert not x.isnan().any(), 'NaN values in input' + assert not gate.isnan().any(), 'NaN values in gate' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not gate.grad.isnan().any(), 'NaN values in gate gradients' + assert attn.project_gate.weight.grad is not None, 'No gradient computed for project_gate' + assert attn.project_x.weight.grad is not None, 'No gradient computed for project_x' + assert attn.psi[1].weight.grad is not None, 'No gradient computed for psi' diff --git a/tests/nn/test_complexaschannel.py b/tests/nn/test_complexaschannel.py new file mode 100644 index 000000000..5eb1e87f8 --- /dev/null +++ b/tests/nn/test_complexaschannel.py @@ -0,0 +1,30 @@ +"""Tests for ComplexAsChannel module.""" + +import pytest +from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_complexaschannel(device): + """Test ComplexAsChannel output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + input_shape = (1, 32) + x = rng.complex64_tensor(input_shape).to(device).requires_grad_(True) + module = ComplexAsChannel(Linear(input_shape[1] * 2, input_shape[1] * 2)).to(device) + output = module(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.is_complex(), 'Output is not complex' + output.sum().abs().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert module.module.weight.grad is not None, 'No gradient computed for weight' + assert module.module.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_convert_linear_conv.py b/tests/nn/test_convert_linear_conv.py new file mode 100644 index 000000000..c977f0936 --- /dev/null +++ b/tests/nn/test_convert_linear_conv.py @@ -0,0 +1,150 @@ +"""Tests for converting between Linear and Conv layers.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.convert_linear_conv import conv_to_linear, linear_to_conv +from mrpro.utils import RandomGenerator +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + +DEVICES = pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +SHAPES = pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], + ids=['1d', '2d', '3d', '3d_no_bias'], +) + + +@SHAPES +@DEVICES +def test_linear_to_conv(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Linear to Conv layer.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias).to(device) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + assert isinstance(conv, (Conv1d, Conv2d, Conv3d)[dim - 1]) + + assert conv.in_channels == channels_in + assert conv.out_channels == channels_out + assert conv.kernel_size == (1,) * dim + assert conv.bias is not None if bias else conv.bias is None + + assert conv.weight.device.type == device + if conv.bias is not None: + assert conv.bias.device.type == device + + +@SHAPES +def test_linear_to_conv_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test functional equivalence of Linear to Conv conversion.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + spatial_shape = (4,) * dim + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) + + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + torch.testing.assert_close(y_conv, y_linear) + + +@SHAPES +@DEVICES +def test_conv_to_linear(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Conv layer to Linear.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias).to(device) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + + assert isinstance(linear, Linear) + assert linear.in_features == channels_in + assert linear.out_features == channels_out + assert linear.bias is not None if bias else linear.bias is None + + assert linear.weight.device.type == device + if bias: + assert linear.bias.device.type == device + + +@SHAPES +def test_conv_to_linear_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test functional equivalence of Conv to Linear conversion.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + spatial_shape = (4,) * dim + + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + torch.testing.assert_close(y_conv, y_linear) + + +def test_conv_to_linear_invalid_kernel(): + """Test conv_to_linear with invalid kernel size.""" + conv = Conv2d(32, 64, kernel_size=3, bias=True) + with pytest.raises(ValueError, match='Kernel size must be 1'): + conv_to_linear(conv) + + +@SHAPES +@DEVICES +def test_round_trip_conversion( + device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool +) -> None: + """Test round-trip conversion between Linear and Conv layers.""" + rng = RandomGenerator(seed=42) + + linear1 = Linear(channels_in, channels_out, bias=bias).to(device) + linear1.weight.data = rng.rand_like(linear1.weight) + if bias: + linear1.bias.data = rng.rand_like(linear1.bias) + + conv = linear_to_conv(linear1, dim) + linear2 = conv_to_linear(conv) + + assert linear2.in_features == channels_in + assert linear2.out_features == channels_out + assert linear2.bias is not None if bias else linear2.bias is None + + torch.testing.assert_close(linear2.weight, linear1.weight) + if bias: + torch.testing.assert_close(linear2.bias, linear1.bias) diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py new file mode 100644 index 000000000..d49cd476b --- /dev/null +++ b/tests/nn/test_film.py @@ -0,0 +1,38 @@ +"""Tests for FiLM module.""" + +import pytest +from mrpro.nn.FiLM import FiLM +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('channels', 'channels_cond', 'input_shape', 'cond_shape'), + [ + (64, 32, (1, 64, 32, 32), (1, 32)), + (32, 16, (2, 32, 16, 16), (2, 16)), + ], +) +def test_film(channels, channels_cond, input_shape, cond_shape, device): + """Test FiLM output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) + film = FiLM(channels=channels, cond_dim=channels_cond).to(device) + output = film(x, cond=cond) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for condedding' + assert not x.isnan().any(), 'NaN values in input' + assert not cond.isnan().any(), 'NaN values in condedding' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not cond.grad.isnan().any(), 'NaN values in condedding gradients' + assert film.project is not None, 'Linear layer is not initialized' + assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' diff --git a/tests/nn/test_groupnorm32.py b/tests/nn/test_groupnorm32.py new file mode 100644 index 000000000..0c936dca7 --- /dev/null +++ b/tests/nn/test_groupnorm32.py @@ -0,0 +1,34 @@ +"""Tests for GroupNorm32 module.""" + +import pytest +from mrpro.nn import GroupNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('channels', 'groups', 'input_shape'), + [ + (32, None, (1, 32, 32, 32)), + (64, 8, (2, 64, 16, 16, 16)), + ], +) +def test_groupnorm32(channels, groups, input_shape, device): + """Test GroupNorm32 output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = GroupNorm(channels=channels, groups=groups).to(device) + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py new file mode 100644 index 000000000..6df1fce7f --- /dev/null +++ b/tests/nn/test_resblock.py @@ -0,0 +1,40 @@ +"""Tests for ResBlock module.""" + +import pytest +from mrpro.nn import ResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (3, 64, 32, 0, (2, 64, 16, 16, 16), None), + ], +) +def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, cond_shape, device): + """Test ResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None + res = ResBlock(dim=dim, channels_in=channels_in, channels_out=channels_out, cond_dim=cond_dim).to(device) + output = res(x, cond=cond) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert res.block[2].weight.grad is not None, 'No gradient computed for first Conv' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py new file mode 100644 index 000000000..83e585498 --- /dev/null +++ b/tests/nn/test_sequential.py @@ -0,0 +1,41 @@ +"""Tests for Sequential module.""" + +import pytest +from mrpro.nn import FiLM, Sequential +from mrpro.operators import FastFourierOp +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('input_shape', 'cond_dim'), + [ + ((1, 32), (1, 16)), + ((2, 64), None), + ], +) +def test_sequential(input_shape, cond_dim, device): + """Test Sequential output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_dim).to(device).requires_grad_(True) if cond_dim else None + seq = Sequential( + Linear(input_shape[1], 64), + FastFourierOp(), + FiLM(channels=64, cond_dim=16), + ).to(device) + output = seq(x, cond=cond) + assert output.shape == (input_shape[0], 32), f'Output shape {output.shape} != expected {(input_shape[0], 32)}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert seq[0].weight.grad is not None, 'No gradient computed for first Linear' + assert seq[2].weight.grad is not None, 'No gradient computed for second Linear' diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py new file mode 100644 index 000000000..7ea8a4175 --- /dev/null +++ b/tests/nn/test_shiftedwindowattention.py @@ -0,0 +1,37 @@ +import pytest +from mrpro.nn import ShiftedWindowAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'window_size', 'shifted'), + [ + (2, 8, False), + (4, 4, True), + ], +) +def test_shifted_window_attentio(dim: int, window_size: int, shifted: bool, device: str) -> None: + batch = 2 + channels = 8 + n_heads = 2 + spatial_shape = (window_size * 4,) * dim + rng = RandomGenerator(13) + x = rng.float32_tensor((batch, channels, *spatial_shape)).to(device).requires_grad_(True) + swin = ShiftedWindowAttention( + dim=dim, channels_in=channels, channels_out=channels, n_heads=n_heads, window_size=window_size, shifted=shifted + ).to(device) + out = swin(x) + assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' + assert not out.isnan().any(), 'NaN values in output' + out.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert swin.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert swin.relative_position_bias_table.grad is not None, 'No gradient computed for relative_position_bias_table' diff --git a/tests/nn/test_sqeezeexcitation.py b/tests/nn/test_sqeezeexcitation.py new file mode 100644 index 000000000..8929b9868 --- /dev/null +++ b/tests/nn/test_sqeezeexcitation.py @@ -0,0 +1,26 @@ +"""Tests for SqueezeExcitation module.""" + +import pytest +from mrpro.nn import SqueezeExcitation +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + ('dim', 'input_shape', 'squeeze_channels'), + [ + (2, (1, 64, 32, 32), 16), + (3, (1, 64, 16, 16, 16), 16), + ], +) +def test_squeeze_excitation(dim, input_shape, squeeze_channels): + """Test SqueezeExcitation output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + se = SqueezeExcitation(dim=dim, input_channels=input_shape[1], squeeze_channels=squeeze_channels) + output = se(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert se.scale[1].weight.grad is not None, 'No gradient computed for Conv' diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py new file mode 100644 index 000000000..ea39781b3 --- /dev/null +++ b/tests/nn/test_transposedattention.py @@ -0,0 +1,36 @@ +"""Tests for TransposedAttention module.""" + +import pytest +from mrpro.nn import TransposedAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels', 'num_heads', 'input_shape'), + [ + (2, 32, 4, (1, 32, 32, 32)), + (3, 64, 8, (2, 64, 16, 16, 16)), + ], +) +def test_transposed_attention(dim, channels, num_heads, input_shape, device): + """Test TransposedAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + attn = TransposedAttention(dim=dim, channels_in=channels, channels_out=channels, n_heads=num_heads).to(device) + output = attn(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.isnan().any(), 'NaN values in input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert attn.to_qkv.weight.grad is not None, 'No gradient computed for qkv' + assert attn.qkv_dwconv.weight.grad is not None, 'No gradient computed for qkv_dwconv' + assert attn.to_out.weight.grad is not None, 'No gradient computed for project_out' + assert attn.temperature.grad is not None, 'No gradient computed for temperature' diff --git a/tests/utils/test_ema.py b/tests/utils/test_ema.py new file mode 100644 index 000000000..3f9ea5ca2 --- /dev/null +++ b/tests/utils/test_ema.py @@ -0,0 +1,89 @@ +"""Tests for EMADict.""" + +from typing import Any + +import pytest +import torch +from mrpro.utils import RandomGenerator +from mrpro.utils.ema import EMADict + + +@pytest.mark.parametrize( + ('key', 'value'), + [ + ('float', 1.0), + ('complex', 1.0 + 1.0j), + ('tensor', torch.ones(2, 3)), + ], +) +def test_ema_dict_numerical( + key: str, + value: Any, +) -> None: + """Test that EMA calculation is numerically correct.""" + decay = 0.8 + ema = EMADict(decay=decay) + + ema[key] = value + new_value = RandomGenerator(seed=42).float32() * value + ema.update({key: new_value}) + + expected = decay * value + (1 - decay) * new_value + if isinstance(value, torch.Tensor): + torch.testing.assert_close(ema[key], expected) + else: + assert ema[key] == pytest.approx(expected) + + +def test_ema_dict_invalid_decay() -> None: + """Test EMADict with invalid decay values.""" + with pytest.raises(ValueError, match='Decay must be between 0 and 1'): + EMADict(decay=-0.1) + with pytest.raises(ValueError, match='Decay must be between 0 and 1'): + EMADict(decay=1.1) + + +def test_ema_dict_update() -> None: + """Test EMADict update method.""" + rng = RandomGenerator(seed=42) + ema = EMADict(decay=0.9) + + new_dict: dict[str, Any] = { + 'float': rng.float32(), + 'complex': rng.complex64(), + 'tensor': rng.float32_tensor((2, 3)), + 'string': 'test', + } + ema.update(new_dict) + + for key, value in new_dict.items(): + assert key in ema + if isinstance(value, torch.Tensor): + torch.testing.assert_close(ema[key], value) + else: + assert ema[key] == value + + +def test_ema_dict_deletion() -> None: + """Test EMADict deletion.""" + rng = RandomGenerator(seed=42) + ema = EMADict(decay=0.9) + + ema['test'] = rng.float32() + assert 'test' in ema + + del ema['test'] + assert 'test' not in ema + + with pytest.raises(KeyError): + del ema['nonexistent'] + + +def test_ema_dict_tensor_detach() -> None: + """Test that tensors are detached from autograd graph.""" + rng = RandomGenerator(seed=42) + ema = EMADict(decay=0.9) + + tensor = rng.float32_tensor((2, 3)).requires_grad_(True) + ema['test'] = tensor + assert not ema['test'].requires_grad From d6cb116a0baafe8e45a2ad4693de3c2d132b5c91 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 11 Jul 2025 13:05:31 +0200 Subject: [PATCH 095/205] Squashed commit of the following: commit c15cb10ca7d0ba07eb72d8a49ea0ef202f125850 Merge: 1be7d321 73b40335 Author: Felix Zimmermann Date: Fri Jul 11 12:09:04 2025 +0200 Merge branch 'main' into pinqi commit 1be7d321a6c2f3e6767cf754068ad10ebf607393 Author: Felix Zimmermann Date: Thu Jul 10 17:38:49 2025 +0200 update commit 08cbecab54ae87ab90b32e3635ac96936688dd4e Merge: 3b17a711 b5a7e1c3 Author: Felix Zimmermann Date: Wed Jul 9 00:45:15 2025 +0200 Merge branch 'main' into pinqi commit 3b17a711ca58a19cf67d2942915ed107cc48db0d Author: Felix Zimmermann Date: Wed Jul 9 00:42:57 2025 +0200 apply pinqi commit 3a91b5d4a10b19681908e12dde0ef2e9a3c68ffd Author: Felix Zimmermann Date: Tue Jul 8 18:14:46 2025 +0200 fix test commit 7cd0d7f8473506b93b06400b51e789e48531251d Author: Felix Zimmermann Date: Fri Jul 4 17:25:08 2025 +0200 modl commit 6b46b3aa8b948879a074ee5d3749dea10617be5c Author: Felix Zimmermann Date: Thu Jul 3 21:38:25 2025 +0200 fix cg commit d8bb3051051cf7dde2009bc683900f29efaf0fa5 Author: Felix Zimmermann Date: Fri Jul 4 17:15:27 2025 +0200 fix ssim commit 816f3a308fac63b4124c6c494752eaffef51fd31 Author: Felix Zimmermann Date: Fri Jul 4 17:14:57 2025 +0200 pinqi commit b0a85171ffa2ee72e35ee999184ecc96787a2fb2 Author: Felix Zimmermann Date: Thu Jul 3 13:12:11 2025 +0200 modl commit 09b70c55a568f07c46be7ab95cfb3f0276fb59e7 Author: Felix Zimmermann Date: Thu Jul 3 13:12:00 2025 +0200 inati: no nans commit 777443f908a6eb99f17c6f58fa8687e73322fab4 Author: Felix Zimmermann Date: Thu Jul 3 13:11:47 2025 +0200 nn commit cc34beb2860eef339c71d047436d19e15222fc72 Author: Felix Zimmermann Date: Thu Jul 3 13:11:25 2025 +0200 fix dataclass error commit 5b40ad1d32d67c6442750bd7607d624edbaee4c9 Author: Felix Zimmermann Date: Thu Jul 3 13:10:40 2025 +0200 fastmri: fix padding undo commit 78a6322ef11d7b62254a869e6916e52ab9edba13 Author: Felix Zimmermann Date: Thu Jul 3 13:10:04 2025 +0200 oberator subtractino commit dfa72826b946665d6c8c51e7a9227205e628921c Author: Felix Zimmermann Date: Wed Jul 2 15:34:16 2025 +0200 modl commit da8aef1b109e989a1f16beb0bba99d230197f3d3 Author: Felix Zimmermann Date: Tue Jul 1 16:01:15 2025 +0200 update pinqi commit 4b9508decde5892dbff999f97947b21449e372f4 Author: Felix Zimmermann Date: Tue Jul 1 14:59:44 2025 +0200 train pinqi commit a66f0dbae7ff76fa117b5e4c85ab6bfc93bdaba6 Author: Felix Zimmermann Date: Tue Jul 1 14:59:25 2025 +0200 update dataclass commit 096038e220c95ed4d57ddd49c3da08edc5a51abd Author: Felix Zimmermann Date: Tue Jul 1 14:59:03 2025 +0200 update nn commit c8340650547f45536968d82ba9047537b0f2302c Author: Felix Zimmermann Date: Tue Jul 1 14:46:12 2025 +0200 train_pinqi commit d0d333e5621a015cfd8f4afec456b972eb9530a6 Author: Felix Zimmermann Date: Tue Jul 1 14:45:53 2025 +0200 fix nn commit 1322b515d9f13025ee5be3852743f057edb6ff77 Author: Felix Zimmermann Date: Tue Jul 1 14:45:16 2025 +0200 fix csmdata init typing commit 7d9dd14f13585d9949b602e49c109b0a98217d84 Author: Felix Zimmermann Date: Tue Jul 1 14:44:54 2025 +0200 fix sat rec commit eadd3a6f86c1eccb74636094550390e717fb1743 Author: Felix Zimmermann Date: Tue Jul 1 14:44:39 2025 +0200 fix brainweb commit 6861ca905258f5bb42549b39feb8bff065ba2cc4 Author: Felix Zimmermann Date: Tue Jul 1 14:43:12 2025 +0200 change tol commit a242e55fe3a5a28cf2ff19ae3e0e7ab66f96cdf5 Author: Felix Zimmermann Date: Sat Jun 28 17:33:48 2025 +0200 add commit 41e421668f93de1dcab60e105429f3afea96b541 Author: Felix Zimmermann Date: Sat Jun 28 17:30:48 2025 +0200 wip commit 6d859543b077689fa66f7c528c1459bbda3d65b2 Merge: 71850ea4 8c9943f1 Author: Felix Zimmermann Date: Fri Jun 27 22:33:38 2025 +0200 Merge branch 'diff' into pinqi commit 71850ea46a00943a9da698f1adbcebb2597eed0d Merge: 8209d117 725c01e4 Author: Felix Zimmermann Date: Fri Jun 27 19:04:36 2025 +0200 Merge branch 'main' into pinqi commit 8209d117b5d21ae921fe4e276264a76132208009 Author: Felix Zimmermann Date: Fri Jun 27 19:02:44 2025 +0200 fix commit aaa68e97317797944ba353b8db1a0bab6a46f649 Author: Felix Zimmermann Date: Mon Jun 23 22:37:26 2025 +0200 separable commit fb6eb419a6d42c2443ddd8b62ee778a2d76af81a Author: Felix Zimmermann Date: Sun Jun 22 12:51:23 2025 +0200 update commit 8c9943f1707c4453bc00bb82aedb21f41a548770 Author: Felix Zimmermann Date: Fri Jun 13 13:26:46 2025 +0200 docstring commit 8d8667a836184238d67ded764c591d5caa281b82 Merge: 630f146a 57497631 Author: Felix F Zimmermann Date: Tue Jun 10 19:15:39 2025 +0200 Merge branch 'main' into diff commit 630f146af2995ed05587ecd6e784df7dc54c9e96 Merge: ca88e3a3 0cb824d9 Author: Felix F Zimmermann Date: Tue Jun 10 18:25:59 2025 +0200 Merge branch 'main' into diff commit ca88e3a32865aaa8f30955dc64035d114c41c334 Author: Felix F Zimmermann Date: Tue Jun 10 18:24:15 2025 +0200 review commit edb3b0fd30c644276acd0609ae7e68cb31ef194f Author: Felix Zimmermann Date: Tue Jun 10 16:09:13 2025 +0200 review commit 44c8ac05433721127e631fc49019ac7a48779d38 Author: Felix F Zimmermann Date: Tue Jun 10 15:32:59 2025 +0200 Apply suggestions from code review Co-authored-by: Andreas Kofler commit 07f9f7afb7dfba8d7f4cb2168e8edf1b7c6aad73 Author: Felix F Zimmermann Date: Tue Jun 10 15:11:51 2025 +0200 Update src/mrpro/operators/ConjugateGradientOp.py Co-authored-by: Andreas Kofler commit 10f19949ec3ec9b5f2619e3015260dc687e28900 Author: Felix Zimmermann Date: Wed Jun 4 02:14:06 2025 +0200 Refactor import statements and enhance EMA documentation - Updated import statements in join.py and Uformer.py to ensure consistent usage of the CondMixin class. - Removed unnecessary import of CondMixin in SpatialTransformerBlock.py for cleaner code. - Added a docstring to ema.py to clarify the purpose of the Exponential Moving Average (EMA) dictionary. commit c7e588ec223cca3ebf2f7aec3ac148ac498b6b17 Author: Felix Zimmermann Date: Wed Jun 4 02:13:22 2025 +0200 Refactor AttentionGate and DropPath modules for improved functionality - Updated AttentionGate to ensure consistent interpolation of gate tensors regardless of shape. - Modified DropPath to conditionally scale the mask based on the keep probability, enhancing flexibility in dropout behavior. - Enhanced _fix_shapes function in join.py to support new interpolation modes (linear and nearest) for better tensor shape handling. - Improved documentation in Concat class to reflect new interpolation options and ensure clarity in parameter descriptions. commit e20d6f7edaa50ef6de13877fea649eac940cf777 Author: Felix Zimmermann Date: Tue Jun 3 23:31:04 2025 +0200 wip commit 4c1ec0feaa67ca540f08126c35e756b88dfb0dba Author: Felix Zimmermann Date: Tue Jun 3 17:20:25 2025 +0200 Refactor parameter documentation in encoding and normalization modules - Updated parameter documentation in FourierFeatures, AbsolutePositionEncoding, GEGLU, and LayerNorm classes to remove type hints from docstrings for consistency. - Enhanced clarity in parameter descriptions while maintaining the overall structure of the documentation. commit ea3109ef8caff466e4d4cb680cc65f992e2489f8 Author: Felix Zimmermann Date: Mon Jun 2 22:56:20 2025 +0200 update commit 6b942ca1dc336193a61f04ffe4619eae434292d9 Author: Felix Zimmermann Date: Mon Jun 2 17:22:00 2025 +0200 wip commit 59569dc72c6c567fcdef479f4ce721c1a33c39d4 Author: Felix Zimmermann Date: Mon Jun 2 16:25:15 2025 +0200 - Updated AttentionGate to include a new 'concatenate' parameter, allowing for optional concatenation of gated and gate signals in the channel dimension. - Adjusted ResBlock to modify the rezero parameter for better stability during training. - Refactored forward methods in both AttentionGate and ResBlock to ensure compatibility with the new features and maintain clarity in tensor operations. - Updated import statements and class references in UNet and related modules to reflect the new AttentionGatedUNet class. commit dca726b02843484b9f6dcc5bad22c78b4d6fcc41 Author: Felix Zimmermann Date: Mon Jun 2 02:25:23 2025 +0200 Enhance UNet architecture with improved attention handling and modularity - Updated UNet class to include configurable attention depths and encoder blocks per scale, enhancing flexibility in model design. - Introduced new attention block functionality and refined block initialization for better clarity and maintainability. - Adjusted forward method to accept conditioning tensors explicitly, improving modularity in the encoder-decoder structure. - Integrated GroupNorm and SiLU into the final layers for improved performance and consistency. commit b1ff7f84991d2dfbcd1baf0c4eb9314464dfb7fe Author: Felix Zimmermann Date: Mon Jun 2 02:25:10 2025 +0200 Refactor neural network modules to standardize feature dimension handling - Updated parameter names from 'channel_last' to 'features_last' across multiple modules for consistency. - Adjusted related logic in GEGLU, LayerNorm, LinearSelfAttention, NeighborhoodSelfAttention, RMSNorm, and BasicTransformerBlock to reflect the new parameter naming. - Enhanced clarity in the handling of feature dimensions, improving modularity and maintainability of the codebase. commit 376be37273f3e3369f89871d164dc4c8cc820b2b Author: Felix Zimmermann Date: Mon Jun 2 02:24:51 2025 +0200 Add Upsample module for tensor resizing functionality - Introduced the Upsample class to facilitate tensor upsampling with configurable scale factors and modes (nearest, linear). - Implemented the forward method to compute new tensor sizes based on the specified dimensions and scale factor, enhancing flexibility in tensor manipulation. commit 24bfbc9306db963ca80c7e6b0ea931dc4e1fed04 Author: Felix Zimmermann Date: Fri May 23 01:46:07 2025 +0200 Refactor ZeroPadOp and pad_or_crop utility for improved functionality and clarity - Updated import statement in ZeroPadOp to directly import pad_or_crop function. - Enhanced pad_or_crop function to include a new 'mode' parameter for padding options, improving flexibility in data manipulation. commit afd7a4573c26567da47d61f92fba77e32ca657c9 Author: Felix Zimmermann Date: Fri May 23 01:45:39 2025 +0200 Refactor FiLM and Uformer modules for improved clarity and functionality - Simplified the FiLM class by removing unnecessary Sequential and Identity layers. - Updated the Uformer architecture to conditionally include FiLM layers based on the provided conditioning dimension. - Enhanced the forward method of LeWinTransformerBlock to accept conditioning tensors, improving modularity and flexibility. commit eafbfc62a6432066e01c4cd8dcb03fbcecf7a7ab Author: Felix Zimmermann Date: Fri May 23 01:45:24 2025 +0200 Add SpatialTransformerBlock and integrate into UNet architecture - Introduced SpatialTransformerBlock for enhanced attention mechanisms. - Updated MultiHeadAttention to support cross-attention channels. - Modified UNet to include SpatialTransformerBlock in encoder and decoder stages based on specified attention depths. - Improved modularity and flexibility of the UNet structure. commit 4c497345b111d92e4b239391abafb712a8212f01 Author: Felix Zimmermann Date: Thu May 22 16:48:14 2025 +0200 Refactor Restormer and Uformer networks to utilize UNetEncoder and UNetDecoder classes for improved modularity and clarity. Update import statements and streamline block initialization for better readability and maintainability. commit f3aaa6ab527e765a0f89081ef27aad2d3fe890f5 Author: Felix Zimmermann Date: Thu May 22 10:00:41 2025 +0200 Refactor EfficientViTBlock and Encoder/Decoder stages to use dynamic head counts based on width; improve sequential structure in PixelUnshuffleDownsample. commit 3757b1fa3a883c5413f1f94c24c6068bbb30335b Author: Felix Zimmermann Date: Thu May 22 02:11:50 2025 +0200 Add EMADict class for Exponential Moving Average functionality and update imports - Introduced EMADict class to maintain exponential moving averages for various data types. - Updated __all__ lists in utils and nn modules to include new EMADict class. - Added tests for EMADict to ensure correct functionality and error handling. commit 7d608b5de0efaf063a36b137ac112106146f02cd Author: Felix Zimmermann Date: Thu May 22 00:04:31 2025 +0200 Refactor method signatures in neural network modules to use keyword-only arguments for conditioning tensors. Update import statements for clarity and add new SwinIR network to the module. Enhance VAE with a mode method for improved functionality. commit 4f6a603a53e05323be4219c049c84a479e139f6d Author: Felix Zimmermann Date: Wed May 21 18:19:46 2025 +0200 Enhance conversion functions between Linear and Conv layers by refining method signatures and improving test structure. Update parameter names for clarity and ensure compatibility with multiple input tensors. commit 01881fef9076635036b932e2a94c01e90145cbb9 Author: Felix Zimmermann Date: Wed May 21 16:04:02 2025 +0200 Refactor imports to use lowercase 'ndmodules' and update method signatures for better compatibility with multiple input tensors. Introduce conversion functions between Linear and Conv layers, along with corresponding tests. commit d626bbbdd420c3b0a59c1deef62444afcf784170 Author: Felix Zimmermann Date: Wed May 21 01:26:18 2025 +0200 update commit 52c8630c89ce48b82120ab4a5f22597389cb5778 Author: Felix Zimmermann Date: Tue May 20 22:43:25 2025 +0200 update commit 3d259bbc3122f3d312f9a0a628c994e452820d81 Author: Felix Zimmermann Date: Tue May 20 02:09:41 2025 +0200 update commit 7e9d12183fae6d953967b3409a64e729848274bf Author: Felix Zimmermann Date: Mon May 19 21:52:03 2025 +0200 update commit 62e04a1f4685464bda7a35f2d15ed35e7a1fc2ff Author: Felix Zimmermann Date: Mon May 19 17:20:20 2025 +0200 update commit 33d95572ecb7e46cf5049d31a26ed5e4d6d69885 Author: Felix Zimmermann Date: Mon May 19 14:41:48 2025 +0200 update commit 97115e77559e36e4adf02c09cba5b96cc53c4100 Author: Felix Zimmermann Date: Mon May 19 02:23:57 2025 +0200 update commit 7f37fa99ba4392cd89a52cca08b5e18ff0fe50a3 Author: Felix Zimmermann Date: Mon May 19 01:52:18 2025 +0200 update commit 3d0122093a1acf71bbeb167005c90398d2bd7eca Author: Felix Zimmermann Date: Sun May 18 16:30:43 2025 +0200 update commit b6a1db3c809c8f2338e7891ab897406da9a0a08f Author: Felix Zimmermann Date: Fri May 16 14:02:00 2025 +0200 doc commit 912d7c8ecfa0b1a54d9e1c53f819b0dc550067c2 Author: Felix Zimmermann Date: Fri May 16 00:21:58 2025 +0200 update commit 3ae37d1ceaf5b8f1ff22e98517b0873efb04e07f Author: Felix Zimmermann Date: Thu May 15 02:33:40 2025 +0200 update commit 54a66b61321c58dd95c6056030ac7e32986608f6 Author: Felix Zimmermann Date: Wed May 14 17:17:38 2025 +0200 fix commit cf4be7f27b769f1a6070bf0169037ad996a452e9 Author: Felix Zimmermann Date: Wed May 14 17:14:33 2025 +0200 uformer commit 9cfae55b33c34de10f606019ca3e0d03fc560ee0 Author: Felix Zimmermann Date: Wed May 14 02:22:29 2025 +0200 update commit 420cdc1b29a6604a91eec2b3578d6b5b6447ad83 Author: Felix Zimmermann Date: Wed May 14 01:18:30 2025 +0200 update commit 633682b1959c07dec272277ade9ea50c5b8898f5 Author: Felix Zimmermann Date: Wed May 14 00:53:02 2025 +0200 update commit c39a9afcb1849005e635c16c74344296acc511d5 Author: Felix Zimmermann Date: Tue May 13 22:27:48 2025 +0200 update commit 26467bf5714f6ff57883c714fed9e26583698aeb Author: Felix Zimmermann Date: Tue May 13 21:36:23 2025 +0200 update commit 7e83be7af248fcc288d69b435de48db1e8ad703c Author: Felix Zimmermann Date: Mon May 12 23:03:54 2025 +0200 update commit a458855117baf3b21bdc0738230ecf60f7256bb2 Author: Felix Zimmermann Date: Sat May 10 21:09:00 2025 +0200 start commit 904f3c941e308fad00dbb0ec857a7c4b23b9060c Author: Felix Zimmermann Date: Sat May 10 21:05:34 2025 +0200 fix doc commit 637cdf0846c91833b1dff2c3daa74f70704afc61 Merge: 78b570d7 e7ee8959 Author: Felix F Zimmermann Date: Wed Apr 30 01:33:21 2025 +0200 Merge branch 'main' into diff commit 78b570d7f72f28bae89eef3ecfb61212d6ef24c1 Author: Felix Zimmermann Date: Wed Apr 30 01:32:40 2025 +0200 fix rhs norm zero commit 227646af9f0e7209d4a92226805a612569b68eba Author: Felix Zimmermann Date: Tue Apr 29 23:58:40 2025 +0200 norm commit c8f51eccd7324cebb2c6c49a61fc11f585e93923 Author: Felix Zimmermann Date: Tue Apr 29 23:43:45 2025 +0200 review commit c9e71602b00c6ada47add4b523b9a27f000cabd1 Merge: 6454662d 29c74bd9 Author: Felix F Zimmermann Date: Tue Apr 29 21:44:56 2025 +0200 Merge branch 'main' into diff commit 6454662dabde169b863e69be4a5a9a9b2ec0c419 Author: Felix Zimmermann Date: Tue Apr 29 21:16:21 2025 +0200 pyr310 commit 3118792823a045408fdf9e426be9965b2c20b951 Author: Felix Zimmermann Date: Tue Apr 29 11:09:41 2025 +0200 py310 commit 41b5aec6c32cc756ca51b969e232384b09687972 Author: Felix Zimmermann Date: Tue Apr 29 09:46:45 2025 +0200 py310 commit 19087c94fc99c9c82ce5f2446e841d2e48e1d5e6 Author: Felix Zimmermann Date: Tue Apr 29 01:21:58 2025 +0200 py310 commit 97f102b2364e20f3c633b1265a160500e18ee4ad Author: Felix Zimmermann Date: Mon Apr 28 23:01:56 2025 +0200 py310 commit d9c024a8b2ebaa4524cdbddee107145e08791a06 Author: Felix Zimmermann Date: Mon Apr 28 22:38:22 2025 +0200 fix commit c3ffa9c211c87bf01d8103dd8259d0486133a234 Author: Felix Zimmermann Date: Mon Apr 28 00:40:50 2025 +0200 fix miniml commit 391979ee6b286bad0c6ca65e54d87b89ab4b48de Author: Felix Zimmermann Date: Mon Apr 28 00:38:31 2025 +0200 cleanup commit c89a25a9f9b7f0bb0648944eb43cd72c50c5be7f Author: Felix Zimmermann Date: Wed Apr 23 14:03:01 2025 +0200 fix merge commit dc6429d7d3704a3382288593dfb0493af4e4c32d Author: Felix Zimmermann Date: Wed Apr 23 13:56:54 2025 +0200 update commit 5d6fe988d33e5c7195f016197ed6d6b712f26757 Author: Felix Zimmermann Date: Wed Apr 23 13:56:22 2025 +0200 update commit a336ded279b85e958f46d8eeb348abaa107777c3 Author: Felix Zimmermann Date: Wed Apr 16 17:04:35 2025 +0200 test commit aaff916b9fb640c23c14462ae40458250a2ee8ce Author: Felix Zimmermann Date: Mon Apr 7 12:51:09 2025 +0200 types commit 1c31feda29736b3e5b4f05cb30a1441e2590efe0 Author: Felix Zimmermann Date: Sun Apr 6 16:37:00 2025 +0200 first draft commit 2be3c4f188ce03f57e075d42d7641768aeedc9de Author: Felix Zimmermann Date: Mon Apr 7 12:51:42 2025 +0200 cg matrix op --- examples/scripts/apply_pinqi.py | 381 +++++++++++ examples/scripts/modl.py | 152 +++++ examples/scripts/pinqi.py.bak | 375 +++++++++++ examples/scripts/train_pinqi.py | 679 ++++++++++++++++++++ src/mrpro/algorithms/csm/inati.py | 13 +- src/mrpro/nn/LayerNorm.py | 10 +- src/mrpro/nn/PermutedBlock.py | 58 ++ src/mrpro/nn/SeparableResBlock.py | 170 +++++ src/mrpro/nn/Sequential.py | 2 +- src/mrpro/nn/SpatialTransformerBlock.py | 1 + src/mrpro/nn/Upsample.py | 14 +- src/mrpro/nn/__init__.py | 7 + src/mrpro/nn/nets/BasicCNN.py | 65 ++ src/mrpro/nn/nets/UNet.py | 220 +------ src/mrpro/nn/nets/__init__.py | 6 +- src/mrpro/operators/LinearOperator.py | 8 +- src/mrpro/operators/LinearOperatorMatrix.py | 6 +- src/mrpro/operators/Operator.py | 48 +- src/mrpro/phantoms/__init__.py | 1 + src/mrpro/utils/pad_or_crop.py | 2 +- 20 files changed, 2004 insertions(+), 214 deletions(-) create mode 100644 examples/scripts/apply_pinqi.py create mode 100644 examples/scripts/modl.py create mode 100644 examples/scripts/pinqi.py.bak create mode 100644 examples/scripts/train_pinqi.py create mode 100644 src/mrpro/nn/PermutedBlock.py create mode 100644 src/mrpro/nn/SeparableResBlock.py create mode 100644 src/mrpro/nn/nets/BasicCNN.py diff --git a/examples/scripts/apply_pinqi.py b/examples/scripts/apply_pinqi.py new file mode 100644 index 000000000..daad3137c --- /dev/null +++ b/examples/scripts/apply_pinqi.py @@ -0,0 +1,381 @@ +# %% +from collections.abc import Sequence +from copy import deepcopy +from pathlib import Path +from typing import Literal, TypedDict + +import einops +import mrpro +import torch + +# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) + + +class BatchType(TypedDict): + """Typehint for a batch of data.""" + + kdata: mrpro.data.KData + csm: mrpro.data.CsmData + m0: torch.Tensor + t1: torch.Tensor + mask: torch.Tensor + + +class Dataset(torch.utils.data.Dataset[BatchType]): + """A brainweb based cartesian qMRI dataset.""" + + def __init__( + self, + folder: Path, + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int, + acceleration: int, + n_coils: int, + max_noise: float, + orientation: Sequence[Literal['axial', 'coronal', 'sagittal']], + random: bool = True, + ): + """Initialize the dataset.""" + if random: + augment = mrpro.phantoms.brainweb.augment(size=size) + else: + augment = mrpro.phantoms.brainweb.augment( + size=size, + max_random_shear=0, + max_random_rotation=0, + max_random_scaling_factor=0, + p_horizontal_flip=0, + p_vertical_flip=1.0, + ) + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + folder=folder, + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=augment, + orientation=orientation, + ) + self.signalmodel = deepcopy(signalmodel) + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + self.max_noise = max_noise + self._n_images = n_images + + def __len__(self) -> int: + """Get the length of the dataset.""" + return len(self.phantom) + + def __getitem__(self, index: int): + """Get an item from the dataset.""" + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = int(torch.randint(0, 1000000, (1,))) if self._random else index + + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=1.5, + n_center=10, + n_other=(self._n_images,), + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + + if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery): + header.ti = self.signalmodel.saturation_time.tolist() + elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery): + header.ti = self.signalmodel.ti.tolist() + + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + csm = mrpro.data.CsmData( + mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), + header, + ) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + +class PINQI(torch.nn.Module): + """PINQI model.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int, + n_features_parameter_net: Sequence[int], + n_features_image_net: Sequence[int], + ): + """Initialize the PINQI model.""" + super().__init__() + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op + self.constraints_op = constraints_op + self._n_images = n_images + self._parameter_is_complex = parameter_is_complex + real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex) + self.parameter_net = mrpro.nn.nets.UNet( + dim=2, + channels_in=n_images * 2, + channels_out=real_parameters, + attention_depths=(-1, -2), + n_features=n_features_parameter_net, + cond_dim=128, + ) + + self.image_net = mrpro.nn.nets.UNet( + 2, channels_in=2, channels_out=2, attention_depths=(), n_features=n_features_image_net, cond_dim=128 + ) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus(beta=5) + self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128) + + def objective_factory( + lambda_parameters: torch.Tensor, + image: torch.Tensor, + *parameter_reg: torch.Tensor, + ): + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.ProximableFunctionalSeparableSum( + *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] + ) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp( + objective_factory, + lambda _l, _i, *parameter_reg: parameter_reg, + ) + + def get_linear_solver(self, gram: mrpro.operators.LinearOperator): + def operator_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + *_, + ): + return gram + lambda_image + lambda_q + + def rhs_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + image_reg: torch.Tensor, + signal: torch.Tensor, + zero_filled_image: torch.Tensor, + ): + return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,) + + return mrpro.operators.ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=rhs_factory, + ) + + def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]: + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> batch (t complex) y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + parameters = self.parameter_net(image.contiguous(), cond=cond) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + i = 0 + result = [] + for is_complex in self._parameter_is_complex: + if is_complex: + result.append(torch.complex(parameters[i], parameters[i + 1])) + i += 2 + else: + result.append(parameters[i]) + i += 1 + return tuple(result) + + def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor: + batch = image.shape[0] + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> (batch t) complex y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + image = image + self.image_net(image.contiguous(), cond=cond) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm_op + gram = acquisition_op.gram + (zero_filled_image,) = acquisition_op.H(kdata.data) + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + parameters = [self.get_parameter_reg(images[-1], 0)] + linear_solver = self.get_linear_solver(gram) + + for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)): + image_reg = self.get_image_reg(images[-1], i + 1) + (signal,) = self.signalmodel(*parameters[-1]) + images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + parameters_reg = self.get_parameter_reg(images[-1], i + 1) + parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + if self.constraints_op is not None: + parameters = [self.constraints_op(*p) for p in parameters] + return images, parameters + + +# %% +# As a baseline methods for comparision, we use a simple non-learned approach. We reconstruct the qualitative images at different saturation times using iterative SENSE. +# We then perform a constrained non-linear least squares regression usingL-BFGS to obtain the parameter maps. +# %% +def baseline_solution( + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + kdata: mrpro.data.KData, + csm: mrpro.data.CsmData, +) -> tuple[torch.Tensor, ...]: + """Compute a baseline solution using SENSE + Regression.""" + sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm) + images = sense(kdata) + objective = mrpro.operators.functionals.L2NormSquared(images.data) @ signalmodel @ constraints_op + initial_values = tuple( + torch.zeros(images.shape[1:], device=images.device, dtype=torch.complex64 if is_complex else torch.float32) + for is_complex in parameter_is_complex + ) + solution = constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values)) + return solution + + +# %% +data_folder = Path('/home/zimmer08/.cache/mrpro/brainweb') + +signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0)) +constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (-2, 2), # M0 in [-2, 2] + (0.01, 6.0), # T1 is constrained between 10 ms and 6 s + ) +) +n_images = len(signalmodel.saturation_time) +parameter_is_complex = [True, False] + + +dataset = torch.utils.data.Subset( + Dataset( + folder=data_folder, + signalmodel=signalmodel, + n_images=n_images, + size=192, + acceleration=8, + n_coils=8, + max_noise=0.05, + orientation=('axial',), + random=False, + ), + list(range(500)), +) +# %% +checkpoint = torch.load('last.ckpt', map_location='cpu') +hyper_parameters = checkpoint['hyper_parameters'] + + +pinqi = PINQI( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + n_iterations=hyper_parameters['n_iterations'], + n_features_parameter_net=hyper_parameters['n_features_parameter_net'], + n_features_image_net=hyper_parameters['n_features_image_net'], +) +state_dict = { + k.replace('pinqi.', '').replace('_orig_mod.', ''): v + for k, v in checkpoint['state_dict'].items() + if 'baseline' not in k +} +pinqi.load_state_dict(state_dict) +# %% +batch = dataset[40] +csm, kdata = batch['csm'], batch['kdata'] + +if torch.cuda.is_available(): + pinqi, csm, kdata = pinqi.cuda(), csm.cuda(), kdata.cuda() +images, parameters = pinqi(kdata[None], csm[None]) +with torch.no_grad(): + predicted_m0, predicted_t1 = (p.cpu().detach().squeeze() for p in parameters[-1]) +baseline_m0, baseline_t1 = baseline_solution(signalmodel, constraints_op, parameter_is_complex, kdata, csm) +# %% +(ssim_t1,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(predicted_t1[None]) +(mse_t1,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(predicted_t1) + +(mse_baseline,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(baseline_t1) +nrmse_t1 = torch.sqrt(mse_t1) / batch['t1'][batch['mask']].max() +(ssim_baseline,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(baseline_t1[None]) +nrmse_baseline = torch.sqrt(mse_baseline) / batch['t1'][batch['mask']].max() + + +# %% +import matplotlib.pyplot as plt +from cmap import Colormap + +cmap = Colormap('lipari').to_matplotlib() + +print(f'SSIM: {ssim_baseline.item():.4f}, NRMSE: {nrmse_baseline.item():.4f}') +print(f'SSIM: {ssim_t1.item():.4f}, NRMSE: {nrmse_t1.item():.4f}') + + +fig, ax = plt.subplots(1, 5, gridspec_kw={'width_ratios': [1, 1, 1, 0.01, 0.075], 'wspace': 0.0}, figsize=(5, 2)) +baseline_t1 = baseline_t1.squeeze() +baseline_t1[~batch['mask']] = torch.nan +ax[0].imshow(baseline_t1, vmin=0, vmax=2, cmap=cmap) +ax[0].axis('off') +ax[0].set_title('SENSE + Regression') +ax[0].text( + 0.5, + -0.00, + f'SSIM: {ssim_baseline.item():.2f}', + color='black', + horizontalalignment='center', + verticalalignment='top', + transform=ax[0].transAxes, +) +predicted_t1 = predicted_t1.squeeze() +predicted_t1[~batch['mask']] = torch.nan +ax[1].imshow(predicted_t1, vmin=0, vmax=2, cmap=cmap) +ax[1].axis('off') +ax[1].set_title('PINQI') +ax[1].text( + 0.5, + -0.0, + f'SSIM: {ssim_t1.item():.2f}', + color='black', + horizontalalignment='center', + verticalalignment='top', + transform=ax[1].transAxes, + size=10, +) + +target_t1 = batch['t1'].squeeze() +target_t1[~batch['mask']] = torch.nan +im = ax[2].imshow(target_t1, vmin=0, vmax=2, cmap=cmap) +ax[2].axis('off') +ax[2].set_title('Ground Truth') +fig.tight_layout() +ax[-2].axis('off') +plt.colorbar(im, cax=ax[-1], label='$T_1$ (s)') +fig.savefig('/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1_2.pdf', bbox_inches='tight') + + +# %% + + +# %% +# %% diff --git a/examples/scripts/modl.py b/examples/scripts/modl.py new file mode 100644 index 000000000..5039d233a --- /dev/null +++ b/examples/scripts/modl.py @@ -0,0 +1,152 @@ +# %% +# %matplotlib inline +from collections.abc import Sequence +from pathlib import Path +from typing import TypedDict + +import matplotlib.axes +import matplotlib.pyplot as plt +import mrpro +import torch +from tqdm import tqdm + + +class BatchType(TypedDict): + data: mrpro.data.KData + target: mrpro.data.IData + csm: mrpro.data.CsmData + + +class AcceleratedFastMRI(torch.utils.data.Dataset): + def __init__(self, path: Path, acceleration: float = 12, noise_level: float = 0.1): + self.acceleration = acceleration + files = list(path.glob('*AXT1*')) + self.dataset = mrpro.phantoms.FastMRIKDataDataset(files) + self.noise_level = noise_level + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index: int) -> BatchType: + data = self.dataset[index] + data = data.remove_readout_os() + data.data /= data.data.std() + reconstruction = mrpro.algorithms.reconstruction.DirectReconstruction( + data, csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64) + ) + csm = reconstruction.csm + target = reconstruction(data) + + n = max(data.data.shape[-2:]) + distance = (torch.linspace(-1, 1, n)[:, None] ** 2 + torch.linspace(-1, 1, n) ** 2).sqrt() + random = 0.1 / (distance + 0.1) + torch.rand_like(distance) + threshold = torch.kthvalue(random.ravel(), int(n**2 * (1 - 1 / self.acceleration))).values + undersampling_mask = mrpro.utils.pad_or_crop(random > threshold, data.data.shape[-2:]) + data_undersampled = data[..., undersampling_mask].rearrange('k ... 1 -> ... k') + + noise = mrpro.utils.RandomGenerator(seed=index).randn_like(data_undersampled.data) + data_undersampled.data += self.noise_level * noise + + assert csm is not None # for mypy + return {'data': data_undersampled, 'target': target, 'csm': csm} + + +class MODL(torch.nn.Module): + def __init__(self, iterations: int = 8, n_features: Sequence[int] = (64, 64, 64, 64)): + super().__init__() + cnn = mrpro.nn.nets.BasicCNN( + dim=2, + channels_in=2, + channels_out=2, + n_features=n_features, + batch_norm=True, + ) + self.network = mrpro.nn.Residual(mrpro.nn.ComplexAsChannel(mrpro.nn.PermutedBlock((-1, -2), cnn))) + self.network = torch.compile(self.network, dynamic=True, fullgraph=True) + self.iterations = iterations + self.regularization_weights = torch.nn.Parameter(0.2 * torch.ones(iterations)) + + def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: + return super().__call__(kdata, csm) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm.as_operator() + (zero_filled_image,) = acquisition_op.H(kdata.data) + gram = acquisition_op.gram + data_consistency_op = mrpro.operators.ConjugateGradientOp( + operator_factory=lambda _image, weight: gram + weight, + rhs_factory=lambda image, weight: zero_filled_image + weight * image, + ) + + (image,) = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=5) + for iteration in range(self.iterations): + regularization = self.network(image) + (image,) = data_consistency_op(regularization, self.regularization_weights[iteration]) + + return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header)) + + +def plot(batch: BatchType, prediction: mrpro.data.IData, step: int): + """Plot the direct, sense, and modl reconstructions.""" + target = batch['target'].rss().cpu().squeeze() + direct = mrpro.algorithms.reconstruction.DirectReconstruction(batch['data'], csm=batch['csm'])(batch['data']) + direct = direct.rss().cpu().squeeze() + direct *= target.std() / direct.std() + sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(batch['data'], csm=batch['csm'])(batch['data']) + sense = sense.rss().cpu().squeeze() + prediction_ = prediction.rss().cpu().squeeze().detach() + + ssim = mrpro.operators.functionals.SSIM(mrpro.utils.pad_or_crop(target[None], (320, 320))) + + def show(ax: matplotlib.axes.Axes, data: torch.Tensor, label: str): + data = mrpro.utils.pad_or_crop(data, (320, 320)) + ax.imshow(data, vmin=0, vmax=target.max().item(), cmap='gray') + if label != 'Ground Truth': + (ssim_value,) = ssim(data[None]) + ax.text( + 0.98, + 0.1, + f'SSIM: {ssim_value.item():.2f}', + color='white', + horizontalalignment='right', + verticalalignment='top', + transform=ax.transAxes, + ) + ax.set_title(label) + ax.set_axis_off() + + fig, ax = plt.subplots(1, 4) + show(ax[0], direct, 'Direct') + show(ax[1], sense, 'CG-SENSE') + show(ax[2], prediction_, 'MODL') + show(ax[3], target, 'Ground Truth') + fig.tight_layout() + fig.savefig(f'modl_{step}.pdf', bbox_inches='tight', pad_inches=0) + + +# %%. +path = Path('/echo/allgemein/resources/publicTrainingData/fastmri/brain_multicoil_train/') +dataset = AcceleratedFastMRI(path) +dataloader = torch.utils.data.DataLoader(dataset, num_workers=16, shuffle=True, collate_fn=lambda batch: batch[0]) +modl = MODL().cuda() +optimizer = torch.optim.Adam(modl.parameters(), lr=1e-3) +pbar = tqdm(dataloader) +for i, batch in enumerate(pbar): + optimizer.zero_grad() + kdata, csm, target = (batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda()) + prediction = modl(kdata, csm) + objective = 0.5 * mrpro.operators.functionals.MSE(target.data) - mrpro.operators.functionals.SSIM(target.data) + (loss,) = objective(prediction.data) + loss.backward() + torch.nn.utils.clip_grad_norm_(modl.parameters(), 5.0) + optimizer.step() + + pbar.set_postfix(loss=loss.item()) + if i % 200 == 0: + plot(batch, prediction, i) + print(modl.regularization_weights) + state = {'modl': modl.state_dict(), 'optimizer': optimizer.state_dict()} + torch.save(state, f'modl_{i}.pt') + +# %% diff --git a/examples/scripts/pinqi.py.bak b/examples/scripts/pinqi.py.bak new file mode 100644 index 000000000..c80c0e035 --- /dev/null +++ b/examples/scripts/pinqi.py.bak @@ -0,0 +1,375 @@ +# %% +import einops +import matplotlib.pyplot as plt +import mrpro +import torch + +# %matplotlib inline + +# %% +# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) +# %% + + +class Dataset(torch.utils.data.Dataset): + def __init__(self, size=192, acceleration=10, n_coils=8, random=True, max_noise=0.1): + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=mrpro.phantoms.brainweb.augment(size=size), + ) + self.signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2, 6)) + self.constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (-1, 1), # M0 in [-1, 1] + (0.001, 4.0), # T1 is constrained between 1 ms and 4 s + ) + ) + + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + self.max_noise = max_noise + + @property + def n_images(self): + return 5 + + @property + def complex_parameters(self): + return [True, False] + + @property + def n_parameters(self): + return len(self.complex_parameters) + + def __len__(self): + return len(self.phantom) + + def __getitem__(self, index): + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = torch.randint(0, 1000000, (1,)).item() if self._random else index + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=2, + n_center=8, + n_other=(self.n_images,), + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + header.ti = self.signalmodel.saturation_time.tolist() + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + csm = mrpro.data.CsmData(mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), header) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + @staticmethod + def collate_fn(batch): + return torch.utils.data._utils.collate.collate( + batch, + collate_fn_map={ + mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), + **torch.utils.data._utils.collate.default_collate_fn_map, + }, + ) + + +# %% +ds = Dataset() +dl = torch.utils.data.DataLoader( + ds, + batch_size=8, + collate_fn=ds.collate_fn, + num_workers=16, + worker_init_fn=lambda *_: torch.set_num_threads(1), + shuffle=True, +) + +# %% +from copy import deepcopy + + +class PINQI(torch.nn.Module): + def __init__(self, signalmodel, parameter_is_complex, n_images, n_iterations=4, constraints_op=None): + super().__init__() + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ deepcopy(signalmodel) + if constraints_op is not None: + self.signalmodel = self.signalmodel @ constraints_op + self.constraints_op = constraints_op + self._n_images = n_images + self._parameter_is_complex = parameter_is_complex + real_parameters = sum(parameter_is_complex) + len(parameter_is_complex) + self.parameter_net = torch.compile( + mrpro.nn.nets.UNet( + dim=2, + channels_in=n_images * 2, + channels_out=real_parameters, + attention_depths=(-1,), + n_features=(64, 128, 192, 256), + ), + dynamic=False, + fullgraph=True, + ) + self.image_net = torch.compile( + mrpro.nn.nets.UNet( + 2, + channels_in=2, + channels_out=2, + attention_depths=(), + n_features=(16, 32, 48, 64), + ), + dynamic=False, + fullgraph=True, + ) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus() + + def objective_factory(lambda_parameters, image, *parameter_reg): + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.ProximableFunctionalSeparableSum( + *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] + ) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp( + objective_factory, lambda _l, _i, *parameter_reg: parameter_reg + ) + + def get_linear_solver(self, gram): + def operator_factory(lambda_image, lambda_q, _image_reg, _signal, _zero_filled_image): + return gram + lambda_image + lambda_q + + def rhs_factory(lambda_image, lambda_q, image_reg, signal, zero_filled_image): + return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,) + + return mrpro.operators.ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=rhs_factory, + ) + + def get_parameter_reg(self, image: torch.Tensor) -> tuple[torch.Tensor, ...]: + image = einops.rearrange(torch.view_as_real(image), 'batch t 1 1 y x complex-> batch (t complex) y x') + parameters = self.parameter_net(image.contiguous()) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + i = 0 + result = [] + for is_complex in self._parameter_is_complex: + if is_complex: + result.append(torch.complex(parameters[i], parameters[i + 1])) + i += 2 + else: + result.append(parameters[i]) + i += 1 + return tuple(result) + + def get_image_reg(self, image): + batch = image.shape[0] + image = einops.rearrange(torch.view_as_real(image), 'batch t 1 1 y x complex-> (batch t) complex y x') + image = image + self.image_net(image.contiguous()) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm_op + gram = acquisition_op.gram + (zero_filled_image,) = acquisition_op.H(kdata.data) + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + parameters = [self.get_parameter_reg(images[-1])] + linear_solver = self.get_linear_solver(gram) + + for lambda_image, lambda_q, lambda_parameter in self.softplus(self.lambdas_raw): + # subproblem 1 + image_reg = self.get_image_reg(images[-1]) + (signal,) = self.signalmodel(*parameters[-1]) + images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + # subproblem 2 + parameters_reg = self.get_parameter_reg(images[-1]) + parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + if self.constraints_op is not None: + parameters = [self.constraints_op(*p) for p in parameters] + return images, parameters + + +# %% +from typing import TypeVar + +T = TypeVar('T') + + +def to_device(batch: T, device: torch.device | str) -> T: + """Moves tensors and Mrpro data to the specified device recursively.""" + if isinstance(batch, torch.Tensor | mrpro.data.Dataclass): + return batch.to(device) + if isinstance(batch, dict): + return {k: to_device(v, device) for k, v in batch.items()} + if isinstance(batch, list): + return [to_device(v, device) for v in batch] + if isinstance(batch, tuple): + return tuple(to_device(v, device) for v in batch) + + return batch + + +# %% +# from tqdm import tqdm + +# pinqi = PINQI(ds.signalmodel, ds.n_parameters, ds.n_images, constraints_op=ds.constraints_op).cuda() +# for epoch in tqdm(range(10)): +# pbar = tqdm(dl, leave=False) +# optim = torch.optim.Adam(pinqi.parameters(), lr=1e-4) +# for batch in pbar: +# batch = to_device(batch, 'cuda') +# images, parameters = pinqi(batch['kdata'], batch['csm']) +# prediction_m0, prediction_t1 = parameters[-1] +# loss_t1 = torch.nn.functional.mse_loss(prediction_t1.squeeze()[batch['mask']], batch['t1'][batch['mask']]) + +# loss_m0 = torch.nn.functional.mse_loss( +# torch.view_as_real((prediction_m0 + 0j).squeeze()[batch['mask']]), +# torch.view_as_real(batch['m0'][batch['mask']]), +# ) + +# loss = loss_t1 + loss_m0 +# pbar.set_postfix(loss=loss.item()) +# loss.backward() +# optim.step() +# optim.zero_grad() + + +# %% +import numpy as np +import torch +from IPython.display import clear_output, display +from tqdm.notebook import tqdm + + +def plot_results( + fig: plt.Figure, + axes: np.ndarray, + losses: list[float], + target_t1: torch.Tensor, + pred_t1: torch.Tensor, + mask: torch.Tensor, +) -> None: + """ + Updates and displays the training plot. + + Parameters + ---------- + fig + The matplotlib figure object. + axes + The array of matplotlib axes objects. + losses + losses for each step + target_t1 + The ground truth T1 map from the last batch. + pred_t1 + The predicted T1 map from the last batch. + mask + The mask from the last batch. + """ + clear_output(wait=True) + + axes[0].clear() + axes[0].semilogy(losses) + axes[0].set_title('Loss') + axes[0].set_xlabel('Step') + axes[0].set_ylabel('Loss') + axes[0].grid(True) + + target_t1_viz = target_t1[1].squeeze().cpu().numpy() + pred_t1_viz = pred_t1[1].squeeze().detach().cpu().numpy() + mask_viz = mask[1].squeeze().detach().cpu().numpy() + target_t1_viz[~mask_viz] = np.nan + pred_t1_viz[~mask_viz] = np.nan + difference = target_t1_viz - pred_t1_viz + vmax = np.nanmax(target_t1_viz) + + axes[1].clear() + axes[1].imshow(target_t1_viz, vmin=0, vmax=vmax) + axes[1].set_title('Target T1') + axes[1].axis('off') + + axes[2].clear() + axes[2].imshow(pred_t1_viz, vmin=0, vmax=vmax) + axes[2].set_title(f'Predicted T1 (Epoch {epoch + 1})') + axes[2].axis('off') + + axes[3].clear() + axes[3].imshow(difference, cmap='coolwarm') + axes[3].set_title('Difference') + axes[3].axis('off') + + fig.tight_layout() + display(fig) + + +# %% +def calculate_loss(predictions, batch, weights=(0.2, 0.1, 0.1, 0.1, 0.5)) -> torch.Tensor: + loss = torch.tensor(0.0) + target_m0 = batch['m0'] + target_t1 = batch['t1'] + mask = batch['mask'] + for prediction, weight in zip(predictions, weights, strict=False): + prediction_m0, prediction_t1 = prediction + loss_t1 = torch.nn.functional.mse_loss(prediction_t1.squeeze()[mask], target_t1[mask]) + loss_m0 = torch.nn.functional.mse_loss( + torch.view_as_real((prediction_m0).squeeze()[mask]), + torch.view_as_real(target_m0[mask]), + ) + loss = loss + weight * (loss_t1 + loss_m0) + return loss + + +# %% +torch.set_float32_matmul_precision('high') +torch._inductor.config.worker_start_method = 'fork' +torch._inductor.config.compile_threads = 4 +torch._dynamo.config.capture_scalar_outputs = True +torch._functorch.config.activation_memory_budget = 0.9 + +pinqi = PINQI(ds.signalmodel, ds.complex_parameters, ds.n_images, constraints_op=ds.constraints_op).to('cuda') +optim = torch.optim.AdamW(pinqi.parameters(), lr=3e-4, weight_decay=1e-4) +torch._dynamo.config.cache_size_limit = 256 +n_epochs = 10 +losses = [] +fig, axes = plt.subplots(1, 4, figsize=(20, 5)) + +for epoch in range(n_epochs): + epoch_losses = [] + pbar = tqdm(dl, desc=f'Epoch {epoch + 1}/{n_epochs}', leave=False) + + for batch in pbar: + batch = to_device(batch, 'cuda') + optim.zero_grad() + + images, parameters = pinqi(batch['kdata'], batch['csm']) + loss = calculate_loss(parameters, batch) + + loss.backward() + optim.step() + epoch_losses.append(loss.item()) + pbar.set_postfix(epoch_loss=f'{np.mean(epoch_losses):.3f}', loss=f'{epoch_losses[-1]:.3f}') + + losses.extend(epoch_losses) + prediction_t1 = parameters[-1][1] + plot_results(fig, axes, losses, batch['t1'], prediction_t1, batch['mask']) + +plt.close(fig) + + +# %% diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py new file mode 100644 index 000000000..bbb88cdfe --- /dev/null +++ b/examples/scripts/train_pinqi.py @@ -0,0 +1,679 @@ +# ruff: noqa: D102, ANN201 + +import collections +from collections.abc import Sequence +from copy import deepcopy +from pathlib import Path +from typing import Any, Literal, TypedDict, cast + +import einops +import matplotlib.pyplot as plt +import mrpro +import numpy as np +import pytorch_lightning as pl +import torch +import torch.utils.data._utils +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import NeptuneLogger +from pytorch_lightning.strategies import DDPStrategy + +# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) + + +class BatchType(TypedDict): + """Typehint for a batch of data.""" + + kdata: mrpro.data.KData + csm: mrpro.data.CsmData + m0: torch.Tensor + t1: torch.Tensor + mask: torch.Tensor + + +class Dataset(torch.utils.data.Dataset): + """A brainweb based cartesian qMRI dataset.""" + + def __init__( + self, + folder: Path, + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int, + acceleration: int, + n_coils: int, + max_noise: float, + orientation: Sequence[Literal['axial', 'coronal', 'sagittal']], + random: bool = True, + ): + """Initialize the dataset.""" + if random: + augment = mrpro.phantoms.brainweb.augment(size=size) + else: + augment = mrpro.phantoms.brainweb.augment( + size=size, + max_random_shear=0, + max_random_rotation=0, + max_random_scaling_factor=0, + p_horizontal_flip=0, + p_vertical_flip=1.0, + ) + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + folder=folder, + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=augment, + orientation=orientation, + ) + self.signalmodel = signalmodel + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + self.max_noise = max_noise + self._n_images = n_images + + def __len__(self) -> int: + """Get the length of the dataset.""" + return len(self.phantom) + + def __getitem__(self, index: int): + """Get an item from the dataset.""" + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = int(torch.randint(0, 1000000, (1,))) if self._random else index + + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=1.5, + n_center=10, + n_other=(self._n_images,), + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + + if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery): + header.ti = self.signalmodel.saturation_time.tolist() + elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery): + header.ti = self.signalmodel.ti.tolist() + + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + csm = mrpro.data.CsmData( + mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), + header, + ) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + +def collate_fn(batch: Any): # noqa: ANN401 + """Join dataclasses to a batch.""" + return torch.utils.data._utils.collate.collate( + batch, + collate_fn_map={ + mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), # noqa: ARG005 + **torch.utils.data._utils.collate.default_collate_fn_map, + }, + ) + + +class PINQI(torch.nn.Module): + """PINQI model.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int, + n_features_parameter_net: Sequence[int], + n_features_image_net: Sequence[int], + ): + """Initialize the PINQI model.""" + super().__init__() + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op + self.constraints_op = constraints_op + self._n_images = n_images + self._parameter_is_complex = parameter_is_complex + real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex) + self.parameter_net = torch.compile( + mrpro.nn.nets.UNet( + dim=2, + channels_in=n_images * 2, + channels_out=real_parameters, + attention_depths=(-1, -2), + n_features=n_features_parameter_net, + cond_dim=128, + ), + dynamic=False, + fullgraph=True, + ) + self.image_net = torch.compile( + mrpro.nn.nets.UNet( + 2, channels_in=2, channels_out=2, attention_depths=(), n_features=n_features_image_net, cond_dim=128 + ), + dynamic=False, + fullgraph=True, + ) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus(beta=5) + self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128) + + def objective_factory( + lambda_parameters: torch.Tensor, + image: torch.Tensor, + *parameter_reg: torch.Tensor, + ): + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.ProximableFunctionalSeparableSum( + *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] + ) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp( + objective_factory, + lambda _l, _i, *parameter_reg: parameter_reg, + ) + + def get_linear_solver(self, gram: mrpro.operators.LinearOperator): + def operator_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + *_, + ): + return gram + lambda_image + lambda_q + + def rhs_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + image_reg: torch.Tensor, + signal: torch.Tensor, + zero_filled_image: torch.Tensor, + ): + return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,) + + return mrpro.operators.ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=rhs_factory, + ) + + def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]: + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> batch (t complex) y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + parameters = self.parameter_net(image.contiguous(), cond=cond) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + i = 0 + result = [] + for is_complex in self._parameter_is_complex: + if is_complex: + result.append(torch.complex(parameters[i], parameters[i + 1])) + i += 2 + else: + result.append(parameters[i]) + i += 1 + return tuple(result) + + def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor: + batch = image.shape[0] + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> (batch t) complex y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + image = image + self.image_net(image.contiguous(), cond=cond) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm_op + gram = acquisition_op.gram + (zero_filled_image,) = acquisition_op.H(kdata.data) + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + parameters = [self.get_parameter_reg(images[-1], 0)] + linear_solver = self.get_linear_solver(gram) + + for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)): + image_reg = self.get_image_reg(images[-1], i + 1) + (signal,) = self.signalmodel(*parameters[-1]) + images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + parameters_reg = self.get_parameter_reg(images[-1], i + 1) + parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + if self.constraints_op is not None: + parameters = [self.constraints_op(*p) for p in parameters] + return images, parameters + + +class DataModule(pl.LightningDataModule): + """Data module for training the PINQI model.""" + + def __init__( + self, + folder: Path, + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int = 192, + acceleration: int = 10, + n_coils: int = 8, + max_noise: float = 0.1, + orientation_train: Sequence[Literal['axial', 'coronal', 'sagittal']] = ( + 'axial', + 'coronal', + 'sagittal', + ), + orientation_val: Sequence[Literal['axial', 'coronal', 'sagittal']] = ('axial',), + batch_size: int = 16, + num_workers: int = 4, + ): + """Initialize the data module.""" + super().__init__() + self.save_hyperparameters(ignore=['signalmodel', 'folder', 'num_workers']) + self.batch_size = batch_size + self.num_workers = num_workers + self.train_dataset = Dataset( + folder=folder, + signalmodel=signalmodel, + n_images=n_images, + size=size, + acceleration=acceleration, + n_coils=n_coils, + max_noise=max_noise, + orientation=orientation_train, + random=True, + ) + self.val_dataset = torch.utils.data.Subset( + Dataset( + folder=folder, + signalmodel=signalmodel, + n_images=n_images, + size=size, + acceleration=acceleration, + n_coils=n_coils, + max_noise=max_noise, + orientation=orientation_val, + random=False, + ), + list(range(30, 500, 20)), + ) + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=False, + persistent_workers=self.num_workers > 0, + collate_fn=collate_fn, + worker_init_fn=lambda *_: torch.set_num_threads(1), + ) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.val_dataset, + batch_size=1, + shuffle=False, + num_workers=self.num_workers, + pin_memory=False, + persistent_workers=self.num_workers > 0, + collate_fn=collate_fn, + ) + + +class PinqiModule(pl.LightningModule): + """Module for training the PINQI model.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int = 4, + n_features_parameter_net: Sequence[int] = (64, 128, 192, 224, 256), + n_features_image_net: Sequence[int] = (16, 32, 48, 64), + lr: float = 4e-4, # noqa: ARG002 + weight_decay: float = 1e-3, # noqa: ARG002 + loss_weights: Sequence[float] = (0.2, 0.1, 0.1, 0.1, 0.8), + ): + """Initialize the PINQI module.""" + super().__init__() + self.save_hyperparameters(ignore=['signalmodel', 'constraints_op']) + if len(loss_weights) != n_iterations + 1: + raise ValueError(f'loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations') + signalmodel, constraints_op = map(deepcopy, (signalmodel, constraints_op)) + self.pinqi = PINQI( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + n_iterations=n_iterations, + n_features_parameter_net=n_features_parameter_net, + n_features_image_net=n_features_image_net, + ) + + self.validation_step_outputs = collections.defaultdict(list) + self.baseline = Baseline(signalmodel, constraints_op, parameter_is_complex) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + """Apply the PINQI model to the data.""" + return self.pinqi(kdata, csm) + + def loss(self, predictions: Sequence[torch.Tensor], batch: BatchType) -> torch.Tensor: + """Compute the loss.""" + loss = torch.tensor(0.0, device=self.device) + target_m0, target_t1, mask = map(torch.squeeze, (batch['m0'], batch['t1'], batch['mask'])) + for prediction, weight in zip(predictions, self.hparams.loss_weights, strict=False): + prediction_m0, prediction_t1 = map(torch.squeeze, prediction) + loss_t1 = torch.nn.functional.mse_loss(prediction_t1[mask], target_t1[mask]) + loss_m0 = torch.nn.functional.mse_loss( + torch.view_as_real(prediction_m0[mask]), + torch.view_as_real(target_m0[mask]), + ) + loss_outside = prediction_m0[~mask].abs().mean() + loss = loss + weight * (loss_t1 + 0.5 * loss_m0 + 0.1 * loss_outside) + return loss + + def training_step(self, batch: BatchType, _batch_idx: int) -> torch.Tensor: + """Training step.""" + images, parameters = self(batch['kdata'], batch['csm']) + loss = self.loss(parameters, batch) + self.log( + 'train/loss', + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=len(batch['mask']), + ) + return loss + + def validation_step(self, batch: BatchType, batch_idx: int) -> None: + """Validate. + + Needs to be adapted for other signal models than Saturation Recovery. + """ + images, parameters = self(batch['kdata'], batch['csm']) + loss = self.loss(parameters, batch) + + pred_m0, pred_t1 = parameters[-1] + target_t1, target_m0 = batch['t1'], batch['m0'] + mask = batch['mask'] + batch_size = len(batch['mask']) + (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1) + (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1) + (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0) + self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size) + + if batch_idx == 0: + self.validation_step_outputs['target_t1'].append(batch['t1']) + self.validation_step_outputs['pred_t1'].append(pred_t1) + self.validation_step_outputs['pred_m0'].append(pred_m0) + self.validation_step_outputs['target_m0'].append(target_m0) + self.validation_step_outputs['mask'].append(batch['mask']) + baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm']) + self.validation_step_outputs['baseline_t1'].append(baseline_t1) + self.validation_step_outputs['baseline_m0'].append(baseline_m0) + + def on_validation_epoch_end(self): + """Validate. + + Needs to be adapted for other signal models than Saturation Recovery. + """ + outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()} + self.validation_step_outputs.clear() + outputs = cast(dict[str, torch.Tensor], self.all_gather(outputs)) + + if not self.trainer.is_global_zero: + return + outputs = {k: v.flatten(0, 1).cpu() for k, v in outputs.items()} + + samples = len(outputs['mask']) + fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16)) + + for i in range(samples): + self.result_plot( + outputs['target_t1'][i], + outputs['pred_t1'][i], + outputs['mask'][i], + axes[:, i], + outputs['baseline_t1'][i], + '$T_1$ (s)', + ) + fig.suptitle(f'$T_1$ Epoch {self.current_epoch}') + self.logger.run['val/images/t1'].log(fig) + plt.close(fig) + + fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 12)) + for i in range(samples): + self.result_plot( + outputs['target_m0'][i].abs(), + outputs['pred_m0'][i].abs(), + outputs['mask'][i], + axes[:, i], + outputs['baseline_m0'][i].abs(), + '$|M_0|$ (a.u.)', + ) + fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}') + self.logger.run['val/images/m0'].log(fig) + plt.close(fig) + + def result_plot( + self, + target: torch.Tensor, + pred: torch.Tensor, + mask: torch.Tensor, + axes: Sequence[plt.Axes], + baseline: torch.Tensor, + label: str, + ) -> None: + """Plot the results.""" + target = target.squeeze().numpy() + pred = pred.squeeze().detach().numpy() + mask = mask.squeeze().detach().numpy().astype(bool) + baseline = baseline.squeeze().detach().numpy() + + target[~mask] = np.nan + pred[~mask] = np.nan + baseline[~mask] = np.nan + difference = (target - pred) / target * 100 + vmax = np.nanmax(target) + + im0 = axes[0].imshow(target, vmin=0, vmax=vmax) + axes[0].set_title('Ground Truth') + axes[0].axis('off') + plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label=label) + + im1 = axes[1].imshow(baseline, vmin=0, vmax=vmax) + axes[1].set_title('SENSE + Regression') + axes[1].axis('off') + plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label=label) + + im2 = axes[2].imshow(pred, vmin=0, vmax=vmax) + axes[2].set_title('PINQI') + axes[2].axis('off') + plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04, label=label) + + diff_vmax = np.nanpercentile(np.abs(difference), 90) + im3 = axes[3].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax) + axes[3].set_title('rel. Error') + axes[3].axis('off') + plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04, label='%') + + def configure_optimizers( + self, + ) -> dict: + """Configure the optimizer and the learning rate scheduler.""" + scalars = ('lambdas_raw', 'rezero') + params, scalar_params = [], [] + for n, p in self.named_parameters(): + if not p.requires_grad: + continue + if any(s in n for s in scalars): + scalar_params.append(p) + else: + params.append(p) + optimizer = torch.optim.AdamW( + [ + {'params': params, 'weight_decay': self.hparams.weight_decay, 'lr': self.hparams.lr}, + {'params': scalar_params, 'weight_decay': 0.0, 'lr': self.hparams.lr * 10}, + ], + ) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=[self.hparams.lr, 10 * self.hparams.lr], + total_steps=self.trainer.estimated_stepping_batches, + pct_start=0.1, + div_factor=30, + final_div_factor=300, + ) + return { + 'optimizer': optimizer, + 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}, + } + + +class Baseline(torch.nn.Module): + """Baseline solution using SENSE + Regression.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + ): + """Initialize the baseline.""" + super().__init__() + self.signalmodel = signalmodel + self.constraints_op = constraints_op + self.parameter_is_complex = parameter_is_complex + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]: + """Compute the baseline solution.""" + sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm) + images = sense(kdata).rearrange('batch time ...-> time batch ...') + + objective = mrpro.operators.functionals.L2NormSquared(images.data) @ self.signalmodel @ self.constraints_op + initial_values = tuple( + torch.zeros(images.shape[1:], device=images.device, dtype=torch.complex64 if is_complex else torch.float32) + for is_complex in self.parameter_is_complex + ) + solution = self.constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values)) + return solution + + +class LogLambdasCallback(pl.Callback): + """Log the lambdas.""" + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: PinqiModule, + _outputs: dict, + _batch: BatchType, + _batch_idx: int, + ) -> None: + if trainer.global_step % 10 == 0: + lambdas = pl_module.pinqi.softplus(pl_module.pinqi.lambdas_raw).detach().cpu().numpy() + for iteration, (lambda_image, lambda_q, lambda_parameter) in enumerate(lambdas): + self.log_dict( + { + f'parameter/lambda_image_{iteration}': lambda_image, + f'parameter/lambda_q_{iteration}': lambda_q, + f'parameter/lambda_parameter_{iteration}': lambda_parameter, + }, + on_step=True, + on_epoch=False, + ) + + +if __name__ == '__main__': + torch.set_float32_matmul_precision('high') + torch._inductor.config.compile_threads = 4 + torch._inductor.config.worker_start_method = 'fork' + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.cache_size_limit = 256 + torch._functorch.config.activation_memory_budget = 0.95 + + data_folder = Path('/scratch/zimmer08/brainweb') + + signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0)) + constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (-2, 2), # M0 in [-2, 2] + (0.01, 6.0), # T1 is constrained between 10 ms and 6 s + ) + ) + n_images = len(signalmodel.saturation_time) + parameter_is_complex = [True, False] + + dm = DataModule( + folder=data_folder, + signalmodel=signalmodel, + n_images=n_images, + batch_size=16, + num_workers=16, + size=192, + acceleration=8, + n_coils=8, + max_noise=0.1, + ) + + model = PinqiModule( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + ) + + neptune_logger = NeptuneLogger( + log_model_checkpoints=False, + dependencies='infer', + ) + neptune_logger.log_model_summary(model=model, max_depth=-1) + + checkpoint_callback = ModelCheckpoint( + monitor='val/loss', + mode='min', + save_top_k=2, + dirpath=Path('checkpoints') / str(neptune_logger.version), + filename='{epoch:02d}-{val/loss:.4f}', + save_last=True, + ) + + strategy = DDPStrategy(find_unused_parameters=False) + trainer = pl.Trainer( + max_epochs=100, + accelerator='gpu', + devices=4, + strategy=strategy, + logger=neptune_logger, + callbacks=[ + LearningRateMonitor(logging_interval='step'), + checkpoint_callback, + LogLambdasCallback(), + ], + log_every_n_steps=10, + gradient_clip_algorithm='norm', + gradient_clip_val=5.0, + ) + + trainer.fit(model, datamodule=dm) diff --git a/src/mrpro/algorithms/csm/inati.py b/src/mrpro/algorithms/csm/inati.py index f202fb629..6e20a3bba 100644 --- a/src/mrpro/algorithms/csm/inati.py +++ b/src/mrpro/algorithms/csm/inati.py @@ -36,7 +36,9 @@ def inati( if isinstance(smoothing_width, int): smoothing_width = SpatialDimension( - z=smoothing_width if coil_img.shape[-3] > 1 else 1, y=smoothing_width, x=smoothing_width + z=smoothing_width if coil_img.shape[-3] > 1 else 1, + y=smoothing_width, + x=smoothing_width, ) if any(ks % 2 != 1 for ks in [smoothing_width.z, smoothing_width.y, smoothing_width.x]): @@ -45,7 +47,14 @@ def inati( ks_halved = [ks // 2 for ks in smoothing_width.zyx] padded_coil_img = torch.nn.functional.pad( coil_img, - (ks_halved[-1], ks_halved[-1], ks_halved[-2], ks_halved[-2], ks_halved[-3], ks_halved[-3]), + ( + ks_halved[-1], + ks_halved[-1], + ks_halved[-2], + ks_halved[-2], + ks_halved[-3], + ks_halved[-3], + ), mode='replicate', ) # Get the voxels in an ROI defined by the smoothing_width around each voxel leading to shape diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 699de57f0..84e1d56e4 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -34,26 +34,30 @@ def __init__(self, channels: int | None, features_last: bool = False, cond_dim: self.weight = Parameter(torch.ones(channels)) self.bias = Parameter(torch.zeros(channels)) self.cond_proj = None - else: + elif channels is not None: self.weight = None self.bias = None self.cond_proj = Linear(cond_dim, 2 * channels) + else: + raise ValueError('cond_dim must be zero or positive.') self.features_last = features_last - def __call__(self, x: torch.Tensor) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply layer normalization to the input tensor. Parameters ---------- x Input tensor + cond + Conditioning tensor. If `None`, no conditioning is applied. Returns ------- Normalized output tensor """ - return super().__call__(x) + return super().__call__(x, cond=cond) def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply layer normalization to the input tensor.""" diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py new file mode 100644 index 000000000..99a27f36a --- /dev/null +++ b/src/mrpro/nn/PermutedBlock.py @@ -0,0 +1,58 @@ +"""Block that applies a submodule along selected spatial dimensions.""" + +from collections.abc import Sequence + +import torch +from torch import nn + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class PermutedBlock(CondMixin, nn.Module): + """Apply a submodule along selected spatial dimensions.""" + + apply_along_dim: tuple[int, ...] + module: nn.Module + + def __init__(self, apply_along_dim: Sequence[int], module: nn.Module, features_last: bool = False): + """Initialize the PermutedBlock. + + Parameters + ---------- + apply_along_dim + Spatial dimension indices to use when applying the module. + These will be moved to the last dimensions. + module + Module to apply on the selected dims. + features_last + If True, the features dimension is assumed to be the last dimension, as common in transformer models. + """ + super().__init__() + self.apply_along_dim = tuple(sorted(apply_along_dim)) + self.module = module + self.features_last = features_last + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions.""" + keep = tuple(d % x.ndim for d in self.apply_along_dim) + if 0 in keep: + raise ValueError('Batch dimension should not be in apply_along_dim.') + if self.features_last: + if x.ndim - 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(1, x.ndim - 1) if d not in keep) + permute = (0, *batch_dim, *keep, x.ndim - 1) + else: + if 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(2, x.ndim) if d not in keep) + permute = (0, *batch_dim, 1, *keep) + h = x.permute(permute) + batch_shape = h.shape[: 1 + len(batch_dim)] + h = h.flatten(0, len(batch_dim)) + h = call_with_cond(self.module, h, cond=cond) + h = h.unflatten(0, batch_shape) + permute_back = [0] * x.ndim + for i, p in enumerate(permute): + permute_back[p] = i + return h.permute(tuple(permute_back)) diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py new file mode 100644 index 000000000..c26a012cf --- /dev/null +++ b/src/mrpro/nn/SeparableResBlock.py @@ -0,0 +1,170 @@ +from collections.abc import Sequence + +import torch +from torch.nn import Module, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.Sequential import Sequential + + +class SeparableResBlock(Module): + """Residual block with separable convolutions and ReZero.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + ) -> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SilU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + channels_in + Number of input channels. + channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, ConvND(len(dims))(channels_in, channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, channels_in if i == 0 else channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(channels_out, cond_dim)) + blocks.extend(block(d, channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if channels_in != channels_out: + self.skip_connection = torch.nn.Linear(channels_in, channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h + + +from collections.abc import Sequence + +import torch +from torch.nn import Module + + +class SeparableResBlock(Module): + """Residual block with separable convolutions and ReZero.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + ) -> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SilU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + channels_in + Number of input channels. + channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, ConvND(len(dims))(channels_in, channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, channels_in if i == 0 else channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(channels_out, cond_dim)) + blocks.extend(block(d, channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if channels_in != channels_out: + self.skip_connection = torch.nn.Linear(channels_in, channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 15b5d0152..fb56bd43f 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -9,7 +9,7 @@ from mrpro.operators import Operator -class Sequential(CondMixin,torch.nn.Sequential): +class Sequential(CondMixin, torch.nn.Sequential): """Sequential container with support for conditioning and Operators. Allows multiple input tensors and a single output tensor of the sequential block. diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 2b4c3e6e2..906560c24 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -129,6 +129,7 @@ def __init__( self.proj_in = Linear(channels, hidden_dim) self.transformer_blocks = Sequential() for group in (g for _ in range(depth) for g in dim_groups): + group = tuple(g - 1 if g < 0 else g for g in group) block = BasicTransformerBlock(hidden_dim, n_heads, p_dropout=dropout, cond_dim=cond_dim, features_last=True) self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) self.proj_out = Linear(hidden_dim, channels) diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py index ec9b0e032..acced8d48 100644 --- a/src/mrpro/nn/Upsample.py +++ b/src/mrpro/nn/Upsample.py @@ -29,10 +29,10 @@ def __init__( super().__init__() self.scale_factor = scale_factor if mode == 'nearest': - dims = [tuple(d) for d in torch.tensor(dim).split(3)] - modes = ['nearest'] * len(self.dim) + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = ['nearest'] * len(dim) elif mode == 'linear': - dims = [tuple(d) for d in torch.tensor(dim).split(3)] + dims = [d.tolist() for d in torch.tensor(dim).split(3)] modes = [{1: 'linear', 2: 'bilinear', 3: 'trilinear'}[len(d)] for d in dims] elif mode == 'cubic': if not len(dim) == 2: @@ -42,18 +42,14 @@ def __init__( self.blocks = Sequential( *[ - PermutedBlock(d, Upsample(d, scale_factor=scale_factor, mode=m)) + PermutedBlock(d, torch.nn.Upsample(scale_factor=len(d) * (scale_factor,), mode=m)) for d, m in zip(dims, modes, strict=False) ] ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Upsample the input tensor.""" - return torch.nn.functional.interpolate( - x, - mode=self.mode, - scale_factor=self.scale_factor, - ) + return self.blocks(x) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Upsample the input tensor. diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index e59e4efde..9d2b1dff4 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -20,12 +20,17 @@ from mrpro.nn.SqueezeExcitation import SqueezeExcitation from mrpro.nn.TransposedAttention import TransposedAttention from mrpro.nn.DropPath import DropPath +from mrpro.nn.Residual import Residual +from mrpro.nn.ComplexAsChannel import ComplexAsChannel from mrpro.nn import nets +from mrpro.nn.PermutedBlock import PermutedBlock + __all__ = [ "AdaptiveAvgPoolND", "AttentionGate", "AvgPoolND", "BatchNormND", + "ComplexAsChannel", "CondMixin", "ConvND", "ConvTransposeND", @@ -35,7 +40,9 @@ "InstanceNormND", "MaxPoolND", "NeighborhoodSelfAttention", + "PermutedBlock", "ResBlock", + "Residual", "Sequential", "ShiftedWindowAttention", "SqueezeExcitation", diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py new file mode 100644 index 000000000..b2671c121 --- /dev/null +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -0,0 +1,65 @@ +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import ReLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import BatchNormND, ConvND +from mrpro.nn.Sequential import Sequential + + +class BasicCNN(Sequential): + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + batch_norm: bool = True, + n_features: Sequence[int] = (64, 64, 64), + cond_dim: int = 0, + ): + """Initialize a basic CNN. + + Parameters + ---------- + dim + The number of spatial dimensions of the input tensor. + channels_in + The number of input channels. + channels_out + The number of output channels. + batch_norm + Whether to use batch normalization. + n_features + The number of features in the hidden layers. The length of this sequence determines the number of hidden layers. + cond_dim + The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + """ + super().__init__() + use_film = cond_dim > 0 + self.append(ConvND(dim)(channels_in, n_features[0], kernel_size=3, padding='same')) + for c_in, c_out in pairwise((*n_features, channels_out)): + if batch_norm: + self.append(BatchNormND(dim)(c_in, affine=not use_film)) + if use_film: + self.append(FiLM(c_in, cond_dim)) + self.append(ReLU(True)) + self.append(ConvND(dim)(c_in, c_out, kernel_size=3, padding='same')) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None) -> torch.Tensor: + """Apply the basic CNN to the input tensor. + + Parameters + ---------- + x + The input tensor. Should be of shape `(batch_size, channels_in, *spatial dimensions)` + with `spatial dimensions` being of length `dim`. + cond + The condition tensor. If None, no FiLM conditioning is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 00d1ea510..e7a8f07bb 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -13,7 +13,9 @@ from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.join import Concat from mrpro.nn.ndmodules import ConvND, MaxPoolND +from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here from mrpro.nn.Sequential import Sequential from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock from mrpro.nn.Upsample import Upsample @@ -217,7 +219,11 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Se encoder_blocks.append(ResBlock(dim, n_feat, n_feat, cond_dim)) decoder_blocks.append(ResBlock(dim, 2 * n_feat, n_feat, cond_dim)) down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) - up_blocks.append(Sequential(Upsample(dim, scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1))) + up_blocks.append( + Sequential( + Upsample(tuple(range(-dim, 0)), scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1) + ) + ) concat_blocks.append(Concat()) up_blocks = up_blocks[::-1] decoder_blocks = decoder_blocks[::-1] @@ -254,9 +260,9 @@ def __init__( dim: int, channels_in: int, channels_out: int, - attention_depths: Sequence[int] = (-1, -2), + attention_depths: Sequence[int] = (-1,), n_features: Sequence[int] = (64, 128, 192, 256), - n_heads: int = 4, + n_heads: int = 8, cond_dim: int = 0, encoder_blocks_per_scale: int = 2, ) -> None: @@ -292,9 +298,7 @@ def __init__( def attention_block(channels: int) -> Module: dim_groups = (tuple(range(-dim, 0)),) - return SpatialTransformerBlock( - dim_groups, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim - ) + return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) def block(channels_in: int, channels_out: int, attention: bool) -> Module: if not attention: @@ -336,7 +340,7 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) ) up_blocks.append(Identity()) - up_blocks.append(Upsample(dim, scale_factor=2)) + up_blocks.append(Upsample(tuple(range(-dim, 0)), scale_factor=2)) up_blocks.pop() # no upsampling after the last resolution level concat_blocks = [Concat() for _ in range(len(decoder_blocks))] last_block = Sequential( @@ -411,44 +415,12 @@ def block(channels_in: int, channels_out: int) -> Module: super().__init__(encoder, decoder) -from collections.abc import Sequence - -from mrpro.nn.PermutedBlock import PermutedBlock -from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here -from mrpro.nn.UNet import UNetBase, UNetDecoder, UNetEncoder - - class SeparableUNet(UNetBase): - """ - UNet with separable convolutions and controlled downsampling. - """ + """UNet with separable convolutions and attention, and grouped downsampling.""" def __init__( self, - dim: int, # Total number of spatial dimensions (e.g., 2 for 2D, 3 for 3D) - dim_groups: Sequence[tuple[int, ...]], - channels_in: int, - channels_out: int, - n_features: Sequence[int], - cond_dim: int, - downsample_dims: Sequence[Sequence[int]] | None = None, - encoder_blocks_per_scale: int = 2, - ) -> None: - """ - Initialize the SeparableUNet. - - Parameters - ---------- - - """ - class SeparableUNet(UNetBase): - """ - UNet with separable convolutions and attention, and grouped downsampling. - """ - - def __init__( - self, - dim:int, + dim: int, dim_groups: Sequence[tuple[int, ...]], channels_in: int, channels_out: int, @@ -488,7 +460,7 @@ def __init__( Sequence specifying which absolute spatial dimensions to downsample at each encoder level. If None, all dimensions in `dim_groups` are combined and downsampled at each level. - If a downsampling step contains more than 3 dimensions, downsampling is performed separatly for each + If a downsampling step contains more than 3 dimensions, downsampling is performed separately for each dimension. If the length of the sequence is less than the number of resolution levels, the sequence is repeated. E.g., ``((-1,-2), (-1,-2,-3))`` for 3D data: first level downsamples x,y; second level x,y,z; third level x,y. @@ -497,35 +469,31 @@ def __init__( """ depth = len(n_features) for group in dim_groups: - if len(group)>3: - raise ValueError(f"dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.") - if any(d>dim+2 or d<-dim for d in group): - raise ValueError(f"dim_group {group} contains dimensions that are out of range for dim={dim}") + if len(group) > 3: + raise ValueError(f'dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.') + if any(d > dim + 2 or d < -dim for d in group): + raise ValueError(f'dim_group {group} contains dimensions that are out of range for dim={dim}') attention_depths = tuple(d % depth for d in attention_depths) if downsample_dims is None: - all_spatial_dims = tuple( - sorted(list(set(d if d<0 else d-dim-2 for group in dim_groups for d in group))) - ) + all_spatial_dims = tuple(sorted(set(d if d < 0 else d - dim - 2 for group in dim_groups for d in group))) downsample_dims = (all_spatial_dims,) * (depth - 1) - def downsampler(level_dims, c_in, c_out) -> Module: - if len(level_dims)>3: - sequence=Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) - for d in level_dims[1:]: - sequence.append(downsampler(d, c_out, c_out)) - return sequence - return PermutedBlock( - level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) + if len(level_dims) > 3: + sequence = Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) + for d in level_dims[1:]: + sequence.append(downsampler(d, c_out, c_out)) + return sequence + return PermutedBlock(level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) def upsampler(level_dims, c_in, c_out) -> Module: - if len(level_dims)>3: - sequence=Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) + if len(level_dims) > 3: + sequence = Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) for d in level_dims[1:]: sequence.append(upsampler(d, c_out, c_out)) return sequence - return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode="nearest")) + return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode='nearest')) def block(c_in: int, c_out: int, apply_attention: bool) -> Module: res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) @@ -548,7 +516,9 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: c_feat = n_feat_level skip_features.append(c_feat) if i_level < depth - 1: - down_blocks.append(_create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1])) + down_blocks.append( + _create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1]) + ) c_feat = n_features[i_level + 1] # -- Middle & Encoder Finalization -- @@ -575,132 +545,10 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: # -- Decoder Finalization -- concat_blocks = [Concat()] * len(decoder_blocks) last_block = Sequential( - GroupNorm(n_features[0]), SiLU(), - PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)) + GroupNorm(n_features[0]), + SiLU(), + PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)), ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) - -# class SpatioTemporalUNet(UNetBase): -# """UNet where blocks apply separable convolutions in different dimensions. -# U-shaped convolutional network with optional patch attention. -# Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [UNET]_, [LDM]_, -# Based on the pseudo-3D residual network of [QUI]_, [TRAN]_, [HO]_, and the residual blocks of [ZIM]_. - -# References -# ---------- -# .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image -# segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 -# .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py -# .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal -# convolutions for action recognition. CVPR 2018. https://arxiv.org/abs/1711.11248 -# .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. -# ICCV 2017. https://arxiv.org/abs/1711.10305 -# .. [HO] Ho, J., Salimans, T., Gritsenko, A., Chan, W., Norouzi, M., & Fleet, D. J. Video diffusion models. -# NeurIPS 2022. https://arxiv.org/abs/2209.11168 -# .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction -# without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 -# """ - - -# def __init__( -# self, -# dim: int, -# in_channels: int, -# out_channels: int, -# attention_depths: Sequence[int] = (-1, -2), -# n_features: Sequence[int] = (64, 128, 192, 256), -# n_heads: int = 4, -# cond_dim: int = 0, -# encoder_blocks_per_scale: int = 2, -# temporal_downsampling: bool = False, -# ) -> None: -# """Initialize the UNet. - -# Parameters -# ---------- -# dim -# Spatial dimension of the input tensor. -# channels_in -# Number of channels in the input tensor. -# channels_out -# Number of channels in the output tensor. -# attention_depths -# The depths at which to apply attention. -# n_features -# Number of features at each resolution level. The length determines the number of resolution levels. -# n_heads -# Number of attention heads. -# cond_dim -# Number of channels in the conditioning tensor. If 0, no conditioning is applied. -# encoder_blocks_per_scale -# Number of encoder blocks per resolution level. The number of decoder blocks is one more. -# temporal_downsampling -# Whether to downsample the temporal dimension. -# """ -# depth = len(n_features) -# if not all(-depth <= d < depth for d in attention_depths): -# raise ValueError( -# f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' -# ) -# attention_depths = tuple(d % depth for d in attention_depths) -# if len(attention_depths) != len(set(attention_depths)): -# raise ValueError(f'attention_depths must be unique, got {attention_depths=}') - -# def attention_block(channels: int) -> Module: -# SpatioTemporalBlock(SpatialTransformerBlock( -# dim, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim -# ) - -# def block(channels_in: int, channels_out: int, attention: bool) -> Module: -# if not attention: -# return ResBlock(dim, channels_in, channels_out, cond_dim) -# return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) - -# first_block = ConvND(dim)(in_channels, n_features[0], 3, padding=1) -# encoder_blocks: list[Module] = [] -# down_blocks: list[Module] = [] -# skip_features = [] -# n_feat_old = n_features[0] -# for i_level, n_feat in enumerate(n_features): -# encoder_blocks.append(Identity()) -# skip_features.append(n_feat_old) -# for _ in range(encoder_blocks_per_scale): -# encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) -# n_feat_old = n_feat -# down_blocks.append(Identity()) -# skip_features.append(n_feat_old) -# down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) -# down_blocks[-1] = Identity() # no downsampling after the last resolution level -# middle_block = Sequential( -# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), -# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), -# ) -# if i_level in attention_depths: -# middle_block.insert(1, attention_block(n_features[-1])) -# encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) - -# decoder_blocks: list[Module] = [] -# up_blocks: list[Module] = [Identity()] -# for i_level, n_feat in reversed(list(enumerate(n_features))): -# decoder_blocks.append( -# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) -# ) -# n_feat_old = n_feat -# for _ in range(encoder_blocks_per_scale): -# decoder_blocks.append( -# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) -# ) -# up_blocks.append(Identity()) -# up_blocks.append(Upsample(dim, scale_factor=2)) -# up_blocks.pop() # no upsampling after the last resolution level -# concat_blocks = [Concat()] * len(decoder_blocks) -# last_block = Sequential( -# GroupNorm(n_features[0]), -# SiLU(), -# ConvND(dim)(n_features[0], out_channels, 3, padding=1), -# ) -# decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) - -# super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 6f540e118..228596dc8 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -2,13 +2,17 @@ from mrpro.nn.nets.Uformer import Uformer from mrpro.nn.nets.DCAE import DCVAE from mrpro.nn.nets.VAE import VAE -from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet +from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet, BasicUNet, SeparableUNet from mrpro.nn.nets.SwinIR import SwinIR +from mrpro.nn.nets.BasicCNN import BasicCNN __all__ = [ "AttentionGatedUNet", + "BasicCNN", + "BasicUNet", "DCVAE", "Restormer", + "SeparableUNet", "SwinIR", "UNet", "Uformer", diff --git a/src/mrpro/operators/LinearOperator.py b/src/mrpro/operators/LinearOperator.py index 8556a6907..15a94093a 100644 --- a/src/mrpro/operators/LinearOperator.py +++ b/src/mrpro/operators/LinearOperator.py @@ -254,7 +254,7 @@ def __matmul__( return OperatorComposition(self, cast(Operator[Unpack[Tin2], tuple[torch.Tensor,]], other)) return NotImplemented # type: ignore[unreachable] - def __radd__(self, other: torch.Tensor) -> LinearOperator: + def __radd__(self, other: torch.Tensor | complex) -> LinearOperator: """Operator addition. Returns ``lambda x: self(x) + other*x`` @@ -262,7 +262,7 @@ def __radd__(self, other: torch.Tensor) -> LinearOperator: return self + other @overload # type: ignore[override] - def __add__(self, other: LinearOperator | torch.Tensor) -> LinearOperator: ... + def __add__(self, other: LinearOperator | torch.Tensor | complex) -> LinearOperator: ... @overload def __add__( @@ -270,14 +270,14 @@ def __add__( ) -> Operator[torch.Tensor, tuple[torch.Tensor,]]: ... def __add__( - self, other: Operator[torch.Tensor, tuple[torch.Tensor,]] | LinearOperator | torch.Tensor + self, other: Operator[torch.Tensor, tuple[torch.Tensor,]] | LinearOperator | torch.Tensor | complex ) -> Operator[torch.Tensor, tuple[torch.Tensor,]] | LinearOperator: """Operator addition. Returns ``lambda x: self(x) + other(x)`` if other is a operator, ``lambda x: self(x) + other`` if other is a tensor """ - if isinstance(other, torch.Tensor): + if isinstance(other, torch.Tensor | complex | int | float): # tensor addition return LinearOperatorSum(self, mrpro.operators.IdentityOp() * other) elif isinstance(self, mrpro.operators.ZeroOp): diff --git a/src/mrpro/operators/LinearOperatorMatrix.py b/src/mrpro/operators/LinearOperatorMatrix.py index 8ace21e9b..d77aa0c87 100644 --- a/src/mrpro/operators/LinearOperatorMatrix.py +++ b/src/mrpro/operators/LinearOperatorMatrix.py @@ -144,7 +144,7 @@ def __repr__(self): return f'LinearOperatorMatrix(shape={self._shape}, operators={self._operators})' # Note: The type ignores are needed because we currently cannot do arithmetic operations with non-linear operators. - def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: ignore[override] + def __add__(self, other: Self | LinearOperator | torch.Tensor | complex) -> Self: # type: ignore[override] """Addition.""" operators: list[list[LinearOperator]] = [] if isinstance(other, LinearOperatorMatrix): @@ -152,7 +152,7 @@ def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: raise ValueError('OperatorMatrix shapes do not match.') for self_row, other_row in zip(self._operators, other._operators, strict=False): operators.append([s + o for s, o in zip(self_row, other_row, strict=False)]) - elif isinstance(other, LinearOperator | torch.Tensor): + elif isinstance(other, LinearOperator | torch.Tensor | complex): if not self.shape[0] == self.shape[1]: raise NotImplementedError('Cannot add a LinearOperator to a non-square OperatorMatrix.') for i, self_row in enumerate(self._operators): @@ -161,7 +161,7 @@ def __add__(self, other: Self | LinearOperator | torch.Tensor) -> Self: # type: return NotImplemented # type: ignore[unreachable] return self.__class__(operators) - def __radd__(self, other: Self | LinearOperator | torch.Tensor) -> Self: + def __radd__(self, other: Self | LinearOperator | torch.Tensor | complex) -> Self: """Right addition.""" return self.__add__(other) diff --git a/src/mrpro/operators/Operator.py b/src/mrpro/operators/Operator.py index d52b4aabc..ea82f90af 100644 --- a/src/mrpro/operators/Operator.py +++ b/src/mrpro/operators/Operator.py @@ -53,7 +53,7 @@ def __matmul__( return OperatorComposition(self, cast(Operator[Unpack[Tin2], tuple[Unpack[Tin]]], other)) def __radd__( - self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor | complex ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: """Operator right addition. @@ -65,18 +65,18 @@ def __radd__( def __add__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: ... @overload def __add__( - self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor | complex ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: ... def __add__( - self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | mrpro.operators.ZeroOp + self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | complex | mrpro.operators.ZeroOp ) -> Operator[Unpack[Tin], Tout] | Operator[Unpack[Tin], tuple[Unpack[Tin]]]: """Operator addition. Returns ``lambda x: self(x) + other(x)`` if other is a operator, ``lambda x: self(x) + other*x`` if other is a tensor """ - if isinstance(other, torch.Tensor): + if isinstance(other, torch.Tensor | complex | int | float): s = cast(Operator[Unpack[Tin], tuple[Unpack[Tin]]], self) o = cast(Operator[Unpack[Tin], tuple[Unpack[Tin]]], mrpro.operators.MultiIdentityOp() * other) return OperatorSum(s, o) @@ -102,6 +102,46 @@ def __rmul__(self, other: torch.Tensor | complex) -> Operator[Unpack[Tin], Tout] """ return OperatorElementwiseProductRight(self, other) + @overload + def __sub__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: ... + + @overload + def __sub__( + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor | complex + ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: ... + + def __sub__( + self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | complex | mrpro.operators.ZeroOp + ) -> Operator[Unpack[Tin], Tout] | Operator[Unpack[Tin], tuple[Unpack[Tin]]]: + """Operator subtraction. + + Returns ``lambda x: self(x) - other(x)`` if other is a operator, + ``lambda x: self(x) - other*x`` if other is a tensor + """ + if isinstance(other, mrpro.operators.ZeroOp): + return self + return self + (-1.0) * other + + @overload + def __rsub__(self, other: Operator[Unpack[Tin], Tout]) -> Operator[Unpack[Tin], Tout]: ... + + @overload + def __rsub__( + self: Operator[Unpack[Tin], tuple[Unpack[Tin]]], other: torch.Tensor | complex + ) -> Operator[Unpack[Tin], tuple[Unpack[Tin]]]: ... + + def __rsub__( + self, other: Operator[Unpack[Tin], Tout] | torch.Tensor | complex | mrpro.operators.ZeroOp + ) -> Operator[Unpack[Tin], Tout] | Operator[Unpack[Tin], tuple[Unpack[Tin]]]: + """Operator subtraction. + + Returns ``lambda x: self(x) - other(x)`` if other is a operator, + ``lambda x: self(x) - other*x`` if other is a tensor + """ + if isinstance(other, mrpro.operators.ZeroOp): + return self + return (-1.0) * self + other + class OperatorComposition(Operator[Unpack[Tin2], Tout]): """Operator composition.""" diff --git a/src/mrpro/phantoms/__init__.py b/src/mrpro/phantoms/__init__.py index d8265a504..c1cc5bf24 100644 --- a/src/mrpro/phantoms/__init__.py +++ b/src/mrpro/phantoms/__init__.py @@ -3,6 +3,7 @@ from mrpro.phantoms.EllipsePhantom import EllipsePhantom from mrpro.phantoms.phantom_elements import EllipseParameters from mrpro.phantoms import brainweb +from mrpro.phantoms import coils from mrpro.phantoms.m4raw import M4RawDataset from mrpro.phantoms import mdcnn from mrpro.phantoms.fastmri import FastMRIKDataDataset, FastMRIImageDataset diff --git a/src/mrpro/utils/pad_or_crop.py b/src/mrpro/utils/pad_or_crop.py index 5bb82f21d..088f27aeb 100644 --- a/src/mrpro/utils/pad_or_crop.py +++ b/src/mrpro/utils/pad_or_crop.py @@ -51,7 +51,7 @@ def pad_or_crop( mode Mode to use for padding. value - Value to use for constant padding. + value to use for constant padding. Returns ------- From 2270236df973b8636f609d4bf8cb29ea0b359ae1 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 11 Jul 2025 13:08:38 +0200 Subject: [PATCH 096/205] pull changes form pinqi --- src/mrpro/nn/LayerNorm.py | 10 +- src/mrpro/nn/PermutedBlock.py | 58 +++++++ src/mrpro/nn/SeparableResBlock.py | 170 ++++++++++++++++++ src/mrpro/nn/Sequential.py | 2 +- src/mrpro/nn/SpatialTransformerBlock.py | 1 + src/mrpro/nn/Upsample.py | 14 +- src/mrpro/nn/__init__.py | 7 + src/mrpro/nn/nets/BasicCNN.py | 65 +++++++ src/mrpro/nn/nets/UNet.py | 220 ++++-------------------- src/mrpro/nn/nets/__init__.py | 6 +- 10 files changed, 353 insertions(+), 200 deletions(-) create mode 100644 src/mrpro/nn/PermutedBlock.py create mode 100644 src/mrpro/nn/SeparableResBlock.py create mode 100644 src/mrpro/nn/nets/BasicCNN.py diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 699de57f0..84e1d56e4 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -34,26 +34,30 @@ def __init__(self, channels: int | None, features_last: bool = False, cond_dim: self.weight = Parameter(torch.ones(channels)) self.bias = Parameter(torch.zeros(channels)) self.cond_proj = None - else: + elif channels is not None: self.weight = None self.bias = None self.cond_proj = Linear(cond_dim, 2 * channels) + else: + raise ValueError('cond_dim must be zero or positive.') self.features_last = features_last - def __call__(self, x: torch.Tensor) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply layer normalization to the input tensor. Parameters ---------- x Input tensor + cond + Conditioning tensor. If `None`, no conditioning is applied. Returns ------- Normalized output tensor """ - return super().__call__(x) + return super().__call__(x, cond=cond) def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply layer normalization to the input tensor.""" diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py new file mode 100644 index 000000000..99a27f36a --- /dev/null +++ b/src/mrpro/nn/PermutedBlock.py @@ -0,0 +1,58 @@ +"""Block that applies a submodule along selected spatial dimensions.""" + +from collections.abc import Sequence + +import torch +from torch import nn + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class PermutedBlock(CondMixin, nn.Module): + """Apply a submodule along selected spatial dimensions.""" + + apply_along_dim: tuple[int, ...] + module: nn.Module + + def __init__(self, apply_along_dim: Sequence[int], module: nn.Module, features_last: bool = False): + """Initialize the PermutedBlock. + + Parameters + ---------- + apply_along_dim + Spatial dimension indices to use when applying the module. + These will be moved to the last dimensions. + module + Module to apply on the selected dims. + features_last + If True, the features dimension is assumed to be the last dimension, as common in transformer models. + """ + super().__init__() + self.apply_along_dim = tuple(sorted(apply_along_dim)) + self.module = module + self.features_last = features_last + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions.""" + keep = tuple(d % x.ndim for d in self.apply_along_dim) + if 0 in keep: + raise ValueError('Batch dimension should not be in apply_along_dim.') + if self.features_last: + if x.ndim - 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(1, x.ndim - 1) if d not in keep) + permute = (0, *batch_dim, *keep, x.ndim - 1) + else: + if 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(2, x.ndim) if d not in keep) + permute = (0, *batch_dim, 1, *keep) + h = x.permute(permute) + batch_shape = h.shape[: 1 + len(batch_dim)] + h = h.flatten(0, len(batch_dim)) + h = call_with_cond(self.module, h, cond=cond) + h = h.unflatten(0, batch_shape) + permute_back = [0] * x.ndim + for i, p in enumerate(permute): + permute_back[p] = i + return h.permute(tuple(permute_back)) diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py new file mode 100644 index 000000000..c26a012cf --- /dev/null +++ b/src/mrpro/nn/SeparableResBlock.py @@ -0,0 +1,170 @@ +from collections.abc import Sequence + +import torch +from torch.nn import Module, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import ConvND +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.Sequential import Sequential + + +class SeparableResBlock(Module): + """Residual block with separable convolutions and ReZero.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + ) -> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SilU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + channels_in + Number of input channels. + channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, ConvND(len(dims))(channels_in, channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, channels_in if i == 0 else channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(channels_out, cond_dim)) + blocks.extend(block(d, channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if channels_in != channels_out: + self.skip_connection = torch.nn.Linear(channels_in, channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h + + +from collections.abc import Sequence + +import torch +from torch.nn import Module + + +class SeparableResBlock(Module): + """Residual block with separable convolutions and ReZero.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + ) -> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SilU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + channels_in + Number of input channels. + channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, ConvND(len(dims))(channels_in, channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, channels_in if i == 0 else channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(channels_out, cond_dim)) + blocks.extend(block(d, channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if channels_in != channels_out: + self.skip_connection = torch.nn.Linear(channels_in, channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index 15b5d0152..fb56bd43f 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -9,7 +9,7 @@ from mrpro.operators import Operator -class Sequential(CondMixin,torch.nn.Sequential): +class Sequential(CondMixin, torch.nn.Sequential): """Sequential container with support for conditioning and Operators. Allows multiple input tensors and a single output tensor of the sequential block. diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 2b4c3e6e2..906560c24 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -129,6 +129,7 @@ def __init__( self.proj_in = Linear(channels, hidden_dim) self.transformer_blocks = Sequential() for group in (g for _ in range(depth) for g in dim_groups): + group = tuple(g - 1 if g < 0 else g for g in group) block = BasicTransformerBlock(hidden_dim, n_heads, p_dropout=dropout, cond_dim=cond_dim, features_last=True) self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) self.proj_out = Linear(hidden_dim, channels) diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py index ec9b0e032..acced8d48 100644 --- a/src/mrpro/nn/Upsample.py +++ b/src/mrpro/nn/Upsample.py @@ -29,10 +29,10 @@ def __init__( super().__init__() self.scale_factor = scale_factor if mode == 'nearest': - dims = [tuple(d) for d in torch.tensor(dim).split(3)] - modes = ['nearest'] * len(self.dim) + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = ['nearest'] * len(dim) elif mode == 'linear': - dims = [tuple(d) for d in torch.tensor(dim).split(3)] + dims = [d.tolist() for d in torch.tensor(dim).split(3)] modes = [{1: 'linear', 2: 'bilinear', 3: 'trilinear'}[len(d)] for d in dims] elif mode == 'cubic': if not len(dim) == 2: @@ -42,18 +42,14 @@ def __init__( self.blocks = Sequential( *[ - PermutedBlock(d, Upsample(d, scale_factor=scale_factor, mode=m)) + PermutedBlock(d, torch.nn.Upsample(scale_factor=len(d) * (scale_factor,), mode=m)) for d, m in zip(dims, modes, strict=False) ] ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Upsample the input tensor.""" - return torch.nn.functional.interpolate( - x, - mode=self.mode, - scale_factor=self.scale_factor, - ) + return self.blocks(x) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Upsample the input tensor. diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index e59e4efde..9d2b1dff4 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -20,12 +20,17 @@ from mrpro.nn.SqueezeExcitation import SqueezeExcitation from mrpro.nn.TransposedAttention import TransposedAttention from mrpro.nn.DropPath import DropPath +from mrpro.nn.Residual import Residual +from mrpro.nn.ComplexAsChannel import ComplexAsChannel from mrpro.nn import nets +from mrpro.nn.PermutedBlock import PermutedBlock + __all__ = [ "AdaptiveAvgPoolND", "AttentionGate", "AvgPoolND", "BatchNormND", + "ComplexAsChannel", "CondMixin", "ConvND", "ConvTransposeND", @@ -35,7 +40,9 @@ "InstanceNormND", "MaxPoolND", "NeighborhoodSelfAttention", + "PermutedBlock", "ResBlock", + "Residual", "Sequential", "ShiftedWindowAttention", "SqueezeExcitation", diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py new file mode 100644 index 000000000..b2671c121 --- /dev/null +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -0,0 +1,65 @@ +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import ReLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import BatchNormND, ConvND +from mrpro.nn.Sequential import Sequential + + +class BasicCNN(Sequential): + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + batch_norm: bool = True, + n_features: Sequence[int] = (64, 64, 64), + cond_dim: int = 0, + ): + """Initialize a basic CNN. + + Parameters + ---------- + dim + The number of spatial dimensions of the input tensor. + channels_in + The number of input channels. + channels_out + The number of output channels. + batch_norm + Whether to use batch normalization. + n_features + The number of features in the hidden layers. The length of this sequence determines the number of hidden layers. + cond_dim + The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + """ + super().__init__() + use_film = cond_dim > 0 + self.append(ConvND(dim)(channels_in, n_features[0], kernel_size=3, padding='same')) + for c_in, c_out in pairwise((*n_features, channels_out)): + if batch_norm: + self.append(BatchNormND(dim)(c_in, affine=not use_film)) + if use_film: + self.append(FiLM(c_in, cond_dim)) + self.append(ReLU(True)) + self.append(ConvND(dim)(c_in, c_out, kernel_size=3, padding='same')) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None) -> torch.Tensor: + """Apply the basic CNN to the input tensor. + + Parameters + ---------- + x + The input tensor. Should be of shape `(batch_size, channels_in, *spatial dimensions)` + with `spatial dimensions` being of length `dim`. + cond + The condition tensor. If None, no FiLM conditioning is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 00d1ea510..e7a8f07bb 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -13,7 +13,9 @@ from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.join import Concat from mrpro.nn.ndmodules import ConvND, MaxPoolND +from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here from mrpro.nn.Sequential import Sequential from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock from mrpro.nn.Upsample import Upsample @@ -217,7 +219,11 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Se encoder_blocks.append(ResBlock(dim, n_feat, n_feat, cond_dim)) decoder_blocks.append(ResBlock(dim, 2 * n_feat, n_feat, cond_dim)) down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) - up_blocks.append(Sequential(Upsample(dim, scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1))) + up_blocks.append( + Sequential( + Upsample(tuple(range(-dim, 0)), scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1) + ) + ) concat_blocks.append(Concat()) up_blocks = up_blocks[::-1] decoder_blocks = decoder_blocks[::-1] @@ -254,9 +260,9 @@ def __init__( dim: int, channels_in: int, channels_out: int, - attention_depths: Sequence[int] = (-1, -2), + attention_depths: Sequence[int] = (-1,), n_features: Sequence[int] = (64, 128, 192, 256), - n_heads: int = 4, + n_heads: int = 8, cond_dim: int = 0, encoder_blocks_per_scale: int = 2, ) -> None: @@ -292,9 +298,7 @@ def __init__( def attention_block(channels: int) -> Module: dim_groups = (tuple(range(-dim, 0)),) - return SpatialTransformerBlock( - dim_groups, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim - ) + return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) def block(channels_in: int, channels_out: int, attention: bool) -> Module: if not attention: @@ -336,7 +340,7 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) ) up_blocks.append(Identity()) - up_blocks.append(Upsample(dim, scale_factor=2)) + up_blocks.append(Upsample(tuple(range(-dim, 0)), scale_factor=2)) up_blocks.pop() # no upsampling after the last resolution level concat_blocks = [Concat() for _ in range(len(decoder_blocks))] last_block = Sequential( @@ -411,44 +415,12 @@ def block(channels_in: int, channels_out: int) -> Module: super().__init__(encoder, decoder) -from collections.abc import Sequence - -from mrpro.nn.PermutedBlock import PermutedBlock -from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here -from mrpro.nn.UNet import UNetBase, UNetDecoder, UNetEncoder - - class SeparableUNet(UNetBase): - """ - UNet with separable convolutions and controlled downsampling. - """ + """UNet with separable convolutions and attention, and grouped downsampling.""" def __init__( self, - dim: int, # Total number of spatial dimensions (e.g., 2 for 2D, 3 for 3D) - dim_groups: Sequence[tuple[int, ...]], - channels_in: int, - channels_out: int, - n_features: Sequence[int], - cond_dim: int, - downsample_dims: Sequence[Sequence[int]] | None = None, - encoder_blocks_per_scale: int = 2, - ) -> None: - """ - Initialize the SeparableUNet. - - Parameters - ---------- - - """ - class SeparableUNet(UNetBase): - """ - UNet with separable convolutions and attention, and grouped downsampling. - """ - - def __init__( - self, - dim:int, + dim: int, dim_groups: Sequence[tuple[int, ...]], channels_in: int, channels_out: int, @@ -488,7 +460,7 @@ def __init__( Sequence specifying which absolute spatial dimensions to downsample at each encoder level. If None, all dimensions in `dim_groups` are combined and downsampled at each level. - If a downsampling step contains more than 3 dimensions, downsampling is performed separatly for each + If a downsampling step contains more than 3 dimensions, downsampling is performed separately for each dimension. If the length of the sequence is less than the number of resolution levels, the sequence is repeated. E.g., ``((-1,-2), (-1,-2,-3))`` for 3D data: first level downsamples x,y; second level x,y,z; third level x,y. @@ -497,35 +469,31 @@ def __init__( """ depth = len(n_features) for group in dim_groups: - if len(group)>3: - raise ValueError(f"dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.") - if any(d>dim+2 or d<-dim for d in group): - raise ValueError(f"dim_group {group} contains dimensions that are out of range for dim={dim}") + if len(group) > 3: + raise ValueError(f'dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.') + if any(d > dim + 2 or d < -dim for d in group): + raise ValueError(f'dim_group {group} contains dimensions that are out of range for dim={dim}') attention_depths = tuple(d % depth for d in attention_depths) if downsample_dims is None: - all_spatial_dims = tuple( - sorted(list(set(d if d<0 else d-dim-2 for group in dim_groups for d in group))) - ) + all_spatial_dims = tuple(sorted(set(d if d < 0 else d - dim - 2 for group in dim_groups for d in group))) downsample_dims = (all_spatial_dims,) * (depth - 1) - def downsampler(level_dims, c_in, c_out) -> Module: - if len(level_dims)>3: - sequence=Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) - for d in level_dims[1:]: - sequence.append(downsampler(d, c_out, c_out)) - return sequence - return PermutedBlock( - level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) + if len(level_dims) > 3: + sequence = Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) + for d in level_dims[1:]: + sequence.append(downsampler(d, c_out, c_out)) + return sequence + return PermutedBlock(level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) def upsampler(level_dims, c_in, c_out) -> Module: - if len(level_dims)>3: - sequence=Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) + if len(level_dims) > 3: + sequence = Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) for d in level_dims[1:]: sequence.append(upsampler(d, c_out, c_out)) return sequence - return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode="nearest")) + return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode='nearest')) def block(c_in: int, c_out: int, apply_attention: bool) -> Module: res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) @@ -548,7 +516,9 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: c_feat = n_feat_level skip_features.append(c_feat) if i_level < depth - 1: - down_blocks.append(_create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1])) + down_blocks.append( + _create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1]) + ) c_feat = n_features[i_level + 1] # -- Middle & Encoder Finalization -- @@ -575,132 +545,10 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: # -- Decoder Finalization -- concat_blocks = [Concat()] * len(decoder_blocks) last_block = Sequential( - GroupNorm(n_features[0]), SiLU(), - PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)) + GroupNorm(n_features[0]), + SiLU(), + PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)), ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) - -# class SpatioTemporalUNet(UNetBase): -# """UNet where blocks apply separable convolutions in different dimensions. -# U-shaped convolutional network with optional patch attention. -# Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [UNET]_, [LDM]_, -# Based on the pseudo-3D residual network of [QUI]_, [TRAN]_, [HO]_, and the residual blocks of [ZIM]_. - -# References -# ---------- -# .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image -# segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 -# .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py -# .. [TRAN] Tran, D., Wang, H., Torresani, L., Ray, J., LeCun, Y., & Paluri, M. A closer look at spatiotemporal -# convolutions for action recognition. CVPR 2018. https://arxiv.org/abs/1711.11248 -# .. [QUI] Qiu, Z., Yao, T., & Mei, T. Learning spatio-temporal representation with pseudo-3d residual networks. -# ICCV 2017. https://arxiv.org/abs/1711.10305 -# .. [HO] Ho, J., Salimans, T., Gritsenko, A., Chan, W., Norouzi, M., & Fleet, D. J. Video diffusion models. -# NeurIPS 2022. https://arxiv.org/abs/2209.11168 -# .. [ZIM] Zimmermann, F. F., & Kofler, A. (2023, October). NoSENSE: Learned unrolled cardiac MRI reconstruction -# without explicit sensitivity maps. STACOM MICCAI 2023. https://arxiv.org/abs/2309.15608 -# """ - - -# def __init__( -# self, -# dim: int, -# in_channels: int, -# out_channels: int, -# attention_depths: Sequence[int] = (-1, -2), -# n_features: Sequence[int] = (64, 128, 192, 256), -# n_heads: int = 4, -# cond_dim: int = 0, -# encoder_blocks_per_scale: int = 2, -# temporal_downsampling: bool = False, -# ) -> None: -# """Initialize the UNet. - -# Parameters -# ---------- -# dim -# Spatial dimension of the input tensor. -# channels_in -# Number of channels in the input tensor. -# channels_out -# Number of channels in the output tensor. -# attention_depths -# The depths at which to apply attention. -# n_features -# Number of features at each resolution level. The length determines the number of resolution levels. -# n_heads -# Number of attention heads. -# cond_dim -# Number of channels in the conditioning tensor. If 0, no conditioning is applied. -# encoder_blocks_per_scale -# Number of encoder blocks per resolution level. The number of decoder blocks is one more. -# temporal_downsampling -# Whether to downsample the temporal dimension. -# """ -# depth = len(n_features) -# if not all(-depth <= d < depth for d in attention_depths): -# raise ValueError( -# f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' -# ) -# attention_depths = tuple(d % depth for d in attention_depths) -# if len(attention_depths) != len(set(attention_depths)): -# raise ValueError(f'attention_depths must be unique, got {attention_depths=}') - -# def attention_block(channels: int) -> Module: -# SpatioTemporalBlock(SpatialTransformerBlock( -# dim, channels, n_heads, channels_per_head=channels // n_heads, cond_dim=cond_dim -# ) - -# def block(channels_in: int, channels_out: int, attention: bool) -> Module: -# if not attention: -# return ResBlock(dim, channels_in, channels_out, cond_dim) -# return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) - -# first_block = ConvND(dim)(in_channels, n_features[0], 3, padding=1) -# encoder_blocks: list[Module] = [] -# down_blocks: list[Module] = [] -# skip_features = [] -# n_feat_old = n_features[0] -# for i_level, n_feat in enumerate(n_features): -# encoder_blocks.append(Identity()) -# skip_features.append(n_feat_old) -# for _ in range(encoder_blocks_per_scale): -# encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) -# n_feat_old = n_feat -# down_blocks.append(Identity()) -# skip_features.append(n_feat_old) -# down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) -# down_blocks[-1] = Identity() # no downsampling after the last resolution level -# middle_block = Sequential( -# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), -# ResBlock(dim, n_features[-1], n_features[-1], cond_dim), -# ) -# if i_level in attention_depths: -# middle_block.insert(1, attention_block(n_features[-1])) -# encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) - -# decoder_blocks: list[Module] = [] -# up_blocks: list[Module] = [Identity()] -# for i_level, n_feat in reversed(list(enumerate(n_features))): -# decoder_blocks.append( -# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) -# ) -# n_feat_old = n_feat -# for _ in range(encoder_blocks_per_scale): -# decoder_blocks.append( -# block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) -# ) -# up_blocks.append(Identity()) -# up_blocks.append(Upsample(dim, scale_factor=2)) -# up_blocks.pop() # no upsampling after the last resolution level -# concat_blocks = [Concat()] * len(decoder_blocks) -# last_block = Sequential( -# GroupNorm(n_features[0]), -# SiLU(), -# ConvND(dim)(n_features[0], out_channels, 3, padding=1), -# ) -# decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) - -# super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 6f540e118..228596dc8 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -2,13 +2,17 @@ from mrpro.nn.nets.Uformer import Uformer from mrpro.nn.nets.DCAE import DCVAE from mrpro.nn.nets.VAE import VAE -from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet +from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet, BasicUNet, SeparableUNet from mrpro.nn.nets.SwinIR import SwinIR +from mrpro.nn.nets.BasicCNN import BasicCNN __all__ = [ "AttentionGatedUNet", + "BasicCNN", + "BasicUNet", "DCVAE", "Restormer", + "SeparableUNet", "SwinIR", "UNet", "Uformer", From 9310fb5f5c685acf74a8db11c59bca417ffd9482 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 12 Jul 2025 12:53:01 +0200 Subject: [PATCH 097/205] simpliy unet --- src/mrpro/nn/MultiHeadAttention.py | 27 +++++--- src/mrpro/nn/NeighborhoodSelfAttention.py | 3 +- src/mrpro/nn/nets/UNet.py | 81 +++++++++-------------- 3 files changed, 51 insertions(+), 60 deletions(-) diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py index c7b5e1cde..b79353192 100644 --- a/src/mrpro/nn/MultiHeadAttention.py +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -1,6 +1,7 @@ """Multi-head Attention.""" import torch +from einops import rearrange from torch.nn import Linear, Module @@ -41,16 +42,14 @@ def __init__( Number of channels for cross-attention. If `None`, use `channels_in`. """ super().__init__() - self.mha = torch.nn.MultiheadAttention( - embed_dim=channels_in, - num_heads=n_heads, - batch_first=True, - dropout=p_dropout, - kdim=channels_cross, - vdim=channels_cross, - ) + channels_per_head_q = channels_in // n_heads + channels_per_head_kv = channels_cross // n_heads if channels_cross is not None else channels_in // n_heads + self.to_q = Linear(channels_in, channels_per_head_q * n_heads) + self.to_kv = Linear(channels_in, channels_per_head_kv * n_heads * 2) + self.p_dropout = p_dropout self.features_last = features_last self.to_out = Linear(channels_in, channels_out) + self.n_heads = n_heads def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: """Apply multi-head attention. @@ -78,8 +77,16 @@ def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) reshaped_x = self._reshape(x) reshaped_cross_attention = self._reshape(cross_attention) if cross_attention is not None else reshaped_x - y = self.mha(reshaped_cross_attention, reshaped_cross_attention, reshaped_x, need_weights=False)[0] - out: torch.Tensor = self.to_out(y) + q = rearrange(self.to_q(reshaped_x), '... L (heads dim) -> ... heads L dim ', heads=self.n_heads) + k, v = rearrange( + self.to_kv(reshaped_cross_attention), + '... S (kv heads dim) -> kv ... heads S dim ', + heads=self.n_heads, + kv=2, + ) + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.p_dropout, is_causal=False) + y = rearrange(y, '... heads L dim -> ... L (heads dim)') + out = self.to_out(y) if not self.features_last: out = out.moveaxis(-1, 1) diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index 91151e4bb..69f81bc02 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -166,7 +166,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- x - The input tensor, with shape `batch, channels, *spatial_dims`. + The input tensor, with shape `(batch, channels, *spatial_dims)` + or `(batch, *spatial_dims, channels)` (if `features_last`). Returns ------- diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index e7a8f07bb..c46c405ec 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -154,11 +154,11 @@ def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequ self.skip_blocks = ModuleList() """Modifications of the skip connections.""" - if len(decoder) != len(encoder): - raise ValueError( - 'The number of resolutions in the encoder and decoder must be the same, ' - f'got {len(decoder)} and {len(encoder)}' - ) + # if len(decoder) != len(encoder): + # raise ValueError( + # 'The number of resolutions in the encoder and decoder must be the same, ' + # f'got {len(decoder)} and {len(encoder)}' + # ) if skip_blocks is None: self.skip_blocks.extend(Identity() for _ in range(len(decoder))) @@ -244,9 +244,8 @@ class UNet(UNetBase): Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_, significant differences to the vanilla UNet [UNET]_ include: - Spatial transformer blocks - - Multiple skip connections per resolution - Convolutional downsampling, nearest neighbor upsampling - - Residual convolution blocks with group normalization and SiLU activation + - Residual convolution blocks with pre-act group normalization and SiLU activation References ---------- @@ -300,54 +299,38 @@ def attention_block(channels: int) -> Module: dim_groups = (tuple(range(-dim, 0)),) return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) - def block(channels_in: int, channels_out: int, attention: bool) -> Module: - if not attention: - return ResBlock(dim, channels_in, channels_out, cond_dim) - return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) - - first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) - encoder_blocks: list[Module] = [] - down_blocks: list[Module] = [] - skip_features = [] - n_feat_old = n_features[0] - for i_level, n_feat in enumerate(n_features): - encoder_blocks.append(Identity()) - skip_features.append(n_feat_old) + def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: + blocks = Sequential() for _ in range(encoder_blocks_per_scale): - encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) - n_feat_old = n_feat - down_blocks.append(Identity()) - skip_features.append(n_feat_old) - down_blocks.append(ConvND(dim)(n_feat, n_feat, 3, stride=2, padding=1)) - down_blocks[-1] = Identity() # no downsampling after the last resolution level + blocks.append(ResBlock(dim, channels_in, channels_out, cond_dim)) + if attention: + blocks.append(attention_block(channels_out)) + channels_in = channels_out + return blocks + + encoder_blocks: list[Module] = [ConvND(dim)(channels_in, n_features[0], 3, padding=1)] + down_blocks: list[Module] = [Identity()] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [Identity()] + + for i_level, (n_feat, n_feat_next) in enumerate(pairwise(n_features)): + encoder_blocks.append(blocks(n_feat, n_feat, i_level in attention_depths)) + down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) + decoder_blocks.append(blocks(n_feat_next + n_feat, n_feat, i_level in attention_depths)) + up_blocks.append(Upsample(tuple(range(-dim, 0)), scale_factor=2)) + middle_block = Sequential( - ResBlock(dim, n_features[-1], n_features[-1], cond_dim), - ResBlock(dim, n_features[-1], n_features[-1], cond_dim), + ResBlock(dim, n_feat_next, n_feat_next, cond_dim), + ResBlock(dim, n_feat_next, n_feat_next, cond_dim), ) if i_level in attention_depths: - middle_block.insert(1, attention_block(n_features[-1])) - encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + middle_block.insert(1, attention_block(n_feat)) + encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) - decoder_blocks: list[Module] = [] - up_blocks: list[Module] = [Identity()] - for i_level, n_feat in reversed(list(enumerate(n_features))): - decoder_blocks.append( - block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) - ) - n_feat_old = n_feat - for _ in range(encoder_blocks_per_scale): - decoder_blocks.append( - block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) - ) - up_blocks.append(Identity()) - up_blocks.append(Upsample(tuple(range(-dim, 0)), scale_factor=2)) - up_blocks.pop() # no upsampling after the last resolution level + decoder_blocks, up_blocks = decoder_blocks[::-1], up_blocks[::-1] + decoder_blocks.append(ResBlock(dim, 2 * n_features[0], n_features[0], cond_dim)) + last_block = ConvND(dim)(n_features[0], channels_out, 1) concat_blocks = [Concat() for _ in range(len(decoder_blocks))] - last_block = Sequential( - GroupNorm(n_features[0]), - SiLU(), - ConvND(dim)(n_features[0], channels_out, 3, padding=1), - ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) From 1dc3d9b6d7a2ae055bd119a7a3b4cc4daf8bc609 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 12 Jul 2025 14:23:45 +0200 Subject: [PATCH 098/205] Squashed commit of the following: commit 66d412a9f1991b0e275b52f7918842ec499d6113 Author: Felix Zimmermann Date: Mon Jun 23 22:37:50 2025 +0200 Revert "separable" This reverts commit aaa68e97317797944ba353b8db1a0bab6a46f649. --- src/mrpro/nn/nets/UNet.py | 171 +++++++++++++++++++++++++++++--------- 1 file changed, 133 insertions(+), 38 deletions(-) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index e7a8f07bb..b48ee8c67 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -15,7 +15,9 @@ from mrpro.nn.ndmodules import ConvND, MaxPoolND from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.ResBlock import ResBlock -from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here +from mrpro.nn.SeparableResBlock import ( + SeparableResBlock, +) # Assuming SeparableResBlock is here from mrpro.nn.Sequential import Sequential from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock from mrpro.nn.Upsample import Upsample @@ -49,7 +51,9 @@ def __len__(self): """Get the number of resolutions levels.""" return len(self.down_blocks) + 1 - def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + def forward( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> tuple[torch.Tensor, ...]: """Apply to Network.""" call = partial(call_with_cond, cond=cond) @@ -65,7 +69,9 @@ def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple return (*xs, x) - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + def __call__( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> tuple[torch.Tensor, ...]: """Apply to Network. Parameters @@ -110,19 +116,25 @@ def __len__(self): """Get the number of resolutions levels.""" return len(self.up_blocks) + 1 - def forward(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply to Network.""" call = partial(call_with_cond, cond=cond) x = hs[-1] # lowest resolution, from middle block - for block, up, concat, h in zip(self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True): + for block, up, concat, h in zip( + self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True + ): x = call(up, x) x = concat(h, x) x = call(block, x) x = call(self.last_block, x) return x - def __call__(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__( + self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply to Network. Parameters @@ -142,7 +154,12 @@ def __call__(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = class UNetBase(Module): """Base class for U-shaped networks.""" - def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequence[Module] | None = None) -> None: + def __init__( + self, + encoder: UNetEncoder, + decoder: UNetDecoder, + skip_blocks: Sequence[Module] | None = None, + ) -> None: """Initialize the UNetBase.""" super().__init__() self.encoder = encoder @@ -156,31 +173,37 @@ def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequ if len(decoder) != len(encoder): raise ValueError( - 'The number of resolutions in the encoder and decoder must be the same, ' - f'got {len(decoder)} and {len(encoder)}' + "The number of resolutions in the encoder and decoder must be the same, " + f"got {len(decoder)} and {len(encoder)}" ) if skip_blocks is None: self.skip_blocks.extend(Identity() for _ in range(len(decoder))) elif len(skip_blocks) != len(decoder): raise ValueError( - f'The number of skip blocks must be the same as the number of resolutions, ' - f'got {len(skip_blocks)} and {len(encoder)}' + f"The number of skip blocks must be the same as the number of resolutions, " + f"got {len(skip_blocks)} and {len(encoder)}" ) else: self.skip_blocks.extend(skip_blocks) - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply to Network.""" xs = self.encoder(x, cond=cond) xs = tuple( - call_with_cond(self.skip_blocks[i], x, cond=cond) if i < len(self.skip_blocks) else x + call_with_cond(self.skip_blocks[i], x, cond=cond) + if i < len(self.skip_blocks) + else x for i, x in enumerate(xs) ) x = self.decoder(xs, cond=cond) return x - def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__( + self, x: torch.Tensor, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply to Network. Parameters @@ -208,7 +231,14 @@ class BasicUNet(UNetBase): segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 """ - def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int): + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + n_features: Sequence[int], + cond_dim: int, + ): """Initialize the BasicUNet.""" encoder_blocks: list[Module] = [] decoder_blocks: list[Module] = [] @@ -221,7 +251,8 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Se down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) up_blocks.append( Sequential( - Upsample(tuple(range(-dim, 0)), scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1) + Upsample(tuple(range(-dim, 0)), scale_factor=2), + ConvND(dim)(n_feat_next, n_feat, 3, padding=1), ) ) concat_blocks.append(Concat()) @@ -229,7 +260,9 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Se decoder_blocks = decoder_blocks[::-1] first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) last_block = Sequential( - GroupNorm(n_features[0]), SiLU(), ConvND(dim)(n_features[0], channels_out, 3, padding=1) + GroupNorm(n_features[0]), + SiLU(), + ConvND(dim)(n_features[0], channels_out, 3, padding=1), ) middle_block = ResBlock(dim, n_features[-1], n_features[-1], cond_dim) encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) @@ -290,20 +323,27 @@ def __init__( depth = len(n_features) if not all(-depth <= d < depth for d in attention_depths): raise ValueError( - f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' + f"attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}" ) attention_depths = tuple(d % depth for d in attention_depths) if len(attention_depths) != len(set(attention_depths)): - raise ValueError(f'attention_depths must be unique, got {attention_depths=}') + raise ValueError( + f"attention_depths must be unique, got {attention_depths=}" + ) def attention_block(channels: int) -> Module: dim_groups = (tuple(range(-dim, 0)),) - return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) + return SpatialTransformerBlock( + dim_groups, channels, n_heads, cond_dim=cond_dim + ) def block(channels_in: int, channels_out: int, attention: bool) -> Module: if not attention: return ResBlock(dim, channels_in, channels_out, cond_dim) - return Sequential(ResBlock(dim, channels_in, channels_out, cond_dim), attention_block(channels_out)) + return Sequential( + ResBlock(dim, channels_in, channels_out, cond_dim), + attention_block(channels_out), + ) first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) encoder_blocks: list[Module] = [] @@ -314,7 +354,9 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: encoder_blocks.append(Identity()) skip_features.append(n_feat_old) for _ in range(encoder_blocks_per_scale): - encoder_blocks.append(block(n_feat_old, n_feat, attention=i_level in attention_depths)) + encoder_blocks.append( + block(n_feat_old, n_feat, attention=i_level in attention_depths) + ) n_feat_old = n_feat down_blocks.append(Identity()) skip_features.append(n_feat_old) @@ -332,12 +374,20 @@ def block(channels_in: int, channels_out: int, attention: bool) -> Module: up_blocks: list[Module] = [Identity()] for i_level, n_feat in reversed(list(enumerate(n_features))): decoder_blocks.append( - block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) + block( + n_feat_old + skip_features.pop(), + n_feat, + attention=i_level in attention_depths, + ) ) n_feat_old = n_feat for _ in range(encoder_blocks_per_scale): decoder_blocks.append( - block(n_feat_old + skip_features.pop(), n_feat, attention=i_level in attention_depths) + block( + n_feat_old + skip_features.pop(), + n_feat, + attention=i_level in attention_depths, + ) ) up_blocks.append(Identity()) up_blocks.append(Upsample(tuple(range(-dim, 0)), scale_factor=2)) @@ -364,7 +414,14 @@ class AttentionGatedUNet(UNetBase): https://arxiv.org/abs/1804.03999 """ - def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int = 0): + def __init__( + self, + dim: int, + channels_in: int, + channels_out: int, + n_features: Sequence[int], + cond_dim: int = 0, + ): """Initialize the AttentionGatedUNet. Parameters @@ -406,7 +463,9 @@ def block(channels_in: int, channels_out: int) -> Module: decoder_blocks: list[Module] = [] up_blocks: list[Module] = [] for n_feat, n_feat_skip in pairwise(n_features[::-1]): - concat_blocks.append(AttentionGate(dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) + concat_blocks.append( + AttentionGate(dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True) + ) decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) up_blocks.append(Upsample(dim, scale_factor=2)) last_block = ConvND(dim)(n_features[0], channels_out, 1) @@ -470,13 +529,25 @@ def __init__( depth = len(n_features) for group in dim_groups: if len(group) > 3: - raise ValueError(f'dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.') + raise ValueError( + f"dim_group {group} can at most contain 3 dimensions. Split it into multiple groups." + ) if any(d > dim + 2 or d < -dim for d in group): - raise ValueError(f'dim_group {group} contains dimensions that are out of range for dim={dim}') + raise ValueError( + f"dim_group {group} contains dimensions that are out of range for dim={dim}" + ) attention_depths = tuple(d % depth for d in attention_depths) if downsample_dims is None: - all_spatial_dims = tuple(sorted(set(d if d < 0 else d - dim - 2 for group in dim_groups for d in group))) + all_spatial_dims = tuple( + sorted( + set( + d if d < 0 else d - dim - 2 + for group in dim_groups + for d in group + ) + ) + ) downsample_dims = (all_spatial_dims,) * (depth - 1) def downsampler(level_dims, c_in, c_out) -> Module: @@ -485,7 +556,9 @@ def downsampler(level_dims, c_in, c_out) -> Module: for d in level_dims[1:]: sequence.append(downsampler(d, c_out, c_out)) return sequence - return PermutedBlock(level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) + return PermutedBlock( + level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1) + ) def upsampler(level_dims, c_in, c_out) -> Module: if len(level_dims) > 3: @@ -493,18 +566,23 @@ def upsampler(level_dims, c_in, c_out) -> Module: for d in level_dims[1:]: sequence.append(upsampler(d, c_out, c_out)) return sequence - return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode='nearest')) + return PermutedBlock( + level_dims, Upsample(len(level_dims), scale_factor=2, mode="nearest") + ) def block(c_in: int, c_out: int, apply_attention: bool) -> Module: res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) if not apply_attention: return res_block - attn_block = SpatialTransformerBlock(dim_groups, c_out, n_heads, cond_dim=cond_dim) + attn_block = SpatialTransformerBlock( + dim_groups, c_out, n_heads, cond_dim=cond_dim + ) return Sequential(res_block, attn_block) # --- Module Construction --- first_block = PermutedBlock( - all_spatial_dims, ConvND(len(all_spatial_dims))(channels_in, n_features[0], 3, padding=1) + all_spatial_dims, + ConvND(len(all_spatial_dims))(channels_in, n_features[0], 3, padding=1), ) # -- Encoder -- @@ -512,12 +590,18 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: c_feat = n_features[0] for i_level, n_feat_level in enumerate(n_features): for _ in range(encoder_blocks_per_scale): - encoder_blocks.append(block(c_feat, n_feat_level, i_level in attention_depths)) + encoder_blocks.append( + block(c_feat, n_feat_level, i_level in attention_depths) + ) c_feat = n_feat_level skip_features.append(c_feat) if i_level < depth - 1: down_blocks.append( - _create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1]) + _create_downsampler( + downsample_dims_per_level[i_level], + c_feat, + n_features[i_level + 1], + ) ) c_feat = n_features[i_level + 1] @@ -533,10 +617,16 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: for i_level in reversed(range(depth)): n_feat_level = n_features[i_level] if i_level > 0: - up_blocks.append(_create_upsampler(downsample_dims_per_level[i_level - 1], c_feat, n_feat_level)) + up_blocks.append( + _create_upsampler( + downsample_dims_per_level[i_level - 1], c_feat, n_feat_level + ) + ) for _ in range(encoder_blocks_per_scale + 1): skip_c = skip_features.pop() - decoder_blocks.append(block(c_feat + skip_c, n_feat_level, i_level in attention_depths)) + decoder_blocks.append( + block(c_feat + skip_c, n_feat_level, i_level in attention_depths) + ) c_feat = n_feat_level decoder_blocks.reverse() @@ -547,7 +637,12 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: last_block = Sequential( GroupNorm(n_features[0]), SiLU(), - PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)), + PermutedBlock( + all_spatial_dims, + ConvND(len(all_spatial_dims))( + n_features[0], channels_out, 3, padding=1 + ), + ), ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) From 12f41b8985bdeb43e9c3378b1b149cc4a0bd3e4d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 13 Jul 2025 01:01:21 +0200 Subject: [PATCH 099/205] simplidy unet --- src/mrpro/nn/nets/UNet.py | 127 +++++++++++--------------------------- 1 file changed, 37 insertions(+), 90 deletions(-) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 85da64b03..c149d2e62 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -15,9 +15,7 @@ from mrpro.nn.ndmodules import ConvND, MaxPoolND from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.ResBlock import ResBlock -from mrpro.nn.SeparableResBlock import ( - SeparableResBlock, -) # Assuming SeparableResBlock is here +from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here from mrpro.nn.Sequential import Sequential from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock from mrpro.nn.Upsample import Upsample @@ -51,9 +49,7 @@ def __len__(self): """Get the number of resolutions levels.""" return len(self.down_blocks) + 1 - def forward( - self, x: torch.Tensor, *, cond: torch.Tensor | None = None - ) -> tuple[torch.Tensor, ...]: + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: """Apply to Network.""" call = partial(call_with_cond, cond=cond) @@ -69,9 +65,7 @@ def forward( return (*xs, x) - def __call__( - self, x: torch.Tensor, *, cond: torch.Tensor | None = None - ) -> tuple[torch.Tensor, ...]: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: """Apply to Network. Parameters @@ -116,25 +110,19 @@ def __len__(self): """Get the number of resolutions levels.""" return len(self.up_blocks) + 1 - def forward( - self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None - ) -> torch.Tensor: + def forward(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply to Network.""" call = partial(call_with_cond, cond=cond) x = hs[-1] # lowest resolution, from middle block - for block, up, concat, h in zip( - self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True - ): + for block, up, concat, h in zip(self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True): x = call(up, x) x = concat(h, x) x = call(block, x) x = call(self.last_block, x) return x - def __call__( - self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None - ) -> torch.Tensor: + def __call__(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply to Network. Parameters @@ -181,29 +169,23 @@ def __init__( self.skip_blocks.extend(Identity() for _ in range(len(decoder))) elif len(skip_blocks) != len(decoder): raise ValueError( - f"The number of skip blocks must be the same as the number of resolutions, " - f"got {len(skip_blocks)} and {len(encoder)}" + f'The number of skip blocks must be the same as the number of resolutions, ' + f'got {len(skip_blocks)} and {len(encoder)}' ) else: self.skip_blocks.extend(skip_blocks) - def forward( - self, x: torch.Tensor, cond: torch.Tensor | None = None - ) -> torch.Tensor: + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply to Network.""" xs = self.encoder(x, cond=cond) xs = tuple( - call_with_cond(self.skip_blocks[i], x, cond=cond) - if i < len(self.skip_blocks) - else x + call_with_cond(self.skip_blocks[i], x, cond=cond) if i < len(self.skip_blocks) else x for i, x in enumerate(xs) ) x = self.decoder(xs, cond=cond) return x - def __call__( - self, x: torch.Tensor, cond: torch.Tensor | None = None - ) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply to Network. Parameters @@ -322,19 +304,15 @@ def __init__( depth = len(n_features) if not all(-depth <= d < depth for d in attention_depths): raise ValueError( - f"attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}" + f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' ) attention_depths = tuple(d % depth for d in attention_depths) if len(attention_depths) != len(set(attention_depths)): - raise ValueError( - f"attention_depths must be unique, got {attention_depths=}" - ) + raise ValueError(f'attention_depths must be unique, got {attention_depths=}') def attention_block(channels: int) -> Module: dim_groups = (tuple(range(-dim, 0)),) - return SpatialTransformerBlock( - dim_groups, channels, n_heads, cond_dim=cond_dim - ) + return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: blocks = Sequential() @@ -345,32 +323,31 @@ def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: channels_in = channels_out return blocks - encoder_blocks: list[Module] = [ - ConvND(dim)(channels_in, n_features[0], 3, padding=1) - ] - down_blocks: list[Module] = [Identity()] + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] decoder_blocks: list[Module] = [] - up_blocks: list[Module] = [Identity()] + up_blocks: list[Module] = [] for i_level, (n_feat, n_feat_next) in enumerate(pairwise(n_features)): encoder_blocks.append(blocks(n_feat, n_feat, i_level in attention_depths)) down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) - decoder_blocks.append( - blocks(n_feat_next + n_feat, n_feat, i_level in attention_depths) - ) + decoder_blocks.append(blocks(n_feat_next + n_feat, n_feat, i_level in attention_depths)) up_blocks.append(Upsample(tuple(range(-dim, 0)), scale_factor=2)) middle_block = Sequential( ResBlock(dim, n_feat_next, n_feat_next, cond_dim), ResBlock(dim, n_feat_next, n_feat_next, cond_dim), ) - if i_level in attention_depths: - middle_block.insert(1, attention_block(n_feat)) - encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) + if depth - 1 in attention_depths: + middle_block.insert(1, attention_block(n_feat_next)) + first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) decoder_blocks, up_blocks = decoder_blocks[::-1], up_blocks[::-1] - decoder_blocks.append(ResBlock(dim, 2 * n_features[0], n_features[0], cond_dim)) - last_block = ConvND(dim)(n_features[0], channels_out, 1) + last_block = Sequential( + SiLU(), + ConvND(dim)(n_features[0], channels_out, 3, padding=1), + ) concat_blocks = [Concat() for _ in range(len(decoder_blocks))] decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) @@ -437,9 +414,7 @@ def block(channels_in: int, channels_out: int) -> Module: decoder_blocks: list[Module] = [] up_blocks: list[Module] = [] for n_feat, n_feat_skip in pairwise(n_features[::-1]): - concat_blocks.append( - AttentionGate(dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True) - ) + concat_blocks.append(AttentionGate(dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) up_blocks.append(Upsample(dim, scale_factor=2)) last_block = ConvND(dim)(n_features[0], channels_out, 1) @@ -503,25 +478,13 @@ def __init__( depth = len(n_features) for group in dim_groups: if len(group) > 3: - raise ValueError( - f"dim_group {group} can at most contain 3 dimensions. Split it into multiple groups." - ) + raise ValueError(f'dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.') if any(d > dim + 2 or d < -dim for d in group): - raise ValueError( - f"dim_group {group} contains dimensions that are out of range for dim={dim}" - ) + raise ValueError(f'dim_group {group} contains dimensions that are out of range for dim={dim}') attention_depths = tuple(d % depth for d in attention_depths) if downsample_dims is None: - all_spatial_dims = tuple( - sorted( - set( - d if d < 0 else d - dim - 2 - for group in dim_groups - for d in group - ) - ) - ) + all_spatial_dims = tuple(sorted(set(d if d < 0 else d - dim - 2 for group in dim_groups for d in group))) downsample_dims = (all_spatial_dims,) * (depth - 1) def downsampler(level_dims, c_in, c_out) -> Module: @@ -530,9 +493,7 @@ def downsampler(level_dims, c_in, c_out) -> Module: for d in level_dims[1:]: sequence.append(downsampler(d, c_out, c_out)) return sequence - return PermutedBlock( - level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1) - ) + return PermutedBlock(level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) def upsampler(level_dims, c_in, c_out) -> Module: if len(level_dims) > 3: @@ -540,17 +501,13 @@ def upsampler(level_dims, c_in, c_out) -> Module: for d in level_dims[1:]: sequence.append(upsampler(d, c_out, c_out)) return sequence - return PermutedBlock( - level_dims, Upsample(len(level_dims), scale_factor=2, mode="nearest") - ) + return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode='nearest')) def block(c_in: int, c_out: int, apply_attention: bool) -> Module: res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) if not apply_attention: return res_block - attn_block = SpatialTransformerBlock( - dim_groups, c_out, n_heads, cond_dim=cond_dim - ) + attn_block = SpatialTransformerBlock(dim_groups, c_out, n_heads, cond_dim=cond_dim) return Sequential(res_block, attn_block) # --- Module Construction --- @@ -564,9 +521,7 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: c_feat = n_features[0] for i_level, n_feat_level in enumerate(n_features): for _ in range(encoder_blocks_per_scale): - encoder_blocks.append( - block(c_feat, n_feat_level, i_level in attention_depths) - ) + encoder_blocks.append(block(c_feat, n_feat_level, i_level in attention_depths)) c_feat = n_feat_level skip_features.append(c_feat) if i_level < depth - 1: @@ -591,16 +546,10 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: for i_level in reversed(range(depth)): n_feat_level = n_features[i_level] if i_level > 0: - up_blocks.append( - _create_upsampler( - downsample_dims_per_level[i_level - 1], c_feat, n_feat_level - ) - ) + up_blocks.append(_create_upsampler(downsample_dims_per_level[i_level - 1], c_feat, n_feat_level)) for _ in range(encoder_blocks_per_scale + 1): skip_c = skip_features.pop() - decoder_blocks.append( - block(c_feat + skip_c, n_feat_level, i_level in attention_depths) - ) + decoder_blocks.append(block(c_feat + skip_c, n_feat_level, i_level in attention_depths)) c_feat = n_feat_level decoder_blocks.reverse() @@ -613,9 +562,7 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: SiLU(), PermutedBlock( all_spatial_dims, - ConvND(len(all_spatial_dims))( - n_features[0], channels_out, 3, padding=1 - ), + ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1), ), ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) From a17e5079bc7077d2e8fdf33c96510141e8ed2ef2 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 13 Jul 2025 01:01:33 +0200 Subject: [PATCH 100/205] fewer param in pinqi --- examples/scripts/train_pinqi.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py index bbb88cdfe..a9956272c 100644 --- a/examples/scripts/train_pinqi.py +++ b/examples/scripts/train_pinqi.py @@ -344,9 +344,9 @@ def __init__( parameter_is_complex: Sequence[bool], n_images: int, n_iterations: int = 4, - n_features_parameter_net: Sequence[int] = (64, 128, 192, 224, 256), - n_features_image_net: Sequence[int] = (16, 32, 48, 64), - lr: float = 4e-4, # noqa: ARG002 + n_features_parameter_net: Sequence[int] = (64, 128, 192, 256), + n_features_image_net: Sequence[int] = (32, 48, 64, 96), + lr: float = 3e-4, # noqa: ARG002 weight_decay: float = 1e-3, # noqa: ARG002 loss_weights: Sequence[float] = (0.2, 0.1, 0.1, 0.1, 0.8), ): @@ -542,7 +542,7 @@ def configure_optimizers( max_lr=[self.hparams.lr, 10 * self.hparams.lr], total_steps=self.trainer.estimated_stepping_batches, pct_start=0.1, - div_factor=30, + div_factor=20, final_div_factor=300, ) return { @@ -606,6 +606,7 @@ def on_train_batch_end( if __name__ == '__main__': + torch.multiprocessing.set_sharing_strategy('file_system') torch.set_float32_matmul_precision('high') torch._inductor.config.compile_threads = 4 torch._inductor.config.worker_start_method = 'fork' From 132491ef6d7482bc2d31fd1d1bf3ccb0ab51bad1 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 13 Jul 2025 01:01:21 +0200 Subject: [PATCH 101/205] simplidy unet --- src/mrpro/nn/nets/UNet.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index c46c405ec..3c74292c2 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -311,7 +311,7 @@ def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: encoder_blocks: list[Module] = [ConvND(dim)(channels_in, n_features[0], 3, padding=1)] down_blocks: list[Module] = [Identity()] decoder_blocks: list[Module] = [] - up_blocks: list[Module] = [Identity()] + up_blocks: list[Module] = [] for i_level, (n_feat, n_feat_next) in enumerate(pairwise(n_features)): encoder_blocks.append(blocks(n_feat, n_feat, i_level in attention_depths)) @@ -323,13 +323,16 @@ def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: ResBlock(dim, n_feat_next, n_feat_next, cond_dim), ResBlock(dim, n_feat_next, n_feat_next, cond_dim), ) - if i_level in attention_depths: - middle_block.insert(1, attention_block(n_feat)) - encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) + if depth - 1 in attention_depths: + middle_block.insert(1, attention_block(n_feat_next)) + first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) decoder_blocks, up_blocks = decoder_blocks[::-1], up_blocks[::-1] - decoder_blocks.append(ResBlock(dim, 2 * n_features[0], n_features[0], cond_dim)) - last_block = ConvND(dim)(n_features[0], channels_out, 1) + last_block = Sequential( + SiLU(), + ConvND(dim)(n_features[0], channels_out, 3, padding=1), + ) concat_blocks = [Concat() for _ in range(len(decoder_blocks))] decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) @@ -530,7 +533,10 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: last_block = Sequential( GroupNorm(n_features[0]), SiLU(), - PermutedBlock(all_spatial_dims, ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1)), + PermutedBlock( + all_spatial_dims, + ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1), + ), ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) From a4e46f6c65ea9cc82a3a6548313ebbea1e47cec1 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 13 Jul 2025 13:25:55 +0200 Subject: [PATCH 102/205] update --- src/mrpro/nn/GroupNorm.py | 6 +- src/mrpro/nn/RoPE.py | 44 ++++++++----- src/mrpro/nn/SeparableResBlock.py | 87 +------------------------ src/mrpro/nn/SpatialTransformerBlock.py | 3 +- src/mrpro/nn/nets/BasicCNN.py | 58 ++++++++++++++--- src/mrpro/nn/nets/UNet.py | 17 ++--- 6 files changed, 90 insertions(+), 125 deletions(-) diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py index 09e91cf11..78805c3ee 100644 --- a/src/mrpro/nn/GroupNorm.py +++ b/src/mrpro/nn/GroupNorm.py @@ -9,7 +9,7 @@ class GroupNorm(torch.nn.GroupNorm): Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. """ - def __init__(self, channels: int, groups: int | None = None): + def __init__(self, channels: int, groups: int | None = None, affine: bool = False): """Initialize GroupNorm32. Parameters @@ -19,6 +19,8 @@ def __init__(self, channels: int, groups: int | None = None): groups The number of groups to use. If None, the number of groups is determined automatically as a power of 2 that is less than or equal to 32 and leaves at least 4 channels per group. + affine + Whether to use learnable affine parameters. """ if groups is None: groups_, candidate = 1, 2 @@ -26,7 +28,7 @@ def __init__(self, channels: int, groups: int | None = None): groups_, candidate = candidate, groups_ * 2 else: groups_ = groups - super().__init__(groups_, channels) + super().__init__(groups_, channels, affine=affine) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply GroupNorm32. diff --git a/src/mrpro/nn/RoPE.py b/src/mrpro/nn/RoPE.py index 90ecb8739..7abfc9426 100644 --- a/src/mrpro/nn/RoPE.py +++ b/src/mrpro/nn/RoPE.py @@ -68,55 +68,67 @@ class AxialRoPE(Module): freqs: torch.Tensor - def __init__(self, dim: int, d_head: int, n_heads: int, headpos: int = -2, non_embed_fraction: float = 0.5): + def __init__( + self, + n_dim: int, + n_channels: int, + n_heads: int, + channels_last: bool = True, + non_embed_fraction: float = 0.5, + ): """Initialize AxialRoPE. Parameters ---------- - dim : int - Dimension of the input space - d_head : int - Dimension of each attention head - n_heads : int + n_dim + Number of (spatial-like) dimensions of the input + n_channels + Number of channels + n_heads Number of attention heads - headpos : int, optional - Position of the head dimension - non_embed_fraction : float, optional - Fraction of dimensions to not embed + channels_last + Whether the channels are the last dimension or dimension 1. + non_embed_fraction + Fraction of channels not used for embedding """ super().__init__() log_min = torch.log(torch.tensor(torch.pi)) log_max = torch.log(torch.tensor(10000.0)) - freqs = torch.exp(torch.linspace(log_min, log_max, d_head // 2)) + if n_channels % n_heads: + raise ValueError(f'Number of channels {n_channels} must be divisible by number of heads {n_heads}') + channels_per_head = n_channels // n_heads + freqs = torch.exp(torch.linspace(log_min, log_max, channels_per_head // 2)) self.register_buffer('freqs', freqs) - self.headpos = headpos + self.channels_last = channels_last + self.n_heads = n_heads def get_theta(self, pos: torch.Tensor) -> torch.Tensor: """Get rotation angles for given positions. Parameters ---------- - pos : torch.Tensor + pos Position tensor Returns ------- - torch.Tensor Rotation angles """ - return (self.freqs * pos[..., None, :, None]).flatten(start_dim=-2).movedim(-2, self.headpos) + return (self.freqs * pos[..., None, :, None]).flatten(start_dim=-2) def forward(self, pos: torch.Tensor, *tensors: torch.Tensor) -> None: """Apply rotary embeddings to input tensors. Parameters ---------- - pos : torch.Tensor + pos Position tensor *tensors : torch.Tensor Tensors to apply rotary embeddings to """ theta = self.get_theta(pos) + if not self.channels_last: + tensors = tuple(t.movedim(-1, 1) for t in tensors) tuple(RotaryEmbedding.apply(x, theta, False) for x in tensors) @staticmethod diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py index c26a012cf..770293884 100644 --- a/src/mrpro/nn/SeparableResBlock.py +++ b/src/mrpro/nn/SeparableResBlock.py @@ -1,3 +1,5 @@ +"""Residual block with separable convolutions.""" + from collections.abc import Sequence import torch @@ -11,90 +13,7 @@ class SeparableResBlock(Module): - """Residual block with separable convolutions and ReZero.""" - - def __init__( - self, - dim_groups: Sequence[Sequence[int]], - channels_in: int, - channels_out: int, - cond_dim: int, - ) -> None: - """Initialize the SeparableResBlock. - - Applies convolutions as separable convolutions with SilU activation and group normalization. - For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, - and one 1D convolution is applied to the last dimension. - The order within the block is Norm->Activation->Conv. - The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. - So for two `dim_groups`, a total of 4 convolutions are applied. - - Parameters - ---------- - dim_groups - Sequence of dimension groups to use in the convolutions. - channels_in - Number of input channels. - channels_out - Number of output channels. - cond_dim - Number of channels in the conditioning tensor. If 0, no conditioning is applied. - """ - super().__init__() - self.rezero = torch.nn.Parameter(torch.tensor(0.1)) - - def block(dims: Sequence[int], channels_in: int) -> Module: - return Sequential( - GroupNorm(channels_in), - SiLU(), - PermutedBlock(dims, ConvND(len(dims))(channels_in, channels_out, 3, padding=1)), - ) - - blocks = Sequential(*(block(d, channels_in if i == 0 else channels_out) for i, d in enumerate(dim_groups))) - if cond_dim > 0: - blocks.append(FiLM(channels_out, cond_dim)) - blocks.extend(block(d, channels_out) for d in dim_groups) - self.block = blocks - self.skip_connection = None - if channels_in != channels_out: - self.skip_connection = torch.nn.Linear(channels_in, channels_out) - - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: - """Apply the SeparableResBlock. - - Parameters - ---------- - x - Input tensor. - cond - Conditioning tensor. - - Returns - ------- - Output tensor with the same number and order of dimensions as the input. - """ - return super().__call__(x, cond=cond) - - def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: - """Apply the SeparableResBlock.""" - h = self.block(x, cond=cond) - if self.skip_connection is None: - skip = x - else: - skip = torch.moveaxis(x, 1, -1) - skip = self.skip_connection(skip) - skip = torch.moveaxis(skip, -1, 1) - return skip + self.rezero * h - - -from collections.abc import Sequence - -import torch -from torch.nn import Module - - -class SeparableResBlock(Module): - """Residual block with separable convolutions and ReZero.""" + """Residual block with separable convolutions.""" def __init__( self, diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index 906560c24..cb431b52d 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -107,7 +107,8 @@ def __init__( dropout: float = 0.0, cond_dim: int = 0, ): - """ + """Initialize the spatial transformer block. + Parameters ---------- dim_groups diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py index b2671c121..620e60038 100644 --- a/src/mrpro/nn/nets/BasicCNN.py +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -1,25 +1,39 @@ +"""Basic CNN.""" + from collections.abc import Sequence from itertools import pairwise +from typing import Literal import torch -from torch.nn import ReLU +from torch.nn import LeakyReLU, ReLU, SiLU from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.ndmodules import BatchNormND, ConvND from mrpro.nn.Sequential import Sequential class BasicCNN(Sequential): + """Basic CNN. + + A series of convolutions (window 3, stride 1, padding 1), normalization and activation. + Allows to use FiLM conditioning. + Order is Conv -> Norm (optional) -> FiLM (optional) -> Activation. + + If you need more flexibility, use `~mrpro.nn.Sequential` directly. + """ + def __init__( self, dim: int, channels_in: int, channels_out: int, - batch_norm: bool = True, + norm: Literal['batch', 'group', 'instance', 'none', 'layer'] = 'none', + activation: Literal['relu', 'silu', 'leaky_relu'] = 'relu', n_features: Sequence[int] = (64, 64, 64), cond_dim: int = 0, ): - """Initialize a basic CNN. + """Initialize a basic CNN. Parameters ---------- @@ -29,25 +43,49 @@ def __init__( The number of input channels. channels_out The number of output channels. - batch_norm - Whether to use batch normalization. + norm + The type of normalization to use. If 'batch', use batch normalization. If 'group', use group normalization, + if 'instance', use instance normalization, and if `layer`, use layer normalization. + If 'none', use no normalization. n_features - The number of features in the hidden layers. The length of this sequence determines the number of hidden layers. + The number of features in the hidden layers. The length of this sequence determines the number of hidden + layers. The total number of convolutions is `len(n_features) + 1`. cond_dim The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + Otherwise, between convolutions, after normalization, FiLM conditioning is applied. """ super().__init__() use_film = cond_dim > 0 + self.append(ConvND(dim)(channels_in, n_features[0], kernel_size=3, padding='same')) + for c_in, c_out in pairwise((*n_features, channels_out)): - if batch_norm: + if norm.lower() == 'batch': self.append(BatchNormND(dim)(c_in, affine=not use_film)) + elif norm.lower() == 'group': + self.append(GroupNorm(c_in, affine=not use_film)) + elif norm.lower() == 'instance': + self.append(GroupNorm(c_in, groups=c_in, affine=not use_film)) # is instance norm + elif norm.lower() == 'layer': + self.append(GroupNorm(c_in, groups=1, affine=not use_film)) # is layer norm + elif norm.lower() != 'none': + raise ValueError(f'Invalid normalization type: {norm}') + if use_film: self.append(FiLM(c_in, cond_dim)) - self.append(ReLU(True)) + + if activation.lower() == 'relu': + self.append(ReLU(True)) + elif activation.lower() == 'silu': + self.append(SiLU(inplace=True)) + elif activation.lower() == 'leaky_relu': + self.append(LeakyReLU(inplace=True)) + else: + raise ValueError(f'Invalid activation type: {activation}') + self.append(ConvND(dim)(c_in, c_out, kernel_size=3, padding='same')) - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] """Apply the basic CNN to the input tensor. Parameters @@ -62,4 +100,4 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None) -> torch.Tenso ------- The output tensor. """ - return super().__call__(x, cond=cond) + return super().__call__(*x, cond=cond) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 3c74292c2..f0b8ffa41 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -394,7 +394,7 @@ def block(channels_in: int, channels_out: int) -> Module: for n_feat, n_feat_skip in pairwise(n_features[::-1]): concat_blocks.append(AttentionGate(dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) - up_blocks.append(Upsample(dim, scale_factor=2)) + up_blocks.append(Upsample(range(-dim, 0), scale_factor=2)) last_block = ConvND(dim)(n_features[0], channels_out, 1) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) @@ -467,19 +467,14 @@ def __init__( def downsampler(level_dims, c_in, c_out) -> Module: if len(level_dims) > 3: - sequence = Sequence(downsampler(d[0], c_in, c_out) for d in level_dims) + sequence = Sequential(*(downsampler(d[0], c_in, c_out) for d in level_dims)) for d in level_dims[1:]: sequence.append(downsampler(d, c_out, c_out)) return sequence return PermutedBlock(level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) def upsampler(level_dims, c_in, c_out) -> Module: - if len(level_dims) > 3: - sequence = Sequence(upsampler(d[0], c_in, c_out) for d in level_dims) - for d in level_dims[1:]: - sequence.append(upsampler(d, c_out, c_out)) - return sequence - return PermutedBlock(level_dims, Upsample(len(level_dims), scale_factor=2, mode='nearest')) + return Upsample(level_dims, scale_factor=2) def block(c_in: int, c_out: int, apply_attention: bool) -> Module: res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) @@ -502,9 +497,7 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: c_feat = n_feat_level skip_features.append(c_feat) if i_level < depth - 1: - down_blocks.append( - _create_downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1]) - ) + down_blocks.append(downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1])) c_feat = n_features[i_level + 1] # -- Middle & Encoder Finalization -- @@ -519,7 +512,7 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: for i_level in reversed(range(depth)): n_feat_level = n_features[i_level] if i_level > 0: - up_blocks.append(_create_upsampler(downsample_dims_per_level[i_level - 1], c_feat, n_feat_level)) + up_blocks.append(upsampler(downsample_dims_per_level[i_level - 1], c_feat, n_feat_level)) for _ in range(encoder_blocks_per_scale + 1): skip_c = skip_features.pop() decoder_blocks.append(block(c_feat + skip_c, n_feat_level, i_level in attention_depths)) From c31a2eee828e6b9072b0769d5676d30b008e1f39 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 14 Jul 2025 02:14:55 +0200 Subject: [PATCH 103/205] update --- src/mrpro/nn/AttentionGate.py | 12 +-- src/mrpro/nn/ComplexAsChannel.py | 8 +- src/mrpro/nn/DropPath.py | 10 +-- src/mrpro/nn/GEGLU.py | 14 ++-- src/mrpro/nn/GluMBConvResBlock.py | 30 +++---- src/mrpro/nn/GroupNorm.py | 20 ++--- src/mrpro/nn/LayerNorm.py | 18 ++--- src/mrpro/nn/LinearSelfAttention.py | 27 ++----- src/mrpro/nn/MultiHeadAttention.py | 26 +++--- src/mrpro/nn/NeighborhoodSelfAttention.py | 14 ++-- src/mrpro/nn/PermutedBlock.py | 19 ++++- src/mrpro/nn/PixelShuffle.py | 99 ++++++++++++++--------- src/mrpro/nn/ShiftedWindowAttention.py | 37 +++++---- src/mrpro/nn/SpatialTransformerBlock.py | 4 +- src/mrpro/nn/SqueezeExcitation.py | 14 ++-- src/mrpro/nn/TransposedAttention.py | 18 ++--- src/mrpro/nn/convert_linear_conv.py | 16 ++-- src/mrpro/nn/ndmodules.py | 72 +++++++++-------- src/mrpro/nn/nets/BasicCNN.py | 4 +- src/mrpro/nn/nets/DCAE.py | 6 +- src/mrpro/nn/nets/Uformer.py | 6 +- tests/nn/test_attentiongate.py | 2 +- tests/nn/test_groupnorm32.py | 2 +- tests/nn/test_shiftedwindowattention.py | 7 +- tests/nn/test_sqeezeexcitation.py | 2 +- tests/nn/test_transposedattention.py | 2 +- 26 files changed, 263 insertions(+), 226 deletions(-) diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/AttentionGate.py index 1d57fe5ee..682100650 100644 --- a/src/mrpro/nn/AttentionGate.py +++ b/src/mrpro/nn/AttentionGate.py @@ -17,12 +17,14 @@ class AttentionGate(Module): https://arxiv.org/abs/1804.03999 """ - def __init__(self, dim: int, channels_gate: int, channels_in: int, channels_hidden: int, concatenate: bool = False): + def __init__( + self, n_dim: int, channels_gate: int, channels_in: int, channels_hidden: int, concatenate: bool = False + ): """Initialize the attention gate. Parameters ---------- - dim + n_dim The dimension, i.e. 1, 2 or 3. channels_gate The number of channels in the gate tensor. @@ -34,11 +36,11 @@ def __init__(self, dim: int, channels_gate: int, channels_in: int, channels_hidd Whether to concatenate the gated signal with the gate signal in the channel dimension (1) """ super().__init__() - self.project_gate = ConvND(dim)(channels_gate, channels_hidden, kernel_size=1) - self.project_x = ConvND(dim)(channels_in, channels_hidden, kernel_size=1) + self.project_gate = ConvND(n_dim)(channels_gate, channels_hidden, kernel_size=1) + self.project_x = ConvND(n_dim)(channels_in, channels_hidden, kernel_size=1) self.psi = Sequential( ReLU(), - ConvND(dim)(channels_hidden, 1, kernel_size=1), + ConvND(n_dim)(channels_hidden, 1, kernel_size=1), Sigmoid(), ) self.concatenate = concatenate diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py index 7c1bec0fd..22e13458e 100644 --- a/src/mrpro/nn/ComplexAsChannel.py +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -20,9 +20,9 @@ def __init__(self, module: Module, convert_back: bool = True): Parameters ---------- - module : Module + module The module to wrap. Should output a single real tensor. - convert_back : bool + convert_back If True, the output is converted back to a complex tensor. The output should have a number of channels that is a multiple of 2. """ @@ -35,9 +35,9 @@ def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch. Parameters ---------- - x : torch.Tensor + x The input tensor. - cond : torch.Tensor | None + cond The conditioning tensor (if used by the wrapped module) """ return super().__call__(*x, cond=cond) diff --git a/src/mrpro/nn/DropPath.py b/src/mrpro/nn/DropPath.py index b1314904e..7262fd86c 100644 --- a/src/mrpro/nn/DropPath.py +++ b/src/mrpro/nn/DropPath.py @@ -21,10 +21,10 @@ def __init__(self, droprate: float = 0.0, scale_by_keep: bool = False): Parameters ---------- - droprate : float, optional + droprate Drop probability - scale_by_keep : bool, optional - If True, the kept samples are scaled by `1/(1-droprate)` + scale_by_keep + If True, the kept samples are scaled by :math:`1/(1-droprate)` """ super().__init__() self.droprate = droprate @@ -35,12 +35,12 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- - x : torch.Tensor + x Input tensor Returns ------- - Tensor with + Tensor with batch samples randomly dropped """ return super().__call__(x) diff --git a/src/mrpro/nn/GEGLU.py b/src/mrpro/nn/GEGLU.py index 0310c6a76..6151503d2 100644 --- a/src/mrpro/nn/GEGLU.py +++ b/src/mrpro/nn/GEGLU.py @@ -12,22 +12,22 @@ class GEGLU(Module): ..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202 """ - def __init__(self, in_features: int, out_features: int | None = None, features_last: bool = False): + def __init__(self, n_channels_in: int, n_channels_out: int | None = None, features_last: bool = False): """Initialize the GEGLU activation function. Parameters ---------- - in_features - The number of input features. - out_features - The number of output features. If None, the number of + n_channels_in + The number of input features/channels. + n_channels_out + The number of output features/channels. If None, the number of output features is the same as the number of input features. features_last If True, the channel dimension is the last dimension, else in the second dimension. """ super().__init__() - out_features_ = in_features if out_features is None else out_features - self.proj = Linear(in_features, out_features_ * 2) # gate and output stacked + out_channels_ = n_channels_in if n_channels_out is None else n_channels_out + self.proj = Linear(n_channels_in, out_channels_ * 2) # gate and output stacked self.features_last = features_last def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py index 3eaf3b9d4..de3cae041 100644 --- a/src/mrpro/nn/GluMBConvResBlock.py +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -12,7 +12,7 @@ class GluMBConvResBlock(CondMixin, Module): """Gated MBConv residual block. - Gated variant [DCAE]_ of the MBConv block [EffNet]_ with a residual connection. + Gated variant [DCAE]_ of the MBConv block [EffNet]_ with a residual connection and (optional) conditioning. References ---------- @@ -24,9 +24,9 @@ class GluMBConvResBlock(CondMixin, Module): def __init__( self, - dim: int, - channels_in: int, - channels_out: int, + n_dim: int, + n_channels_in: int, + n_channels_out: int, expand_ratio: int = 6, stride: int = 1, kernel_size: int = 3, @@ -36,7 +36,7 @@ def __init__( Parameters ---------- - dim + n_dim Number of spatial dimensions. channels_in Number of input channels. @@ -52,21 +52,21 @@ def __init__( Dimension of the conditioning tensor used in a FiLM. If 0, no FiLM is used. """ super().__init__() - channels_mid = channels_in * expand_ratio - if stride == 1 and channels_in == channels_out: + channels_mid = n_channels_in * expand_ratio + if stride == 1 and n_channels_in == n_channels_out: self.skip: Module = Identity() else: - self.skip = ConvND(dim)(channels_in, channels_out, kernel_size=1, stride=stride) + self.skip = ConvND(n_dim)(n_channels_in, n_channels_out, kernel_size=1, stride=stride) self.inverted_conv = Sequential( - ConvND(dim)( - channels_in, + ConvND(n_dim)( + n_channels_in, channels_mid * 2, kernel_size=1, ), SiLU(), ) self.depth_conv = Sequential( - ConvND(dim)( + ConvND(n_dim)( channels_mid * 2, channels_mid * 2, kernel_size=kernel_size, @@ -77,12 +77,12 @@ def __init__( SiLU(), ) self.point_conv = Sequential( - ConvND(dim)( + ConvND(n_dim)( channels_mid, - channels_out, + n_channels_out, kernel_size=1, ), - RMSNorm(channels_out), + RMSNorm(n_channels_out), SiLU(), ) if cond_dim > 0: @@ -98,7 +98,7 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc x Input tensor. cond - Conditioning tensor. If None, no conditioning is applied. + Conditioning tensor. If `None`, no conditioning is applied. Returns ------- diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py index 78805c3ee..e0090d018 100644 --- a/src/mrpro/nn/GroupNorm.py +++ b/src/mrpro/nn/GroupNorm.py @@ -4,31 +4,31 @@ class GroupNorm(torch.nn.GroupNorm): - """A 32-bit GroupNorm. + """A 32-bit GroupNorm with (optional) automatic group size selection. Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. """ - def __init__(self, channels: int, groups: int | None = None, affine: bool = False): - """Initialize GroupNorm32. + def __init__(self, n_channels: int, n_groups: int | None = None, affine: bool = False): + """Initialize GroupNorm. Parameters ---------- - channels + n_channels The number of channels in the input tensor. - groups + n_groups The number of groups to use. If None, the number of groups is determined automatically as a power of 2 that is less than or equal to 32 and leaves at least 4 channels per group. affine Whether to use learnable affine parameters. """ - if groups is None: + if n_groups is None: groups_, candidate = 1, 2 - while (candidate <= min(32, channels // 4)) and (channels % candidate == 0): + while (candidate <= min(32, n_channels // 4)) and (n_channels % candidate == 0): groups_, candidate = candidate, groups_ * 2 else: - groups_ = groups - super().__init__(groups_, channels, affine=affine) + groups_ = n_groups + super().__init__(groups_, n_channels, affine=affine) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply GroupNorm32. @@ -45,5 +45,5 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: return super().__call__(x.float()).type(x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply GroupNorm32.""" + """Apply GroupNorm.""" return super().forward(x.float()).type(x.dtype) diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 84e1d56e4..7c35eee96 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -10,12 +10,12 @@ class LayerNorm(CondMixin, Module): """Layer normalization.""" - def __init__(self, channels: int | None, features_last: bool = False, cond_dim: int = 0) -> None: + def __init__(self, n_channels: int | None, features_last: bool = False, cond_dim: int = 0) -> None: """Initialize the layer normalization. Parameters ---------- - channels + n_channels Number of channels in the input tensor. If `None`, the layer normalization does not do an elementwise affine transformation. features_last @@ -24,20 +24,20 @@ def __init__(self, channels: int | None, features_last: bool = False, cond_dim: Number of channels in the conditioning tensor. If `0`, no adaptive scaling is applied. """ super().__init__() - if channels is None and cond_dim == 0: + if n_channels is None and cond_dim == 0: self.weight: Parameter | None = None self.bias: Parameter | None = None self.cond_proj: Linear | None = None - elif channels is None and cond_dim > 0: + elif n_channels is None and cond_dim > 0: raise ValueError('channels must be provided if cond_dim > 0') - elif channels is not None and cond_dim == 0: - self.weight = Parameter(torch.ones(channels)) - self.bias = Parameter(torch.zeros(channels)) + elif n_channels is not None and cond_dim == 0: + self.weight = Parameter(torch.ones(n_channels)) + self.bias = Parameter(torch.zeros(n_channels)) self.cond_proj = None - elif channels is not None: + elif n_channels is not None: self.weight = None self.bias = None - self.cond_proj = Linear(cond_dim, 2 * channels) + self.cond_proj = Linear(cond_dim, 2 * n_channels) else: raise ValueError('cond_dim must be zero or positive.') diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/LinearSelfAttention.py index 612f0ffe2..2bab08930 100644 --- a/src/mrpro/nn/LinearSelfAttention.py +++ b/src/mrpro/nn/LinearSelfAttention.py @@ -14,25 +14,14 @@ class LinearSelfAttention(Module): References ---------- - .. [KAT20] Katharopoulos, Angelos, et al. Transformers are rnns: Fast autoregressive transformers with linear + .. [KAT20] Katharopoulos, Angelos, et al. Transformers are RNNs: Fast autoregressive transformers with linear attention. ICML 2020. https://arxiv.org/abs/2006.16236 - - Parameters - ---------- - channels - Input and output channel dimension. - n_heads - Number of attention heads. - bias - Whether to use bias in the QKV projection. - eps - Small epsilon for numerical stability in normalization. """ def __init__( self, - channels_in: int, - channels_out: int, + n_channels_in: int, + n_channels_out: int, n_heads: int, eps: float = 1e-6, features_last: bool = False, @@ -41,9 +30,9 @@ def __init__( Parameters ---------- - channels_in + n_channels_in Input channel dimension. - channels_out + n_channels_out Output channel dimension. n_heads Number of attention heads. @@ -57,10 +46,10 @@ def __init__( self.features_last = features_last self.eps = eps self.n_heads = n_heads - channels_per_head = channels_in // n_heads - self.to_qkv = Linear(channels_in, 3 * channels_per_head * n_heads) + channels_per_head = n_channels_in // n_heads + self.to_qkv = Linear(n_channels_in, 3 * channels_per_head * n_heads) self.kernel_function = ReLU() - self.to_out = Linear(channels_per_head * n_heads, channels_out) + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) def __call__(self, x: Tensor) -> Tensor: """Apply linear self-attention. diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/MultiHeadAttention.py index b79353192..18c34d446 100644 --- a/src/mrpro/nn/MultiHeadAttention.py +++ b/src/mrpro/nn/MultiHeadAttention.py @@ -14,22 +14,20 @@ class MultiHeadAttention(Module): def __init__( self, - channels_in: int, - channels_out: int, + n_channels_in: int, + n_channels_out: int, n_heads: int, features_last: bool = False, p_dropout: float = 0.0, - channels_cross: int | None = None, + n_channels_cross: int | None = None, ): """Initialize the Multi-head Attention. Parameters ---------- - dim - Number of spatial dimensions. - channels_in + n_channels_in Number of input channels. - channels_out + n_channels_out Number of output channels. n_heads number of attention heads @@ -38,17 +36,17 @@ def __init__( or the second dimension, as common in image models. p_dropout Dropout probability. - channels_cross - Number of channels for cross-attention. If `None`, use `channels_in`. + n_channels_cross + Number of channels for cross-attention. If `None`, use `n_channels_in`. """ super().__init__() - channels_per_head_q = channels_in // n_heads - channels_per_head_kv = channels_cross // n_heads if channels_cross is not None else channels_in // n_heads - self.to_q = Linear(channels_in, channels_per_head_q * n_heads) - self.to_kv = Linear(channels_in, channels_per_head_kv * n_heads * 2) + channels_per_head_q = n_channels_in // n_heads + channels_per_head_kv = n_channels_cross // n_heads if n_channels_cross is not None else n_channels_in // n_heads + self.to_q = Linear(n_channels_in, channels_per_head_q * n_heads) + self.to_kv = Linear(n_channels_in, channels_per_head_kv * n_heads * 2) self.p_dropout = p_dropout self.features_last = features_last - self.to_out = Linear(channels_in, channels_out) + self.to_out = Linear(n_channels_in, n_channels_out) self.n_heads = n_heads def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index 69f81bc02..5523e081e 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -119,8 +119,8 @@ class NeighborhoodSelfAttention(Module): def __init__( self, - channels_in: int, - channels_out: int, + n_channels_in: int, + n_channels_out: int, n_heads: int, kernel_size: int | Sequence[int], dilation: int | Sequence[int] = 1, @@ -134,9 +134,9 @@ def __init__( Parameters ---------- - channels_in + n_channels_in The number of channels in the input tensor. - channels_out + n_channels_out The number of channels in the output tensor. n_heads The number of attention heads. @@ -156,9 +156,9 @@ def __init__( self.dilation = dilation if isinstance(dilation, int) else tuple(dilation) self.circular = circular if isinstance(circular, bool) else tuple(circular) self.features_last = features_last - channels_per_head = channels_in // n_heads - self.to_qkv = Linear(channels_in, 3 * channels_per_head * n_heads) - self.to_out = Linear(channels_per_head * n_heads, channels_out) + channels_per_head = n_channels_in // n_heads + self.to_qkv = Linear(n_channels_in, 3 * channels_per_head * n_heads) + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply neighborhood attention to the input tensor. diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py index 99a27f36a..935d114dc 100644 --- a/src/mrpro/nn/PermutedBlock.py +++ b/src/mrpro/nn/PermutedBlock.py @@ -32,7 +32,24 @@ def __init__(self, apply_along_dim: Sequence[int], module: nn.Module, features_l self.module = module self.features_last = features_last - def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor, passed to the module if it supports conditioning + (that is, if it is a subclass of `~mrpro.nn.CondMixin`) + + Returns + ------- + Output tensor. + """ + return self.forward(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module along the selected dimensions.""" keep = tuple(d % x.ndim for d in self.apply_along_dim) if 0 in keep: diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index 70b0270df..6a1dc3a50 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -9,7 +9,7 @@ class PixelUnshuffle(Module): """ND-version of PixelUnshuffle downscaling.""" - def __init__(self, downscale_factor: int): + def __init__(self, downscale_factor: int, features_last: bool = False): """Initialize PixelUnshuffle. Reduces spatial dimensions and increases the channel number by reshaping. @@ -20,11 +20,15 @@ def __init__(self, downscale_factor: int): Parameters ---------- - downscale_factor : int + downscale_factor The factor by which to downscale the input tensor. + features_last + Whether the features/channels dimension is the last dimension as common in transformer models or the + second dimension as common in CNN models. """ super().__init__() self.downscale_factor = downscale_factor + self.features_last = features_last def __call__(self, x: torch.Tensor) -> torch.Tensor: """Downscale the input. @@ -32,32 +36,38 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- x - Tensor of shape `batch, channels, *spatial_dims` + Tensor of shape `batch, channels, *spatial_dims` or `batch, *spatial_dims, channels` (if `features_last`). Returns ------- - Tensor of shape `batch, channels * downscale_factor**dim, *spatial_dims/downscale_factor` + Tensor of shape `batch, channels * downscale_factor**dim, *spatial_dims/downscale_factor` or + `batch, *spatial_dims/downscale_factor, channels * downscale_factor**dim` (if `features_last`). """ return super().__call__(x) def forward(self, x: torch.Tensor) -> torch.Tensor: """Downscale the input.""" - dim = x.ndim - 2 - if dim == 2: # fast path for 2D + n_dim = x.ndim - 2 + if n_dim == 2 and not self.features_last: # fast path for 2D images return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor) - new_shape = list(x.shape[:2]) + new_shape = list(x.shape[:1]) if self.features_last else list(x.shape[:2]) source_positions = [] - for i, old in enumerate(x.shape[2:]): + for i, old in enumerate(x.shape[1:-1] if self.features_last else x.shape[2:]): if old % self.downscale_factor: raise ValueError('Spatial size must be divisible by downscale_factor.') new_shape.append(old // self.downscale_factor) new_shape.append(self.downscale_factor) source_positions.append(2 + 2 * i) - + if self.features_last: + new_shape.append(x.shape[-1]) x = x.view(new_shape) - x = x.moveaxis(source_positions, tuple(range(-dim, 0))) - x = x.flatten(1, -dim - 1) + x = x.moveaxis(source_positions, tuple(range(-n_dim, 0))) + if self.features_last: + x = x.flatten(-n_dim - 1) + else: + x = x.flatten(1, -n_dim - 1) + return x @@ -73,29 +83,29 @@ class PixelUnshuffleDownsample(Module): """ def __init__( - self, dim: int, channels_in: int, channels_out: int, downscale_factor: int = 2, residual: bool = False + self, n_dim: int, n_channels_in: int, n_channels_out: int, downscale_factor: int = 2, residual: bool = False ): """Initialize a PixelUnshuffleDownsample layer. Parameters ---------- - dim : int + n_dim Dimension of the input tensor. - channels_in : int + channels_in Number of channels in the input tensor. - channels_out : int + channels_out Number of channels in the output tensor. - downscale_factor : int, optional + downscale_factor Factor by which to downscale the input tensor. - residual : bool, optional + residual Whether to use a residual connection as proposed in [DCAE]_. """ super().__init__() self.pixel_unshuffle = PixelUnshuffle(downscale_factor) - out_ratio = downscale_factor**dim - if channels_out % out_ratio != 0: - raise ValueError(f'channels_out must be divisible by downscale_factor**{dim}.') - self.conv = ConvND(dim)(channels_in, channels_out // out_ratio, kernel_size=3, padding='same') + out_ratio = downscale_factor**n_dim + if n_channels_out % out_ratio != 0: + raise ValueError(f'channels_out must be divisible by downscale_factor**{n_dim}.') + self.conv = ConvND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') self.residual = residual def __call__(self, x: torch.Tensor) -> torch.Tensor: @@ -134,24 +144,26 @@ class PixelShuffleUpsample(Module): https://arxiv.org/abs/2410.10733 """ - def __init__(self, dim: int, channels_in: int, channels_out: int, upscale_factor: int = 2, residual: bool = False): + def __init__( + self, n_dim: int, n_channels_in: int, n_channels_out: int, upscale_factor: int = 2, residual: bool = False + ): """Initialize a PixelShuffleUpsample layer. Parameters ---------- - dim : int + n_dim Dimension of the input tensor. - channels_in : int + n_channels_in Number of channels in the input tensor. - channels_out : int + n_channels_out Number of channels in the output tensor. - upscale_factor : int, optional + upscale_factor Factor by which to upscale the input tensor. - residual : bool, optional + residual Whether to use a residual connection as proposed in [DCAE]_. """ super().__init__() - self.conv = ConvND(dim)(channels_in, channels_out * upscale_factor**dim, kernel_size=3, padding='same') + self.conv = ConvND(n_dim)(n_channels_in, n_channels_out * upscale_factor**n_dim, kernel_size=3, padding='same') self.pixel_shuffle = PixelShuffle(upscale_factor) self.residual = residual @@ -165,7 +177,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: Returns ------- - Tensor of shape `batch, channels_out, *spatial_dims * upscale_factor` + Tensor of shape `batch, channels_out, *spatial_dims * upscale_factor` """ return super().__call__(x) @@ -181,7 +193,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PixelShuffle(Module): """ND-version of PixelShuffle upscaling.""" - def __init__(self, upscale_factor: int): + def __init__(self, upscale_factor: int, features_last: bool = False): """Initialize PixelShuffle. Upscales spatial dimensions and decreases the channel number by reshaping. @@ -192,11 +204,15 @@ def __init__(self, upscale_factor: int): Parameters ---------- - upscale_factor : int + upscale_factor The factor by which to upscale the spatial dimensions. + features_last + Whether the features/channels dimension is the last dimension as common in transformer models or the + second dimension as common in CNN models. """ super().__init__() self.upscale_factor = upscale_factor + self.features_last = features_last def __call__(self, x: torch.Tensor) -> torch.Tensor: """Upscale the input. @@ -204,23 +220,28 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- x - Tensor of shape `batch, channels, *spatial_dims` + Tensor of shape `batch, channels, *spatial_dims` or `batch, *spatial_dims, channels` (if `features_last`). Returns ------- - Tensor of shape `batch, channels / upscale_factor**dim, *spatial_dims * upscale_factor` + Tensor of shape `batch, channels / upscale_factor**n_dim, *spatial_dims * upscale_factor` or + `batch, *spatial_dims * upscale_factor, channels / upscale_factor**n_dim` (if `features_last`). """ return super().__call__(x) def forward(self, x: torch.Tensor) -> torch.Tensor: """Upscale the input.""" - dim = x.ndim - 2 - if dim == 2: # fast path for 2D + n_dim = x.ndim - 2 + if n_dim == 2 and not self.features_last: # fast path for 2D return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) - new_shape = (x.shape[0], -1, *(old * self.upscale_factor for old in x.shape[-dim:])) - - x = x.unflatten(1, (-1, *(self.upscale_factor,) * dim)) - x = x.moveaxis(tuple(range(2, 2 + dim)), tuple(range(-2 * dim + 1, 0, 2))) + if self.features_last: + new_shape = (x.shape[0], *(old * self.upscale_factor for old in x.shape[-n_dim - 1 : -1]), -1) + x = x.unflatten(-1, (-1, *(self.upscale_factor,) * n_dim)) + x = x.moveaxis(tuple(range(-n_dim, 0)), tuple(range(-2 * n_dim, 0, 2))) + else: + new_shape = (x.shape[0], -1, *(old * self.upscale_factor for old in x.shape[-n_dim:])) + x = x.unflatten(1, (-1, *(self.upscale_factor,) * n_dim)) + x = x.moveaxis(tuple(range(2, 2 + n_dim)), tuple(range(-2 * n_dim + 1, 0, 2))) x = x.reshape(new_shape) return x diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py index f66e2277e..190418964 100644 --- a/src/mrpro/nn/ShiftedWindowAttention.py +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -23,41 +23,47 @@ class ShiftedWindowAttention(Module): rel_position_index: torch.Tensor def __init__( - self, dim: int, channels_in: int, channels_out: int, n_heads: int, window_size: int = 7, shifted: bool = True + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + window_size: int = 7, + shifted: bool = True, ): """Initialize the ShiftedWindowAttention module. Parameters ---------- - dim : int + n_dim The dimension of the input. - channels_in : int + n_channels_in The number of channels in the input tensor. - channels_out : int + n_channels_out The number of channels in the output tensor. - n_heads : int + n_heads The number of attention heads. The number if channels per head is ``channels // n_heads``. - window_size : int + window_size The size of the window. - shifted : bool + shifted Whether to shift the window. """ super().__init__() self.n_heads = n_heads self.window_size = window_size self.shifted = shifted - channels_per_head = channels_in // n_heads - self.to_qkv = ConvND(dim)(channels_per_head * n_heads, 3 * channels_per_head * n_heads, 1) - self.to_out = ConvND(dim)(channels_per_head * n_heads, channels_out, 1) - self.dim = dim + channels_per_head = n_channels_in // n_heads + self.to_qkv = ConvND(n_dim)(channels_per_head * n_heads, 3 * channels_per_head * n_heads, 1) + self.to_out = ConvND(n_dim)(channels_per_head * n_heads, n_channels_out, 1) + self.dim = n_dim coords_1d = torch.arange(window_size) - coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * dim), indexing='ij'), 0).flatten(1) + coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * n_dim), indexing='ij'), 0).flatten(1) rel_coords = coords_nd[:, :, None] - coords_nd[:, None, :] # (dim, window_size**dim, window_size**dim) rel_coords += window_size - 1 # shift to >=0 - rel_position_index = ravel_multi_index(tuple(rel_coords), (2 * window_size - 1,) * dim) + rel_position_index = ravel_multi_index(tuple(rel_coords), (2 * window_size - 1,) * n_dim) self.register_buffer('rel_position_index', rel_position_index) - self.relative_position_bias_table = torch.nn.Parameter(torch.empty((2 * window_size - 1) ** dim, n_heads)) + self.relative_position_bias_table = torch.nn.Parameter(torch.empty((2 * window_size - 1) ** n_dim, n_heads)) torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02, a=-0.04, b=0.04) def __call__(self, x: torch.Tensor) -> torch.Tensor: @@ -98,6 +104,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: attention = torch.roll(attention, (self.window_size // 2,) * self.dim, dims=tuple(range(-self.dim, 0))) out = self.to_out(attention) return out - - -'' diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/SpatialTransformerBlock.py index cb431b52d..78be70d77 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/SpatialTransformerBlock.py @@ -57,8 +57,8 @@ def __init__( self.selfattention = Sequential( LayerNorm(channels, features_last=True), MultiHeadAttention( - channels_in=channels, - channels_out=channels, + n_channels_in=channels, + n_channels_out=channels, n_heads=n_heads, p_dropout=p_dropout, features_last=True, diff --git a/src/mrpro/nn/SqueezeExcitation.py b/src/mrpro/nn/SqueezeExcitation.py index 787817173..bd0fab4e8 100644 --- a/src/mrpro/nn/SqueezeExcitation.py +++ b/src/mrpro/nn/SqueezeExcitation.py @@ -17,24 +17,24 @@ class SqueezeExcitation(Module): ..[SE] Hu, Jie, Li Shen, and Gang Sun. "Squeeze-and-excitation networks." CVPR 2018, https://arxiv.org/abs/1709.01507 """ - def __init__(self, dim: int, input_channels: int, squeeze_channels: int) -> None: + def __init__(self, n_dim: int, n_channels_input: int, n_channels_squeeze: int) -> None: """Initialize SqueezeExcitation. Parameters ---------- - dim + n_dim The dimension of the input tensor. - input_channels + n_channels_input The number of channels in the input tensor. - squeeze_channels + n_channels_squeeze The number of channels in the squeeze tensor. """ super().__init__() self.scale = Sequential( - AdaptiveAvgPoolND(dim)(1), - ConvND(dim)(input_channels, squeeze_channels, kernel_size=1), + AdaptiveAvgPoolND(n_dim)(1), + ConvND(n_dim)(n_channels_input, n_channels_squeeze, kernel_size=1), ReLU(), - ConvND(dim)(squeeze_channels, input_channels, kernel_size=1), + ConvND(n_dim)(n_channels_squeeze, n_channels_input, kernel_size=1), Sigmoid(), ) diff --git a/src/mrpro/nn/TransposedAttention.py b/src/mrpro/nn/TransposedAttention.py index 043afa750..1f99c0fe7 100644 --- a/src/mrpro/nn/TransposedAttention.py +++ b/src/mrpro/nn/TransposedAttention.py @@ -19,16 +19,16 @@ class TransposedAttention(Module): CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf """ - def __init__(self, dim: int, channels_in: int, channels_out: int, n_heads: int): + def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, n_heads: int): """Initialize a TransposedAttention layer. Parameters ---------- - dim + n_dim input dimension - channels_in + n_channels_in Number of channels in the input tensor. - channels_out + n_channels_out Number of channels in the output tensor. n_heads Number of attention heads. @@ -36,17 +36,17 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, n_heads: int): super().__init__() self.n_heads = n_heads self.temperature = Parameter(torch.ones(n_heads, 1, 1)) - channels_per_head = channels_in // n_heads - self.to_qkv = ConvND(dim)(channels_in, channels_per_head * n_heads * 3, kernel_size=1) - self.qkv_dwconv = ConvND(dim)( + channels_per_head = n_channels_in // n_heads + self.to_qkv = ConvND(n_dim)(n_channels_in, channels_per_head * n_heads * 3, kernel_size=1) + self.qkv_dwconv = ConvND(n_dim)( channels_per_head * n_heads * 3, channels_per_head * n_heads * 3, kernel_size=3, - groups=channels_in * 3, + groups=n_channels_in * 3, padding=1, bias=False, ) - self.to_out = ConvND(dim)(channels_per_head * n_heads, channels_out, kernel_size=1) + self.to_out = ConvND(n_dim)(channels_per_head * n_heads, n_channels_out, kernel_size=1) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply transposed attention. diff --git a/src/mrpro/nn/convert_linear_conv.py b/src/mrpro/nn/convert_linear_conv.py index a6dac5f33..beb09d4b0 100644 --- a/src/mrpro/nn/convert_linear_conv.py +++ b/src/mrpro/nn/convert_linear_conv.py @@ -9,22 +9,22 @@ @overload -def linear_to_conv(linear_layer: Linear, dim: Literal[1]) -> Conv1d: ... +def linear_to_conv(linear_layer: Linear, n_dim: Literal[1]) -> Conv1d: ... @overload -def linear_to_conv(linear_layer: Linear, dim: Literal[2]) -> Conv2d: ... +def linear_to_conv(linear_layer: Linear, n_dim: Literal[2]) -> Conv2d: ... @overload -def linear_to_conv(linear_layer: Linear, dim: Literal[3]) -> Conv3d: ... +def linear_to_conv(linear_layer: Linear, n_dim: Literal[3]) -> Conv3d: ... @overload -def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: ... +def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d: ... -def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: +def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d: """Convert a Linear layer to a ConvNd layer with kernel size 1. Rearranging the spatial dimensions to the batch dimension, @@ -39,16 +39,16 @@ def linear_to_conv(linear_layer: Linear, dim: int) -> Conv1d | Conv2d | Conv3d: Parameters ---------- - linear_layer : nn.Linear + linear_layer The linear layer to convert. - dim : int + n_dim The convolution dimension (1, 2, or 3). Returns ------- A Conv layer with equivalent weights and bias. """ - conv = ConvND(dim)( + conv = ConvND(n_dim)( in_channels=linear_layer.in_features, out_channels=linear_layer.out_features, kernel_size=1, diff --git a/src/mrpro/nn/ndmodules.py b/src/mrpro/nn/ndmodules.py index b4bf089ea..b7626ab5a 100644 --- a/src/mrpro/nn/ndmodules.py +++ b/src/mrpro/nn/ndmodules.py @@ -3,19 +3,19 @@ import torch -def ConvND(dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 - """Get the `dim`-dimensional convolution class. +def ConvND(n_dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 + """Get the `n_dim`-dimensional convolution class. Parameters ---------- - dim + n_dim The dimension of the convolution. Returns ------- The convolution class. """ - match dim: + match n_dim: case 1: return torch.nn.Conv1d case 2: @@ -23,24 +23,24 @@ def ConvND(dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[tor case 3: return torch.nn.Conv3d case _: - raise NotImplementedError(f'ConvND for dim {dim} not implemented. Raise an issue if you need this.') + raise NotImplementedError(f'ConvND for dim {n_dim} not implemented. Raise an issue if you need this.') def ConvTransposeND( # noqa: N802 - dim: int, + n_dim: int, ) -> type[torch.nn.ConvTranspose1d] | type[torch.nn.ConvTranspose2d] | type[torch.nn.ConvTranspose3d]: - """Get the `dim`-dimensional transposed convolution class. + """Get the `n_dim`-dimensional transposed convolution class. Parameters ---------- - dim + n_dim The dimension of the transposed convolution. Returns ------- The transposed convolution class. """ - match dim: + match n_dim: case 1: return torch.nn.ConvTranspose1d case 2: @@ -49,23 +49,23 @@ def ConvTransposeND( # noqa: N802 return torch.nn.ConvTranspose3d case _: raise NotImplementedError( - f'ConvTransposeND for dim {dim} not implemented. Raise an issue if you need this.' + f'ConvTransposeND for dim {n_dim} not implemented. Raise an issue if you need this.' ) -def MaxPoolND(dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 - """Get the `dim`-dimensional max pooling class. +def MaxPoolND(n_dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 + """Get the `n_dim`-dimensional max pooling class. Parameters ---------- - dim + n_dim The dimension of the max pooling. Returns ------- The max pooling class. """ - match dim: + match n_dim: case 1: return torch.nn.MaxPool1d case 2: @@ -73,22 +73,22 @@ def MaxPoolND(dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | case 3: return torch.nn.MaxPool3d case _: - raise NotImplementedError(f'MaxPoolNd for dim {dim} not implemented. Raise an issue if you need this.') + raise NotImplementedError(f'MaxPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') -def AvgPoolND(dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 - """Get the `dim`-dimensional average pooling class. +def AvgPoolND(n_dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 + """Get the `n_dim`-dimensional average pooling class. Parameters ---------- - dim + n_dim The dimension of the average pooling. Returns ------- The average pooling class. """ - match dim: + match n_dim: case 1: return torch.nn.AvgPool1d case 2: @@ -96,24 +96,24 @@ def AvgPoolND(dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | case 3: return torch.nn.AvgPool3d case _: - raise NotImplementedError(f'AvgPoolNd for dim {dim} not implemented. Raise an issue if you need this.') + raise NotImplementedError(f'AvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') def AdaptiveAvgPoolND( # noqa: N802 - dim: int, + n_dim: int, ) -> type[torch.nn.AdaptiveAvgPool1d] | type[torch.nn.AdaptiveAvgPool2d] | type[torch.nn.AdaptiveAvgPool3d]: - """Get the `dim`-dimensional adaptive average pooling class. + """Get the `n_dim`-dimensional adaptive average pooling class. Parameters ---------- - dim + n_dim The dimension of the adaptive average pooling. Returns ------- The adaptive average pooling class. """ - match dim: + match n_dim: case 1: return torch.nn.AdaptiveAvgPool1d case 2: @@ -122,25 +122,25 @@ def AdaptiveAvgPoolND( # noqa: N802 return torch.nn.AdaptiveAvgPool3d case _: raise NotImplementedError( - f'AdaptiveAvgPoolNd for dim {dim} not implemented. Raise an issue if you need this.' + f'AdaptiveAvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.' ) def InstanceNormND( # noqa: N802 - dim: int, + n_dim: int, ) -> type[torch.nn.InstanceNorm1d] | type[torch.nn.InstanceNorm2d] | type[torch.nn.InstanceNorm3d]: - """Get the `dim`-dimensional instance normalization class. + """Get the `n_dim`-dimensional instance normalization class. Parameters ---------- - dim + n_dim The dimension of the instance normalization. Returns ------- The instance normalization class. """ - match dim: + match n_dim: case 1: return torch.nn.InstanceNorm1d case 2: @@ -148,24 +148,26 @@ def InstanceNormND( # noqa: N802 case 3: return torch.nn.InstanceNorm3d case _: - raise NotImplementedError(f'InstanceNormNd for dim {dim} not implemented. Raise an issue if you need this.') + raise NotImplementedError( + f'InstanceNormNd for dim {n_dim} not implemented. Raise an issue if you need this.' + ) def BatchNormND( # noqa: N802 - dim: int, + n_dim: int, ) -> type[torch.nn.BatchNorm1d] | type[torch.nn.BatchNorm2d] | type[torch.nn.BatchNorm3d]: - """Get the `dim`-dimensional batch normalization class. + """Get the `n_dim`-dimensional batch normalization class. Parameters ---------- - dim + n_dim The dimension of the batch normalization. Returns ------- The batch normalization class. """ - match dim: + match n_dim: case 1: return torch.nn.BatchNorm1d case 2: @@ -173,4 +175,4 @@ def BatchNormND( # noqa: N802 case 3: return torch.nn.BatchNorm3d case _: - raise NotImplementedError(f'BatchNormNd for dim {dim} not implemented. Raise an issue if you need this.') + raise NotImplementedError(f'BatchNormNd for dim {n_dim} not implemented. Raise an issue if you need this.') diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py index 620e60038..280b79cd0 100644 --- a/src/mrpro/nn/nets/BasicCNN.py +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -65,9 +65,9 @@ def __init__( elif norm.lower() == 'group': self.append(GroupNorm(c_in, affine=not use_film)) elif norm.lower() == 'instance': - self.append(GroupNorm(c_in, groups=c_in, affine=not use_film)) # is instance norm + self.append(GroupNorm(c_in, n_groups=c_in, affine=not use_film)) # is instance norm elif norm.lower() == 'layer': - self.append(GroupNorm(c_in, groups=1, affine=not use_film)) # is layer norm + self.append(GroupNorm(c_in, n_groups=1, affine=not use_film)) # is layer norm elif norm.lower() != 'none': raise ValueError(f'Invalid normalization type: {norm}') diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index 1f49a0297..c5f4eaaa7 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -93,9 +93,9 @@ def __init__( attention = MultiHeadAttention(channels, channels, n_heads, features_last=False) self.context_module = Residual(Sequential(attention, RMSNorm(channels))) self.local_module = GluMBConvResBlock( - dim=dim, - channels_in=channels, - channels_out=channels, + n_dim=dim, + n_channels_in=channels, + n_channels_out=channels, expand_ratio=expand_ratio, ) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index dd97efe59..83ec8c49b 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -59,9 +59,9 @@ def __init__( hidden_dim = int(channels * mlp_ratio) self.norm1 = InstanceNormND(dim)(channels) self.attn = ShiftedWindowAttention( - dim=dim, - channels_in=channels, - channels_out=channels, + n_dim=dim, + n_channels_in=channels, + n_channels_out=channels, n_heads=n_heads, window_size=window_size, shifted=shifted, diff --git a/tests/nn/test_attentiongate.py b/tests/nn/test_attentiongate.py index 4b470be1c..7a7bb18ec 100644 --- a/tests/nn/test_attentiongate.py +++ b/tests/nn/test_attentiongate.py @@ -25,7 +25,7 @@ def test_attention_gate(dim, channels_gate, channels_in, channels_hidden, input_ x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) gate = rng.float32_tensor(gate_shape).to(device).requires_grad_(True) attn = AttentionGate( - dim=dim, channels_gate=channels_gate, channels_in=channels_in, channels_hidden=channels_hidden + n_dim=dim, channels_gate=channels_gate, channels_in=channels_in, channels_hidden=channels_hidden ).to(device) output = attn(x, gate) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' diff --git a/tests/nn/test_groupnorm32.py b/tests/nn/test_groupnorm32.py index 0c936dca7..468541aef 100644 --- a/tests/nn/test_groupnorm32.py +++ b/tests/nn/test_groupnorm32.py @@ -23,7 +23,7 @@ def test_groupnorm32(channels, groups, input_shape, device): """Test GroupNorm32 output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - norm = GroupNorm(channels=channels, groups=groups).to(device) + norm = GroupNorm(n_channels=channels, n_groups=groups).to(device) output = norm(x) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index 7ea8a4175..3e30b7fde 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -25,7 +25,12 @@ def test_shifted_window_attentio(dim: int, window_size: int, shifted: bool, devi rng = RandomGenerator(13) x = rng.float32_tensor((batch, channels, *spatial_shape)).to(device).requires_grad_(True) swin = ShiftedWindowAttention( - dim=dim, channels_in=channels, channels_out=channels, n_heads=n_heads, window_size=window_size, shifted=shifted + n_dim=dim, + n_channels_in=channels, + n_channels_out=channels, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, ).to(device) out = swin(x) assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' diff --git a/tests/nn/test_sqeezeexcitation.py b/tests/nn/test_sqeezeexcitation.py index 8929b9868..369ad0f3c 100644 --- a/tests/nn/test_sqeezeexcitation.py +++ b/tests/nn/test_sqeezeexcitation.py @@ -16,7 +16,7 @@ def test_squeeze_excitation(dim, input_shape, squeeze_channels): """Test SqueezeExcitation output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).requires_grad_(True) - se = SqueezeExcitation(dim=dim, input_channels=input_shape[1], squeeze_channels=squeeze_channels) + se = SqueezeExcitation(n_dim=dim, n_channels_input=input_shape[1], n_channels_squeeze=squeeze_channels) output = se(x) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index ea39781b3..361dc5799 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -23,7 +23,7 @@ def test_transposed_attention(dim, channels, num_heads, input_shape, device): """Test TransposedAttention output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - attn = TransposedAttention(dim=dim, channels_in=channels, channels_out=channels, n_heads=num_heads).to(device) + attn = TransposedAttention(n_dim=dim, n_channels_in=channels, n_channels_out=channels, n_heads=num_heads).to(device) output = attn(x) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() From da5baffa57fce6dd95b0f5fec70b88241d73558a Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 14 Jul 2025 11:16:41 +0200 Subject: [PATCH 104/205] Refactor variable names in GluMBConvResBlock and PixelShuffle classes for consistency; add unit tests for DropPath, GEGLU, PixelShuffle, and UNet functionality. --- src/mrpro/nn/GluMBConvResBlock.py | 4 +- src/mrpro/nn/PixelShuffle.py | 8 +-- tests/nn/nets/test_unet.py | 37 +++++++++++++ tests/nn/test_droppath.py | 22 ++++++++ tests/nn/test_geglu.py | 30 +++++++++++ tests/nn/test_pixelshuffle.py | 86 +++++++++++++++++++++++++++++++ 6 files changed, 182 insertions(+), 5 deletions(-) create mode 100644 tests/nn/nets/test_unet.py create mode 100644 tests/nn/test_droppath.py create mode 100644 tests/nn/test_geglu.py create mode 100644 tests/nn/test_pixelshuffle.py diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py index de3cae041..0455cf118 100644 --- a/src/mrpro/nn/GluMBConvResBlock.py +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -38,9 +38,9 @@ def __init__( ---------- n_dim Number of spatial dimensions. - channels_in + n_channels_in Number of input channels. - channels_out + n_channels_out Number of output channels. expand_ratio Expansion ratio inside the block. diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index 6a1dc3a50..b78853da7 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -1,5 +1,7 @@ """ND-version of PixelShuffle and PixelUnshuffle.""" +from math import ceil + import torch from torch.nn import Module @@ -91,9 +93,9 @@ def __init__( ---------- n_dim Dimension of the input tensor. - channels_in + n_channels_in Number of channels in the input tensor. - channels_out + n_channels_out Number of channels in the output tensor. downscale_factor Factor by which to downscale the input tensor. @@ -185,7 +187,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply upsampling.""" h = self.conv(x) if self.residual: - h = h + x.repeat_interleave(h.shape[1] // x.shape[1], dim=1) + h = h + x.repeat_interleave(ceil(h.shape[1] / x.shape[1]), dim=1)[:, : h.shape[1]] out = self.pixel_shuffle(h) return out diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py new file mode 100644 index 000000000..5c831262d --- /dev/null +++ b/tests/nn/nets/test_unet.py @@ -0,0 +1,37 @@ +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import UNet + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_unet_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the UNet.""" + unet = UNet( + dim=2, + channels_in=1, + channels_out=1, + attention_depths=(-1,), + n_features=(4, 6, 8), + n_heads=2, + cond_dim=32, + encoder_blocks_per_scale=1, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + unet = unet.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + unet = cast(UNet, torch.compile(unet)) + y = unet(x, cond=cond) + assert y.shape == (1, 1, 16, 16) diff --git a/tests/nn/test_droppath.py b/tests/nn/test_droppath.py new file mode 100644 index 000000000..323fffd5b --- /dev/null +++ b/tests/nn/test_droppath.py @@ -0,0 +1,22 @@ +"""Test DropPath.""" + +from mrpro.nn.DropPath import DropPath +from mrpro.utils import RandomGenerator + + +def test_droppath_no_drop(): + """Test DropPath.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)) + droppath = DropPath(0) + y = droppath(x) + assert (y == x).all() + + +def test_droppath_drop_all(): + """Test DropPath.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)) + droppath = DropPath(1.0) + y = droppath(x) + assert (y == 0).all() diff --git a/tests/nn/test_geglu.py b/tests/nn/test_geglu.py new file mode 100644 index 000000000..c412b8779 --- /dev/null +++ b/tests/nn/test_geglu.py @@ -0,0 +1,30 @@ +"""Test GEGLU.""" + +import torch +from mrpro.nn.GEGLU import GEGLU +from mrpro.utils import RandomGenerator + + +def test_geglu(): + """Test GELU.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + gelu = GEGLU(3, 4) + y = gelu(x) + assert y.shape == (1, 4, 4, 5) + + y.sum().backward() + assert x.grad is not None + assert gelu.proj.weight.grad is not None + + +def test_geglu_features_last(): + """Test GELU with features last.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + gelu_last = GEGLU(3, 4, features_last=True) + gelu = GEGLU(3, 4, features_last=False) + gelu.proj = gelu_last.proj # need to set the same weights + y_last = gelu_last(x.moveaxis(1, -1)) + y = gelu(x) + torch.testing.assert_close(y, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_pixelshuffle.py b/tests/nn/test_pixelshuffle.py new file mode 100644 index 000000000..9f098a4d3 --- /dev/null +++ b/tests/nn/test_pixelshuffle.py @@ -0,0 +1,86 @@ +"""Test PixelShuffle and PixelUnshuffle.""" + +import torch +from mrpro.nn.PixelShuffle import PixelShuffle, PixelShuffleUpsample, PixelUnshuffle, PixelUnshuffleDownsample +from mrpro.utils import RandomGenerator + + +def test_pixel_shuffle_2d(): + """Test PixelUnshuffle's fast path for 2D images.""" + x = torch.arange(3 * 4 * 8).reshape(1, 3, 4, 8) + pixel_unshuffle = PixelUnshuffle(2) + y = pixel_unshuffle(x) + assert y.shape == (1, 3 * 4, 4 // 2, 8 // 2) + + pixel_shuffle = PixelShuffle(2) + z = pixel_shuffle(y) + assert z.shape == (1, 3, 4, 8) + assert (x == z).all() + + +def test_pixel_unshuffle_4d(): + """Test PixelUnshuffle's general case.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) + pixel_unshuffle = PixelUnshuffle(2) + y = pixel_unshuffle(x) + assert y.shape == (1, 3 * 16, 4 // 2, 8 // 2, 10 // 2, 12 // 2) + + pixel_shuffle = PixelShuffle(2) + z = pixel_shuffle(y) + assert z.shape == (1, 3, 4, 8, 10, 12) + assert (x == z).all() + + +def test_pixelunshuffle_features_last(): + """Test PixelUnshuffle with features_last.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) + pixel_unshuffle_last = PixelUnshuffle(2, features_last=True) + pixel_unshuffle = PixelUnshuffle(2, features_last=False) + y_last = pixel_unshuffle_last(x.moveaxis(1, -1)).moveaxis(-1, 1) + y_normal = pixel_unshuffle(x) + assert (y_last == y_normal).all() + + +def test_pixelshuffle_features_last(): + """Test PixelS huffle with features_last.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, -1, 2, 4, 5, 6) + pixel_shuffle_last = PixelShuffle(2, features_last=True) + pixel_shuffle = PixelShuffle(2, features_last=False) + y_last = pixel_shuffle_last(x.moveaxis(1, -1)).moveaxis(-1, 1) + y_normal = pixel_shuffle(x) + assert (y_last == y_normal).all() + + +def test_unpixelshuffledownsample_residual(): + """Test PixelUnshuffleDownsample with residual.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 2, 9, 12, 15)) + downsample = PixelUnshuffleDownsample(3, 2, 27, downscale_factor=3, residual=True) + y = downsample(x) + assert y.shape == (1, 27, 3, 4, 5) + + +def test_pixelshuffleupsample_residual(): + """Test PixelShuffleUpsample with residual.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 2, 3, 4, 5)) + upsample = PixelShuffleUpsample(3, 2, 1, upscale_factor=3, residual=True) + y = upsample(x) + assert y.shape == (1, 1, 9, 12, 15) + + +def test_pixelshuffleupsample_pixelunshuffledownsample(): + """Test if PixelUnshuffleDownsample is the inverse of PixelShuffleUpsample.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3**3, 3, 4, 5)) + # Only without residual, the upsample and downsample are inverses. + downsample = PixelUnshuffleDownsample(3, 1, 3**3, downscale_factor=3, residual=False) + upsample = PixelShuffleUpsample(3, 3**3, 1, upscale_factor=3, residual=False) + # Only if the convs are Identity, the upsample and downsample are inverses. + torch.nn.init.dirac_(downsample.conv.weight) + torch.nn.init.dirac_(upsample.conv.weight) + torch.nn.init.zeros_(downsample.conv.bias) # type: ignore[arg-type] + torch.nn.init.zeros_(upsample.conv.bias) # type: ignore[arg-type] + y = downsample(upsample(x)) + assert y.shape == (1, 3**3, 3, 4, 5) + torch.testing.assert_close(y, x, msg='Upsample and downsample are not inverses.') From 79124d1d9fedd4f370ffb52bc3c1a8297ab53263 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 14 Jul 2025 13:37:27 +0200 Subject: [PATCH 105/205] tests --- tests/nn/test_attentiongate.py | 18 ++- tests/nn/test_complexaschannel.py | 4 +- tests/nn/test_droppath.py | 14 +- tests/nn/test_film.py | 12 +- tests/nn/test_geglu.py | 16 ++- tests/nn/test_groupnorm.py | 39 ++++++ tests/nn/test_groupnorm32.py | 34 ----- tests/nn/test_layernorm.py | 175 ++++++++++++++++++++++++ tests/nn/test_resblock.py | 2 +- tests/nn/test_sequential.py | 17 ++- tests/nn/test_shiftedwindowattention.py | 12 +- tests/nn/test_sqeezeexcitation.py | 2 +- tests/nn/test_transposedattention.py | 2 +- 13 files changed, 279 insertions(+), 68 deletions(-) create mode 100644 tests/nn/test_groupnorm.py delete mode 100644 tests/nn/test_groupnorm32.py create mode 100644 tests/nn/test_layernorm.py diff --git a/tests/nn/test_attentiongate.py b/tests/nn/test_attentiongate.py index 7a7bb18ec..99d463a00 100644 --- a/tests/nn/test_attentiongate.py +++ b/tests/nn/test_attentiongate.py @@ -1,5 +1,7 @@ """Tests for AttentionGate module.""" +from collections.abc import Sequence + import pytest from mrpro.nn.AttentionGate import AttentionGate from mrpro.utils import RandomGenerator @@ -13,26 +15,34 @@ ], ) @pytest.mark.parametrize( - ('dim', 'channels_gate', 'channels_in', 'channels_hidden', 'input_shape', 'gate_shape'), + ('n_dim', 'n_channels_gate', 'n_channels_in', 'n_channels_hidden', 'input_shape', 'gate_shape'), [ (2, 32, 32, 16, (1, 32, 32, 32), (1, 32, 16, 16)), (3, 32, 4, 8, (2, 4, 16, 16, 16), (2, 32, 16, 16, 16)), ], ) -def test_attention_gate(dim, channels_gate, channels_in, channels_hidden, input_shape, gate_shape, device): +def test_attention_gate( + n_dim: int, + n_channels_gate: int, + n_channels_in: int, + n_channels_hidden: int, + input_shape: Sequence[int], + gate_shape: Sequence[int], + device: str, +) -> None: """Test AttentionGate output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) gate = rng.float32_tensor(gate_shape).to(device).requires_grad_(True) attn = AttentionGate( - n_dim=dim, channels_gate=channels_gate, channels_in=channels_in, channels_hidden=channels_hidden + n_dim=n_dim, channels_gate=n_channels_gate, channels_in=n_channels_in, channels_hidden=n_channels_hidden ).to(device) output = attn(x, gate) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() assert x.grad is not None, 'No gradient computed for input' assert gate.grad is not None, 'No gradient computed for gate' - assert not x.isnan().any(), 'NaN values in input' + assert not output.isnan().any(), 'NaN values in output' assert not gate.isnan().any(), 'NaN values in gate' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert not gate.grad.isnan().any(), 'NaN values in gate gradients' diff --git a/tests/nn/test_complexaschannel.py b/tests/nn/test_complexaschannel.py index 5eb1e87f8..37889f654 100644 --- a/tests/nn/test_complexaschannel.py +++ b/tests/nn/test_complexaschannel.py @@ -13,7 +13,7 @@ pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), ], ) -def test_complexaschannel(device): +def test_complexaschannel(device: str) -> None: """Test ComplexAsChannel output shape and backpropagation.""" rng = RandomGenerator(seed=42) input_shape = (1, 32) @@ -24,7 +24,7 @@ def test_complexaschannel(device): assert output.is_complex(), 'Output is not complex' output.sum().abs().backward() assert x.grad is not None, 'No gradient computed for input' - assert not x.isnan().any(), 'NaN values in input' + assert not output.isnan().any(), 'NaN values in output' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert module.module.weight.grad is not None, 'No gradient computed for weight' assert module.module.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_droppath.py b/tests/nn/test_droppath.py index 323fffd5b..ff66c69d5 100644 --- a/tests/nn/test_droppath.py +++ b/tests/nn/test_droppath.py @@ -1,14 +1,22 @@ """Test DropPath.""" +import pytest from mrpro.nn.DropPath import DropPath from mrpro.utils import RandomGenerator -def test_droppath_no_drop(): +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_droppath_no_drop(device): """Test DropPath.""" rng = RandomGenerator(seed=42) - x = rng.float32_tensor((1, 3, 4, 5)) - droppath = DropPath(0) + x = rng.float32_tensor((1, 3, 4, 5)).to(device) + droppath = DropPath(0).to(device) y = droppath(x) assert (y == x).all() diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py index d49cd476b..40ada0a1b 100644 --- a/tests/nn/test_film.py +++ b/tests/nn/test_film.py @@ -1,5 +1,7 @@ """Tests for FiLM module.""" +from collections.abc import Sequence + import pytest from mrpro.nn.FiLM import FiLM from mrpro.utils import RandomGenerator @@ -13,24 +15,26 @@ ], ) @pytest.mark.parametrize( - ('channels', 'channels_cond', 'input_shape', 'cond_shape'), + ('n_channels', 'n_channels_cond', 'input_shape', 'cond_shape'), [ (64, 32, (1, 64, 32, 32), (1, 32)), (32, 16, (2, 32, 16, 16), (2, 16)), ], ) -def test_film(channels, channels_cond, input_shape, cond_shape, device): +def test_film( + n_channels: int, n_channels_cond: int, input_shape: Sequence[int], cond_shape: Sequence[int], device: str +) -> None: """Test FiLM output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) - film = FiLM(channels=channels, cond_dim=channels_cond).to(device) + film = FiLM(channels=n_channels, cond_dim=n_channels_cond).to(device) output = film(x, cond=cond) assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() assert x.grad is not None, 'No gradient computed for input' assert cond.grad is not None, 'No gradient computed for condedding' - assert not x.isnan().any(), 'NaN values in input' + assert not output.isnan().any(), 'NaN values in output' assert not cond.isnan().any(), 'NaN values in condedding' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert not cond.grad.isnan().any(), 'NaN values in condedding gradients' diff --git a/tests/nn/test_geglu.py b/tests/nn/test_geglu.py index c412b8779..9de03103c 100644 --- a/tests/nn/test_geglu.py +++ b/tests/nn/test_geglu.py @@ -1,15 +1,23 @@ """Test GEGLU.""" +import pytest import torch from mrpro.nn.GEGLU import GEGLU from mrpro.utils import RandomGenerator -def test_geglu(): +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_geglu(device: str) -> None: """Test GELU.""" rng = RandomGenerator(seed=42) - x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) - gelu = GEGLU(3, 4) + x = rng.float32_tensor((1, 3, 4, 5)).to(device).requires_grad_(True) + gelu = GEGLU(3, 4).to(device) y = gelu(x) assert y.shape == (1, 4, 4, 5) @@ -18,7 +26,7 @@ def test_geglu(): assert gelu.proj.weight.grad is not None -def test_geglu_features_last(): +def test_geglu_features_last() -> None: """Test GELU with features last.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) diff --git a/tests/nn/test_groupnorm.py b/tests/nn/test_groupnorm.py new file mode 100644 index 000000000..945860bca --- /dev/null +++ b/tests/nn/test_groupnorm.py @@ -0,0 +1,39 @@ +"""Tests for GroupNorm module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn import GroupNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'n_groups', 'input_shape', 'affine'), + [ + (32, None, (1, 32, 32, 32), True), + (64, 8, (2, 64, 16, 16, 16), False), + ], +) +def test_groupnorm(n_channels: int, n_groups: int, input_shape: Sequence[int], device: str, affine: bool) -> None: + """Test GroupNorm output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = GroupNorm(n_channels=n_channels, n_groups=n_groups, affine=affine).to(device) + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if affine: + assert norm.weight is not None, 'Weight should not be None when affine is True' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias is not None, 'Bias should not be None when affine is True' + assert norm.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_groupnorm32.py b/tests/nn/test_groupnorm32.py deleted file mode 100644 index 468541aef..000000000 --- a/tests/nn/test_groupnorm32.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Tests for GroupNorm32 module.""" - -import pytest -from mrpro.nn import GroupNorm -from mrpro.utils import RandomGenerator - - -@pytest.mark.parametrize( - 'device', - [ - pytest.param('cpu', id='cpu'), - pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), - ], -) -@pytest.mark.parametrize( - ('channels', 'groups', 'input_shape'), - [ - (32, None, (1, 32, 32, 32)), - (64, 8, (2, 64, 16, 16, 16)), - ], -) -def test_groupnorm32(channels, groups, input_shape, device): - """Test GroupNorm32 output shape and backpropagation.""" - rng = RandomGenerator(seed=42) - x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - norm = GroupNorm(n_channels=channels, n_groups=groups).to(device) - output = norm(x) - assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' - output.sum().backward() - assert x.grad is not None, 'No gradient computed for input' - assert not x.isnan().any(), 'NaN values in input' - assert not x.grad.isnan().any(), 'NaN values in input gradients' - assert norm.weight.grad is not None, 'No gradient computed for weight' - assert norm.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_layernorm.py b/tests/nn/test_layernorm.py new file mode 100644 index 000000000..85dc136a3 --- /dev/null +++ b/tests/nn/test_layernorm.py @@ -0,0 +1,175 @@ +"""Tests for LayerNorm module.""" + +import pytest +import torch +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'features_last', 'input_shape'), + [ + (32, False, (1, 32, 32, 32)), + (64, True, (2, 16, 16, 64)), + (None, False, (1, 32, 32, 32)), + (None, True, (2, 16, 16, 64)), + ], +) +def test_layernorm_basic(n_channels, features_last, input_shape, device): + """Test LayerNorm basic functionality.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = LayerNorm(n_channels=n_channels, features_last=features_last).to(device) + output = norm(x) + + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + if n_channels is not None: + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' + + +@pytest.mark.parametrize( + ('n_channels', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (32, 16, (1, 32, 32, 32), (1, 16)), + (64, 32, (2, 64, 16, 16), (2, 32)), + ], +) +def test_layernorm_with_conditioning(n_channels, cond_dim, input_shape, cond_shape): + """Test LayerNorm with conditioning.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).requires_grad_(True) + norm = LayerNorm(n_channels=n_channels, cond_dim=cond_dim) + + output = norm(x, cond=cond) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert norm.cond_proj is not None, 'cond_proj should not be None when cond_dim > 0' + assert norm.cond_proj.weight.grad is not None, 'No gradient computed for cond_proj' + + +def test_layernorm_features_last(): + """Test LayerNorm with features_last=True.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + norm_standard = LayerNorm(n_channels=3, features_last=False) + y_standard = norm_standard(x) + + norm_last = LayerNorm(n_channels=3, features_last=True) + y_last = norm_last(x.moveaxis(1, -1)) + + torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) + + +def test_layernorm_no_channels(): + """Test LayerNorm without channels (pure normalization).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) + norm = LayerNorm(n_channels=None) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + # Check that normalization is applied (mean close to 0, std close to 1) + dims = tuple(range(1, x.ndim)) + mean = output.mean(dim=dims) + std = output.std(dim=dims) + + assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-6), 'Mean not close to 0' + assert torch.allclose(std, torch.ones_like(std), atol=1e-5), 'Std not close to 1' + + +def test_layernorm_conditioning_without_channels(): + """Test LayerNorm with conditioning but no channels (should raise error).""" + with pytest.raises(ValueError, match='channels must be provided if cond_dim > 0'): + LayerNorm(n_channels=None, cond_dim=16) + + +def test_layernorm_invalid_cond_dim(): + """Test LayerNorm with invalid cond_dim.""" + with pytest.raises(RuntimeError, match='Trying to create tensor with negative dimension'): + LayerNorm(n_channels=32, cond_dim=-1) + + +def test_layernorm_3d_input(): + """Test LayerNorm with 3D input.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((2, 64, 128)).requires_grad_(True) + norm = LayerNorm(n_channels=64) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + + +def test_layernorm_5d_input(): + """Test LayerNorm with 5D input.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 16, 16, 16)).requires_grad_(True) + norm = LayerNorm(n_channels=32) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + + +def test_layernorm_conditioning_features_last(): + """Test LayerNorm with conditioning and features_last=True.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + cond = rng.float32_tensor((1, 8)).requires_grad_(True) + + norm = LayerNorm(n_channels=3, features_last=True, cond_dim=8) + output = norm(x.moveaxis(1, -1), cond=cond) + + assert output.shape == x.moveaxis(1, -1).shape, f'Output shape {output.shape} != expected shape' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + + +def test_layernorm_gradient_flow(): + """Test that gradients flow properly through LayerNorm.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) + norm = LayerNorm(n_channels=32) + + output = norm(x) + loss = output.sum() + loss.backward() + + # Check that gradients are computed for all learnable parameters + assert x.grad is not None, 'Input gradients not computed' + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'Weight gradients not computed' + assert norm.bias.grad is not None, 'Bias gradients not computed' + + # Check that gradients are finite + assert torch.isfinite(x.grad).all(), 'Input gradients contain non-finite values' + assert torch.isfinite(norm.weight.grad).all(), 'Weight gradients contain non-finite values' + assert torch.isfinite(norm.bias.grad).all(), 'Bias gradients contain non-finite values' diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index 6df1fce7f..dfbfc8a9e 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -31,7 +31,7 @@ def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, cond_sh ) output.sum().backward() assert x.grad is not None, 'No gradient computed for input' - assert not x.isnan().any(), 'NaN values in input' + assert not output.isnan().any(), 'NaN values in output' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert res.block[2].weight.grad is not None, 'No gradient computed for first Conv' if cond is not None: diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py index 83e585498..9d382a6a0 100644 --- a/tests/nn/test_sequential.py +++ b/tests/nn/test_sequential.py @@ -2,7 +2,7 @@ import pytest from mrpro.nn import FiLM, Sequential -from mrpro.operators import FastFourierOp +from mrpro.operators import FastFourierOp, MagnitudeOp from mrpro.utils import RandomGenerator from torch.nn import Linear @@ -18,7 +18,7 @@ ('input_shape', 'cond_dim'), [ ((1, 32), (1, 16)), - ((2, 64), None), + ((2, 32), None), ], ) def test_sequential(input_shape, cond_dim, device): @@ -28,14 +28,17 @@ def test_sequential(input_shape, cond_dim, device): cond = rng.float32_tensor(cond_dim).to(device).requires_grad_(True) if cond_dim else None seq = Sequential( Linear(input_shape[1], 64), - FastFourierOp(), + FastFourierOp(dim=(-1,)), FiLM(channels=64, cond_dim=16), + MagnitudeOp(), ).to(device) output = seq(x, cond=cond) - assert output.shape == (input_shape[0], 32), f'Output shape {output.shape} != expected {(input_shape[0], 32)}' + assert output.shape == (input_shape[0], 64) output.sum().backward() assert x.grad is not None, 'No gradient computed for input' - assert not x.isnan().any(), 'NaN values in input' + assert not output.isnan().any(), 'NaN values in output' assert not x.grad.isnan().any(), 'NaN values in input gradients' - assert seq[0].weight.grad is not None, 'No gradient computed for first Linear' - assert seq[2].weight.grad is not None, 'No gradient computed for second Linear' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for cond' + assert not cond.grad.isnan().any(), 'NaN values in cond gradients' + assert seq[0].weight.grad is not None, 'No gradient computed for Linear' diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index 3e30b7fde..9ccd4f5d0 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -17,17 +17,15 @@ (4, 4, True), ], ) -def test_shifted_window_attentio(dim: int, window_size: int, shifted: bool, device: str) -> None: - batch = 2 - channels = 8 - n_heads = 2 +def test_shifted_window_attention(dim: int, window_size: int, shifted: bool, device: str) -> None: + n_batch, n_channels, n_heads = 2, 8, 2 spatial_shape = (window_size * 4,) * dim rng = RandomGenerator(13) - x = rng.float32_tensor((batch, channels, *spatial_shape)).to(device).requires_grad_(True) + x = rng.float32_tensor((n_batch, n_channels, *spatial_shape)).to(device).requires_grad_(True) swin = ShiftedWindowAttention( n_dim=dim, - n_channels_in=channels, - n_channels_out=channels, + n_channels_in=n_channels, + n_channels_out=n_channels, n_heads=n_heads, window_size=window_size, shifted=shifted, diff --git a/tests/nn/test_sqeezeexcitation.py b/tests/nn/test_sqeezeexcitation.py index 369ad0f3c..b0ddf7050 100644 --- a/tests/nn/test_sqeezeexcitation.py +++ b/tests/nn/test_sqeezeexcitation.py @@ -21,6 +21,6 @@ def test_squeeze_excitation(dim, input_shape, squeeze_channels): assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() assert x.grad is not None, 'No gradient computed for input' - assert not x.isnan().any(), 'NaN values in input' + assert not output.isnan().any(), 'NaN values in output' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert se.scale[1].weight.grad is not None, 'No gradient computed for Conv' diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index 361dc5799..8b72b071f 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -28,7 +28,7 @@ def test_transposed_attention(dim, channels, num_heads, input_shape, device): assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() assert x.grad is not None, 'No gradient computed for input' - assert not x.isnan().any(), 'NaN values in input' + assert not output.isnan().any(), 'NaN values in output' assert not x.grad.isnan().any(), 'NaN values in input gradients' assert attn.to_qkv.weight.grad is not None, 'No gradient computed for qkv' assert attn.qkv_dwconv.weight.grad is not None, 'No gradient computed for qkv_dwconv' From b3c8d4081d21f7ad7e8ee0d7c56ea8d34c1a798a Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 14 Jul 2025 13:37:38 +0200 Subject: [PATCH 106/205] fix swin attention --- src/mrpro/nn/ShiftedWindowAttention.py | 38 ++++++++++++++++---------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/ShiftedWindowAttention.py index 190418964..5c960dbbb 100644 --- a/src/mrpro/nn/ShiftedWindowAttention.py +++ b/src/mrpro/nn/ShiftedWindowAttention.py @@ -2,9 +2,8 @@ import torch from einops import rearrange -from torch.nn import Module +from torch.nn import Linear, Module -from mrpro.nn.ndmodules import ConvND from mrpro.utils.reshape import ravel_multi_index from mrpro.utils.sliding_window import sliding_window @@ -30,6 +29,7 @@ def __init__( n_heads: int, window_size: int = 7, shifted: bool = True, + features_last: bool = False, ): """Initialize the ShiftedWindowAttention module. @@ -47,15 +47,18 @@ def __init__( The size of the window. shifted Whether to shift the window. + features_last + Whether the features are last in the input tensor or in the second dimension. """ super().__init__() self.n_heads = n_heads self.window_size = window_size self.shifted = shifted + self.features_last = features_last channels_per_head = n_channels_in // n_heads - self.to_qkv = ConvND(n_dim)(channels_per_head * n_heads, 3 * channels_per_head * n_heads, 1) - self.to_out = ConvND(n_dim)(channels_per_head * n_heads, n_channels_out, 1) - self.dim = n_dim + self.to_qkv = Linear(channels_per_head * n_heads, 3 * channels_per_head * n_heads) + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) + self.n_dim = n_dim coords_1d = torch.arange(window_size) coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * n_dim), indexing='ij'), 0).flatten(1) rel_coords = coords_nd[:, :, None] - coords_nd[:, None, :] # (dim, window_size**dim, window_size**dim) @@ -82,25 +85,32 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the ShiftedWindowAttention.""" + if not self.features_last: + x = x.moveaxis(1, -1) # now it is features last if self.shifted: - x = torch.roll(x, (-(self.window_size // 2),) * self.dim, dims=tuple(range(-self.dim, 0))) + x = torch.roll(x, (-(self.window_size // 2),) * self.n_dim, dims=tuple(range(-self.n_dim - 1, -1))) qkv = self.to_qkv(x) - windowed = sliding_window(qkv, window_shape=self.window_size, stride=self.window_size, dim=range(-self.dim, 0)) - flat = windowed.flatten(0, self.dim - 1).flatten(-self.dim) + windowed = sliding_window( + qkv, window_shape=self.window_size, stride=self.window_size, dim=range(-self.n_dim - 1, -1) + ) q, k, v = rearrange( - flat, - 'spatial batch (qkv heads channels) window->qkv spatial batch heads window channels', + windowed.flatten(-self.n_dim - 1, -2), + '... sequence (qkv heads channels)->qkv ... heads sequence channels', heads=self.n_heads, qkv=3, ) bias = rearrange(self.relative_position_bias_table[self.rel_position_index], 'wd1 wd2 heads -> 1 heads wd1 wd2') attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) - attention = rearrange(attention, 'spatial batch head window channels->batch (head channels) spatial window') - attention = attention.unflatten(-2, windowed.shape[: self.dim]).unflatten(-1, (self.window_size,) * self.dim) + attention = rearrange(attention, '... head sequence channels->... sequence (head channels)') + attention = attention.unflatten(-2, windowed.shape[-self.n_dim - 1 : -1]) # permute (in 3d) batch channels z y x wz wy wx -> batch channels wz z wy y wx x - attention = attention.moveaxis(list(range(-self.dim, 0)), list(range(3, 3 + 2 * self.dim, 2))) + attention = attention.moveaxis(list(range(self.n_dim)), list(range(2, 2 + 2 * self.n_dim, 2))) attention = attention.reshape(x.shape) if self.shifted: - attention = torch.roll(attention, (self.window_size // 2,) * self.dim, dims=tuple(range(-self.dim, 0))) + attention = torch.roll( + attention, (self.window_size // 2,) * self.n_dim, dims=tuple(range(-self.n_dim - 1, -1)) + ) out = self.to_out(attention) + if not self.features_last: + out = out.moveaxis(-1, 1) return out From a748ceda7514300e8ad43ed13d167553db1189c1 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 14 Jul 2025 13:54:50 +0200 Subject: [PATCH 107/205] allow reflection padding --- src/mrpro/operators/ZeroPadOp.py | 2 +- src/mrpro/utils/pad_or_crop.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/ZeroPadOp.py b/src/mrpro/operators/ZeroPadOp.py index c4adfc831..19f19b23e 100644 --- a/src/mrpro/operators/ZeroPadOp.py +++ b/src/mrpro/operators/ZeroPadOp.py @@ -5,7 +5,7 @@ import torch from mrpro.operators.LinearOperator import LinearOperator -from mrpro.utils import pad_or_crop +from mrpro.utils.pad_or_crop import pad_or_crop class ZeroPadOp(LinearOperator): diff --git a/src/mrpro/utils/pad_or_crop.py b/src/mrpro/utils/pad_or_crop.py index af61e7354..5bb82f21d 100644 --- a/src/mrpro/utils/pad_or_crop.py +++ b/src/mrpro/utils/pad_or_crop.py @@ -5,7 +5,6 @@ from typing import Literal import torch -import torch.nn.functional as F # noqa: N812 def normalize_index(ndim: int, index: int) -> int: @@ -35,7 +34,7 @@ def pad_or_crop( data: torch.Tensor, new_shape: Sequence[int] | torch.Size, dim: None | Sequence[int] = None, - mode: Literal['constant', 'replicate', 'circular'] = 'constant', + mode: Literal['constant', 'reflect', 'replicate', 'circular'] = 'constant', value: float = 0.0, ) -> torch.Tensor: """Change shape of data by center cropping or symmetric padding. @@ -50,9 +49,9 @@ def pad_or_crop( Dimensions the `new_shape` corresponds to. `None` is interpreted as last ``len(new_shape)`` dimensions. mode - Mode for padding. + Mode to use for padding. value - Value to use for padding. + Value to use for constant padding. Returns ------- @@ -82,5 +81,5 @@ def pad_or_crop( if any(npad): # F.pad expects paddings in reversed order - data = F.pad(data, npad[::-1], mode=mode, value=value) + data = torch.nn.functional.pad(data, npad[::-1], value=value, mode=mode) return data From 3e55d01d12165b9d38128393129f7693fbb1f57f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 14 Jul 2025 17:11:35 +0200 Subject: [PATCH 108/205] Add RMSNorm module, update NeighborhoodSelfAttention to accept device parameter, and introduce unit tests for RMSNorm and NeighborhoodSelfAttention. Update BasicCNN to include activation type in parameters. --- src/mrpro/nn/NeighborhoodSelfAttention.py | 14 +++- src/mrpro/nn/RMSNorm.py | 50 ++++++++---- src/mrpro/nn/__init__.py | 2 + src/mrpro/nn/nets/BasicCNN.py | 2 + tests/nn/test_ndmodules.py | 76 ++++++++++++++++++ tests/nn/test_neighborhoodselfattention.py | 89 ++++++++++++++++++++++ tests/nn/test_rmsnorm.py | 58 ++++++++++++++ 7 files changed, 273 insertions(+), 18 deletions(-) create mode 100644 tests/nn/test_ndmodules.py create mode 100644 tests/nn/test_neighborhoodselfattention.py create mode 100644 tests/nn/test_rmsnorm.py diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/NeighborhoodSelfAttention.py index 5523e081e..71b4aeb06 100644 --- a/src/mrpro/nn/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/NeighborhoodSelfAttention.py @@ -16,6 +16,7 @@ @cache def neighborhood_mask( + device: str, input_size: torch.Size, kernel_size: int | tuple[int, ...], # tuples instead of Sequence for cache dilation: int | tuple[int, ...] = 1, @@ -42,6 +43,8 @@ def neighborhood_mask( circular Whether the neighborhood wraps around the edges (circular padding). Can be a single boolean or a sequence of booleans. + device + The device to create the mask on. Returns ------- @@ -98,7 +101,7 @@ def mask( return reduce(lambda x, y: x & y, masks) qkv_len = input_size.numel() - return create_block_mask(mask, B=None, H=None, Q_LEN=qkv_len, KV_LEN=qkv_len, _compile=True) + return torch.compile(create_block_mask)(mask, B=None, H=None, Q_LEN=qkv_len, KV_LEN=qkv_len, device=device) class NeighborhoodSelfAttention(Module): @@ -178,14 +181,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: spatial_shape = x.shape[1:-1] qkv = self.to_qkv(x) query, key, value = rearrange( - qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channel', qkv=3, head=self.n_head + qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channels', qkv=3, head=self.n_head ) # the mask depends on the input size. To be more flexible if used within CNNs, we compute it here. # The computation is cached.. mask = neighborhood_mask( - input_size=spatial_shape, kernel_size=self.kernel_size, dilation=self.dilation, circular=self.circular + device=str(qkv.device), + input_size=spatial_shape, + kernel_size=self.kernel_size, + dilation=self.dilation, + circular=self.circular, ) out: torch.Tensor = flex_attention(query.contiguous(), key.contiguous(), value.contiguous(), block_mask=mask) # type: ignore[assignment] # wrong type hints + out = rearrange(out, 'batch head sequence channels -> batch sequence(head channels)') out = self.to_out(out) out = out.unflatten(-2, spatial_shape) if not self.features_last: diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py index 28cecbf9f..b97641545 100644 --- a/src/mrpro/nn/RMSNorm.py +++ b/src/mrpro/nn/RMSNorm.py @@ -1,29 +1,46 @@ -"""RMSNorm module for root mean square normalization.""" +"""RMSNorm over the channel dimension.""" import torch from torch.nn import Module, Parameter class RMSNorm(Module): - """RMSNorm over the channel dimension.""" + """RMSNorm over the channel dimension. - def __init__(self, channels: int, eps: float = 1e-8, features_last: bool = False): + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_channels: int | None = None, + eps: float = 1e-8, + features_last: bool = False, + ): """Initialize RMSNorm. - Includes a learnable weight and bias. + Includes a learnable weight and bias if n_channels is provided. Parameters ---------- - channels - Number of channels. + n_channels + Number of channels. If `None`, no learnable weight and bias are included. eps Epsilon value to avoid division by zero. features_last If True, the channel dimension is the last dimension. """ super().__init__() - self.weight = Parameter(torch.zeros(channels)) - self.bias = Parameter(torch.zeros(channels)) + if n_channels is not None: + self.weight: Parameter | None = Parameter(torch.zeros(n_channels)) + self.bias: Parameter | None = Parameter(torch.zeros(n_channels)) + else: + self.weight = None + self.bias = None self.eps = eps self.channel_dim = -1 if features_last else 1 @@ -43,11 +60,14 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply RMSNorm over the channel dimension.""" - mean_square = x.square().mean(dim=self.channel_dim, keepdim=True) + x32 = x.to(torch.float32) # normalization in float32 to stabilize mixed precision training + mean_square = x32.square().mean(dim=self.channel_dim, keepdim=True) scale = (mean_square + self.eps).rsqrt() - x = x * scale - shape = [1] * x.ndim - shape[self.channel_dim] = -1 - weight = (self.weight + 1).view(shape) - bias = self.bias.view(shape) - return x * weight + bias + x32 = x32 * scale + if self.weight is not None and self.bias is not None: + shape = [1] * x.ndim + shape[self.channel_dim] = -1 + weight = (self.weight.to(x32.dtype) + 1).view(shape) + bias = self.bias.view(shape) + x32 = x32 * weight + bias + return x32.to(x.dtype) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index 9d2b1dff4..b5a67e17e 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -24,6 +24,7 @@ from mrpro.nn.ComplexAsChannel import ComplexAsChannel from mrpro.nn import nets from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.RMSNorm import RMSNorm __all__ = [ "AdaptiveAvgPoolND", @@ -41,6 +42,7 @@ "MaxPoolND", "NeighborhoodSelfAttention", "PermutedBlock", + "RMSNorm", "ResBlock", "Residual", "Sequential", diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py index 280b79cd0..f294de715 100644 --- a/src/mrpro/nn/nets/BasicCNN.py +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -47,6 +47,8 @@ def __init__( The type of normalization to use. If 'batch', use batch normalization. If 'group', use group normalization, if 'instance', use instance normalization, and if `layer`, use layer normalization. If 'none', use no normalization. + activation + The type of activation to use. If 'relu', use ReLU. If 'silu', use SiLU. If 'leaky_relu', use LeakyReLU. n_features The number of features in the hidden layers. The length of this sequence determines the number of hidden layers. The total number of convolutions is `len(n_features) + 1`. diff --git a/tests/nn/test_ndmodules.py b/tests/nn/test_ndmodules.py new file mode 100644 index 000000000..b34999170 --- /dev/null +++ b/tests/nn/test_ndmodules.py @@ -0,0 +1,76 @@ +"""Tests for the ndmodules module.""" + +import pytest +import torch +from mrpro.nn.ndmodules import ( + AdaptiveAvgPoolND, + AvgPoolND, + BatchNormND, + ConvND, + ConvTransposeND, + InstanceNormND, + MaxPoolND, +) + + +def test_convnd() -> None: + """Test ConvND.""" + assert ConvND(1) is torch.nn.Conv1d + assert ConvND(2) is torch.nn.Conv2d + assert ConvND(3) is torch.nn.Conv3d + with pytest.raises(NotImplementedError): + ConvND(4) + + +def test_convtransposend() -> None: + """Test ConvTransposeND.""" + assert ConvTransposeND(1) is torch.nn.ConvTranspose1d + assert ConvTransposeND(2) is torch.nn.ConvTranspose2d + assert ConvTransposeND(3) is torch.nn.ConvTranspose3d + with pytest.raises(NotImplementedError): + ConvTransposeND(4) + + +def test_maxpoolnd() -> None: + """Test MaxPoolND.""" + assert MaxPoolND(1) is torch.nn.MaxPool1d + assert MaxPoolND(2) is torch.nn.MaxPool2d + assert MaxPoolND(3) is torch.nn.MaxPool3d + with pytest.raises(NotImplementedError): + MaxPoolND(4) + + +def test_avgpoolnd() -> None: + """Test AvgPoolND.""" + assert AvgPoolND(1) is torch.nn.AvgPool1d + assert AvgPoolND(2) is torch.nn.AvgPool2d + assert AvgPoolND(3) is torch.nn.AvgPool3d + with pytest.raises(NotImplementedError): + AvgPoolND(4) + + +def test_adaptiveavgpoolnd() -> None: + """Test AdaptiveAvgPoolND.""" + assert AdaptiveAvgPoolND(1) is torch.nn.AdaptiveAvgPool1d + assert AdaptiveAvgPoolND(2) is torch.nn.AdaptiveAvgPool2d + assert AdaptiveAvgPoolND(3) is torch.nn.AdaptiveAvgPool3d + with pytest.raises(NotImplementedError): + AdaptiveAvgPoolND(4) + + +def test_instancenormnd() -> None: + """Test InstanceNormND.""" + assert InstanceNormND(1) is torch.nn.InstanceNorm1d + assert InstanceNormND(2) is torch.nn.InstanceNorm2d + assert InstanceNormND(3) is torch.nn.InstanceNorm3d + with pytest.raises(NotImplementedError): + InstanceNormND(4) + + +def test_batchnormnd() -> None: + """Test BatchNormND.""" + assert BatchNormND(1) is torch.nn.BatchNorm1d + assert BatchNormND(2) is torch.nn.BatchNorm2d + assert BatchNormND(3) is torch.nn.BatchNorm3d + with pytest.raises(NotImplementedError): + BatchNormND(4) diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py new file mode 100644 index 000000000..5e66f5be9 --- /dev/null +++ b/tests/nn/test_neighborhoodselfattention.py @@ -0,0 +1,89 @@ +"""Tests for NeighborhoodSelfAttention module.""" + +import pytest +from mrpro.nn import NeighborhoodSelfAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels_in', 'n_channels_out', 'n_heads', 'kernel_size', 'input_shape', 'features_last'), + [ + (2, 3, 1, 2, (1, 2, 16, 16), False), + (3, 2, 2, 4, (1, 3, 8, 8, 8, 8), True), + ], +) +def test_neighborhood_self_attention( + n_channels_in: int, + n_channels_out: int, + n_heads: int, + kernel_size: int, + input_shape: tuple[int, ...], + features_last: bool, + device: str, +) -> None: + """Test NeighborhoodSelfAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + + if features_last: + x = x.moveaxis(1, -1) + + attn = NeighborhoodSelfAttention( + n_channels_in=n_channels_in, + n_channels_out=n_channels_out, + n_heads=n_heads, + kernel_size=kernel_size, + features_last=features_last, + ).to(device) + + output = attn(x) + + expected_shape = (x.shape[0], n_channels_out, *x.shape[2:]) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + assert attn.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert attn.to_qkv.bias.grad is not None, 'No gradient computed for to_qkv.bias' + assert attn.to_out.weight.grad is not None, 'No gradient computed for to_out.weight' + assert attn.to_out.bias.grad is not None, 'No gradient computed for to_out.bias' + + +@pytest.mark.parametrize( + ('kernel_size', 'dilation', 'circular'), + [ + (3, 1, False), + (5, 2, True), + (7, 1, False), + ], +) +def test_neighborhood_attention_variants(kernel_size: int, dilation: int, circular: bool) -> None: + """Test NeighborhoodSelfAttention with different neighborhood configurations.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 16, 16)).requires_grad_(True) + + attn = NeighborhoodSelfAttention( + n_channels_in=32, + n_channels_out=32, + n_heads=4, + kernel_size=kernel_size, + dilation=dilation, + circular=circular, + ) + + output = attn(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' diff --git a/tests/nn/test_rmsnorm.py b/tests/nn/test_rmsnorm.py new file mode 100644 index 000000000..aab133da0 --- /dev/null +++ b/tests/nn/test_rmsnorm.py @@ -0,0 +1,58 @@ +"""Tests for RMSNorm module.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.nn import RMSNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'features_last', 'input_shape'), + [ + (32, False, (1, 32, 32, 32)), + (64, True, (2, 16, 16, 64)), + (None, False, (1, 32, 32, 32)), + (None, True, (2, 16, 16, 64)), + ], +) +def test_rmsnorm_basic(n_channels: int | None, features_last: bool, input_shape: Sequence[int], device: str) -> None: + """Test RMSNorm basic functionality.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = RMSNorm(n_channels=n_channels, features_last=features_last).to(device) + output = norm(x) + + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + if n_channels is not None: + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' + + +def test_rmsnorm_features_last(): + """Test RMSNorm with features_last=True.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + norm_standard = RMSNorm(n_channels=3, features_last=False) + y_standard = norm_standard(x) + + norm_last = RMSNorm(n_channels=3, features_last=True) + y_last = norm_last(x.moveaxis(1, -1)) + + torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) From f734e3f12c31ab2ce6a0a5e1fc2fa71644fc0856 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 14 Jul 2025 23:55:14 +0200 Subject: [PATCH 109/205] Add LinearSelfAttention module, update join.py to accept string mode, and refactor padding logic in pad_or_crop.py. Remove unused CNN.py file and add unit tests for join and LinearSelfAttention modules. --- src/mrpro/nn/__init__.py | 2 + src/mrpro/nn/join.py | 6 +- src/mrpro/nn/nets/CNN.py | 60 ---------- src/mrpro/nn/nets/UNet.py | 16 ++- src/mrpro/utils/pad_or_crop.py | 14 ++- tests/nn/test_join.py | 160 +++++++++++++++++++++++++++ tests/nn/test_linearselfattention.py | 58 ++++++++++ 7 files changed, 245 insertions(+), 71 deletions(-) delete mode 100644 src/mrpro/nn/nets/CNN.py create mode 100644 tests/nn/test_join.py create mode 100644 tests/nn/test_linearselfattention.py diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index b5a67e17e..ab5c13513 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -19,6 +19,7 @@ from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention from mrpro.nn.SqueezeExcitation import SqueezeExcitation from mrpro.nn.TransposedAttention import TransposedAttention +from mrpro.nn.LinearSelfAttention import LinearSelfAttention from mrpro.nn.DropPath import DropPath from mrpro.nn.Residual import Residual from mrpro.nn.ComplexAsChannel import ComplexAsChannel @@ -39,6 +40,7 @@ "FiLM", "GroupNorm", "InstanceNormND", + "LinearSelfAttention", "MaxPoolND", "NeighborhoodSelfAttention", "PermutedBlock", diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py index 0aed41b8d..204f301a8 100644 --- a/src/mrpro/nn/join.py +++ b/src/mrpro/nn/join.py @@ -12,7 +12,7 @@ def _fix_shapes( xs: Sequence[torch.Tensor], - mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'], + mode: str, dim: Sequence[int], ) -> tuple[torch.Tensor, ...]: """Fix shapes of input tensors by padding or cropping.""" @@ -29,7 +29,7 @@ def _fix_shapes( if mode == 'zero' or mode == 'crop': return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) else: - return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) + return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) # type: ignore class Concat(Module): @@ -55,7 +55,7 @@ def __init__( Dimension along which to concatenate. """ super().__init__() - modes = {'fail', 'crop', 'zero', 'replicate', 'circular', 'interpolate'} + modes = {'fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'} if mode not in modes: raise ValueError(f'mode must be one of {modes}') self.mode = mode diff --git a/src/mrpro/nn/nets/CNN.py b/src/mrpro/nn/nets/CNN.py deleted file mode 100644 index 3fbabcc4f..000000000 --- a/src/mrpro/nn/nets/CNN.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Simple Convolutional Neural Network.""" - -from collections.abc import Sequence -from itertools import pairwise - -from torch.nn import ReLU - -from mrpro.nn.FiLM import FiLM -from mrpro.nn.GroupNorm import GroupNorm -from mrpro.nn.ndmodules import ConvND -from mrpro.nn.Residual import Residual -from mrpro.nn.Sequential import Sequential - - -class CNN(Sequential): - """A simple CNN network.""" - - def __init__( - self, - dim: int, - channels_in: int, - channels_out: int, - features: Sequence[int], - norm: bool = True, - residual: bool = True, - cond_dim: int = 0, - ): - """Initialize the CNN. - - Parameters - ---------- - dim - The number of spatial dimensions. - channels_in - The number of input channels. - channels_out - The number of output channels. - features - The number of features in each layer. The length of the list is the number of hidden layers. - norm - Whether to use layer normalization. - residual - Whether to use residual connections. - cond_dim - The dimension of the conditioning tensor. If 0, no FiLM is used. - """ - super().__init__() - channels = [channels_in, *features] - for i, (channels_current, channels_next) in enumerate(pairwise(channels)): - block = Sequential(ConvND(dim)(channels_current, channels_next, 3, padding=1), ReLU(True)) - if norm: - block.append(GroupNorm(1)) - if cond_dim > 0 and i % 2 == 0: - block.append(FiLM(channels_next, cond_dim)) - if residual: - self.append(Residual(block)) - else: - self.append(block) - - self.append(ConvND(dim)(channels_next, channels_out, 3, padding=1)) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index f0b8ffa41..14db1d3a0 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -154,11 +154,11 @@ def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequ self.skip_blocks = ModuleList() """Modifications of the skip connections.""" - # if len(decoder) != len(encoder): - # raise ValueError( - # 'The number of resolutions in the encoder and decoder must be the same, ' - # f'got {len(decoder)} and {len(encoder)}' - # ) + if len(decoder) != len(encoder): + raise ValueError( + 'The number of resolutions in the encoder and decoder must be the same, ' + f'got {len(decoder)} and {len(encoder)}' + ) if skip_blocks is None: self.skip_blocks.extend(Identity() for _ in range(len(decoder))) @@ -241,17 +241,21 @@ class UNet(UNetBase): """UNet. U-shaped convolutional network with optional patch attention. - Inspired by the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_, + Inspired by [NOSENSE_] and the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_. significant differences to the vanilla UNet [UNET]_ include: - Spatial transformer blocks - Convolutional downsampling, nearest neighbor upsampling - Residual convolution blocks with pre-act group normalization and SiLU activation + References ---------- .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py + .. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without + explicit sensitivity maps." STACOM 2023. https://github.com/fzimmermann89/CMRxRecon/blob/master/src/cmrxrecon/nets/unet.py + """ def __init__( diff --git a/src/mrpro/utils/pad_or_crop.py b/src/mrpro/utils/pad_or_crop.py index 5bb82f21d..f9c120d3e 100644 --- a/src/mrpro/utils/pad_or_crop.py +++ b/src/mrpro/utils/pad_or_crop.py @@ -76,8 +76,18 @@ def pad_or_crop( diff = new - old after = math.trunc(diff / 2) before = diff - after - npad.append(before) - npad.append(after) + if before != 0 or after != 0: + npad.append(before) + npad.append(after) + + if mode != 'constant': + while len(npad) // 2 < data.ndim - 2: + npad = [0, 0, *npad] + if len(npad) // 2 > data.ndim - 2: + raise ValueError( + 'replicate and circular padding are only supported for up to the last 3 dimensions of 4D/5D data, ' + 'last 2 dimensions of 3D/4D data and last dimension of 2D/3D data.' + ) if any(npad): # F.pad expects paddings in reversed order diff --git a/tests/nn/test_join.py b/tests/nn/test_join.py new file mode 100644 index 000000000..f86647ac4 --- /dev/null +++ b/tests/nn/test_join.py @@ -0,0 +1,160 @@ +"""Tests for join modules.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.join import Add, Concat +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('mode', 'input_shapes', 'expected_shape'), + [ + ('crop', [(1, 3, 32, 32), (1, 5, 30, 30)], (1, 8, 30, 30)), + ('zero', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ('linear', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ('nearest', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ], +) +def test_concat_basic( + mode: Literal['crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'], + input_shapes: list[tuple[int, ...]], + expected_shape: tuple[int, ...], + device: str, +) -> None: + """Test Concat basic functionality.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).to(device).requires_grad_(True) for shape in input_shapes] + concat = Concat(mode=mode).to(device) + + output = concat(*xs) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + for x in xs: + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('mode', 'input_shapes', 'expected_shape'), + [ + ('crop', [(1, 3, 32, 32), (1, 3, 30, 30)], (1, 3, 30, 30)), + ('zero', [(1, 3, 32, 32), (1, 3, 34, 34)], (1, 3, 34, 34)), + ('replicate', [(1, 1, 1, 2), (1, 1, 1, 3)], (1, 1, 1, 3)), + ('circular', [(1, 1, 1, 2), (1, 1, 1, 4)], (1, 1, 1, 4)), + ], +) +def test_add_basic( + mode: Literal['crop', 'zero', 'replicate', 'circular'], + input_shapes: list[tuple[int, ...]], + expected_shape: tuple[int, ...], + device: str, +) -> None: + """Test Add basic functionality.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).to(device).requires_grad_(True) for shape in input_shapes] + add = Add(mode=mode).to(device) + + output = add(*xs) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + for x in xs: + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + +@pytest.mark.parametrize( + ('dim', 'input_shapes', 'expected_shape'), + [ + (0, [(1, 3, 32, 32), (1, 3, 32, 32)], (2, 3, 32, 32)), + (1, [(1, 3, 32, 32), (1, 5, 32, 32)], (1, 8, 32, 32)), + (2, [(1, 3, 32, 32), (1, 3, 32, 32)], (1, 3, 64, 32)), + ], +) +def test_concat_dimensions(dim: int, input_shapes: list[tuple[int, ...]], expected_shape: tuple[int, ...]) -> None: + """Test Concat with different concatenation dimensions.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).requires_grad_(True) for shape in input_shapes] + concat = Concat(mode='fail', dim=dim) + output = concat(*xs) + assert output.shape == expected_shape + + +def test_concat_values() -> None: + """Test that Concat preserves input values correctly.""" + x1 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]).requires_grad_(True) + x2 = torch.tensor([[[[5.0, 6.0], [7.0, 8.0]]]]).requires_grad_(True) + + concat = Concat(mode='fail') + output = concat(x1, x2) + + expected = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]) + torch.testing.assert_close(output, expected) + + +def test_add_values() -> None: + """Test that Add correctly sums input values.""" + x1 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]).requires_grad_(True) + x2 = torch.tensor([[[[5.0, 6.0], [7.0, 8.0]]]]).requires_grad_(True) + + add = Add(mode='fail') + output = add(x1, x2) + + expected = torch.tensor([[[[6.0, 8.0], [10.0, 12.0]]]]) + torch.testing.assert_close(output, expected) + + +def test_concat_mode_fail() -> None: + """Test Concat with mode='fail'.""" + rng = RandomGenerator(seed=42) + + x1 = rng.float32_tensor((1, 3, 32, 32)) + x2 = rng.float32_tensor((1, 5, 32, 32)) + concat = Concat(mode='fail') + output = concat(x1, x2) + assert output.shape == (1, 8, 32, 32) + + x3 = rng.float32_tensor((1, 3, 30, 30)) + with pytest.raises(RuntimeError): + concat(x1, x3) + + +def test_add_mode_fail() -> None: + """Test Add with mode='fail'.""" + rng = RandomGenerator(seed=42) + + x1 = rng.float32_tensor((1, 3, 32, 32)) + x2 = rng.float32_tensor((1, 3, 32, 32)) + add = Add(mode='fail') + output = add(x1, x2) + assert output.shape == (1, 3, 32, 32) + + x3 = rng.float32_tensor((1, 3, 30, 30)) + with pytest.raises(RuntimeError): + add(x1, x3) + + +@pytest.mark.parametrize('module_class', [Concat, Add]) +def test_invalid_mode(module_class: type) -> None: + """Test modules with invalid mode.""" + with pytest.raises(ValueError, match='mode must be one of'): + module_class(mode='invalid_mode') diff --git a/tests/nn/test_linearselfattention.py b/tests/nn/test_linearselfattention.py new file mode 100644 index 000000000..dc42fb197 --- /dev/null +++ b/tests/nn/test_linearselfattention.py @@ -0,0 +1,58 @@ +"""Tests for LinearSelfAttention module.""" + +import pytest +from mrpro.nn import LinearSelfAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels_in', 'n_channels_out', 'n_heads', 'input_shape', 'features_last'), + [ + (32, 32, 4, (1, 32, 32, 32), False), + (64, 64, 8, (2, 64, 16, 16), False), + (16, 16, 2, (1, 16, 16, 16), True), + ], +) +def test_linear_self_attention( + n_channels_in: int, + n_channels_out: int, + n_heads: int, + input_shape: tuple[int, ...], + features_last: bool, + device: str, +) -> None: + """Test LinearSelfAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + + attn = LinearSelfAttention( + n_channels_in=n_channels_in, + n_channels_out=n_channels_out, + n_heads=n_heads, + features_last=features_last, + ).to(device) + + if features_last: + output = attn(x.moveaxis(1, -1)).moveaxis(-1, 1) + else: + output = attn(x) + + expected_shape = (x.shape[0], n_channels_out, *x.shape[2:]) + assert output.shape == expected_shape, f'Output shape {output.shape} != expected shape {expected_shape}' + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + assert attn.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert attn.to_qkv.bias.grad is not None, 'No gradient computed for to_qkv.bias' + assert attn.to_out.weight.grad is not None, 'No gradient computed for to_out.weight' + assert attn.to_out.bias.grad is not None, 'No gradient computed for to_out.bias' From e91b6dc84a008f5a9c686fc31269fdaae2e1d35e Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 15 Jul 2025 00:06:09 +0200 Subject: [PATCH 110/205] bump torch version --- docker/minimal_requirements.txt | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/minimal_requirements.txt b/docker/minimal_requirements.txt index c9be333c9..8385cb69c 100644 --- a/docker/minimal_requirements.txt +++ b/docker/minimal_requirements.txt @@ -1,4 +1,4 @@ -torch==2.3.1+cpu +torch==2.5.0+cpu torchvision==0.18.1+cpu numpy==1.23 ismrmrd==1.14.1 diff --git a/pyproject.toml b/pyproject.toml index c8f90da87..4132279eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ classifiers = [ ] dependencies = [ "numpy>=1.23, <3.0", - "torch>=2.3.1", + "torch>=2.5", "ismrmrd>=1.14.1", "einops>=0.7.0", "pydicom>=3.0.1", From b387ca517643766f6d6294ac4f67652bf7d148e2 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 15 Jul 2025 00:22:26 +0200 Subject: [PATCH 111/205] fix pad --- src/mrpro/utils/pad_or_crop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/utils/pad_or_crop.py b/src/mrpro/utils/pad_or_crop.py index f9c120d3e..f8c5c3897 100644 --- a/src/mrpro/utils/pad_or_crop.py +++ b/src/mrpro/utils/pad_or_crop.py @@ -76,7 +76,7 @@ def pad_or_crop( diff = new - old after = math.trunc(diff / 2) before = diff - after - if before != 0 or after != 0: + if before or after or npad: npad.append(before) npad.append(after) From 233ecf5a8b7434256d0900afb491c2f83ab6e14d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 15 Jul 2025 10:18:10 +0200 Subject: [PATCH 112/205] include build essentials --- docker/install_system.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install_system.sh b/docker/install_system.sh index 82faef8fd..ccc9f3bea 100644 --- a/docker/install_system.sh +++ b/docker/install_system.sh @@ -8,13 +8,13 @@ apt-get update -qq ${APT_GET_INSTALL} --reinstall ca-certificates # base utilities -${APT_GET_INSTALL} git software-properties-common gpg-agent curl jq +${APT_GET_INSTALL} git software-properties-common gpg-agent curl jq build-essential # add repo for python installation add-apt-repository ppa:deadsnakes/ppa apt update -qq -${APT_GET_INSTALL} $PYTHON-full +${APT_GET_INSTALL} $PYTHON-dev # pip if [[ "$PYTHON" == "python3.10" ]]; then From 0edd9e7f8768defaf1f0f1b00e15e830f0c8f2f3 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 15 Jul 2025 12:30:16 +0200 Subject: [PATCH 113/205] dev and full --- docker/install_system.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/install_system.sh b/docker/install_system.sh index ccc9f3bea..858d268fa 100644 --- a/docker/install_system.sh +++ b/docker/install_system.sh @@ -14,7 +14,7 @@ ${APT_GET_INSTALL} git software-properties-common gpg-agent curl jq build-essent add-apt-repository ppa:deadsnakes/ppa apt update -qq -${APT_GET_INSTALL} $PYTHON-dev +${APT_GET_INSTALL} $PYTHON-dev $PYTHON-full # pip if [[ "$PYTHON" == "python3.10" ]]; then From be7c9a184093a5094fed791f6586b4d361ecadd6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 15 Jul 2025 23:42:24 +0200 Subject: [PATCH 114/205] fix --- src/mrpro/utils/pad_or_crop.py | 26 +++++++++++++++++++++++--- tests/utils/test_pad_or_crop.py | 25 ++++++++++++++++++++----- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/mrpro/utils/pad_or_crop.py b/src/mrpro/utils/pad_or_crop.py index 5bb82f21d..00f86f947 100644 --- a/src/mrpro/utils/pad_or_crop.py +++ b/src/mrpro/utils/pad_or_crop.py @@ -6,6 +6,8 @@ import torch +from mrpro.utils.reshape import unsqueeze_left + def normalize_index(ndim: int, index: int) -> int: """Normalize possibly negative indices. @@ -71,15 +73,33 @@ def pad_or_crop( # Update elements in data.shape at indices specified in dim with corresponding elements from new_shape new_shape = tuple(new_shape[dim.index(i)] if i in dim else s for i, s in enumerate(data.shape)) - npad = [] + npad: list[int] = [] for old, new in zip(data.shape, new_shape, strict=True): diff = new - old after = math.trunc(diff / 2) before = diff - after - npad.append(before) - npad.append(after) + if before or after or npad: + npad.append(before) + npad.append(after) + + n_extended_dims = 0 + if mode != 'constant': + # See https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pad.html for the supported shapes + while len(npad) // 2 < data.ndim - 2: + npad = [0, 0, *npad] + + n_extended_dims = max(0, len(npad) // 2 - (data.ndim - 2)) + if n_extended_dims: # We need to extend data such that the padding is supported. + data = unsqueeze_left(data, n_extended_dims) + + if len(npad) > 6: # TODO: reshape and call multiple times + raise ValueError('replicate and circular padding are only supported for up to the last 3 dimensions.') if any(npad): # F.pad expects paddings in reversed order data = torch.nn.functional.pad(data, npad[::-1], value=value, mode=mode) + + if n_extended_dims: + idx = n_extended_dims * (0,) + data = data[idx] return data diff --git a/tests/utils/test_pad_or_crop.py b/tests/utils/test_pad_or_crop.py index 15a678419..3781c15f7 100644 --- a/tests/utils/test_pad_or_crop.py +++ b/tests/utils/test_pad_or_crop.py @@ -1,18 +1,33 @@ """Tests for padding and cropping of data tensors.""" +from typing import Literal + +import pytest import torch from mrpro.utils import RandomGenerator from mrpro.utils.pad_or_crop import pad_or_crop -def test_pad_or_crop_content(): +@pytest.mark.parametrize('mode', ['constant', 'reflect', 'replicate', 'circular']) +def test_pad_or_crop_content(mode: Literal['constant', 'reflect', 'replicate', 'circular']): """Test changing data by cropping and padding.""" generator = RandomGenerator(seed=0) original_data_shape = (100, 200, 50) - new_data_shape = (80, 100, 240) + new_data_shape = (80, 100, 70) original_data = generator.complex64_tensor(original_data_shape) - new_data = pad_or_crop(original_data, new_data_shape, dim=(-3, -2, -1), value=123) + new_data = pad_or_crop( + original_data, new_data_shape, dim=(-3, -2, -1), value=123 if mode == 'constant' else 0, mode=mode + ) # Compare overlapping region - torch.testing.assert_close(original_data[10:90, 50:150, :], new_data[:, :, 95:145]) - assert new_data[0, 0, 0] == 123 + torch.testing.assert_close(original_data[10:90, 50:150, :], new_data[:, :, 10:60]) + # ... and padded region + match mode: + case 'constant': + assert new_data[0, 0, 0] == 123 + case 'reflect': + assert new_data[0, 0, 9] == original_data[10, 50, 1] + case 'replicate': + assert new_data[0, 0, 9] == original_data[10, 50, 0] + case 'circular': + assert new_data[0, 0, 9] == original_data[10, 50, -1] From ef42f4f7e8eadeeeb2b10f4e1b10ca4dbc6c9ae2 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 16 Jul 2025 23:19:21 +0200 Subject: [PATCH 115/205] update dependencies --- docker/minimal_requirements.txt | 6 +++--- pyproject.toml | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/minimal_requirements.txt b/docker/minimal_requirements.txt index 8385cb69c..2d5b7e24a 100644 --- a/docker/minimal_requirements.txt +++ b/docker/minimal_requirements.txt @@ -1,12 +1,12 @@ -torch==2.5.0+cpu -torchvision==0.18.1+cpu +torch==2.5.1+cpu +torchvision==0.20.1+cpu numpy==1.23 ismrmrd==1.14.1 einops==0.7.0 pydicom==3.0.1 pypulseq==1.4.2 pytorch-finufft==0.1.0 -cufinufft==2.3.1 +cufinufft==2.4.1 scipy==1.12 ptwt==0.1.8 tqdm==4.60.0 diff --git a/pyproject.toml b/pyproject.toml index 4132279eb..8d4e7fa3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,16 +61,16 @@ classifiers = [ ] dependencies = [ "numpy>=1.23, <3.0", - "torch>=2.5", + "torch>=2.5.1", "ismrmrd>=1.14.1", "einops>=0.7.0", "pydicom>=3.0.1", "pypulseq>=1.4.2", "pytorch-finufft>=0.1.0", - "cufinufft==2.3.1; platform_system=='Linux'", + "cufinufft==2.4.1; platform_system=='Linux'", "scipy>=1.12", "ptwt>=0.1.8", - "torchvision>=0.18.1", + "torchvision>=0.20.1", "tqdm>=4.60.0", "typing-extensions>=4.12", "platformdirs>=4.0", From 3bbd66ee187b14333744adeb2ac9649b49804cad Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 16 Jul 2025 23:54:31 +0200 Subject: [PATCH 116/205] Add tests for NeighborhoodSelfAttention module --- tests/nn/test_neighborhoodselfattention.py | 64 ++++++++++++++++++++-- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py index 5e66f5be9..925dd4693 100644 --- a/tests/nn/test_neighborhoodselfattention.py +++ b/tests/nn/test_neighborhoodselfattention.py @@ -1,6 +1,7 @@ """Tests for NeighborhoodSelfAttention module.""" import pytest +import torch from mrpro.nn import NeighborhoodSelfAttention from mrpro.utils import RandomGenerator @@ -18,6 +19,7 @@ (2, 3, 1, 2, (1, 2, 16, 16), False), (3, 2, 2, 4, (1, 3, 8, 8, 8, 8), True), ], + ids=['2d_kernel2', '4d_features-last_kernel4'], ) def test_neighborhood_self_attention( n_channels_in: int, @@ -32,9 +34,6 @@ def test_neighborhood_self_attention( rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) - if features_last: - x = x.moveaxis(1, -1) - attn = NeighborhoodSelfAttention( n_channels_in=n_channels_in, n_channels_out=n_channels_out, @@ -43,9 +42,12 @@ def test_neighborhood_self_attention( features_last=features_last, ).to(device) - output = attn(x) + if features_last: + output = attn(x.moveaxis(1, -1)).moveaxis(-1, 1) + else: + output = attn(x) - expected_shape = (x.shape[0], n_channels_out, *x.shape[2:]) + expected_shape = (input_shape[0], n_channels_out, *input_shape[2:]) assert output.shape == expected_shape assert not output.isnan().any(), 'NaN values in output' @@ -87,3 +89,55 @@ def test_neighborhood_attention_variants(kernel_size: int, dilation: int, circul output.sum().backward() assert x.grad is not None, 'No gradient computed for input' assert not output.isnan().any(), 'NaN values in output' + + +@pytest.mark.parametrize( + ('kernel_size', 'circular', 'input_shape'), + [ + (11, False, (1, 8, 32, 32)), + (7, True, (1, 8, 16, 16)), + ], + ids=['regular', 'circular'], +) +def test_neighborhood_constraint(kernel_size: int, circular: bool, input_shape: tuple[int, int, int, int]) -> None: + """Test that neighborhood attention only affects pixels within the kernel window.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + + attn = NeighborhoodSelfAttention( + n_channels_in=8, + n_channels_out=8, + n_heads=2, + kernel_size=kernel_size, + dilation=1, + circular=circular, + ) + + output_original = attn(x) + x_modified = x.clone() + test_point = (input_shape[-2] - 2, input_shape[-1] - 2) + x_modified[..., test_point[0], test_point[1]] += 1.0 + output_modified = attn(x_modified) + + diff = output_modified - output_original + changed_pixels = torch.abs(diff).sum(dim=(0, 1)) > 1e-6 + + half_kernel = kernel_size // 2 + h, w = input_shape[2], input_shape[3] + + i_coords, j_coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') + + if circular: + h_dist = torch.minimum((i_coords - test_point[0]) % h, (test_point[0] - i_coords) % h) + w_dist = torch.minimum((j_coords - test_point[1]) % w, (test_point[1] - j_coords) % w) + in_neighborhood = (h_dist <= half_kernel) & (w_dist <= half_kernel) + else: + h_min, h_max = max(0, test_point[0] - half_kernel), min(h, test_point[0] + half_kernel + 1) + w_min, w_max = max(0, test_point[1] - half_kernel), min(w, test_point[1] + half_kernel + 1) + in_neighborhood = (i_coords >= h_min) & (i_coords < h_max) & (j_coords >= w_min) & (j_coords < w_max) + + neighborhood_changed = changed_pixels[in_neighborhood].all() + outside_changed = changed_pixels[~in_neighborhood].any() + + assert neighborhood_changed, 'Not all pixels in the neighborhood changed, which indicates a problem' + assert not outside_changed, 'Pixels outside the neighborhood changed, which violates the constraint' From a2f7eb6a3fa31c030a2d2b2d8dd3d1f131aaa868 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 02:04:47 +0200 Subject: [PATCH 117/205] move attention --- src/mrpro/nn/__init__.py | 15 +++------------ src/mrpro/nn/{ => attention}/AttentionGate.py | 0 .../nn/{ => attention}/LinearSelfAttention.py | 0 .../nn/{ => attention}/MultiHeadAttention.py | 0 .../{ => attention}/NeighborhoodSelfAttention.py | 0 .../nn/{ => attention}/ShiftedWindowAttention.py | 0 .../nn/{ => attention}/SpatialTransformerBlock.py | 2 +- src/mrpro/nn/{ => attention}/SqueezeExcitation.py | 0 .../nn/{ => attention}/TransposedAttention.py | 0 src/mrpro/nn/attention/__init__,py | 15 +++++++++++++++ src/mrpro/nn/nets/DCAE.py | 4 ++-- src/mrpro/nn/nets/Restormer.py | 2 +- src/mrpro/nn/nets/SwinIR.py | 2 +- src/mrpro/nn/nets/UNet.py | 4 ++-- src/mrpro/nn/nets/Uformer.py | 2 +- tests/nn/test_attentiongate.py | 2 +- tests/nn/test_linearselfattention.py | 2 +- tests/nn/test_neighborhoodselfattention.py | 2 +- tests/nn/test_sqeezeexcitation.py | 2 +- tests/nn/test_transposedattention.py | 2 +- 20 files changed, 31 insertions(+), 25 deletions(-) rename src/mrpro/nn/{ => attention}/AttentionGate.py (100%) rename src/mrpro/nn/{ => attention}/LinearSelfAttention.py (100%) rename src/mrpro/nn/{ => attention}/MultiHeadAttention.py (100%) rename src/mrpro/nn/{ => attention}/NeighborhoodSelfAttention.py (100%) rename src/mrpro/nn/{ => attention}/ShiftedWindowAttention.py (100%) rename src/mrpro/nn/{ => attention}/SpatialTransformerBlock.py (98%) rename src/mrpro/nn/{ => attention}/SqueezeExcitation.py (100%) rename src/mrpro/nn/{ => attention}/TransposedAttention.py (100%) create mode 100644 src/mrpro/nn/attention/__init__,py diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index ab5c13513..5dc019ebc 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -1,6 +1,5 @@ """Neural network modules and utilities.""" -from mrpro.nn.AttentionGate import AttentionGate from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm @@ -13,23 +12,19 @@ InstanceNormND, MaxPoolND, ) -from mrpro.nn.NeighborhoodSelfAttention import NeighborhoodSelfAttention from mrpro.nn.ResBlock import ResBlock from mrpro.nn.Sequential import Sequential -from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention -from mrpro.nn.SqueezeExcitation import SqueezeExcitation -from mrpro.nn.TransposedAttention import TransposedAttention -from mrpro.nn.LinearSelfAttention import LinearSelfAttention + from mrpro.nn.DropPath import DropPath from mrpro.nn.Residual import Residual from mrpro.nn.ComplexAsChannel import ComplexAsChannel from mrpro.nn import nets +from mrpro.nn import attention from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.RMSNorm import RMSNorm __all__ = [ "AdaptiveAvgPoolND", - "AttentionGate", "AvgPoolND", "BatchNormND", "ComplexAsChannel", @@ -40,16 +35,12 @@ "FiLM", "GroupNorm", "InstanceNormND", - "LinearSelfAttention", "MaxPoolND", - "NeighborhoodSelfAttention", "PermutedBlock", "RMSNorm", "ResBlock", "Residual", "Sequential", - "ShiftedWindowAttention", - "SqueezeExcitation", - "TransposedAttention", + "attention", "nets" ] \ No newline at end of file diff --git a/src/mrpro/nn/AttentionGate.py b/src/mrpro/nn/attention/AttentionGate.py similarity index 100% rename from src/mrpro/nn/AttentionGate.py rename to src/mrpro/nn/attention/AttentionGate.py diff --git a/src/mrpro/nn/LinearSelfAttention.py b/src/mrpro/nn/attention/LinearSelfAttention.py similarity index 100% rename from src/mrpro/nn/LinearSelfAttention.py rename to src/mrpro/nn/attention/LinearSelfAttention.py diff --git a/src/mrpro/nn/MultiHeadAttention.py b/src/mrpro/nn/attention/MultiHeadAttention.py similarity index 100% rename from src/mrpro/nn/MultiHeadAttention.py rename to src/mrpro/nn/attention/MultiHeadAttention.py diff --git a/src/mrpro/nn/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py similarity index 100% rename from src/mrpro/nn/NeighborhoodSelfAttention.py rename to src/mrpro/nn/attention/NeighborhoodSelfAttention.py diff --git a/src/mrpro/nn/ShiftedWindowAttention.py b/src/mrpro/nn/attention/ShiftedWindowAttention.py similarity index 100% rename from src/mrpro/nn/ShiftedWindowAttention.py rename to src/mrpro/nn/attention/ShiftedWindowAttention.py diff --git a/src/mrpro/nn/SpatialTransformerBlock.py b/src/mrpro/nn/attention/SpatialTransformerBlock.py similarity index 98% rename from src/mrpro/nn/SpatialTransformerBlock.py rename to src/mrpro/nn/attention/SpatialTransformerBlock.py index 78be70d77..ac38dd030 100644 --- a/src/mrpro/nn/SpatialTransformerBlock.py +++ b/src/mrpro/nn/attention/SpatialTransformerBlock.py @@ -9,7 +9,7 @@ from mrpro.nn.GEGLU import GEGLU from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.LayerNorm import LayerNorm -from mrpro.nn.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.Sequential import Sequential diff --git a/src/mrpro/nn/SqueezeExcitation.py b/src/mrpro/nn/attention/SqueezeExcitation.py similarity index 100% rename from src/mrpro/nn/SqueezeExcitation.py rename to src/mrpro/nn/attention/SqueezeExcitation.py diff --git a/src/mrpro/nn/TransposedAttention.py b/src/mrpro/nn/attention/TransposedAttention.py similarity index 100% rename from src/mrpro/nn/TransposedAttention.py rename to src/mrpro/nn/attention/TransposedAttention.py diff --git a/src/mrpro/nn/attention/__init__,py b/src/mrpro/nn/attention/__init__,py new file mode 100644 index 000000000..7b3d24115 --- /dev/null +++ b/src/mrpro/nn/attention/__init__,py @@ -0,0 +1,15 @@ +from mrpro.nn.attention.AttentionGate import AttentionGate +from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.attention.SqueezeExcitation import SqueezeExcitation +from mrpro.nn.attention.TransposedAttention import TransposedAttention + +__all__ = [ + 'AttentionGate', + 'LinearSelfAttention', + 'NeighborhoodSelfAttention', + 'ShiftedWindowAttention', + 'SqueezeExcitation', + 'TransposedAttention', +] diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index c5f4eaaa7..a83407132 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -6,9 +6,9 @@ import torch from torch.nn import Module, ReLU, SiLU +from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock -from mrpro.nn.LinearSelfAttention import LinearSelfAttention -from mrpro.nn.MultiHeadAttention import MultiHeadAttention from mrpro.nn.ndmodules import ConvND from mrpro.nn.nets.VAE import VAE from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index f95fd2f98..3cba23ecc 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -6,6 +6,7 @@ import torch from torch.nn import Module +from mrpro.nn.attention.TransposedAttention import TransposedAttention from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.join import Concat @@ -13,7 +14,6 @@ from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.Sequential import Sequential -from mrpro.nn.TransposedAttention import TransposedAttention class GDFN(Module): diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index e3e8a440a..16c19a9fc 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -3,11 +3,11 @@ import torch from torch.nn import GELU, Module +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM from mrpro.nn.ndmodules import ConvND, InstanceNormND from mrpro.nn.Sequential import Sequential -from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention class SwinTransformerLayer(Module): diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 14db1d3a0..6f21971a0 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -7,7 +7,8 @@ import torch from torch.nn import Identity, Module, ModuleList, ReLU, SiLU -from mrpro.nn.AttentionGate import AttentionGate +from mrpro.nn.attention.AttentionGate import AttentionGate +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock from mrpro.nn.CondMixin import call_with_cond from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm @@ -17,7 +18,6 @@ from mrpro.nn.ResBlock import ResBlock from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here from mrpro.nn.Sequential import Sequential -from mrpro.nn.SpatialTransformerBlock import SpatialTransformerBlock from mrpro.nn.Upsample import Upsample diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 83ec8c49b..4424ce91c 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -6,6 +6,7 @@ import torch from torch.nn import GELU, LeakyReLU, Module +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention from mrpro.nn.CondMixin import CondMixin from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM @@ -13,7 +14,6 @@ from mrpro.nn.ndmodules import ConvND, ConvTransposeND, InstanceNormND from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder from mrpro.nn.Sequential import Sequential -from mrpro.nn.ShiftedWindowAttention import ShiftedWindowAttention class LeWinTransformerBlock(CondMixin, Module): diff --git a/tests/nn/test_attentiongate.py b/tests/nn/test_attentiongate.py index 99d463a00..10d30cb07 100644 --- a/tests/nn/test_attentiongate.py +++ b/tests/nn/test_attentiongate.py @@ -3,7 +3,7 @@ from collections.abc import Sequence import pytest -from mrpro.nn.AttentionGate import AttentionGate +from mrpro.nn.attention import AttentionGate from mrpro.utils import RandomGenerator diff --git a/tests/nn/test_linearselfattention.py b/tests/nn/test_linearselfattention.py index dc42fb197..d7649a3df 100644 --- a/tests/nn/test_linearselfattention.py +++ b/tests/nn/test_linearselfattention.py @@ -1,7 +1,7 @@ """Tests for LinearSelfAttention module.""" import pytest -from mrpro.nn import LinearSelfAttention +from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention from mrpro.utils import RandomGenerator diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py index 925dd4693..0dbf1ccb5 100644 --- a/tests/nn/test_neighborhoodselfattention.py +++ b/tests/nn/test_neighborhoodselfattention.py @@ -2,7 +2,7 @@ import pytest import torch -from mrpro.nn import NeighborhoodSelfAttention +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention from mrpro.utils import RandomGenerator diff --git a/tests/nn/test_sqeezeexcitation.py b/tests/nn/test_sqeezeexcitation.py index b0ddf7050..7b7509f1d 100644 --- a/tests/nn/test_sqeezeexcitation.py +++ b/tests/nn/test_sqeezeexcitation.py @@ -1,7 +1,7 @@ """Tests for SqueezeExcitation module.""" import pytest -from mrpro.nn import SqueezeExcitation +from mrpro.nn.attention.SqueezeExcitation import SqueezeExcitation from mrpro.utils import RandomGenerator diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index 8b72b071f..a2688f36d 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -1,7 +1,7 @@ """Tests for TransposedAttention module.""" import pytest -from mrpro.nn import TransposedAttention +from mrpro.nn.attention.TransposedAttention import TransposedAttention from mrpro.utils import RandomGenerator From 2bfba604110d5958af506c1ce96dd94a3f688c0a Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 02:05:01 +0200 Subject: [PATCH 118/205] dc --- .../data_consistency/AnalyticCartesianDC.py | 51 +++++++++++++++++++ .../data_consistency/ConjugateGradientDC.py | 13 +++++ .../nn/data_consistency/GradientDescentDC.py | 37 ++++++++++++++ src/mrpro/nn/data_consistency/__init__.py | 5 ++ 4 files changed, 106 insertions(+) create mode 100644 src/mrpro/nn/data_consistency/AnalyticCartesianDC.py create mode 100644 src/mrpro/nn/data_consistency/ConjugateGradientDC.py create mode 100644 src/mrpro/nn/data_consistency/GradientDescentDC.py create mode 100644 src/mrpro/nn/data_consistency/__init__.py diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py new file mode 100644 index 000000000..53394784d --- /dev/null +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -0,0 +1,51 @@ +import torch +from torch.nn import Module, Parameter + +from mrpro.data.KData import KData +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.IdentityOp import IdentityOp + + +class AnalyticCartesianDC(Module): + r"""Analytic Cartesian data consistency. + + Solves the following problem: + :math:`\min_x \|Ax - k\|_2^2 + \lambda \|x-p\|_2^2` + where :math:`A` is the acquistion operator and :math:`k` is the data, :math:`\lambda` is the regularization parameter, + and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a special case + for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil data or + to apply data consistancy to each coil image [NOSENSE]_ + + References + ---------- + .. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without + explicit sensitivity maps." STACOM@MICCAI 2023. https://arxiv.org/abs/2309.15608 + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + + + """ + + def __init__(self, initial_regularization_weight: torch.Tensor | float): + super().__init__() + self.regularization_weight = Parameter(torch.as_tensor(initial_regularization_weight)) + + def forward( + self, + x: torch.Tensor, + data: KData | torch.Tensor, + fourier_op: FourierOp, + ): + if not isinstance(fourier_op, FourierOp) or fourier_op._nufft_dims or fourier_op._fast_fourier_op is None: + raise ValueError('Only Cartesian acquisitions are supported') + + data_ = data.data if isinstance(data, KData) else data + fft_op = fourier_op._fast_fourier_op + sampling_op = fourier_op._cart_sampling_op if fourier_op._cart_sampling_op is not None else IdentityOp() + (k_pred,) = fft_op(x) + (k,) = sampling_op.gram((data_ - k_pred) / (1 + self.regularization_weight)) + (delta,) = fft_op.H(k) + return x + delta diff --git a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py new file mode 100644 index 000000000..e9bf2b635 --- /dev/null +++ b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py @@ -0,0 +1,13 @@ +import torch +from torch.nn import Module + +from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp +from mrpro.operators.FourierOp import FourierOp + + +class ConjugateGradientDC(Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, data: torch.Tensor, fourier_op: FourierOp): + cg_op = ConjugateGradientOp(fourier_op) diff --git a/src/mrpro/nn/data_consistency/GradientDescentDC.py b/src/mrpro/nn/data_consistency/GradientDescentDC.py new file mode 100644 index 000000000..ecf9c0111 --- /dev/null +++ b/src/mrpro/nn/data_consistency/GradientDescentDC.py @@ -0,0 +1,37 @@ +import torch +from torch.nn import Module, Parameter + +from mrpro.data.KData import KData +from mrpro.operators.LinearOperator import LinearOperator + + +class GradientDescentDC(Module): + r"""Gradient descent data consistency. + + Performs gradient descent steps on + :math:`\|Ax - k\|_2^2` where :math:`A` is the acquistion operator and :math:`k` is the data. + + Parameters + ---------- + initial_stepsize + Initial stepsize. The stepsize is a trainable parameter. + n_steps + Number of gradient descent steps. + + Returns + ------- + The updated image. + """ + + def __init__(self, initial_stepsize: float, n_steps: int = 1) -> None: + super().__init__() + self.stepsize = Parameter(torch.tensor(initial_stepsize)) + self.n_steps = n_steps + + def forward(self, x: torch.Tensor, data: KData | torch.Tensor, acquistion_operator: LinearOperator) -> torch.Tensor: + """Forward pass.""" + data_ = data.data if isinstance(data, KData) else data + for _ in range(self.n_steps): + residual = acquistion_operator(x)[0] - data_ + x = x - self.stepsize * acquistion_operator.adjoint(residual)[0] + return x diff --git a/src/mrpro/nn/data_consistency/__init__.py b/src/mrpro/nn/data_consistency/__init__.py new file mode 100644 index 000000000..b881fc2fc --- /dev/null +++ b/src/mrpro/nn/data_consistency/__init__.py @@ -0,0 +1,5 @@ +from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC +from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC +from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC + +__all__ = ['AnalyticCartesianDC', 'GradientDescentDC', 'ConjugateGradientDC'] \ No newline at end of file From 396d4b6a99426d18bf83a28fffa99d7f2446cca2 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 02:52:21 +0200 Subject: [PATCH 119/205] fix --- src/mrpro/nn/attention/SpatialTransformerBlock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/nn/attention/SpatialTransformerBlock.py b/src/mrpro/nn/attention/SpatialTransformerBlock.py index ac38dd030..50b5bbfa7 100644 --- a/src/mrpro/nn/attention/SpatialTransformerBlock.py +++ b/src/mrpro/nn/attention/SpatialTransformerBlock.py @@ -5,11 +5,11 @@ import torch from torch.nn import Dropout, Linear, Module +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention from mrpro.nn.CondMixin import CondMixin from mrpro.nn.GEGLU import GEGLU from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.LayerNorm import LayerNorm -from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.Sequential import Sequential From 8311bc5b9e405e77049c57d5ff855a7a6fb3d5ca Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 02:54:08 +0200 Subject: [PATCH 120/205] dc --- .../data_consistency/AnalyticCartesianDC.py | 4 +- .../data_consistency/ConjugateGradientDC.py | 46 +++++++++++++++++-- .../nn/data_consistency/GradientDescentDC.py | 10 ++-- src/mrpro/nn/data_consistency/__init__.py | 2 +- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py index 53394784d..65af18cdd 100644 --- a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -11,10 +11,10 @@ class AnalyticCartesianDC(Module): Solves the following problem: :math:`\min_x \|Ax - k\|_2^2 + \lambda \|x-p\|_2^2` - where :math:`A` is the acquistion operator and :math:`k` is the data, :math:`\lambda` is the regularization parameter, + where :math:`A` is the acquisition operator and :math:`k` is the data, :math:`\lambda` is the regularization parameter, and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a special case for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil data or - to apply data consistancy to each coil image [NOSENSE]_ + to apply data consistency to each coil image [NOSENSE]_ References ---------- diff --git a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py index e9bf2b635..646527d0a 100644 --- a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py +++ b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py @@ -1,13 +1,51 @@ +from inspect import Parameter + import torch -from torch.nn import Module +from torch.nn import Module, Parameter +from mrpro.data.CsmData import CsmData +from mrpro.data.KData import KData from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.SensitivityOp import SensitivityOp class ConjugateGradientDC(Module): - def __init__(self): + """Conjugate gradient data consistency.""" + + def __init__(self, initial_regularization_weight: torch.Tensor | float): + """Initialize the conjugate gradient data consistency. + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. + """ super().__init__() + self.regularization_weight = Parameter(torch.as_tensor(initial_regularization_weight)) + + def operator_factory( + fourier_op: FourierOp, csm: torch.Tensor | CsmData | None, regularization_weight: torch.Tensor | float, *_ + ): + csm_op = SensitivityOp(csm) if csm is not None else IdentityOp() + op = csm_op.H @ fourier_op.gram @ csm_op + regularization_weight + return op + + self.cg_op = ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=lambda _fourier, _csm, regularization_weight, zero_filled, regularization: zero_filled + + regularization_weight * regularization, + ) - def forward(self, x: torch.Tensor, data: torch.Tensor, fourier_op: FourierOp): - cg_op = ConjugateGradientOp(fourier_op) + def forward( + self, + x: torch.Tensor, + data: torch.Tensor | KData, + fourier_op: FourierOp, + csm: torch.Tensor | CSMData | None, + ): + data_ = data.data if isinstance(data, KData) else data + zero_filled = fourier_op.adjoint(data_) + x = self.cg_op(fourier_op, csm, self.regularization_weight, zero_filled, x) + return x diff --git a/src/mrpro/nn/data_consistency/GradientDescentDC.py b/src/mrpro/nn/data_consistency/GradientDescentDC.py index ecf9c0111..ee05ca8db 100644 --- a/src/mrpro/nn/data_consistency/GradientDescentDC.py +++ b/src/mrpro/nn/data_consistency/GradientDescentDC.py @@ -9,7 +9,7 @@ class GradientDescentDC(Module): r"""Gradient descent data consistency. Performs gradient descent steps on - :math:`\|Ax - k\|_2^2` where :math:`A` is the acquistion operator and :math:`k` is the data. + :math:`\|Ax - k\|_2^2` where :math:`A` is the acquisition operator and :math:`k` is the data. Parameters ---------- @@ -28,10 +28,12 @@ def __init__(self, initial_stepsize: float, n_steps: int = 1) -> None: self.stepsize = Parameter(torch.tensor(initial_stepsize)) self.n_steps = n_steps - def forward(self, x: torch.Tensor, data: KData | torch.Tensor, acquistion_operator: LinearOperator) -> torch.Tensor: + def forward( + self, x: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator + ) -> torch.Tensor: """Forward pass.""" data_ = data.data if isinstance(data, KData) else data for _ in range(self.n_steps): - residual = acquistion_operator(x)[0] - data_ - x = x - self.stepsize * acquistion_operator.adjoint(residual)[0] + residual = acquisition_operator(x)[0] - data_ + x = x - self.stepsize * acquisition_operator.adjoint(residual)[0] return x diff --git a/src/mrpro/nn/data_consistency/__init__.py b/src/mrpro/nn/data_consistency/__init__.py index b881fc2fc..cea955c85 100644 --- a/src/mrpro/nn/data_consistency/__init__.py +++ b/src/mrpro/nn/data_consistency/__init__.py @@ -2,4 +2,4 @@ from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC -__all__ = ['AnalyticCartesianDC', 'GradientDescentDC', 'ConjugateGradientDC'] \ No newline at end of file +__all__ = ["AnalyticCartesianDC", "ConjugateGradientDC", "GradientDescentDC"] \ No newline at end of file From 5c3195209f7c820e7be16fc6b81fa6c5ae7f11cc Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 15:38:54 +0200 Subject: [PATCH 121/205] fix --- src/mrpro/nn/data_consistency/AnalyticCartesianDC.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py index 65af18cdd..742490cd6 100644 --- a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -45,7 +45,8 @@ def forward( data_ = data.data if isinstance(data, KData) else data fft_op = fourier_op._fast_fourier_op sampling_op = fourier_op._cart_sampling_op if fourier_op._cart_sampling_op is not None else IdentityOp() + (zero_filled,) = sampling_op.adjoint(data_) (k_pred,) = fft_op(x) - (k,) = sampling_op.gram((data_ - k_pred) / (1 + self.regularization_weight)) + (k,) = sampling_op.gram((zero_filled - k_pred) / (1 + self.regularization_weight)) (delta,) = fft_op.H(k) return x + delta From 6495eb0ed84477fb0059dc4cc6086b89b033ad11 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 16:33:31 +0200 Subject: [PATCH 122/205] Add fully sampled Cartesian trajectory generation and improve error handling in EllipsePhantom - Implemented `fullysampled` class method in `KTrajectoryCartesian` to generate fully sampled Cartesian trajectories based on the encoding matrix. - Refactored error handling in `EllipsePhantom` to ensure kx and ky shapes are broadcastable and devices match. - Updated tests to validate the new functionality in `KTrajectoryCartesian` and `EllipsePhantom` for k-space data generation. --- .../traj_calculators/KTrajectoryCartesian.py | 18 ++++++++ src/mrpro/phantoms/EllipsePhantom.py | 45 +++++++++++++++---- src/mrpro/phantoms/coils.py | 19 ++++---- tests/data/test_traj_calculators.py | 14 +++++- tests/phantoms/test_ellipse_phantom.py | 25 ++++++++--- 5 files changed, 98 insertions(+), 23 deletions(-) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py index 633958595..0d5a9a682 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py @@ -59,6 +59,24 @@ def __call__( kz, ky, kx = unsqueeze_tensors_left(kz, ky, kx, ndim=5) return KTrajectory(kz, ky, kx) + @classmethod + def fullysampled(cls, encoding_matrix: SpatialDimension[int]) -> KTrajectory: + """Generate fully sampled Cartesian trajectory. + + Parameters + ---------- + encoding_matrix + Encoded K-space size. + """ + return cls()( + n_k0=encoding_matrix.x, + k0_center=encoding_matrix.x // 2, + k1_idx=torch.arange(encoding_matrix.y)[:, None], + k1_center=encoding_matrix.y // 2, + k2_idx=torch.arange(encoding_matrix.z)[:, None, None], + k2_center=encoding_matrix.z // 2, + ) + @classmethod def gaussian_variable_density( cls, diff --git a/src/mrpro/phantoms/EllipsePhantom.py b/src/mrpro/phantoms/EllipsePhantom.py index 696dbd1bf..3f715c155 100644 --- a/src/mrpro/phantoms/EllipsePhantom.py +++ b/src/mrpro/phantoms/EllipsePhantom.py @@ -2,10 +2,12 @@ from collections.abc import Sequence -import numpy as np import torch from einops import repeat +from mrpro.data.KData import KData +from mrpro.data.KHeader import KHeader +from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension from mrpro.phantoms.phantom_elements import EllipseParameters @@ -51,17 +53,21 @@ def kspace(self, ky: torch.Tensor, kx: torch.Tensor) -> torch.Tensor: phantom in the Fourier domain. MRM 58(2) https://doi.org/10.1002/mrm.21292 .. """ - # kx and ky have to be of same shape - if kx.shape != ky.shape: - raise ValueError(f'shape mismatch between kx {kx.shape} and ky {ky.shape}') - - kdata = torch.zeros_like(kx, dtype=torch.complex64) + # kx and ky should be broadcastable + try: + shape = torch.broadcast_shapes(kx.shape, ky.shape) + except RuntimeError: + raise ValueError(f'shape mismatch between kx {kx.shape} and ky {ky.shape}') from None + if kx.device != ky.device: + raise ValueError(f'device mismatch between kx {kx.device} and ky {ky.device}') + + kdata = torch.zeros(shape, dtype=torch.complex64, device=kx.device) for ellipse in self.ellipses: arg = torch.sqrt((ellipse.radius_x * 2) ** 2 * kx**2 + (ellipse.radius_y * 2) ** 2 * ky**2) arg[arg < 1e-6] = 1e-6 # avoid zeros cdata = 2 * 2 * ellipse.radius_x * ellipse.radius_y * 0.5 * torch.special.bessel_j1(torch.pi * arg) / arg - kdata += ( + kdata = kdata + ( torch.exp(-1j * 2 * torch.pi * (ellipse.center_x * kx + ellipse.center_y * ky)) * cdata * ellipse.intensity @@ -69,7 +75,7 @@ def kspace(self, ky: torch.Tensor, kx: torch.Tensor) -> torch.Tensor: # Scale k-space data by factor sqrt(number of points) to ensure correct scaling after FFT with # normalization "ortho". See e.g. https://docs.scipy.org/doc/scipy/tutorial/fft.html - kdata *= np.sqrt(torch.numel(kdata)) + kdata = kdata * kdata.numel() ** 0.5 return kdata def image_space(self, image_dimensions: SpatialDimension[int]) -> torch.Tensor: @@ -98,3 +104,26 @@ def image_space(self, image_dimensions: SpatialDimension[int]) -> torch.Tensor: idata += ellipse.intensity * in_ellipse return repeat(idata, 'y x->other coils z y x', other=1, coils=1, z=1) + + def kdata(self, trajectory: KTrajectory, encoding_matrix: SpatialDimension[int]) -> KData: + """Create k-space data for the phantom. + + Parameters + ---------- + trajectory + Trajectory. + encoding_matrix + Encoding matrix. + """ + if (trajectory.kz != 0).any(): + raise ValueError('Only 2D k-space data is supported') + + data = self.kspace(trajectory.ky, trajectory.kx) + header = KHeader( + recon_fov=SpatialDimension(z=0.01, y=1, x=1), + encoding_fov=SpatialDimension(z=0.01, y=1, x=1), + encoding_matrix=encoding_matrix, + recon_matrix=encoding_matrix, + ) + kdata = KData(data=data, header=header, traj=trajectory) + return kdata diff --git a/src/mrpro/phantoms/coils.py b/src/mrpro/phantoms/coils.py index dd9c208fe..3125d5875 100644 --- a/src/mrpro/phantoms/coils.py +++ b/src/mrpro/phantoms/coils.py @@ -7,7 +7,7 @@ def birdcage_2d( - number_of_coils: int, + n_coils: int, image_dimensions: SpatialDimension[int], relative_radius: float = 1.5, normalize_with_rss: bool = True, @@ -19,21 +19,25 @@ def birdcage_2d( Parameters ---------- - number_of_coils + n_coils number of coil elements image_dimensions - number of voxels in the image - This is a 2D simulation so the output will be (1 number_of_coils 1 image_dimensions.y image_dimensions.x) + number of pixels in the image in y and x direction relative_radius relative radius of birdcage normalize_with_rss If set to true, the calculated sensitivities are normalized by the root-sum-of-squares + Returns + ------- + Coil sensitivities. + Shape: `(1, n_coils, 1, image_dimensions.y, image_dimensions.x)` + References ---------- .. [ISMc] ISMRMRD Python tools https://github.com/ismrmrd/ismrmrd-python-tools """ - dim = [number_of_coils, image_dimensions.y, image_dimensions.x] + dim = [n_coils, image_dimensions.y, image_dimensions.x] x_co, y_co = torch.meshgrid( torch.linspace(-dim[2] // 2, dim[2] // 2 - 1, dim[2]), torch.linspace(-dim[1] // 2, dim[1] // 2 - 1, dim[1]), @@ -53,9 +57,8 @@ def birdcage_2d( sensitivities = (1 / rr) * torch.exp(1j * phi) if normalize_with_rss: - rss = sensitivities.abs().square().sum(0).sqrt() - # Normalize only where rss is > 0 - sensitivities[:, rss > 0] /= rss[None, rss > 0] + rss = sensitivities.abs().square().sum(0, keepdim=True).sqrt() + sensitivities /= rss + 1e-8 return repeat(sensitivities, 'coils y x->other coils z y x', other=1, z=1) diff --git a/tests/data/test_traj_calculators.py b/tests/data/test_traj_calculators.py index 9f2fcb477..66d3ae6b6 100644 --- a/tests/data/test_traj_calculators.py +++ b/tests/data/test_traj_calculators.py @@ -217,6 +217,18 @@ def test_KTrajectoryCartesian_random(acceleration: int = 2, n_k: int = 64) -> No assert center_idx in lines1 +def test_KTrajectoryCartesian_fullysampled() -> None: + """Test the generation of a fully sampled Cartesian trajectory""" + traj = KTrajectoryCartesian.fullysampled(SpatialDimension(10, 64, 64)) + assert traj.kx.shape == (1, 1, 1, 1, 1, 64) + assert traj.ky.shape == (1, 1, 1, 1, 64, 1) + assert traj.kz.shape == (1, 1, 1, 10, 1, 1) + assert len(traj.kx.unique()) == 64 + assert len(traj.ky.unique()) == 64 + assert len(traj.kz.unique()) == 10 + assert traj.kx.diff().unique() == 1 + + @pytest.mark.parametrize('acceleration', [1, 16]) def test_KTrajectoryCartesian_random_edgecases(acceleration: int, n_k=128) -> None: """Test the generation of a 2D gaussian variable density pattern""" @@ -229,7 +241,7 @@ def test_KTrajectoryCartesian_random_edgecases(acceleration: int, n_k=128) -> No assert center_idx in traj.ky.ravel() -def test_KTrajectorySpiral(): +def test_KTrajectorySpiral() -> None: """Test the generation of a 2D spiral trajectory""" trajectory_calculator = KTrajectorySpiral2D() trajectory = trajectory_calculator( diff --git a/tests/phantoms/test_ellipse_phantom.py b/tests/phantoms/test_ellipse_phantom.py index c7ab58850..0946d94dd 100644 --- a/tests/phantoms/test_ellipse_phantom.py +++ b/tests/phantoms/test_ellipse_phantom.py @@ -3,25 +3,27 @@ import pytest import torch from mrpro.data import SpatialDimension -from mrpro.operators import FastFourierOp +from mrpro.data.traj_calculators import KTrajectoryCartesian +from mrpro.operators import FastFourierOp, FourierOp from tests import relative_image_difference +from tests.phantoms import EllipsePhantomTestData -def test_image_space(ellipse_phantom): +def test_image_space(ellipse_phantom: EllipsePhantomTestData) -> None: """Check if image space has correct shape.""" img_dimension = SpatialDimension(z=1, y=ellipse_phantom.n_y, x=ellipse_phantom.n_x) img = ellipse_phantom.phantom.image_space(img_dimension) assert img.shape[-2:] == (ellipse_phantom.n_y, ellipse_phantom.n_x) -def test_kspace_correct_shape(ellipse_phantom): +def test_kspace_correct_shape(ellipse_phantom: EllipsePhantomTestData) -> None: """Check if kspace has correct shape.""" kdata = ellipse_phantom.phantom.kspace(ellipse_phantom.ky, ellipse_phantom.kx) assert kdata.shape == (ellipse_phantom.n_y, ellipse_phantom.n_x) -def test_kspace_raises_error(ellipse_phantom): +def test_kspace_raises_error(ellipse_phantom: EllipsePhantomTestData) -> None: """Check if kspace raises error if kx and ky have different shapes.""" [kx_, _] = torch.meshgrid( torch.linspace(-ellipse_phantom.n_x // 2, ellipse_phantom.n_x // 2, ellipse_phantom.n_x + 1), @@ -32,7 +34,7 @@ def test_kspace_raises_error(ellipse_phantom): ellipse_phantom.phantom.kspace(ellipse_phantom.ky, kx_) -def test_kspace_image_match(ellipse_phantom): +def test_kspace_image_match(ellipse_phantom: EllipsePhantomTestData) -> None: """Check if fft of kspace matches image.""" img_dimension = SpatialDimension(z=1, y=ellipse_phantom.n_y, x=ellipse_phantom.n_x) img = ellipse_phantom.phantom.image_space(img_dimension) @@ -42,4 +44,15 @@ def test_kspace_image_match(ellipse_phantom): # Due to discretization artifacts the reconstructed image will be different to the reference image. Using standard # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high # differences along the edges of the elliptic objects. - assert relative_image_difference(reconstructed_img, img[0, 0, 0, :, :]) <= 0.05 + assert relative_image_difference(reconstructed_img, img[0, 0, 0]) <= 0.05 + + +def test_kspace_fullysampled(ellipse_phantom: EllipsePhantomTestData) -> None: + """Check if kspace has correct shape.""" + matrix = SpatialDimension(z=1, y=ellipse_phantom.n_y, x=ellipse_phantom.n_x) + traj = KTrajectoryCartesian.fullysampled(matrix) + kdata = ellipse_phantom.phantom.kdata(traj, matrix) + fourier_op = FourierOp.from_kdata(kdata) + (reconstructed_img,) = fourier_op.adjoint(kdata.data) + img = ellipse_phantom.phantom.image_space(matrix) + assert relative_image_difference(reconstructed_img[0, 0, 0], img[0, 0, 0]) <= 0.05 From 089d312c33d2ccd01ac1067d44f6982ef14ab6ee Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 16:37:00 +0200 Subject: [PATCH 123/205] docstring --- .../data/traj_calculators/KTrajectoryCartesian.py | 4 ++++ src/mrpro/phantoms/EllipsePhantom.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py index 0d5a9a682..f75009127 100644 --- a/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py +++ b/src/mrpro/data/traj_calculators/KTrajectoryCartesian.py @@ -67,6 +67,10 @@ def fullysampled(cls, encoding_matrix: SpatialDimension[int]) -> KTrajectory: ---------- encoding_matrix Encoded K-space size. + + Returns + ------- + Cartesian trajectory. """ return cls()( n_k0=encoding_matrix.x, diff --git a/src/mrpro/phantoms/EllipsePhantom.py b/src/mrpro/phantoms/EllipsePhantom.py index 3f715c155..cd05c7929 100644 --- a/src/mrpro/phantoms/EllipsePhantom.py +++ b/src/mrpro/phantoms/EllipsePhantom.py @@ -47,6 +47,10 @@ def kspace(self, ky: torch.Tensor, kx: torch.Tensor) -> torch.Tensor: kx k-space locations in kx (frequency encoding direction). + Returns + ------- + K-space data. + References ---------- .. [KOA2007] Koay C, Sarlls J, Oezarslan E (2007) Three-dimensional analytical magnetic resonance imaging @@ -86,6 +90,10 @@ def image_space(self, image_dimensions: SpatialDimension[int]) -> torch.Tensor: image_dimensions Number of voxels in the image. This is a 2D simulation, so the output will be of shape `(1 1 1 image_dimensions.y image_dimensions.x)`. + + Returns + ------- + Image representation of phantom """ # Calculate image representation of phantom ny, nx = image_dimensions.y, image_dimensions.x @@ -114,6 +122,10 @@ def kdata(self, trajectory: KTrajectory, encoding_matrix: SpatialDimension[int]) Trajectory. encoding_matrix Encoding matrix. + + Returns + ------- + K-space data with header and trajectory. """ if (trajectory.kz != 0).any(): raise ValueError('Only 2D k-space data is supported') From 1959fda7f7644184b8ebf7da9228858c6faa335c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 16:44:26 +0200 Subject: [PATCH 124/205] fix test --- tests/data/test_traj_calculators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/data/test_traj_calculators.py b/tests/data/test_traj_calculators.py index 66d3ae6b6..c2585be3b 100644 --- a/tests/data/test_traj_calculators.py +++ b/tests/data/test_traj_calculators.py @@ -220,9 +220,9 @@ def test_KTrajectoryCartesian_random(acceleration: int = 2, n_k: int = 64) -> No def test_KTrajectoryCartesian_fullysampled() -> None: """Test the generation of a fully sampled Cartesian trajectory""" traj = KTrajectoryCartesian.fullysampled(SpatialDimension(10, 64, 64)) - assert traj.kx.shape == (1, 1, 1, 1, 1, 64) - assert traj.ky.shape == (1, 1, 1, 1, 64, 1) - assert traj.kz.shape == (1, 1, 1, 10, 1, 1) + assert traj.kx.shape == (1, 1, 1, 1, 64) + assert traj.ky.shape == (1, 1, 1, 64, 1) + assert traj.kz.shape == (1, 1, 10, 1, 1) assert len(traj.kx.unique()) == 64 assert len(traj.ky.unique()) == 64 assert len(traj.kz.unique()) == 10 From 956fc2faa689c84a0491f504329abf20f07c1b39 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 17:54:06 +0200 Subject: [PATCH 125/205] Refactor data consistency modules and enhance attention imports - Added `data_consistency` to the `__init__.py` of the `nn` module. - Created `attention/__init__.py` to organize attention-related classes. - Updated `AnalyticCartesianDC` to improve parameter validation and logging. - Refactored `ConjugateGradientDC` to correct type hint for `csm`. - Adjusted import statements in test files for `LinearSelfAttention`, `ShiftedWindowAttention`, `SqueezeExcitation`, and `TransposedAttention` to reflect the new structure. --- src/mrpro/nn/__init__.py | 2 + .../nn/attention/{__init__,py => __init__.py} | 0 .../data_consistency/AnalyticCartesianDC.py | 65 ++++++++++++++++--- .../data_consistency/ConjugateGradientDC.py | 4 +- tests/nn/test_linearselfattention.py | 2 +- tests/nn/test_shiftedwindowattention.py | 2 +- tests/nn/test_sqeezeexcitation.py | 2 +- tests/nn/test_transposedattention.py | 2 +- 8 files changed, 64 insertions(+), 15 deletions(-) rename src/mrpro/nn/attention/{__init__,py => __init__.py} (100%) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index 5dc019ebc..16842a302 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -20,6 +20,7 @@ from mrpro.nn.ComplexAsChannel import ComplexAsChannel from mrpro.nn import nets from mrpro.nn import attention +from mrpro.nn import data_consistency from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.RMSNorm import RMSNorm @@ -42,5 +43,6 @@ "Residual", "Sequential", "attention", + "data_consistency", "nets" ] \ No newline at end of file diff --git a/src/mrpro/nn/attention/__init__,py b/src/mrpro/nn/attention/__init__.py similarity index 100% rename from src/mrpro/nn/attention/__init__,py rename to src/mrpro/nn/attention/__init__.py diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py index 742490cd6..3bb658d82 100644 --- a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -1,3 +1,5 @@ +from typing import overload + import torch from torch.nn import Module, Parameter @@ -30,15 +32,59 @@ class AnalyticCartesianDC(Module): """ def __init__(self, initial_regularization_weight: torch.Tensor | float): + """Initialize the data consistency. + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + Must be a positive scalar. + """ super().__init__() - self.regularization_weight = Parameter(torch.as_tensor(initial_regularization_weight)) + weight = torch.as_tensor(initial_regularization_weight) + if weight.ndim != 0: + raise ValueError('Regularization weight must be a scalar') + if weight.item() <= 0: + raise ValueError('Regularization weight must be positive') + self.log_weight = Parameter(weight.log()) + + @overload + def __call__(self, image: torch.Tensor, data: KData, fourier_op: FourierOp | None = None) -> torch.Tensor: ... + + @overload + def __call__(self, image: torch.Tensor, data: torch.Tensor, fourier_op: FourierOp) -> torch.Tensor: ... + + def __call__( + self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate. + data + k-space data. + fourier_op + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, the Fourier operator is + automatically created from the data. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, fourier_op) def forward( - self, - x: torch.Tensor, - data: KData | torch.Tensor, - fourier_op: FourierOp, - ): + self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None + ) -> torch.Tensor: + """Apply the data consistency.""" + if fourier_op is None: + if isinstance(data, KData): + fourier_op = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or a FourierOp is required') + if not isinstance(fourier_op, FourierOp) or fourier_op._nufft_dims or fourier_op._fast_fourier_op is None: raise ValueError('Only Cartesian acquisitions are supported') @@ -46,7 +92,8 @@ def forward( fft_op = fourier_op._fast_fourier_op sampling_op = fourier_op._cart_sampling_op if fourier_op._cart_sampling_op is not None else IdentityOp() (zero_filled,) = sampling_op.adjoint(data_) - (k_pred,) = fft_op(x) - (k,) = sampling_op.gram((zero_filled - k_pred) / (1 + self.regularization_weight)) + (k_pred,) = fft_op(image) + regularization_weight = self.log_weight.exp() + (k,) = sampling_op.gram((zero_filled - k_pred) / (1 + regularization_weight)) (delta,) = fft_op.H(k) - return x + delta + return image + delta diff --git a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py index 646527d0a..8243775eb 100644 --- a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py +++ b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py @@ -1,4 +1,4 @@ -from inspect import Parameter +"""Conjugate gradient data consistency.""" import torch from torch.nn import Module, Parameter @@ -43,7 +43,7 @@ def forward( x: torch.Tensor, data: torch.Tensor | KData, fourier_op: FourierOp, - csm: torch.Tensor | CSMData | None, + csm: torch.Tensor | CsmData | None, ): data_ = data.data if isinstance(data, KData) else data zero_filled = fourier_op.adjoint(data_) diff --git a/tests/nn/test_linearselfattention.py b/tests/nn/test_linearselfattention.py index d7649a3df..11d17b301 100644 --- a/tests/nn/test_linearselfattention.py +++ b/tests/nn/test_linearselfattention.py @@ -1,7 +1,7 @@ """Tests for LinearSelfAttention module.""" import pytest -from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.attention import LinearSelfAttention from mrpro.utils import RandomGenerator diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index 9ccd4f5d0..c863a680a 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -1,5 +1,5 @@ import pytest -from mrpro.nn import ShiftedWindowAttention +from mrpro.nn.attention import ShiftedWindowAttention from mrpro.utils import RandomGenerator diff --git a/tests/nn/test_sqeezeexcitation.py b/tests/nn/test_sqeezeexcitation.py index 7b7509f1d..8b2a9720e 100644 --- a/tests/nn/test_sqeezeexcitation.py +++ b/tests/nn/test_sqeezeexcitation.py @@ -1,7 +1,7 @@ """Tests for SqueezeExcitation module.""" import pytest -from mrpro.nn.attention.SqueezeExcitation import SqueezeExcitation +from mrpro.nn.attention import SqueezeExcitation from mrpro.utils import RandomGenerator diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index a2688f36d..b2c27e8cf 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -1,7 +1,7 @@ """Tests for TransposedAttention module.""" import pytest -from mrpro.nn.attention.TransposedAttention import TransposedAttention +from mrpro.nn.attention import TransposedAttention from mrpro.utils import RandomGenerator From fd670244ef3219a4a9fd8387be9356e4569ed74c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 17 Jul 2025 22:47:19 +0200 Subject: [PATCH 126/205] dc --- .../data_consistency/ConjugateGradientDC.py | 108 ++++++++++++++---- .../nn/data_consistency/GradientDescentDC.py | 68 ++++++++++- tests/nn/data_consistency/conftest.py | 44 +++++++ .../test_analyticcertesiandc.py | 17 +++ .../test_conjugategradientdc.py | 17 +++ .../test_gradientdescentdc.py | 17 +++ 6 files changed, 245 insertions(+), 26 deletions(-) create mode 100644 tests/nn/data_consistency/conftest.py create mode 100644 tests/nn/data_consistency/test_analyticcertesiandc.py create mode 100644 tests/nn/data_consistency/test_conjugategradientdc.py create mode 100644 tests/nn/data_consistency/test_gradientdescentdc.py diff --git a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py index 8243775eb..f81b24f98 100644 --- a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py +++ b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py @@ -1,5 +1,7 @@ """Conjugate gradient data consistency.""" +from typing import overload + import torch from torch.nn import Module, Parameter @@ -7,7 +9,7 @@ from mrpro.data.KData import KData from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp from mrpro.operators.FourierOp import FourierOp -from mrpro.operators.IdentityOp import IdentityOp +from mrpro.operators.LinearOperator import LinearOperator from mrpro.operators.SensitivityOp import SensitivityOp @@ -20,32 +22,96 @@ def __init__(self, initial_regularization_weight: torch.Tensor | float): Parameters ---------- initial_regularization_weight - Initial regularization weight. + Initial regularization weight. The regularization weight is a trainable parameter. + Must be a positive scalar. """ super().__init__() - self.regularization_weight = Parameter(torch.as_tensor(initial_regularization_weight)) + weight = torch.as_tensor(initial_regularization_weight) + if weight.ndim != 0: + raise ValueError('Regularization weight must be a scalar') + if weight.item() <= 0: + raise ValueError('Regularization weight must be positive') + self.log_weight = Parameter(weight.log()) - def operator_factory( - fourier_op: FourierOp, csm: torch.Tensor | CsmData | None, regularization_weight: torch.Tensor | float, *_ - ): - csm_op = SensitivityOp(csm) if csm is not None else IdentityOp() - op = csm_op.H @ fourier_op.gram @ csm_op + regularization_weight - return op + @overload + def __call__( + self, + image: torch.Tensor, + data: KData, + fourier_op: FourierOp | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: ... + + @overload + def __call__( + self, + image: torch.Tensor, + data: torch.Tensor, + fourier_op: FourierOp, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: ... + + def __call__( + self, + image: torch.Tensor, + data: KData | torch.Tensor, + fourier_op: LinearOperator | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: + """Apply the data consistency. - self.cg_op = ConjugateGradientOp( - operator_factory=operator_factory, - rhs_factory=lambda _fourier, _csm, regularization_weight, zero_filled, regularization: zero_filled - + regularization_weight * regularization, - ) + Parameters + ---------- + image + Current image estimate. + data + k-space data. + fourier_op + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. + This operator can already include the coil sensitivity weighting, if gradients wrt the coil sensitivity maps + NOT required. Otherwise, they should be given as an additional argument. + csm + Coil sensitivity maps. If None, no coil sensitivity weighting is applied. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, fourier_op, csm) def forward( self, - x: torch.Tensor, + image: torch.Tensor, data: torch.Tensor | KData, - fourier_op: FourierOp, - csm: torch.Tensor | CsmData | None, - ): + fourier_op: FourierOp | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: + """Apply the data consistency.""" + if fourier_op is None: + if isinstance(data, KData): + fourier_op = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or a FourierOp is required') + data_ = data.data if isinstance(data, KData) else data - zero_filled = fourier_op.adjoint(data_) - x = self.cg_op(fourier_op, csm, self.regularization_weight, zero_filled, x) - return x + + if csm is None: + csm = torch.tensor(()) + elif isinstance(csm, CsmData): + csm = csm.data + + def operator_factory(csm: torch.Tensor, weight: torch.Tensor, *_): + op = fourier_op.gram + if csm.numel(): + csm_op = SensitivityOp(csm) + op = csm_op.H @ op @ csm_op + op = op + weight + return op + + def rhs_factory(_csm: torch.Tensor, weight: torch.Tensor, zero_filled: torch.Tensor, image: torch.Tensor): + return (zero_filled + weight * image,) + + cg_op = ConjugateGradientOp(operator_factory=operator_factory, rhs_factory=rhs_factory) + (result,) = cg_op(csm, self.log_weight.exp(), fourier_op.adjoint(data_)[0], image) + return result diff --git a/src/mrpro/nn/data_consistency/GradientDescentDC.py b/src/mrpro/nn/data_consistency/GradientDescentDC.py index ee05ca8db..38eb97f6c 100644 --- a/src/mrpro/nn/data_consistency/GradientDescentDC.py +++ b/src/mrpro/nn/data_consistency/GradientDescentDC.py @@ -1,7 +1,10 @@ +from typing import overload + import torch from torch.nn import Module, Parameter from mrpro.data.KData import KData +from mrpro.operators.FourierOp import FourierOp from mrpro.operators.LinearOperator import LinearOperator @@ -15,6 +18,7 @@ class GradientDescentDC(Module): ---------- initial_stepsize Initial stepsize. The stepsize is a trainable parameter. + Must be a positive scalar. n_steps Number of gradient descent steps. @@ -23,17 +27,71 @@ class GradientDescentDC(Module): The updated image. """ - def __init__(self, initial_stepsize: float, n_steps: int = 1) -> None: + def __init__(self, initial_stepsize: float | torch.Tensor, n_steps: int = 1) -> None: + """Initialize the gradient descent data consistency. + + Parameters + ---------- + initial_stepsize + Initial stepsize. The stepsize is a trainable parameter. + Must be a positive scalar. + n_steps + Number of gradient descent steps. + """ super().__init__() - self.stepsize = Parameter(torch.tensor(initial_stepsize)) + stepsize = torch.as_tensor(initial_stepsize) + if stepsize.ndim != 0: + raise ValueError('Stepsize must be a scalar') + if stepsize.item() <= 0: + raise ValueError('Stepsize must be positive') + self.log_stepsize = Parameter(stepsize.log()) self.n_steps = n_steps + @overload + def __call__( + self, image: torch.Tensor, data: KData, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: ... + + @overload + def __call__( + self, image: torch.Tensor, data: torch.Tensor, acquisition_operator: LinearOperator + ) -> torch.Tensor: ... + + def __call__( + self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate. + data + k-space data. + acquisition_operator + Acquisition operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, acquisition_operator) + def forward( - self, x: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator + self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None ) -> torch.Tensor: - """Forward pass.""" + """Apply the data consistency.""" + if acquisition_operator is None: + if isinstance(data, KData): + acquisition_operator = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or an acquisition operator is required') + data_ = data.data if isinstance(data, KData) else data + stepsize = self.log_stepsize.exp() + x = image for _ in range(self.n_steps): residual = acquisition_operator(x)[0] - data_ - x = x - self.stepsize * acquisition_operator.adjoint(residual)[0] + x = x - stepsize * acquisition_operator.adjoint(residual)[0] return x diff --git a/tests/nn/data_consistency/conftest.py b/tests/nn/data_consistency/conftest.py new file mode 100644 index 000000000..e15087e04 --- /dev/null +++ b/tests/nn/data_consistency/conftest.py @@ -0,0 +1,44 @@ +import pytest +from mrpro.data.KData import KData +from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian +from mrpro.operators.FourierOp import FourierOp +from mrpro.phantoms.EllipsePhantom import EllipsePhantom +from mrpro.utils import RandomGenerator + + +@pytest.fixture +def kdata(): + matrix = SpatialDimension(x=128, y=128, z=1) + kdata = EllipsePhantom().kdata(KTrajectoryCartesian.fullysampled(matrix), matrix) + return kdata + + +@pytest.fixture +def kdata_noisy(kdata: KData): + kdata_noisy = kdata.clone() + kdata_noisy.data += 0.1 * RandomGenerator(123).randn_like(kdata_noisy.data) + return kdata_noisy + + +@pytest.fixture +def kdata_us(kdata: KData): + return kdata[..., ::2, :].clone() + + +@pytest.fixture +def image_noisy(kdata_noisy: KData): + fourier_op = FourierOp.from_kdata(kdata_noisy) + return fourier_op.adjoint(kdata_noisy.data)[0] + + +@pytest.fixture +def image(kdata: KData): + fourier_op = FourierOp.from_kdata(kdata) + return fourier_op.adjoint(kdata.data)[0] + + +@pytest.fixture +def image_us(kdata_us: KData): + fourier_op = FourierOp.from_kdata(kdata_us) + return fourier_op.adjoint(kdata_us.data)[0] diff --git a/tests/nn/data_consistency/test_analyticcertesiandc.py b/tests/nn/data_consistency/test_analyticcertesiandc.py new file mode 100644 index 000000000..dc200e8b7 --- /dev/null +++ b/tests/nn/data_consistency/test_analyticcertesiandc.py @@ -0,0 +1,17 @@ +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC + + +def test_analytic_cartesian_dc(image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor): + image_noisy = image_noisy.clone().requires_grad_(True) + dc = AnalyticCartesianDC(initial_regularization_weight=1e-6) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_weight.grad is not None + assert not dc.log_weight.grad.isnan() + assert not image_noisy.grad.isnan().any() diff --git a/tests/nn/data_consistency/test_conjugategradientdc.py b/tests/nn/data_consistency/test_conjugategradientdc.py new file mode 100644 index 000000000..2527259f4 --- /dev/null +++ b/tests/nn/data_consistency/test_conjugategradientdc.py @@ -0,0 +1,17 @@ +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC + + +def test_conjugate_gradient_dc(image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor): + image_noisy = image_noisy.clone().requires_grad_(True) + dc = ConjugateGradientDC(initial_regularization_weight=1.0) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_weight.grad is not None + assert not dc.log_weight.grad.isnan() + assert not image_noisy.grad.isnan().any() diff --git a/tests/nn/data_consistency/test_gradientdescentdc.py b/tests/nn/data_consistency/test_gradientdescentdc.py new file mode 100644 index 000000000..fae284b8e --- /dev/null +++ b/tests/nn/data_consistency/test_gradientdescentdc.py @@ -0,0 +1,17 @@ +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC + + +def test_gradient_descent_dc(image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor): + image_noisy = image_noisy.clone().requires_grad_(True) + dc = GradientDescentDC(initial_stepsize=1.0) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_stepsize.grad is not None + assert not dc.log_stepsize.grad.isnan() + assert not image_noisy.grad.isnan().any() From e629a55f768f385a153f8cae923c171e9f06cb74 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 18 Jul 2025 00:52:56 +0200 Subject: [PATCH 127/205] update --- src/mrpro/nn/ResBlock.py | 24 +-- .../nn/attention/SpatialTransformerBlock.py | 13 ++ .../data_consistency/AnalyticCartesianDC.py | 8 +- .../nn/data_consistency/GradientDescentDC.py | 6 +- src/mrpro/nn/encoding.py | 18 +-- src/mrpro/nn/join.py | 2 +- src/mrpro/nn/nets/BasicCNN.py | 20 +-- src/mrpro/nn/nets/DCAE.py | 138 ++++++++++-------- src/mrpro/nn/nets/Restormer.py | 72 ++++----- src/mrpro/nn/nets/SwinIR.py | 134 ++++++++--------- src/mrpro/nn/nets/UNet.py | 95 ++++++------ src/mrpro/nn/nets/Uformer.py | 74 +++++----- tests/nn/nets/test_unet.py | 6 +- tests/nn/test_resblock.py | 2 +- 14 files changed, 325 insertions(+), 287 deletions(-) diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index 8f61e6022..f115e205f 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -13,16 +13,16 @@ class ResBlock(CondMixin, Module): """Residual convolution block with two convolutions.""" - def __init__(self, dim: int, channels_in: int, channels_out: int, cond_dim: int) -> None: + def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, cond_dim: int) -> None: """Initialize the ResBlock. Parameters ---------- - dim - The dimension, i.e. 1, 2 or 3. - channels_in + n_dim + The number of dimensions, i.e. 1, 2 or 3. + n_channels_in The number of channels in the input tensor. - channels_out + n_channels_out The number of channels in the output tensor. cond_dim The number of features in the conditioning tensor used in a FiLM. @@ -32,20 +32,20 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, cond_dim: int) super().__init__() self.rezero = torch.nn.Parameter(torch.tensor(0.1)) self.block = Sequential( - GroupNorm(channels_in), + GroupNorm(n_channels_in), SiLU(), - ConvND(dim)(channels_in, channels_out, kernel_size=3, padding=1), - GroupNorm(channels_out), + ConvND(n_dim)(n_channels_in, n_channels_out, kernel_size=3, padding=1), + GroupNorm(n_channels_out), SiLU(), - ConvND(dim)(channels_out, channels_out, kernel_size=3, padding=1), + ConvND(n_dim)(n_channels_out, n_channels_out, kernel_size=3, padding=1), ) if cond_dim > 0: - self.block.insert(-3, FiLM(channels_out, cond_dim)) + self.block.insert(-3, FiLM(n_channels_out, cond_dim)) - if channels_out == channels_in: + if n_channels_out == n_channels_in: self.skip_connection: Module = Identity() else: - self.skip_connection = ConvND(dim)(channels_in, channels_out, kernel_size=1) + self.skip_connection = ConvND(n_dim)(n_channels_in, n_channels_out, kernel_size=1) def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock. diff --git a/src/mrpro/nn/attention/SpatialTransformerBlock.py b/src/mrpro/nn/attention/SpatialTransformerBlock.py index 50b5bbfa7..d278c232a 100644 --- a/src/mrpro/nn/attention/SpatialTransformerBlock.py +++ b/src/mrpro/nn/attention/SpatialTransformerBlock.py @@ -147,4 +147,17 @@ def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch return skip + h def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the spatial transformer block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor. + """ return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py index 3bb658d82..ed6b812b1 100644 --- a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -1,3 +1,5 @@ +"""Analytic Cartesian data consistency.""" + from typing import overload import torch @@ -13,8 +15,8 @@ class AnalyticCartesianDC(Module): Solves the following problem: :math:`\min_x \|Ax - k\|_2^2 + \lambda \|x-p\|_2^2` - where :math:`A` is the acquisition operator and :math:`k` is the data, :math:`\lambda` is the regularization parameter, - and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a special case + where :math:`A` is the acquisition operator and :math:`k` is the data, :math:`\lambda` is the regularization + parameter and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a special case for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil data or to apply data consistency to each coil image [NOSENSE]_ @@ -62,7 +64,7 @@ def __call__( Parameters ---------- image - Current image estimate. + Current image estimate, i.e. the regularized image. data k-space data. fourier_op diff --git a/src/mrpro/nn/data_consistency/GradientDescentDC.py b/src/mrpro/nn/data_consistency/GradientDescentDC.py index 38eb97f6c..579e5181a 100644 --- a/src/mrpro/nn/data_consistency/GradientDescentDC.py +++ b/src/mrpro/nn/data_consistency/GradientDescentDC.py @@ -1,3 +1,5 @@ +"""Gradient descent data consistency.""" + from typing import overload import torch @@ -69,8 +71,8 @@ def __call__( data k-space data. acquisition_operator - Acquisition operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, - the Fourier operator is automatically created from the data. + Acquisition operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` + object, the Fourier operator is automatically created from the data. Returns ------- diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/encoding.py index 39f48c51e..94d527b21 100644 --- a/src/mrpro/nn/encoding.py +++ b/src/mrpro/nn/encoding.py @@ -63,13 +63,13 @@ class AbsolutePositionEncoding(Module): encoding: torch.Tensor - def __init__(self, dim: int, features: int, include_radii: bool = True, base_resolution: int = 128): + def __init__(self, n_dim: int, features: int, include_radii: bool = True, base_resolution: int = 128): """Initialize absolute position encoding layer. Parameters ---------- - dim - Dimension of the input space (1, 2, or 3) + n_dim + Dimensions of the input space (1, 2, or 3) features Number of output features include_radii @@ -79,19 +79,19 @@ def __init__(self, dim: int, features: int, include_radii: bool = True, base_res """ super().__init__() - coords = [unsqueeze_right(torch.linspace(-1, 1, base_resolution), i) for i in range(dim)] + coords = [unsqueeze_right(torch.linspace(-1, 1, base_resolution), i) for i in range(n_dim)] if include_radii: - for n in range(2, dim + 1): + for n in range(2, n_dim + 1): for combination in combinations(coords, n): coords.append((2 * sum([c**2 for c in combination])) ** 0.5 - 1) n_freqs = ceil(features / len(coords) / 2) - freqs = unsqueeze_right((base_resolution) ** torch.linspace(0, 1, n_freqs), dim) + freqs = unsqueeze_right((base_resolution) ** torch.linspace(0, 1, n_freqs), n_dim) encoding = [] for coord in coords: - encoding.append(torch.sin(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * dim))) - encoding.append(torch.cos(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * dim))) + encoding.append(torch.sin(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * n_dim))) + encoding.append(torch.cos(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * n_dim))) self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :features]) - self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][dim - 1] + self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][n_dim - 1] def __call__(self, x: torch.Tensor) -> torch.Tensor: """ diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py index 204f301a8..b6a749c47 100644 --- a/src/mrpro/nn/join.py +++ b/src/mrpro/nn/join.py @@ -29,7 +29,7 @@ def _fix_shapes( if mode == 'zero' or mode == 'crop': return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) else: - return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) # type: ignore + return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) # type: ignore[arg-type] class Concat(Module): diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py index f294de715..adca090b6 100644 --- a/src/mrpro/nn/nets/BasicCNN.py +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -25,9 +25,9 @@ class BasicCNN(Sequential): def __init__( self, - dim: int, - channels_in: int, - channels_out: int, + n_dim: int, + n_channels_in: int, + n_channels_out: int, norm: Literal['batch', 'group', 'instance', 'none', 'layer'] = 'none', activation: Literal['relu', 'silu', 'leaky_relu'] = 'relu', n_features: Sequence[int] = (64, 64, 64), @@ -37,11 +37,11 @@ def __init__( Parameters ---------- - dim + n_dim The number of spatial dimensions of the input tensor. - channels_in + n_channels_in The number of input channels. - channels_out + n_channels_out The number of output channels. norm The type of normalization to use. If 'batch', use batch normalization. If 'group', use group normalization, @@ -59,11 +59,11 @@ def __init__( super().__init__() use_film = cond_dim > 0 - self.append(ConvND(dim)(channels_in, n_features[0], kernel_size=3, padding='same')) + self.append(ConvND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding='same')) - for c_in, c_out in pairwise((*n_features, channels_out)): + for c_in, c_out in pairwise((*n_features, n_channels_out)): if norm.lower() == 'batch': - self.append(BatchNormND(dim)(c_in, affine=not use_film)) + self.append(BatchNormND(n_dim)(c_in, affine=not use_film)) elif norm.lower() == 'group': self.append(GroupNorm(c_in, affine=not use_film)) elif norm.lower() == 'instance': @@ -85,7 +85,7 @@ def __init__( else: raise ValueError(f'Invalid activation type: {activation}') - self.append(ConvND(dim)(c_in, c_out, kernel_size=3, padding='same')) + self.append(ConvND(n_dim)(c_in, c_out, kernel_size=3, padding='same')) def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] """Apply the basic CNN to the input tensor. diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCAE.py index a83407132..7a5362514 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCAE.py @@ -30,24 +30,24 @@ class CNNBlock(Residual): def __init__( self, - dim: int, - channels: int, + n_dim: int, + n_channels: int, ): """Initialize the CNNBlock. Parameters ---------- - dim : int - The spatial dimension of the input tensor. - channels : int + n_dim + The number of spatial dimensions of the input tensor. + n_channels The number of channels in the input tensor. """ super().__init__( Sequential( - ConvND(dim)(channels, channels, kernel_size=3, padding=1), + ConvND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1), SiLU(True), - ConvND(dim)(channels, channels, kernel_size=3, padding=1, bias=False), - RMSNorm(channels), + ConvND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1, bias=False), + RMSNorm(n_channels), ) ) @@ -65,8 +65,8 @@ class EfficientViTBlock(Module): def __init__( self, - dim: int, - channels: int, + n_dim: int, + n_channels: int, n_heads: int, expand_ratio: int = 4, linear_attn: bool = False, @@ -75,27 +75,27 @@ def __init__( Parameters ---------- - dim : int - The spatial dimension of the input tensor. - channels : int + n_dim + The number of spatial dimensions of the input tensor. + n_channels The number of channels in the input tensor. - n_heads : int + n_heads The number of attention heads. - expand_ratio : int + expand_ratio The expansion ratio of the GluMBConvResBlock. - linear_attn : bool + linear_attn Whether to use linear attention instead of softmax attention with quadratic complexity. """ super().__init__() if linear_attn: - attention: Module = LinearSelfAttention(channels, channels, n_heads) + attention: Module = LinearSelfAttention(n_channels, n_channels, n_heads) else: - attention = MultiHeadAttention(channels, channels, n_heads, features_last=False) - self.context_module = Residual(Sequential(attention, RMSNorm(channels))) + attention = MultiHeadAttention(n_channels, n_channels, n_heads, features_last=False) + self.context_module = Residual(Sequential(attention, RMSNorm(n_channels))) self.local_module = GluMBConvResBlock( - n_dim=dim, - n_channels_in=channels, - n_channels_out=channels, + n_dim=n_dim, + n_channels_in=n_channels, + n_channels_out=n_channels, expand_ratio=expand_ratio, ) @@ -133,9 +133,9 @@ class Encoder(Sequential): def __init__( self, - dim: int = 2, - channels_in: int = 3, - channels_out: int = 32, + n_dim: int = 2, + n_channels_in: int = 3, + n_channels_out: int = 32, block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), widths: Sequence[int] = (256, 512, 512, 1024, 1024), depths: Sequence[int] = (4, 6, 2, 2, 2), @@ -147,41 +147,43 @@ def __init__( Parameters ---------- - dim : int - The spatial dimension of the input tensor. - channels_in : int + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in The number of channels in the input tensor, i.e. the latent space - channels_out : int + n_channels_out The number of channels in the output tensor, i.e. the original space - block_types : Sequence[str] + block_types The types of blocks to use in the decoder. - widths : Sequence[int] + widths The widths of the blocks in the decoder, i.e. the number of channels in the blocks - depths : Sequence[int] + depths The depths of the blocks in the decoder, i.e. the number blocks in the stage """ super().__init__() - self.append(PixelUnshuffleDownsample(dim, channels_in, widths[0], downscale_factor=2, residual=False)) + self.append(PixelUnshuffleDownsample(n_dim, n_channels_in, widths[0], downscale_factor=2, residual=False)) if len(block_types) != len(widths) or len(block_types) != len(depths): raise ValueError('block_types, widths, and depths must have the same length') for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): match block_type: case 'CNN': - stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] + stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] case 'LinearViT': - stage = [EfficientViTBlock(dim, width, max(1, width // 32), linear_attn=True) for _ in range(depth)] + stage = [ + EfficientViTBlock(n_dim, width, max(1, width // 32), linear_attn=True) for _ in range(depth) + ] case 'ViT': - stage = [EfficientViTBlock(dim, width, max(1, width // 32)) for _ in range(depth)] + stage = [EfficientViTBlock(n_dim, width, max(1, width // 32)) for _ in range(depth)] case _: raise ValueError(f'Block type {block_type} not supported') self.append(Sequential(*stage)) if next_width: - self.append(PixelUnshuffleDownsample(dim, width, next_width, downscale_factor=2, residual=True)) + self.append(PixelUnshuffleDownsample(n_dim, width, next_width, downscale_factor=2, residual=True)) self.append( Sequential( RMSNorm(widths[-1]), ReLU(), - PixelUnshuffleDownsample(dim, widths[-1], channels_out, downscale_factor=1, residual=True), + PixelUnshuffleDownsample(n_dim, widths[-1], n_channels_out, downscale_factor=1, residual=True), ) ) @@ -199,9 +201,9 @@ class Decoder(Sequential): def __init__( self, - dim: int = 2, - channels_in: int = 32, - channels_out: int = 3, + n_dim: int = 2, + n_channels_in: int = 32, + n_channels_out: int = 3, block_types: Sequence[Literal['ViT', 'LinearViT', 'CNN']] = ('ViT', 'LinearViT', 'LinearViT', 'CNN', 'CNN'), widths: Sequence[int] = (1024, 1024, 512, 512, 256), depths: Sequence[int] = (2, 2, 2, 6, 4), @@ -213,45 +215,47 @@ def __init__( Parameters ---------- - dim : int - The spatial dimension of the input tensor. - channels_in : int + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in The number of channels in the input tensor, i.e. the latent space - channels_out : int + n_channels_out The number of channels in the output tensor, i.e. the original space - block_types : Sequence[str] + block_types The types of blocks to use in the decoder. - widths : Sequence[int] + widths The widths of the blocks in the decoder, i.e. the number of channels in the blocks - depths : Sequence[int] + depths The depths of the blocks in the decoder, i.e. the number blocks in the stage """ super().__init__() if not (len(block_types) == len(widths) == len(depths)): raise ValueError('block_types, widths, and depths must have the same length') - self.append(PixelShuffleUpsample(dim, channels_in, widths[0], upscale_factor=1, residual=True)) + self.append(PixelShuffleUpsample(n_dim, n_channels_in, widths[0], upscale_factor=1, residual=True)) for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): match block_type: case 'CNN': - stage: list[Module] = [CNNBlock(dim, width) for _ in range(depth)] + stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] case 'LinearViT': - stage = [EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=True) for _ in range(depth)] + stage = [ + EfficientViTBlock(n_dim, width, n_heads=width // 32, linear_attn=True) for _ in range(depth) + ] case 'ViT': stage = [ - EfficientViTBlock(dim, width, n_heads=width // 32, linear_attn=False) for _ in range(depth) + EfficientViTBlock(n_dim, width, n_heads=width // 32, linear_attn=False) for _ in range(depth) ] case _: raise ValueError(f'Block type {block_type} not supported') self.append(Sequential(*stage)) if next_width: - self.append(PixelShuffleUpsample(dim, width, next_width, upscale_factor=2, residual=True)) + self.append(PixelShuffleUpsample(n_dim, width, next_width, upscale_factor=2, residual=True)) self.append( Sequential( RMSNorm(widths[-1]), ReLU(), - PixelShuffleUpsample(dim, widths[-1], channels_out, upscale_factor=2), + PixelShuffleUpsample(n_dim, widths[-1], n_channels_out, upscale_factor=2), ) ) @@ -267,14 +271,30 @@ class DCVAE(VAE): def __init__( self, - dim: int, - channels: int, + n_dim: int, + n_channels: int, latent_dim: int = 32, block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), widths: Sequence[int] = (256, 512, 512, 1024, 1024), depths: Sequence[int] = (4, 6, 2, 2, 2), ): - """Initialize the DCVAE.""" - encoder = Encoder(dim, channels, latent_dim * 2, block_types, widths, depths) - decoder = Decoder(dim, latent_dim, channels, block_types[::-1], widths[::-1], depths[::-1]) + """Initialize the DCVAE. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + latent_dim + The number of channels in the latent space. + block_types + The types of blocks to use in the encoder and decoder. + widths + The widths of the blocks in the encoder and decoder. + depths + The depths of the blocks in the encoder and decoder. + """ + encoder = Encoder(n_dim, n_channels, latent_dim * 2, block_types, widths, depths) + decoder = Decoder(n_dim, latent_dim, n_channels, block_types[::-1], widths[::-1], depths[::-1]) super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 3cba23ecc..11ce8f95a 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -22,23 +22,23 @@ class GDFN(Module): As used in the Restormer architecture. """ - def __init__(self, dim: int, channels: int, mlp_ratio: float): + def __init__(self, n_dim: int, n_channels: int, mlp_ratio: float): """Initialize GDFN. Parameters ---------- - dim : int - Dimension of the input space - channels : int - Number of input/output channels - mlp_ratio : float + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + mlp_ratio Ratio for hidden dimension expansion """ super().__init__() - hidden_features = int(channels * mlp_ratio) - self.project_in = ConvND(dim)(channels, hidden_features * 2, kernel_size=1) - self.depthwise_conv = ConvND(dim)( + hidden_features = int(n_channels * mlp_ratio) + self.project_in = ConvND(n_dim)(n_channels, hidden_features * 2, kernel_size=1) + self.depthwise_conv = ConvND(n_dim)( hidden_features * 2, hidden_features * 2, kernel_size=3, @@ -46,7 +46,7 @@ def __init__(self, dim: int, channels: int, mlp_ratio: float): padding=1, groups=hidden_features * 2, ) - self.project_out = ConvND(dim)(hidden_features, channels, kernel_size=1) + self.project_out = ConvND(n_dim)(hidden_features, n_channels, kernel_size=1) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply the gated depthwise feed forward network. @@ -58,7 +58,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: Returns ------- - Output tensor + Output tensor """ x = self.project_in(x) x1, x2 = self.depthwise_conv(x).chunk(2, dim=1) @@ -70,15 +70,15 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: class RestormerBlock(CondMixin, Module): """Transformer block with transposed attention and gated depthwise feed forward network.""" - def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond_dim: int = 0): + def __init__(self, n_dim: int, n_channels: int, n_heads: int, mlp_ratio: float, cond_dim: int = 0): """Initialize RestormerBlock. Parameters ---------- - dim - Dimension of the input space - channels : int - Number of input/output channels + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. n_heads Number of attention heads mlp_ratio @@ -87,12 +87,12 @@ def __init__(self, dim: int, channels: int, n_heads: int, mlp_ratio: float, cond Dimension of conditioning input. If 0, no conditioning is applied. """ super().__init__() - self.norm1 = Sequential(InstanceNormND(dim)(channels)) - self.attn = TransposedAttention(dim, channels, channels, n_heads) - self.norm2 = Sequential(InstanceNormND(dim)(channels)) - self.ffn = GDFN(dim, channels, mlp_ratio) + self.norm1 = Sequential(InstanceNormND(n_dim)(n_channels)) + self.attn = TransposedAttention(n_dim, n_channels, n_channels, n_heads) + self.norm2 = Sequential(InstanceNormND(n_dim)(n_channels)) + self.ffn = GDFN(n_dim, n_channels, mlp_ratio) if cond_dim > 0: - self.norm2.append(FiLM(channels=channels, cond_dim=cond_dim)) + self.norm2.append(FiLM(channels=n_channels, cond_dim=cond_dim)) def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply Restormer block. @@ -131,9 +131,9 @@ class Restormer(UNetBase): def __init__( self, - dim: int, - channels_in: int, - channels_out: int, + n_dim: int, + n_channels_in: int, + n_channels_out: int, n_blocks: Sequence[int] = (4, 6, 6, 8), n_refinement_blocks: int = 4, n_heads: Sequence[int] = (1, 2, 4, 8), @@ -145,12 +145,12 @@ def __init__( Parameters ---------- - dim - Dimension of the input space - channels_in - Number of input channels - channels_out - Number of output channels + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. n_blocks Number of blocks in each stage n_refinement_blocks @@ -167,17 +167,17 @@ def __init__( def blocks(n_heads: int, n_blocks: int): layers = Sequential( - *(RestormerBlock(dim, n_channels_per_head * n_heads, n_heads, mlp_ratio) for _ in range(n_blocks)) + *(RestormerBlock(n_dim, n_channels_per_head * n_heads, n_heads, mlp_ratio) for _ in range(n_blocks)) ) if cond_dim > 0 and n_blocks > 1: layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) return layers - first_block = ConvND(dim)(channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) + first_block = ConvND(n_dim)(n_channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) encoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)] down_blocks = [ - PixelUnshuffleDownsample(dim, n_channels_per_head * head_current, n_channels_per_head * head_next) + PixelUnshuffleDownsample(n_dim, n_channels_per_head * head_current, n_channels_per_head * head_next) for head_current, head_next in pairwise(n_heads) ] middle_block = blocks(n_heads[-1], n_blocks[-1]) @@ -189,14 +189,14 @@ def blocks(n_heads: int, n_blocks: int): ) up_blocks = [ - PixelShuffleUpsample(dim, n_channels_per_head * head_next, n_channels_per_head * head_current) + PixelShuffleUpsample(n_dim, n_channels_per_head * head_next, n_channels_per_head * head_current) for head_current, head_next in pairwise(n_heads) ][::-1] concat_blocks = [Concat() for _ in range(len(encoder_blocks))] decoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1] last_block = Sequential( - *(RestormerBlock(dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)), - ConvND(dim)(n_channels_per_head, channels_out, kernel_size=3, stride=1, padding=1), + *(RestormerBlock(n_dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)), + ConvND(n_dim)(n_channels_per_head, n_channels_out, kernel_size=3, stride=1, padding=1), ) decoder = UNetDecoder( blocks=decoder_blocks, diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index 16c19a9fc..af944a753 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -18,8 +18,8 @@ class SwinTransformerLayer(Module): def __init__( self, - dim: int, - channels: int, + n_dim: int, + n_channels: int, n_heads: int, window_size: int, mlp_ratio: int = 4, @@ -30,31 +30,31 @@ def __init__( Parameters ---------- - dim : int - Dimension of the input space - channels : int - Number of input/output channels - n_heads : int + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads Number of attention heads - window_size : int + window_size Size of the attention window - mlp_ratio : int + mlp_ratio Ratio for hidden dimension expansion in MLP - emb_dim : int + emb_dim Dimension of conditioning input. If 0, no FiLM conditioning is used. - p_droppath : float + p_droppath Droppath probability for MLP """ super().__init__() - self.norm1 = InstanceNormND(dim)(channels) - self.attn = ShiftedWindowAttention(dim, channels, channels, n_heads, window_size) - self.norm2 = Sequential(InstanceNormND(dim)(channels)) + self.norm1 = InstanceNormND(n_dim)(n_channels) + self.attn = ShiftedWindowAttention(n_dim, n_channels, n_channels, n_heads, window_size) + self.norm2 = Sequential(InstanceNormND(n_dim)(n_channels)) if emb_dim > 0: - self.norm2.append(FiLM(channels=channels, cond_dim=emb_dim)) + self.norm2.append(FiLM(channels=n_channels, cond_dim=emb_dim)) self.mlp = Sequential( - ConvND(dim)(channels, channels * mlp_ratio, 1), + ConvND(n_dim)(n_channels, n_channels * mlp_ratio, 1), GELU('tanh'), - ConvND(dim)(channels * mlp_ratio, channels, 1), + ConvND(n_dim)(n_channels * mlp_ratio, n_channels, 1), DropPath(p_droppath), ) @@ -63,9 +63,9 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T Parameters ---------- - x : torch.Tensor + x Input tensor - cond : torch.Tensor | None, optional + cond Conditioning input Returns @@ -91,8 +91,8 @@ class ResidualSwinTransformerBlock(Module): def __init__( self, - dim: int, - channels: int, + n_dim: int, + n_channels: int, n_heads: int, window_size: int, depth: int, @@ -104,42 +104,42 @@ def __init__( Parameters ---------- - dim : int - Dimension of the input space - channels : int - Number of input/output channels - n_heads : int + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads Number of attention heads - window_size : int + window_size Size of the attention window - depth : int + depth Number of Swin Transformer layers - emb_dim : int, optional + emb_dim Dimension of conditioning input. If 0, no FiLM conditioning is used. - p_droppath : float, optional + p_droppath Droppath probability for MLP. - mlp_ratio : int, optional + mlp_ratio Ratio for hidden dimension expansion in MLP """ super().__init__() self.layers = Sequential( *( SwinTransformerLayer( - dim, channels, n_heads, window_size, emb_dim=emb_dim, p_droppath=p_droppath, mlp_ratio=mlp_ratio + n_dim, n_channels, n_heads, window_size, emb_dim=emb_dim, p_droppath=p_droppath, mlp_ratio=mlp_ratio ) for _ in range(depth) ) ) - self.conv = ConvND(dim)(channels, channels, 3, padding=1) + self.conv = ConvND(n_dim)(n_channels, n_channels, 3, padding=1) def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the residual Swin Transformer block. Parameters ---------- - x : torch.Tensor + x Input tensor - cond : torch.Tensor | None, optional + cond Conditioning input. If None, no FiLM conditioning is used. Returns @@ -168,10 +168,10 @@ class SwinIR(Module): def __init__( self, - dim: int, - channels_in: int, - channels_out: int, - channels_per_head: int = 16, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_channels_per_head: int = 16, n_heads: int = 6, window_size: int = 64, n_blocks: int = 6, @@ -184,36 +184,36 @@ def __init__( Parameters ---------- - dim : int - Dimension of the input space - channels_in : int - Number of input channels - channels_out : int - Number of output channels - channels_per_head : int, optional - Number of channels per attention head - n_heads : int, optional - Number of attention heads - window_size : int - Size of the attention window. Inputs sizes must be divisible by this value. - n_blocks : int - Number of residual blocks - n_attn_per_block : int - Number of attention layers per block - emb_dim : int, optional - Dimension of conditioning input. If 0, no FiLM conditioning is used. - p_droppath : float, optional - Droppath probability for MLP. - mlp_ratio : int, optional - Ratio for hidden dimension expansion in MLP. + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + n_channels_per_head + The number of channels per attention head. + n_heads + The number of attention heads. + window_size + The size of the attention window. Inputs sizes must be divisible by this value. + n_blocks + The number of residual blocks. + n_attn_per_block + The number of attention layers per block. + emb_dim + The dimension of the conditioning input. If 0, no FiLM conditioning is used. + p_droppath + The droppath probability for MLP. + mlp_ratio + The ratio for hidden dimension expansion in MLP. """ super().__init__() - self.first = ConvND(dim)(channels_in, channels_per_head * n_heads, kernel_size=3, padding=1) + self.first = ConvND(n_dim)(n_channels_in, n_channels_per_head * n_heads, kernel_size=3, padding=1) self.blocks = Sequential( *( ResidualSwinTransformerBlock( - dim, - channels_per_head * n_heads, + n_dim, + n_channels_per_head * n_heads, n_heads, window_size, n_attn_per_block, @@ -224,16 +224,16 @@ def __init__( for _ in range(n_blocks) ) ) - self.last = ConvND(dim)(channels_per_head * n_heads, channels_out, kernel_size=3, padding=1) + self.last = ConvND(n_dim)(n_channels_per_head * n_heads, n_channels_out, kernel_size=3, padding=1) def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply SwinIR. Parameters ---------- - x : torch.Tensor + x Input tensor - cond : torch.Tensor | None, optional + cond Conditioning input. If None, no FiLM conditioning is used. Returns diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 6f21971a0..fdcda6e2a 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -260,9 +260,9 @@ class UNet(UNetBase): def __init__( self, - dim: int, - channels_in: int, - channels_out: int, + n_dim: int, + n_channels_in: int, + n_channels_out: int, attention_depths: Sequence[int] = (-1,), n_features: Sequence[int] = (64, 128, 192, 256), n_heads: int = 8, @@ -273,12 +273,12 @@ def __init__( Parameters ---------- - dim - Spatial dimension of the input tensor. - channels_in - Number of channels in the input tensor. - channels_out - Number of channels in the output tensor. + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. attention_depths The depths at which to apply attention. n_features @@ -300,42 +300,42 @@ def __init__( raise ValueError(f'attention_depths must be unique, got {attention_depths=}') def attention_block(channels: int) -> Module: - dim_groups = (tuple(range(-dim, 0)),) + dim_groups = (tuple(range(-n_dim, 0)),) return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: blocks = Sequential() for _ in range(encoder_blocks_per_scale): - blocks.append(ResBlock(dim, channels_in, channels_out, cond_dim)) + blocks.append(ResBlock(n_dim, channels_in, channels_out, cond_dim)) if attention: blocks.append(attention_block(channels_out)) channels_in = channels_out return blocks - encoder_blocks: list[Module] = [ConvND(dim)(channels_in, n_features[0], 3, padding=1)] + encoder_blocks: list[Module] = [ConvND(n_dim)(n_channels_in, n_features[0], 3, padding=1)] down_blocks: list[Module] = [Identity()] decoder_blocks: list[Module] = [] up_blocks: list[Module] = [] for i_level, (n_feat, n_feat_next) in enumerate(pairwise(n_features)): encoder_blocks.append(blocks(n_feat, n_feat, i_level in attention_depths)) - down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) + down_blocks.append(ConvND(n_dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) decoder_blocks.append(blocks(n_feat_next + n_feat, n_feat, i_level in attention_depths)) - up_blocks.append(Upsample(tuple(range(-dim, 0)), scale_factor=2)) + up_blocks.append(Upsample(tuple(range(-n_dim, 0)), scale_factor=2)) middle_block = Sequential( - ResBlock(dim, n_feat_next, n_feat_next, cond_dim), - ResBlock(dim, n_feat_next, n_feat_next, cond_dim), + ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim), + ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim), ) if depth - 1 in attention_depths: middle_block.insert(1, attention_block(n_feat_next)) - first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) + first_block = ConvND(n_dim)(n_channels_in, n_features[0], 3, padding=1) encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) decoder_blocks, up_blocks = decoder_blocks[::-1], up_blocks[::-1] last_block = Sequential( SiLU(), - ConvND(dim)(n_features[0], channels_out, 3, padding=1), + ConvND(n_dim)(n_features[0], n_channels_out, 3, padding=1), ) concat_blocks = [Concat() for _ in range(len(decoder_blocks))] decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) @@ -354,17 +354,19 @@ class AttentionGatedUNet(UNetBase): https://arxiv.org/abs/1804.03999 """ - def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int = 0): + def __init__( + self, n_dim: int, n_channels_in: int, n_channels_out: int, n_features: Sequence[int], cond_dim: int = 0 + ): """Initialize the AttentionGatedUNet. Parameters ---------- - dim - Spatial dimension of the input tensor. - channels_in - Number of channels in the input tensor. - channels_out - Number of channels in the output tensor. + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. n_features Number of features at each resolution level. The length determines the number of resolution levels. cond_dim @@ -373,9 +375,9 @@ def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Se def block(channels_in: int, channels_out: int) -> Module: block = Sequential( - ConvND(dim)(channels_in, channels_out, 3, padding=1), + ConvND(n_dim)(channels_in, channels_out, 3, padding=1), ReLU(True), - ConvND(dim)(channels_out, channels_out, 3, padding=1), + ConvND(n_dim)(channels_out, channels_out, 3, padding=1), ReLU(True), ) if cond_dim > 0: @@ -384,10 +386,10 @@ def block(channels_in: int, channels_out: int) -> Module: encoder_blocks: list[Module] = [] down_blocks: list[Module] = [] - n_feat_old = channels_in + n_feat_old = n_channels_in for n_feat in n_features[:-1]: encoder_blocks.append(block(n_feat_old, n_feat)) - down_blocks.append(MaxPoolND(dim)(2)) + down_blocks.append(MaxPoolND(n_dim)(2)) n_feat_old = n_feat middle_block = block(n_features[-2], n_features[-1]) encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) @@ -396,10 +398,10 @@ def block(channels_in: int, channels_out: int) -> Module: decoder_blocks: list[Module] = [] up_blocks: list[Module] = [] for n_feat, n_feat_skip in pairwise(n_features[::-1]): - concat_blocks.append(AttentionGate(dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) + concat_blocks.append(AttentionGate(n_dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) - up_blocks.append(Upsample(range(-dim, 0), scale_factor=2)) - last_block = ConvND(dim)(n_features[0], channels_out, 1) + up_blocks.append(Upsample(range(-n_dim, 0), scale_factor=2)) + last_block = ConvND(n_dim)(n_features[0], n_channels_out, 1) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) @@ -410,10 +412,10 @@ class SeparableUNet(UNetBase): def __init__( self, - dim: int, + n_dim: int, dim_groups: Sequence[tuple[int, ...]], - channels_in: int, - channels_out: int, + n_channels_in: int, + n_channels_out: int, n_features: Sequence[int] = (64, 128, 256, 512), cond_dim: int = 0, encoder_blocks_per_scale: int = 2, @@ -426,16 +428,15 @@ def __init__( Parameters ---------- - dim - Total number of non batch, non channel dimensions. - E.g., 2 for 2D images, 3 for 3D volumes or 2D+time for 2D+time images. + n_dim + The number of spatial dimensions of the input tensor. dim_groups A list of tuples, where each tuple contains the spatial dimension indices for one separable convolution. Each group must contain fewer than 3 dimensions. - channels_in - Number of channels in the input tensor. - channels_out - Number of channels in the output tensor. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. n_features Number of features at each resolution level. cond_dim @@ -461,12 +462,12 @@ def __init__( for group in dim_groups: if len(group) > 3: raise ValueError(f'dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.') - if any(d > dim + 2 or d < -dim for d in group): - raise ValueError(f'dim_group {group} contains dimensions that are out of range for dim={dim}') + if any(d > n_dim + 2 or d < -n_dim for d in group): + raise ValueError(f'dim_group {group} contains dimensions that are out of range for dim={n_dim}') attention_depths = tuple(d % depth for d in attention_depths) if downsample_dims is None: - all_spatial_dims = tuple(sorted(set(d if d < 0 else d - dim - 2 for group in dim_groups for d in group))) + all_spatial_dims = tuple(sorted(set(d if d < 0 else d - n_dim - 2 for group in dim_groups for d in group))) downsample_dims = (all_spatial_dims,) * (depth - 1) def downsampler(level_dims, c_in, c_out) -> Module: @@ -489,7 +490,7 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: # --- Module Construction --- first_block = PermutedBlock( - all_spatial_dims, ConvND(len(all_spatial_dims))(channels_in, n_features[0], 3, padding=1) + all_spatial_dims, ConvND(len(all_spatial_dims))(n_channels_in, n_features[0], 3, padding=1) ) # -- Encoder -- @@ -532,7 +533,7 @@ def block(c_in: int, c_out: int, apply_attention: bool) -> Module: SiLU(), PermutedBlock( all_spatial_dims, - ConvND(len(all_spatial_dims))(n_features[0], channels_out, 3, padding=1), + ConvND(len(all_spatial_dims))(n_features[0], n_channels_out, 3, padding=1), ), ) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 4424ce91c..960d14e17 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -24,7 +24,7 @@ class LeWinTransformerBlock(CondMixin, Module): def __init__( self, - dim: int, + n_dim: int, n_channels_per_head: int, n_heads: int, window_size: int = 8, @@ -37,10 +37,10 @@ def __init__( Parameters ---------- - dim - Dimension of the input, e.g. 2 or 3 + n_dim + The number of spatial dimensions of the input tensor. n_channels_per_head - Number of features per head + The number of features per head. n_heads Number of attention heads window_size @@ -57,26 +57,26 @@ def __init__( super().__init__() channels = n_channels_per_head * n_heads hidden_dim = int(channels * mlp_ratio) - self.norm1 = InstanceNormND(dim)(channels) + self.norm1 = InstanceNormND(n_dim)(channels) self.attn = ShiftedWindowAttention( - n_dim=dim, + n_dim=n_dim, n_channels_in=channels, n_channels_out=channels, n_heads=n_heads, window_size=window_size, shifted=shifted, ) - self.norm2 = InstanceNormND(dim)(channels) + self.norm2 = InstanceNormND(n_dim)(channels) self.ff = Sequential( - ConvND(dim)(channels, hidden_dim, 1), + ConvND(n_dim)(channels, hidden_dim, 1), GELU(), - ConvND(dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), + ConvND(n_dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), GELU(), - ConvND(dim)(hidden_dim, channels, 1), + ConvND(n_dim)(hidden_dim, channels, 1), ) if cond_dim > 0: self.ff.append(FiLM(channels, cond_dim)) - self.modulator = torch.nn.Parameter(torch.empty(channels, *((window_size,) * dim))) + self.modulator = torch.nn.Parameter(torch.empty(channels, *((window_size,) * n_dim))) torch.nn.init.trunc_normal_(self.modulator) self.drop_path = DropPath(droprate=p_droppath) @@ -120,9 +120,9 @@ class Uformer(UNetBase): def __init__( self, - dim: int, - channels_in: int, - channels_out: int, + n_dim: int, + n_channels_in: int, + n_channels_out: int, n_channels_per_head: int = 32, n_heads: Sequence[int] = (1, 2, 4, 8), n_blocks: int = 2, @@ -135,26 +135,26 @@ def __init__( Parameters ---------- - dim : int - Dimension of the input, e.g. 2 or 3 - channels_in : int - Number of input channels - channels_out : int - Number of output channels - n_channels_per_head : int, optional - Number of features per head. The number of features at a resolution level is given by + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + n_channels_per_head + The number of features per head. The number of features at a resolution level is given by `n_channels_per_head * n_heads`. - n_heads : Sequence[int], optional + n_heads Number of attention heads at each resolution level. - n_blocks : int, optional - Number of transformer blocks at each resolution level in the input and output path - cond_dim : int, optional + n_blocks + The number of transformer blocks at each resolution level in the input and output path + cond_dim Dimension of a conditioning tensor. If `0`, no FiLM layers are added. - window_size : int, optional - Size of the attention windows in the (shifted) window attention layers. - mlp_ratio : float, optional + window_size + The size of the attention windows in the (shifted) window attention layers. + mlp_ratio Ratio of the hidden dimension to the input dimension in the feed-forward blocks - max_droppath_rate : float, optional + max_droppath_rate Maximum drop path rate. As in the original implementation, the drop path rate in the input path is linearly increased from `0` to `max_droppath_rate` with decreasing resolution. The rate in output blocks is fixed to `max_droppath_rate`. @@ -164,7 +164,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): return Sequential( *( LeWinTransformerBlock( - dim=dim, + n_dim=n_dim, n_heads=n_heads, n_channels_per_head=n_channels_per_head, window_size=window_size, @@ -178,7 +178,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): ) first_block = torch.nn.Sequential( - ConvND(dim)(channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), + ConvND(n_dim)(n_channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), LeakyReLU(), ) drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() @@ -187,7 +187,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True) ] down_blocks = [ - ConvND(dim)( + ConvND(n_dim)( n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=4, @@ -207,18 +207,18 @@ def blocks(n_heads: int, p_droppath: float = 0.0): decoder_blocks = [blocks(n_heads=2 * n_head, p_droppath=max_droppath_rate) for n_head in reversed(n_heads[:-1])] concat_blocks = [Concat() for _ in range(len(decoder_blocks))] up_blocks = [ - ConvTransposeND(dim)( + ConvTransposeND(n_dim)( n_channels_per_head * n_heads[-1], n_channels_per_head * n_heads[-2], kernel_size=2, stride=2 ) ] for n_head_current, n_head_next in pairwise(reversed(n_heads[:-1])): up_blocks.append( - ConvTransposeND(dim)( + ConvTransposeND(n_dim)( 2 * n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=2, stride=2 ) ) - last_block = ConvND(dim)( - 2 * n_channels_per_head * n_heads[0], channels_out, kernel_size=3, stride=1, padding='same' + last_block = ConvND(n_dim)( + 2 * n_channels_per_head * n_heads[0], n_channels_out, kernel_size=3, stride=1, padding='same' ) decoder = UNetDecoder( blocks=decoder_blocks, diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py index 5c831262d..798be6a85 100644 --- a/tests/nn/nets/test_unet.py +++ b/tests/nn/nets/test_unet.py @@ -16,9 +16,9 @@ def test_unet_forward(torch_compile: bool, device: str) -> None: """Test the forward pass of the UNet.""" unet = UNet( - dim=2, - channels_in=1, - channels_out=1, + n_dim=2, + n_channels_in=1, + n_channels_out=1, attention_depths=(-1,), n_features=(4, 6, 8), n_heads=2, diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index dfbfc8a9e..195217638 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -24,7 +24,7 @@ def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, cond_sh rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None - res = ResBlock(dim=dim, channels_in=channels_in, channels_out=channels_out, cond_dim=cond_dim).to(device) + res = ResBlock(n_dim=dim, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim).to(device) output = res(x, cond=cond) assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' From 3b92c6798c3b0654b6b58fe502d2a07f1b06f097 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 16:22:31 +0200 Subject: [PATCH 128/205] rope --- src/mrpro/nn/AxialRoPE.py | 104 ++++++++++++++++++++++++++ src/mrpro/nn/RoPE.py | 153 -------------------------------------- tests/nn/test_rope.py | 40 ++++++++++ 3 files changed, 144 insertions(+), 153 deletions(-) create mode 100644 src/mrpro/nn/AxialRoPE.py delete mode 100644 src/mrpro/nn/RoPE.py create mode 100644 tests/nn/test_rope.py diff --git a/src/mrpro/nn/AxialRoPE.py b/src/mrpro/nn/AxialRoPE.py new file mode 100644 index 000000000..71a0b5c83 --- /dev/null +++ b/src/mrpro/nn/AxialRoPE.py @@ -0,0 +1,104 @@ +"""Rotary Position Embedding (RoPE).""" + +from collections.abc import Sequence + +import torch +from einops import rearrange +from torch.nn import Module + + +@torch.compile +def get_theta(shape: Sequence[int], n_embedding_channels: int, device: torch.device) -> torch.Tensor: + """Get rotation angles. + + Parameters + ---------- + shape + Spatial shape of the input tensor to use for the position embedding, + i.e. the shape excluding batch and channel dimensions. + n_embedding_channels + Number of embedding channels per head + """ + position = torch.stack( + torch.meshgrid([torch.arange(s, device=device) - s // 2 for s in shape], indexing='ij'), dim=-1 + ) + log_min = torch.log(torch.tensor(torch.pi)) + log_max = torch.log(torch.tensor(10000.0)) + freqs = torch.exp(torch.linspace(log_min, log_max, n_embedding_channels // (2 * position.shape[-1]), device=device)) + return rearrange(freqs * position[..., None], '... dim freqs ->... 1 (dim freqs)') + + +class AxialRoPE(Module): + """Axial Rotary Position Embedding. + + Applies rotary position embeddings along each axis independently. + """ + + freqs: torch.Tensor + + def __init__( + self, + n_heads: int, + non_embed_fraction: float = 0.0, + ): + """Initialize AxialRoPE. + + Parameters + ---------- + n_heads + Number of attention heads + non_embed_fraction + Fraction of channels not used for embedding + """ + super().__init__() + self.non_embed_fraction = non_embed_fraction + if non_embed_fraction < 0 or non_embed_fraction > 1: + raise ValueError('non_embed_fraction must be between 0 and 1') + self.n_heads = n_heads + + def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply rotary embeddings to input tensors. + + Parameters + ---------- + *tensors + Tensors to apply rotary embeddings to + """ + if self.non_embed_fraction == 1.0: + return tensors + + shape = tensors[0].shape + if not all(t.shape == shape for t in tensors): + raise ValueError('All tensors must have the same shape') + device = tensors[0].device + if not all(t.device == device for t in tensors): + raise ValueError('All tensors must be on the same device') + + shape, n_channels = shape[1:-1], shape[-1] + if n_channels % self.n_heads: + raise ValueError(f'Number of channels {n_channels} must be divisible by number of heads {self.n_heads}') + n_channels_per_head = n_channels // self.n_heads + tensors = tuple(t.unflatten(-1, (self.n_heads, -1)) for t in tensors) + n_embedding_channels = int(n_channels_per_head * (1 - self.non_embed_fraction)) + theta = get_theta(shape, n_embedding_channels, device) + return tuple(self.apply_rotary_emb(t, theta).flatten(-2) for t in tensors) + + @staticmethod + def apply_rotary_emb(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: + """Add rotary embedding to the input tensor. + + Parameters + ---------- + x + Input tensor to modify + theta + Rotation angles + """ + n_emb = theta.shape[-1] * 2 + if n_emb > x.shape[-1]: + raise ValueError(f'Embedding dimension {n_emb} is larger than input dimension {x.shape[-1]}') + (x1, x2), x_unembed = x[..., :n_emb].chunk(2, dim=-1), x[..., n_emb:] + result = torch.cat( + [x1 * theta.cos() - x2 * theta.sin(), x2 * theta.cos() + x1 * theta.sin(), x_unembed], dim=-1 + ) + return result diff --git a/src/mrpro/nn/RoPE.py b/src/mrpro/nn/RoPE.py deleted file mode 100644 index 7abfc9426..000000000 --- a/src/mrpro/nn/RoPE.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Rotary Position Embedding (RoPE).""" - -import torch -from torch.nn import Module - - -@torch.compile -def apply_rotary_emb_(x: torch.Tensor, theta: torch.Tensor, conjugated: bool) -> None: - """Add rotary embedding to the input tensor (inplace). - - This is a helper function for the `AxialRoPE` class. - - Parameters - ---------- - x : torch.Tensor - Input tensor to modify - theta : torch.Tensor - Rotation angles - conjugated : bool - Whether to use conjugated rotation - """ - n_emb = theta.shape[-1] * 2 - if n_emb > x.shape[-1]: - raise ValueError(f'Embedding dimension {n_emb} is larger than input dimension {x.shape[-1]}') - x1, x2 = x[..., :n_emb].chunk(2, dim=-1) - if conjugated: - x1, x2 = x2, x1 - x[..., :n_emb] = torch.cat([x1 * theta.cos() - x2 * theta.sin(), x2 * theta.cos() + x1 * theta.sin()], dim=-1) - - -class RotaryEmbedding(torch.autograd.Function): - """Custom autograd function for rotary embeddings.""" - - @staticmethod - def forward( - x: torch.Tensor, - theta: torch.Tensor, - conjugated: bool, - ) -> torch.Tensor: - """Apply rotary embedding in forward pass.""" - apply_rotary_emb_(x, theta, conjugated) - return x - - @staticmethod - def setup_context( - ctx: torch.autograd.function.FunctionCtx, inputs: tuple[torch.Tensor, torch.Tensor, bool], _output: torch.Tensor - ) -> None: - """Save tensors for backward pass.""" - _, theta, conjugated = inputs - ctx.save_for_backward(theta) - ctx.conjugated = conjugated # type: ignore[attr-defined] - - @staticmethod - def backward( # type: ignore[override] - ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor - ) -> tuple[torch.Tensor, None, None]: - """Apply backward pass.""" - (theta,) = ctx.saved_tensors # type: ignore[attr-defined] - apply_rotary_emb_(grad_output, theta, ctx.conjugated) # type: ignore[attr-defined] - return grad_output, None, None - - -class AxialRoPE(Module): - """Axial Rotary Position Embedding. - - Applies rotary position embeddings along each axis independently. - """ - - freqs: torch.Tensor - - def __init__( - self, - n_dim: int, - n_channels: int, - n_heads: int, - channels_last: bool = True, - non_embed_fraction: float = 0.5, - ): - """Initialize AxialRoPE. - - Parameters - ---------- - n_dim - Number of (spatial-like) dimensions of the input - n_channels - Number of channels - n_heads - Number of attention heads - channels_last - Whether the channels are the last dimension or dimension 1. - non_embed_fraction - Fraction of channels not used for embedding - """ - super().__init__() - log_min = torch.log(torch.tensor(torch.pi)) - log_max = torch.log(torch.tensor(10000.0)) - if n_channels % n_heads: - raise ValueError(f'Number of channels {n_channels} must be divisible by number of heads {n_heads}') - channels_per_head = n_channels // n_heads - freqs = torch.exp(torch.linspace(log_min, log_max, channels_per_head // 2)) - self.register_buffer('freqs', freqs) - self.channels_last = channels_last - self.n_heads = n_heads - - def get_theta(self, pos: torch.Tensor) -> torch.Tensor: - """Get rotation angles for given positions. - - Parameters - ---------- - pos - Position tensor - - Returns - ------- - Rotation angles - """ - return (self.freqs * pos[..., None, :, None]).flatten(start_dim=-2) - - def forward(self, pos: torch.Tensor, *tensors: torch.Tensor) -> None: - """Apply rotary embeddings to input tensors. - - Parameters - ---------- - pos - Position tensor - *tensors : torch.Tensor - Tensors to apply rotary embeddings to - """ - theta = self.get_theta(pos) - if not self.channels_last: - tensors = tuple(t.movedim(-1, 1) for t in tensors) - tuple(RotaryEmbedding.apply(x, theta, False) for x in tensors) - - @staticmethod - def make_axial_positions(*shape: int) -> torch.Tensor: - """Create axial position tensors. - - Parameters - ---------- - *shape : int - Shape of the position tensor - - Returns - ------- - torch.Tensor - Position tensor - """ - m = torch.as_tensor(shape).max() - pos = torch.stack( - [torch.arange(s, device=m.device) - s // 2 for s in shape], - dim=-1, - ) - return pos diff --git a/tests/nn/test_rope.py b/tests/nn/test_rope.py new file mode 100644 index 000000000..4b479e8a1 --- /dev/null +++ b/tests/nn/test_rope.py @@ -0,0 +1,40 @@ +import pytest +import torch +from mrpro.nn import AxialRoPE +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_rope(device: torch.device): + shape = (10, 10) + n_heads = 2 + n_channels = 64 + n_embed = int(0.5 * n_channels // n_heads) + q, k = RandomGenerator(seed=42).float32_tensor((2, 1, *shape, n_channels), low=0.5).to(device) + rope = AxialRoPE(2, non_embed_fraction=0.5) + (q_rope, k_rope) = rope(q, k) + assert q_rope.shape == q.shape + assert k_rope.shape == k.shape + + # non embedded channels should be the same + torch.testing.assert_close( + q.unflatten(-1, (n_heads, -1))[..., n_embed:], q_rope.unflatten(-1, (n_heads, -1))[..., n_embed:] + ) + torch.testing.assert_close( + k.unflatten(-1, (n_heads, -1))[..., n_embed:], k_rope.unflatten(-1, (n_heads, -1))[..., n_embed:] + ) + + # other should change + q_emb = q_rope.unflatten(-1, (n_heads, -1))[..., :n_embed] + q_orig = q.unflatten(-1, (n_heads, -1))[..., :n_embed] + k_emb = k_rope.unflatten(-1, (n_heads, -1))[..., :n_embed] + k_orig = k.unflatten(-1, (n_heads, -1))[..., :n_embed] + assert not torch.isclose(q_emb, q_orig).all() + assert not torch.isclose(k_emb, k_orig).all() + assert not torch.isclose(q_emb, k_emb).all() From cb08cafbc31b333cf4837364cae2001a99ab1333 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 16:22:46 +0200 Subject: [PATCH 129/205] rope --- src/mrpro/nn/attention/MultiHeadAttention.py | 15 ++++-- .../nn/attention/NeighborhoodSelfAttention.py | 7 +++ .../nn/attention/SpatialTransformerBlock.py | 51 ++++++++++++++++--- tests/nn/test_neighborhoodselfattention.py | 11 ++-- 4 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/mrpro/nn/attention/MultiHeadAttention.py b/src/mrpro/nn/attention/MultiHeadAttention.py index 18c34d446..9daab4385 100644 --- a/src/mrpro/nn/attention/MultiHeadAttention.py +++ b/src/mrpro/nn/attention/MultiHeadAttention.py @@ -4,6 +4,8 @@ from einops import rearrange from torch.nn import Linear, Module +from mrpro.nn.AxialRoPE import AxialRoPE + class MultiHeadAttention(Module): """Multi-head Attention. @@ -20,6 +22,7 @@ def __init__( features_last: bool = False, p_dropout: float = 0.0, n_channels_cross: int | None = None, + rope_embed_fraction: float = 0.0, ): """Initialize the Multi-head Attention. @@ -38,6 +41,8 @@ def __init__( Dropout probability. n_channels_cross Number of channels for cross-attention. If `None`, use `n_channels_in`. + rope_embed_fraction + Fraction of channels to embed with RoPE. """ super().__init__() channels_per_head_q = n_channels_in // n_heads @@ -48,6 +53,7 @@ def __init__( self.features_last = features_last self.to_out = Linear(n_channels_in, n_channels_out) self.n_heads = n_heads + self.rope = AxialRoPE(n_heads, non_embed_fraction=1 - rope_embed_fraction) def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: """Apply multi-head attention. @@ -75,14 +81,17 @@ def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) reshaped_x = self._reshape(x) reshaped_cross_attention = self._reshape(cross_attention) if cross_attention is not None else reshaped_x - q = rearrange(self.to_q(reshaped_x), '... L (heads dim) -> ... heads L dim ', heads=self.n_heads) - k, v = rearrange( + query = rearrange(self.to_q(reshaped_x), '... L (heads dim) -> ... heads L dim ', heads=self.n_heads) + key, value = rearrange( self.to_kv(reshaped_cross_attention), '... S (kv heads dim) -> kv ... heads S dim ', heads=self.n_heads, kv=2, ) - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.p_dropout, is_causal=False) + query, key = self.rope(query, key) # NO-OP if rope_embed_fraction is 0.0 + y = torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=self.p_dropout, is_causal=False + ) y = rearrange(y, '... heads L dim -> ... L (heads dim)') out = self.to_out(y) diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index 71b4aeb06..232f153a8 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -9,6 +9,7 @@ from torch.nn import Linear, Module from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention +from mrpro.nn.AxialRoPE import AxialRoPE from mrpro.utils.to_tuple import to_tuple T = TypeVar('T') @@ -129,6 +130,7 @@ def __init__( dilation: int | Sequence[int] = 1, circular: bool | Sequence[bool] = False, features_last: bool = False, + rope_embed_fraction: float = 0.0, ) -> None: """Initialize a neighborhood attention module. @@ -152,6 +154,9 @@ def __init__( features_last Whether the channels are in the last dimension of the tensor, as common in visíon transformers. Otherwise, assume the channels are in the second dimension, as common in CNN models. + rope_embed_fraction + Fraction of channels to embed with RoPE. + """ super().__init__() self.n_head = n_heads @@ -162,6 +167,7 @@ def __init__( channels_per_head = n_channels_in // n_heads self.to_qkv = Linear(n_channels_in, 3 * channels_per_head * n_heads) self.to_out = Linear(channels_per_head * n_heads, n_channels_out) + self.rope = AxialRoPE(n_heads, rope_embed_fraction) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply neighborhood attention to the input tensor. @@ -183,6 +189,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: query, key, value = rearrange( qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channels', qkv=3, head=self.n_head ) + query, key = self.rope(query, key) # NO-OP if rope_embed_fraction is 0.0 # the mask depends on the input size. To be more flexible if used within CNNs, we compute it here. # The computation is cached.. mask = neighborhood_mask( diff --git a/src/mrpro/nn/attention/SpatialTransformerBlock.py b/src/mrpro/nn/attention/SpatialTransformerBlock.py index d278c232a..f765770ac 100644 --- a/src/mrpro/nn/attention/SpatialTransformerBlock.py +++ b/src/mrpro/nn/attention/SpatialTransformerBlock.py @@ -6,6 +6,7 @@ from torch.nn import Dropout, Linear, Module from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention from mrpro.nn.CondMixin import CondMixin from mrpro.nn.GEGLU import GEGLU from mrpro.nn.GroupNorm import GroupNorm @@ -34,6 +35,8 @@ def __init__( cond_dim: int = 0, mlp_ratio: float = 4, features_last: bool = False, + rope_embed_fraction: float = 0.0, + attention_neighborhood: int | None = None, ): """Initialize the basic transformer block. @@ -51,19 +54,37 @@ def __init__( Ratio of the hidden dimension to the input dimension. features_last Whether the features are last in the input tensor. + rope_embed_fraction + Fraction of channels to embed with RoPE. + attention_neighborhood + If not None, use neighborhood self attention with the given neighborhood size instead + of global self attention. """ super().__init__() self.features_last = features_last - self.selfattention = Sequential( - LayerNorm(channels, features_last=True), - MultiHeadAttention( + + if attention_neighborhood is None: + attention: Module = MultiHeadAttention( n_channels_in=channels, n_channels_out=channels, n_heads=n_heads, p_dropout=p_dropout, features_last=True, - ), - ) + rope_embed_fraction=rope_embed_fraction, + ) + else: + if p_dropout > 0: + raise ValueError('p_dropout > 0 is not supported for neighborhood self attention') + attention = NeighborhoodSelfAttention( + n_channels_in=channels, + n_channels_out=channels, + n_heads=n_heads, + features_last=True, + kernel_size=attention_neighborhood, + circular=True, + rope_embed_fraction=rope_embed_fraction, + ) + self.selfattention = Sequential(LayerNorm(channels, features_last=True), attention) hidden_dim = int(channels * mlp_ratio) self.ff = Sequential( LayerNorm(channels, features_last=True, cond_dim=cond_dim), @@ -104,8 +125,10 @@ def __init__( channels: int, n_heads: int, depth: int = 1, - dropout: float = 0.0, + p_dropout: float = 0.0, cond_dim: int = 0, + rope_embed_fraction: float = 0.0, + attention_neighborhood: int | None = None, ): """Initialize the spatial transformer block. @@ -119,10 +142,14 @@ def __init__( Number of attention heads for each group. depth Number of transformer blocks for each group. - dropout + p_dropout Dropout probability. cond_dim Dimension of the conditioning tensor. + rope_embed_fraction + Fraction of channels to embed with RoPE. + attention_neighborhood + If not None, use NeighborhoodSelfAttention with the given neighborhood size instead of MultiHeadAttention. """ super().__init__() hidden_dim = n_heads * (channels // n_heads) @@ -131,7 +158,15 @@ def __init__( self.transformer_blocks = Sequential() for group in (g for _ in range(depth) for g in dim_groups): group = tuple(g - 1 if g < 0 else g for g in group) - block = BasicTransformerBlock(hidden_dim, n_heads, p_dropout=dropout, cond_dim=cond_dim, features_last=True) + block = BasicTransformerBlock( + hidden_dim, + n_heads, + p_dropout=p_dropout, + cond_dim=cond_dim, + features_last=True, + rope_embed_fraction=rope_embed_fraction, + attention_neighborhood=attention_neighborhood, + ) self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) self.proj_out = Linear(hidden_dim, channels) diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py index 0dbf1ccb5..43f866683 100644 --- a/tests/nn/test_neighborhoodselfattention.py +++ b/tests/nn/test_neighborhoodselfattention.py @@ -62,14 +62,14 @@ def test_neighborhood_self_attention( @pytest.mark.parametrize( - ('kernel_size', 'dilation', 'circular'), + ('kernel_size', 'dilation', 'circular', 'rope'), [ - (3, 1, False), - (5, 2, True), - (7, 1, False), + (3, 1, False, True), + (5, 2, True, False), + (7, 1, False, True), ], ) -def test_neighborhood_attention_variants(kernel_size: int, dilation: int, circular: bool) -> None: +def test_neighborhood_attention_variants(kernel_size: int, dilation: int, circular: bool, rope: bool) -> None: """Test NeighborhoodSelfAttention with different neighborhood configurations.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 32, 16, 16)).requires_grad_(True) @@ -81,6 +81,7 @@ def test_neighborhood_attention_variants(kernel_size: int, dilation: int, circul kernel_size=kernel_size, dilation=dilation, circular=circular, + rope_embed_fraction=1.0 if rope else 0.0, ) output = attn(x) From bf63a4b7db4415c0e2b7b571542c454b055db3e7 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 16:23:01 +0200 Subject: [PATCH 130/205] hourglass v1 --- src/mrpro/nn/nets/HourglassTransformer.py | 94 +++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 src/mrpro/nn/nets/HourglassTransformer.py diff --git a/src/mrpro/nn/nets/HourglassTransformer.py b/src/mrpro/nn/nets/HourglassTransformer.py new file mode 100644 index 000000000..134e71bfe --- /dev/null +++ b/src/mrpro/nn/nets/HourglassTransformer.py @@ -0,0 +1,94 @@ +from collections.abc import Sequence + +from torch.nn import Module + +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.join import Interpolate +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Sequential import Sequential +from mrpro.operators.RearrangeOp import RearrangeOp +from mrpro.utils import to_tuple + + +class HourglassTransformer(UNetBase): + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_features: Sequence[int] | int, + depths: Sequence[int] | int = 3, + attention_neighborhood: Sequence[None | int] | int | None = 11, + n_heads: int | Sequence[int] = 4, + cond_dim: int = 0, + ): + n_layers_ = [ + len(x) + for x in (n_features, depths, attention_neighborhood, n_heads) + if (x is not None and not isinstance(x, int)) + ] + n_layers = n_layers_[0] + + if any(x != n_layers_[0] for x in n_layers_): + raise ValueError('All arguments must have the same length or be scalars') + + n_features_ = to_tuple(n_layers, n_features) + depths_ = to_tuple(n_layers, depths) + attention_neighborhood_ = to_tuple(n_layers, attention_neighborhood) + n_heads_ = to_tuple(n_layers, n_heads) + + move_channels_last = RearrangeOp('batch ... channels -> batch ... channels') + first_block = Sequential( + PixelUnshuffleDownsample(n_dim, n_channels_in, n_features_[0], downscale_factor=2), + move_channels_last, + ) + dim = (tuple(range(-n_dim, 0)),) # TODO: allow arbitrary dimensions. + encoder_blocks: list[Module] = [] + decoder_blocks: list[Module] = [] + merge_blocks: list[Module] = [] + down_blocks: list[Module] = [] + up_blocks: list[Module] = [] + for channels, depth, neighborhood, head in zip( + n_features_, depths_, attention_neighborhood_, n_heads_, strict=True + ): + encoder_blocks.append( + SpatialTransformerBlock( + dim_groups=dim, + channels=channels, + depth=depth, + attention_neighborhood=neighborhood, + n_heads=head, + rope_embed_fraction=1.0, + cond_dim=cond_dim, + ) + ) + decoder_blocks.append( + SpatialTransformerBlock( + dim_groups=dim, + channels=channels, + depth=depth, + attention_neighborhood=neighborhood, + n_heads=head, + rope_embed_fraction=1.0, + cond_dim=cond_dim, + ) + ) + merge_blocks.append(Interpolate()) + + last_block = Sequential( + move_channels_last.H, PixelShuffleUpsample(n_dim, n_features_[-1], n_channels_out, upscale_factor=2) + ) + middle_block = SpatialTransformerBlock( + dim_groups=dim, + channels=n_features_[-1], + depth=depths_[-1], + attention_neighborhood=attention_neighborhood_[-1], + n_heads=n_heads_[-1], + rope_embed_fraction=1.0, + cond_dim=cond_dim, + ) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + decoder = UNetDecoder(decoder_blocks, up_blocks, merge_blocks, last_block) + + super().__init__(encoder, decoder) From 62577236c7138c266b73cfd362463e6a74b2a71a Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 16:23:15 +0200 Subject: [PATCH 131/205] Add AxialRoPE to nn module and introduce Interpolate class - Included `AxialRoPE` in the `__init__.py` of the `nn` module for better accessibility. - Updated the `join.py` file to add type hints for the `interpolate` function. - Refined parameter descriptions in the `Concat` and `Add` classes for clarity. - Introduced a new `Interpolate` class for linear interpolation between tensors, enhancing functionality in the `nn` module. - Removed the `BasicUNet` and `SeparableUNet` classes from `UNet.py` to streamline the codebase and avoid duplication. --- src/mrpro/nn/__init__.py | 2 + .../data_consistency/AnalyticCartesianDC.py | 10 +- src/mrpro/nn/join.py | 59 +++++- src/mrpro/nn/nets/UNet.py | 177 ------------------ src/mrpro/nn/nets/__init__.py | 6 +- 5 files changed, 65 insertions(+), 189 deletions(-) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index 16842a302..a5b64bafc 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -23,10 +23,12 @@ from mrpro.nn import data_consistency from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.AxialRoPE import AxialRoPE __all__ = [ "AdaptiveAvgPoolND", "AvgPoolND", + "AxialRoPE", "BatchNormND", "ComplexAsChannel", "CondMixin", diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py index ed6b812b1..5b2848fca 100644 --- a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -16,9 +16,9 @@ class AnalyticCartesianDC(Module): Solves the following problem: :math:`\min_x \|Ax - k\|_2^2 + \lambda \|x-p\|_2^2` where :math:`A` is the acquisition operator and :math:`k` is the data, :math:`\lambda` is the regularization - parameter and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a special case - for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil data or - to apply data consistency to each coil image [NOSENSE]_ + parameter and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a + special case for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil + data or to apply data consistency to each coil image [NOSENSE]_. References ---------- @@ -68,8 +68,8 @@ def __call__( data k-space data. fourier_op - Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, the Fourier operator is - automatically created from the data. + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. Returns ------- diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py index b6a749c47..d98aeb7b2 100644 --- a/src/mrpro/nn/join.py +++ b/src/mrpro/nn/join.py @@ -25,7 +25,7 @@ def _fix_shapes( else: # largest as target target = tuple(max(s) for s in zip(*shapes, strict=True)) if mode == 'linear' or mode == 'nearest': - return tuple(interpolate(x, target, dim=dim, mode=mode) for x in xs) + return tuple(interpolate(x, target, dim=dim, mode=mode) for x in xs) # type: ignore[arg-type] if mode == 'zero' or mode == 'crop': return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) else: @@ -43,7 +43,7 @@ def __init__( Parameters ---------- mode - How to handle mismatched spatial dimensions: + How to handle mismatched dimensions: - 'fail': do not align, raise error if shapes mismatch - 'crop': center-crop to smallest spatial size - 'zero': zero-pad to largest spatial size @@ -90,8 +90,8 @@ def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular' Parameters ---------- - mode : {'fail', 'crop', 'zero', 'replicate', 'circular'}, default='zero' - How to handle mismatched spatial dimensions: + mode + How to handle mismatched dimensions: - 'fail': do not align, raise error if shapes mismatch - 'crop': center-crop to smallest spatial size - 'zero': zero-pad to largest spatial size @@ -123,3 +123,54 @@ def __call__(self, *xs: torch.Tensor) -> torch.Tensor: Summed tensor """ return super().__call__(*xs) + + +class Interpolate(Module): + """Linear interpolate between two tensors. + + As suggestions for the Hourglass Transformer [CR]_ + + References + ---------- + .. [CK] Crowson, Katherine, et al. "Scalable high-resolution pixel-space image synthesis with + hourglass diffusion transformers." ICML 2024, https://arxiv.org/abs/2401.11605 + """ + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: + """Initialize learned linear interpolation. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.weight = torch.nn.Parameter(torch.tensor(0.5)) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Linear interpolate between two tensors.""" + x1, x2 = _fix_shapes((x1, x2), self.mode, dim=range(max(x.ndim for x in (x1, x2)))) + return x1 * self.weight + x2 * (1 - self.weight) + + def __call__(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Linear interpolate between two tensors. + + Parameters + ---------- + x1, x2 + Input tensors + + Returns + ------- + Interpolated tensor + """ + return super().__call__(x1, x2) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index fdcda6e2a..3bca93bc8 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -11,12 +11,9 @@ from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock from mrpro.nn.CondMixin import call_with_cond from mrpro.nn.FiLM import FiLM -from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.join import Concat from mrpro.nn.ndmodules import ConvND, MaxPoolND -from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.ResBlock import ResBlock -from mrpro.nn.SeparableResBlock import SeparableResBlock # Assuming SeparableResBlock is here from mrpro.nn.Sequential import Sequential from mrpro.nn.Upsample import Upsample @@ -197,46 +194,6 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T return super().__call__(x, cond=cond) -class BasicUNet(UNetBase): - """Basic UNet. - - A Basic UNet with residual blocks, convolutional downsampling, and nearest neighbor upsampling. - - References - ---------- - .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image - segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 - """ - - def __init__(self, dim: int, channels_in: int, channels_out: int, n_features: Sequence[int], cond_dim: int): - """Initialize the BasicUNet.""" - encoder_blocks: list[Module] = [] - decoder_blocks: list[Module] = [] - down_blocks: list[Module] = [] - up_blocks: list[Module] = [] - concat_blocks: list[Module] = [] - for n_feat, n_feat_next in pairwise(n_features): - encoder_blocks.append(ResBlock(dim, n_feat, n_feat, cond_dim)) - decoder_blocks.append(ResBlock(dim, 2 * n_feat, n_feat, cond_dim)) - down_blocks.append(ConvND(dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) - up_blocks.append( - Sequential( - Upsample(tuple(range(-dim, 0)), scale_factor=2), ConvND(dim)(n_feat_next, n_feat, 3, padding=1) - ) - ) - concat_blocks.append(Concat()) - up_blocks = up_blocks[::-1] - decoder_blocks = decoder_blocks[::-1] - first_block = ConvND(dim)(channels_in, n_features[0], 3, padding=1) - last_block = Sequential( - GroupNorm(n_features[0]), SiLU(), ConvND(dim)(n_features[0], channels_out, 3, padding=1) - ) - middle_block = ResBlock(dim, n_features[-1], n_features[-1], cond_dim) - encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) - decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) - super().__init__(encoder, decoder) - - class UNet(UNetBase): """UNet. @@ -405,137 +362,3 @@ def block(channels_in: int, channels_out: int) -> Module: decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) - - -class SeparableUNet(UNetBase): - """UNet with separable convolutions and attention, and grouped downsampling.""" - - def __init__( - self, - n_dim: int, - dim_groups: Sequence[tuple[int, ...]], - n_channels_in: int, - n_channels_out: int, - n_features: Sequence[int] = (64, 128, 256, 512), - cond_dim: int = 0, - encoder_blocks_per_scale: int = 2, - attention_depths: Sequence[int] = (-1,), - n_heads: int = 8, - downsample_dims: Sequence[Sequence[int]] | None = None, - ) -> None: - """ - Initialize the SeparableUNet. - - Parameters - ---------- - n_dim - The number of spatial dimensions of the input tensor. - dim_groups - A list of tuples, where each tuple contains the spatial dimension - indices for one separable convolution. Each group must contain fewer than 3 dimensions. - n_channels_in - The number of channels in the input tensor. - n_channels_out - The number of channels in the output tensor. - n_features - Number of features at each resolution level. - cond_dim - Number of channels in the conditioning tensor. - encoder_blocks_per_scale - Number of encoder blocks per resolution level. - attention_depths - The depths at which to apply attention. - n_heads - Number of attention heads. - downsample_dims - Sequence specifying which absolute spatial dimensions to downsample - at each encoder level. If None, all dimensions in `dim_groups` are combined - and downsampled at each level. - If a downsampling step contains more than 3 dimensions, downsampling is performed separately for each - dimension. If the length of the sequence is less than the number of resolution levels, the sequence is - repeated. E.g., ``((-1,-2), (-1,-2,-3))`` for 3D data: first level downsamples x,y; second level x,y,z; - third level x,y. - - - """ - depth = len(n_features) - for group in dim_groups: - if len(group) > 3: - raise ValueError(f'dim_group {group} can at most contain 3 dimensions. Split it into multiple groups.') - if any(d > n_dim + 2 or d < -n_dim for d in group): - raise ValueError(f'dim_group {group} contains dimensions that are out of range for dim={n_dim}') - - attention_depths = tuple(d % depth for d in attention_depths) - if downsample_dims is None: - all_spatial_dims = tuple(sorted(set(d if d < 0 else d - n_dim - 2 for group in dim_groups for d in group))) - downsample_dims = (all_spatial_dims,) * (depth - 1) - - def downsampler(level_dims, c_in, c_out) -> Module: - if len(level_dims) > 3: - sequence = Sequential(*(downsampler(d[0], c_in, c_out) for d in level_dims)) - for d in level_dims[1:]: - sequence.append(downsampler(d, c_out, c_out)) - return sequence - return PermutedBlock(level_dims, ConvND(len(level_dims))(c_in, c_out, 3, stride=2, padding=1)) - - def upsampler(level_dims, c_in, c_out) -> Module: - return Upsample(level_dims, scale_factor=2) - - def block(c_in: int, c_out: int, apply_attention: bool) -> Module: - res_block = SeparableResBlock(dim_groups, c_in, c_out, cond_dim) - if not apply_attention: - return res_block - attn_block = SpatialTransformerBlock(dim_groups, c_out, n_heads, cond_dim=cond_dim) - return Sequential(res_block, attn_block) - - # --- Module Construction --- - first_block = PermutedBlock( - all_spatial_dims, ConvND(len(all_spatial_dims))(n_channels_in, n_features[0], 3, padding=1) - ) - - # -- Encoder -- - encoder_blocks, down_blocks, skip_features = [], [], [] - c_feat = n_features[0] - for i_level, n_feat_level in enumerate(n_features): - for _ in range(encoder_blocks_per_scale): - encoder_blocks.append(block(c_feat, n_feat_level, i_level in attention_depths)) - c_feat = n_feat_level - skip_features.append(c_feat) - if i_level < depth - 1: - down_blocks.append(downsampler(downsample_dims_per_level[i_level], c_feat, n_features[i_level + 1])) - c_feat = n_features[i_level + 1] - - # -- Middle & Encoder Finalization -- - middle_block = Sequential( - block(c_feat, c_feat, depth - 1 in attention_depths), - block(c_feat, c_feat, depth - 1 in attention_depths), - ) - encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) - - # -- Decoder -- - decoder_blocks, up_blocks = [], [] - for i_level in reversed(range(depth)): - n_feat_level = n_features[i_level] - if i_level > 0: - up_blocks.append(upsampler(downsample_dims_per_level[i_level - 1], c_feat, n_feat_level)) - for _ in range(encoder_blocks_per_scale + 1): - skip_c = skip_features.pop() - decoder_blocks.append(block(c_feat + skip_c, n_feat_level, i_level in attention_depths)) - c_feat = n_feat_level - - decoder_blocks.reverse() - up_blocks.reverse() - - # -- Decoder Finalization -- - concat_blocks = [Concat()] * len(decoder_blocks) - last_block = Sequential( - GroupNorm(n_features[0]), - SiLU(), - PermutedBlock( - all_spatial_dims, - ConvND(len(all_spatial_dims))(n_features[0], n_channels_out, 3, padding=1), - ), - ) - decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) - - super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 228596dc8..a26bf253d 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -2,17 +2,17 @@ from mrpro.nn.nets.Uformer import Uformer from mrpro.nn.nets.DCAE import DCVAE from mrpro.nn.nets.VAE import VAE -from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet, BasicUNet, SeparableUNet +from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.HourglassTransformer import HourglassTransformer __all__ = [ "AttentionGatedUNet", "BasicCNN", - "BasicUNet", "DCVAE", + "HourglassTransformer", "Restormer", - "SeparableUNet", "SwinIR", "UNet", "Uformer", From 93666bbea2914404fe97473a955631c301df5e01 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 16:36:02 +0200 Subject: [PATCH 132/205] docstrings --- src/mrpro/nn/AxialRoPE.py | 6 ++++ src/mrpro/nn/nets/HourglassTransformer.py | 36 +++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/src/mrpro/nn/AxialRoPE.py b/src/mrpro/nn/AxialRoPE.py index 71a0b5c83..2eeb37d32 100644 --- a/src/mrpro/nn/AxialRoPE.py +++ b/src/mrpro/nn/AxialRoPE.py @@ -18,6 +18,12 @@ def get_theta(shape: Sequence[int], n_embedding_channels: int, device: torch.dev i.e. the shape excluding batch and channel dimensions. n_embedding_channels Number of embedding channels per head + device + Device to create the rotation angles on + + Returns + ------- + Rotation angles """ position = torch.stack( torch.meshgrid([torch.arange(s, device=device) - s // 2 for s in shape], indexing='ij'), dim=-1 diff --git a/src/mrpro/nn/nets/HourglassTransformer.py b/src/mrpro/nn/nets/HourglassTransformer.py index 134e71bfe..594e2c855 100644 --- a/src/mrpro/nn/nets/HourglassTransformer.py +++ b/src/mrpro/nn/nets/HourglassTransformer.py @@ -1,3 +1,5 @@ +"""Hourglass Transformer.""" + from collections.abc import Sequence from torch.nn import Module @@ -12,6 +14,18 @@ class HourglassTransformer(UNetBase): + """Hourglass Transformer. + + A U-shaped transformer [CK]_ with neighborhood self-attention [NAT]_. + + References + ---------- + .. [CK] Crowson, Katherine, et al. "Scalable high-resolution pixel-space image synthesis with + hourglass diffusion transformers." ICML 2024, https://arxiv.org/abs/2401.11605 + .. [NAT] Hassani, A. et al. "Neighborhood Attention Transformer" CVPR, 2023, https://arxiv.org/abs/2204.07143 + + """ + def __init__( self, n_dim: int, @@ -23,6 +37,28 @@ def __init__( n_heads: int | Sequence[int] = 4, cond_dim: int = 0, ): + """Initialize the Hourglass Transformer. + + Parameters + ---------- + n_dim + Number of (spatial)dimensions of the input data. + n_channels_in + Number of channels in the input data. + n_channels_out + Number of channels in the output data. + n_features + Number of features in each stage. + depths + Number of layers in each stage. + attention_neighborhood + Neighborhood size for the neighborhood self-attention. If None, use global attention + for that stage. + n_heads + Number of heads in each stage. + cond_dim + Number of dimensions of the conditioning tensor. + """ n_layers_ = [ len(x) for x in (n_features, depths, attention_neighborhood, n_heads) From 08e20b7b719dc3434c588fcc11d28e1121f994d6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 17:38:14 +0200 Subject: [PATCH 133/205] fix unet --- src/mrpro/nn/nets/UNet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 3bca93bc8..9a2e6db8f 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -269,8 +269,8 @@ def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: channels_in = channels_out return blocks - encoder_blocks: list[Module] = [ConvND(n_dim)(n_channels_in, n_features[0], 3, padding=1)] - down_blocks: list[Module] = [Identity()] + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] decoder_blocks: list[Module] = [] up_blocks: list[Module] = [] From 88fe7a27aa0b286135cbd985bb17cdb603755f54 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 17:38:30 +0200 Subject: [PATCH 134/205] cahnge rope shape --- src/mrpro/nn/AxialRoPE.py | 33 ++++++++----------- src/mrpro/nn/attention/MultiHeadAttention.py | 18 ++++++---- .../nn/attention/NeighborhoodSelfAttention.py | 8 +++-- tests/nn/test_rope.py | 23 +++++-------- 4 files changed, 37 insertions(+), 45 deletions(-) diff --git a/src/mrpro/nn/AxialRoPE.py b/src/mrpro/nn/AxialRoPE.py index 2eeb37d32..60ee2327e 100644 --- a/src/mrpro/nn/AxialRoPE.py +++ b/src/mrpro/nn/AxialRoPE.py @@ -31,7 +31,7 @@ def get_theta(shape: Sequence[int], n_embedding_channels: int, device: torch.dev log_min = torch.log(torch.tensor(torch.pi)) log_max = torch.log(torch.tensor(10000.0)) freqs = torch.exp(torch.linspace(log_min, log_max, n_embedding_channels // (2 * position.shape[-1]), device=device)) - return rearrange(freqs * position[..., None], '... dim freqs ->... 1 (dim freqs)') + return rearrange(freqs * position[..., None], '... dim freqs ->... (dim freqs)') class AxialRoPE(Module): @@ -44,23 +44,19 @@ class AxialRoPE(Module): def __init__( self, - n_heads: int, - non_embed_fraction: float = 0.0, + embed_fraction: float = 1.0, ): """Initialize AxialRoPE. Parameters ---------- - n_heads - Number of attention heads - non_embed_fraction - Fraction of channels not used for embedding + embed_fraction + Fraction of channels used for embedding """ super().__init__() - self.non_embed_fraction = non_embed_fraction - if non_embed_fraction < 0 or non_embed_fraction > 1: - raise ValueError('non_embed_fraction must be between 0 and 1') - self.n_heads = n_heads + self.embed_fraction = embed_fraction + if embed_fraction < 0 or embed_fraction > 1: + raise ValueError('embed_fraction must be between 0 and 1') def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: """Apply rotary embeddings to input tensors. @@ -68,9 +64,10 @@ def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: Parameters ---------- *tensors - Tensors to apply rotary embeddings to + Tensors to apply rotary embeddings to. + Shape must be `(batch, heads, *spatial_dims, channels)`. """ - if self.non_embed_fraction == 1.0: + if self.embed_fraction == 1.0: return tensors shape = tensors[0].shape @@ -80,14 +77,10 @@ def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: if not all(t.device == device for t in tensors): raise ValueError('All tensors must be on the same device') - shape, n_channels = shape[1:-1], shape[-1] - if n_channels % self.n_heads: - raise ValueError(f'Number of channels {n_channels} must be divisible by number of heads {self.n_heads}') - n_channels_per_head = n_channels // self.n_heads - tensors = tuple(t.unflatten(-1, (self.n_heads, -1)) for t in tensors) - n_embedding_channels = int(n_channels_per_head * (1 - self.non_embed_fraction)) + shape, n_channels_per_head = shape[2:-1], shape[-1] + n_embedding_channels = int(n_channels_per_head * self.embed_fraction) theta = get_theta(shape, n_embedding_channels, device) - return tuple(self.apply_rotary_emb(t, theta).flatten(-2) for t in tensors) + return tuple(self.apply_rotary_emb(t, theta) for t in tensors) @staticmethod def apply_rotary_emb(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: diff --git a/src/mrpro/nn/attention/MultiHeadAttention.py b/src/mrpro/nn/attention/MultiHeadAttention.py index 9daab4385..0ff503639 100644 --- a/src/mrpro/nn/attention/MultiHeadAttention.py +++ b/src/mrpro/nn/attention/MultiHeadAttention.py @@ -53,7 +53,7 @@ def __init__( self.features_last = features_last self.to_out = Linear(n_channels_in, n_channels_out) self.n_heads = n_heads - self.rope = AxialRoPE(n_heads, non_embed_fraction=1 - rope_embed_fraction) + self.rope = AxialRoPE(rope_embed_fraction) def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: """Apply multi-head attention. @@ -78,21 +78,25 @@ def _reshape(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: """Apply multi-head attention.""" - reshaped_x = self._reshape(x) - reshaped_cross_attention = self._reshape(cross_attention) if cross_attention is not None else reshaped_x + if cross_attention is None: + cross_attention = x + if not self.features_last: + x = x.moveaxis(1, -1) + cross_attention = cross_attention.moveaxis(1, -1) - query = rearrange(self.to_q(reshaped_x), '... L (heads dim) -> ... heads L dim ', heads=self.n_heads) + query = rearrange(self.to_q(x), 'batch ... (heads channels) -> batch heads ... channels ', heads=self.n_heads) key, value = rearrange( - self.to_kv(reshaped_cross_attention), - '... S (kv heads dim) -> kv ... heads S dim ', + self.to_kv(cross_attention), + 'batch ... (kv heads channels) -> kv batch heads ... channels ', heads=self.n_heads, kv=2, ) query, key = self.rope(query, key) # NO-OP if rope_embed_fraction is 0.0 + query, key, value = query.flatten(2, -2), key.flatten(2, -2), value.flatten(2, -2) y = torch.nn.functional.scaled_dot_product_attention( query, key, value, dropout_p=self.p_dropout, is_causal=False ) - y = rearrange(y, '... heads L dim -> ... L (heads dim)') + y = rearrange(y, '... heads L channels -> ... L (heads channels)') out = self.to_out(y) if not self.features_last: diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index 232f153a8..d389ecd67 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -15,6 +15,8 @@ T = TypeVar('T') +# coverage does not pick up the use via flex_attention, as the code gets compiled. +# pragma: no cover @cache def neighborhood_mask( device: str, @@ -130,7 +132,7 @@ def __init__( dilation: int | Sequence[int] = 1, circular: bool | Sequence[bool] = False, features_last: bool = False, - rope_embed_fraction: float = 0.0, + rope_embed_fraction: float = 1.0, ) -> None: """Initialize a neighborhood attention module. @@ -167,7 +169,7 @@ def __init__( channels_per_head = n_channels_in // n_heads self.to_qkv = Linear(n_channels_in, 3 * channels_per_head * n_heads) self.to_out = Linear(channels_per_head * n_heads, n_channels_out) - self.rope = AxialRoPE(n_heads, rope_embed_fraction) + self.rope = AxialRoPE(rope_embed_fraction) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply neighborhood attention to the input tensor. @@ -187,7 +189,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: spatial_shape = x.shape[1:-1] qkv = self.to_qkv(x) query, key, value = rearrange( - qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channels', qkv=3, head=self.n_head + qkv, 'batch ... (qkv heads channels) -> qkv batch heads (...) channels', qkv=3, heads=self.n_head ) query, key = self.rope(query, key) # NO-OP if rope_embed_fraction is 0.0 # the mask depends on the input size. To be more flexible if used within CNNs, we compute it here. diff --git a/tests/nn/test_rope.py b/tests/nn/test_rope.py index 4b479e8a1..3b484f733 100644 --- a/tests/nn/test_rope.py +++ b/tests/nn/test_rope.py @@ -16,25 +16,18 @@ def test_rope(device: torch.device): n_heads = 2 n_channels = 64 n_embed = int(0.5 * n_channels // n_heads) - q, k = RandomGenerator(seed=42).float32_tensor((2, 1, *shape, n_channels), low=0.5).to(device) - rope = AxialRoPE(2, non_embed_fraction=0.5) + q, k = RandomGenerator(seed=42).float32_tensor((2, 1, n_heads, *shape, n_channels), low=0.5).to(device) + + rope = AxialRoPE(embed_fraction=0.5) (q_rope, k_rope) = rope(q, k) + assert q_rope.shape == q.shape assert k_rope.shape == k.shape # non embedded channels should be the same - torch.testing.assert_close( - q.unflatten(-1, (n_heads, -1))[..., n_embed:], q_rope.unflatten(-1, (n_heads, -1))[..., n_embed:] - ) - torch.testing.assert_close( - k.unflatten(-1, (n_heads, -1))[..., n_embed:], k_rope.unflatten(-1, (n_heads, -1))[..., n_embed:] - ) + torch.testing.assert_close(q[..., n_embed:], q_rope[..., n_embed:]) + torch.testing.assert_close(k[..., n_embed:], k_rope[..., n_embed:]) # other should change - q_emb = q_rope.unflatten(-1, (n_heads, -1))[..., :n_embed] - q_orig = q.unflatten(-1, (n_heads, -1))[..., :n_embed] - k_emb = k_rope.unflatten(-1, (n_heads, -1))[..., :n_embed] - k_orig = k.unflatten(-1, (n_heads, -1))[..., :n_embed] - assert not torch.isclose(q_emb, q_orig).all() - assert not torch.isclose(k_emb, k_orig).all() - assert not torch.isclose(q_emb, k_emb).all() + assert not torch.isclose(q_rope[..., :n_embed], q[..., :n_embed]).all() + assert not torch.isclose(k_rope[..., :n_embed], k[..., :n_embed]).all() From d5895ce4a5f9534c30e0ad06ca39a7a0121e483d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 22:47:39 +0200 Subject: [PATCH 135/205] fix restormer --- src/mrpro/nn/PixelShuffle.py | 2 +- src/mrpro/nn/Sequential.py | 2 +- src/mrpro/nn/nets/Restormer.py | 10 +++++++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index b78853da7..a39bd7731 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -103,12 +103,12 @@ def __init__( Whether to use a residual connection as proposed in [DCAE]_. """ super().__init__() - self.pixel_unshuffle = PixelUnshuffle(downscale_factor) out_ratio = downscale_factor**n_dim if n_channels_out % out_ratio != 0: raise ValueError(f'channels_out must be divisible by downscale_factor**{n_dim}.') self.conv = ConvND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') self.residual = residual + self.pixel_unshuffle = PixelUnshuffle(downscale_factor) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply downsampling. diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py index fb56bd43f..e0c29f37d 100644 --- a/src/mrpro/nn/Sequential.py +++ b/src/mrpro/nn/Sequential.py @@ -30,7 +30,7 @@ def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch. ------- The output tensor. """ - return super().__call__(*x, cond=cond) + return torch.nn.Sequential.__call__(self, *x, cond=cond) def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply all modules in series to the input.""" diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 11ce8f95a..ae1204748 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -164,6 +164,8 @@ def __init__( cond_dim Dimension of conditioning input. If 0, no conditioning is applied. """ + if len(n_blocks) != len(n_heads): + raise ValueError('n_blocks and n_heads must have the same length.') def blocks(n_heads: int, n_blocks: int): layers = Sequential( @@ -192,7 +194,13 @@ def blocks(n_heads: int, n_blocks: int): PixelShuffleUpsample(n_dim, n_channels_per_head * head_next, n_channels_per_head * head_current) for head_current, head_next in pairwise(n_heads) ][::-1] - concat_blocks = [Concat() for _ in range(len(encoder_blocks))] + concat_blocks = [ + Sequential( + Concat(), + ConvND(n_dim)(2 * n_channels_per_head * head, n_channels_per_head * head, kernel_size=1), + ) + for head in n_heads[1::-1] + ] decoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1] last_block = Sequential( *(RestormerBlock(n_dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)), From 5526bac27829cb85bd22b82532873e85751ec31f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 22:47:55 +0200 Subject: [PATCH 136/205] tests --- tests/nn/nets/test_cnn.py | 58 +++++++++++++++++++++++++++++++ tests/nn/nets/test_restormer.py | 61 +++++++++++++++++++++++++++++++++ tests/nn/nets/test_uformer.py | 59 +++++++++++++++++++++++++++++++ tests/nn/nets/test_unet.py | 26 ++++++++++++++ 4 files changed, 204 insertions(+) create mode 100644 tests/nn/nets/test_cnn.py create mode 100644 tests/nn/nets/test_restormer.py create mode 100644 tests/nn/nets/test_uformer.py diff --git a/tests/nn/nets/test_cnn.py b/tests/nn/nets/test_cnn.py new file mode 100644 index 000000000..6ba3df0ad --- /dev/null +++ b/tests/nn/nets/test_cnn.py @@ -0,0 +1,58 @@ +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import BasicCNN + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_cnn_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the cnn.""" + cnn = BasicCNN( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + norm='layer', + n_features=(8, 8), + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + cnn = cnn.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + cnn = cast(BasicCNN, torch.compile(cnn)) + y = cnn(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + assert y.mean().abs() < 0.1 + + +def test_cnn_backward(): + cnn = BasicCNN( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + norm='instance', + activation='silu', + n_features=(8, 8), + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = cnn(x, cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in cnn.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_restormer.py b/tests/nn/nets/test_restormer.py new file mode 100644 index 000000000..821ba7acb --- /dev/null +++ b/tests/nn/nets/test_restormer.py @@ -0,0 +1,61 @@ +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import Restormer + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_restormer_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the restormer.""" + restormer = Restormer( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2, 4), + n_blocks=(2, 1, 1), + cond_dim=32, + n_channels_per_head=2, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + restormer = restormer.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + restormer = cast(Restormer, torch.compile(restormer)) + y = restormer(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + assert y.mean().abs() < 0.2 + + +def test_restormer_backward(): + restormer = Restormer( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2), + n_blocks=(2, 2), + cond_dim=32, + n_channels_per_head=4, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = restormer(x, cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in restormer.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_uformer.py b/tests/nn/nets/test_uformer.py new file mode 100644 index 000000000..2985aeca2 --- /dev/null +++ b/tests/nn/nets/test_uformer.py @@ -0,0 +1,59 @@ +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import Uformer + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_uformer_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the uformer.""" + uformer = Uformer( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2, 4), + cond_dim=32, + n_channels_per_head=8, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + uformer = uformer.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + uformer = cast(Uformer, torch.compile(uformer)) + y = uformer(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + assert y.mean().abs() < 0.1 + + +def test_uformer_backward(): + uformer = Uformer( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2, 4), + cond_dim=32, + n_channels_per_head=8, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = uformer(x, cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in uformer.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py index 798be6a85..5511722a0 100644 --- a/tests/nn/nets/test_unet.py +++ b/tests/nn/nets/test_unet.py @@ -35,3 +35,29 @@ def test_unet_forward(torch_compile: bool, device: str) -> None: unet = cast(UNet, torch.compile(unet)) y = unet(x, cond=cond) assert y.shape == (1, 1, 16, 16) + assert y.mean().abs() < 0.1 + + +def test_unet_backward(): + unet = UNet( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + attention_depths=(-1,), + n_features=(4, 6, 8), + n_heads=2, + cond_dim=32, + encoder_blocks_per_scale=1, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = unet(x, cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in unet.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From f63e059f16f9b2d8553601496b0ebfe3c55eb2ad Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 21 Jul 2025 23:42:17 +0200 Subject: [PATCH 137/205] update --- src/mrpro/nn/attention/MultiHeadAttention.py | 4 +- src/mrpro/nn/nets/{DCAE.py => DCVAE.py} | 6 +- src/mrpro/nn/nets/VAE.py | 6 +- src/mrpro/nn/nets/__init__.py | 2 +- tests/nn/nets/test_dcvae.py | 59 ++++++++++++++++++++ tests/nn/nets/test_unet.py | 55 +++++++++++++++++- 6 files changed, 123 insertions(+), 9 deletions(-) rename src/mrpro/nn/nets/{DCAE.py => DCVAE.py} (98%) create mode 100644 tests/nn/nets/test_dcvae.py diff --git a/src/mrpro/nn/attention/MultiHeadAttention.py b/src/mrpro/nn/attention/MultiHeadAttention.py index 0ff503639..212bc68eb 100644 --- a/src/mrpro/nn/attention/MultiHeadAttention.py +++ b/src/mrpro/nn/attention/MultiHeadAttention.py @@ -97,9 +97,9 @@ def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) query, key, value, dropout_p=self.p_dropout, is_causal=False ) y = rearrange(y, '... heads L channels -> ... L (heads channels)') - out = self.to_out(y) + out = self.to_out(y).reshape(x.shape) if not self.features_last: out = out.moveaxis(-1, 1) - return out.reshape(x.shape) + return out diff --git a/src/mrpro/nn/nets/DCAE.py b/src/mrpro/nn/nets/DCVAE.py similarity index 98% rename from src/mrpro/nn/nets/DCAE.py rename to src/mrpro/nn/nets/DCVAE.py index 7a5362514..5b3da8a01 100644 --- a/src/mrpro/nn/nets/DCAE.py +++ b/src/mrpro/nn/nets/DCVAE.py @@ -239,11 +239,13 @@ def __init__( stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] case 'LinearViT': stage = [ - EfficientViTBlock(n_dim, width, n_heads=width // 32, linear_attn=True) for _ in range(depth) + EfficientViTBlock(n_dim, width, n_heads=max(1, width // 32), linear_attn=True) + for _ in range(depth) ] case 'ViT': stage = [ - EfficientViTBlock(n_dim, width, n_heads=width // 32, linear_attn=False) for _ in range(depth) + EfficientViTBlock(n_dim, width, n_heads=max(1, width // 32), linear_attn=False) + for _ in range(depth) ] case _: raise ValueError(f'Block type {block_type} not supported') diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py index cd4a1260a..e0b1bfc58 100644 --- a/src/mrpro/nn/nets/VAE.py +++ b/src/mrpro/nn/nets/VAE.py @@ -20,9 +20,9 @@ def __init__(self, encoder: Module, decoder: Module): Parameters ---------- - encoder : Module + encoder Encoder module. Should return double the number of channels of the latent space. - decoder : Module + decoder Decoder module """ super().__init__() @@ -37,7 +37,7 @@ def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: Parameters ---------- - x : torch.Tensor + x Input tensor Returns diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index a26bf253d..4a7464198 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,6 +1,6 @@ from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.Uformer import Uformer -from mrpro.nn.nets.DCAE import DCVAE +from mrpro.nn.nets.DCVAE import DCVAE from mrpro.nn.nets.VAE import VAE from mrpro.nn.nets.UNet import UNet, AttentionGatedUNet from mrpro.nn.nets.SwinIR import SwinIR diff --git a/tests/nn/nets/test_dcvae.py b/tests/nn/nets/test_dcvae.py new file mode 100644 index 000000000..f5b190546 --- /dev/null +++ b/tests/nn/nets/test_dcvae.py @@ -0,0 +1,59 @@ +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import DCVAE + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_dcvae_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the DCVAE.""" + dcvae = DCVAE( + n_dim=2, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(32, 64, 32), + depths=(2, 2, 3), + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + dcvae = dcvae.to(device) + x = x.to(device) + if torch_compile: + dcvae = cast(DCVAE, torch.compile(dcvae)) + y, kl = dcvae(x) + assert y.shape == (1, 1, 16, 16) + assert kl.shape == () + latent = dcvae.encoder(x) + assert latent.shape == (1, 2 * 4, 2, 2) # 2 because of mean and logvar + + +def test_dcvae_backward(): + """Test the backward pass of the DCVAE.""" + dcvae = DCVAE( + n_dim=1, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(8, 12, 16), + depths=(2, 2, 3), + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + + y, kl = dcvae(x) + y.sum().backward() + kl.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in dcvae.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py index 5511722a0..20e7659ab 100644 --- a/tests/nn/nets/test_unet.py +++ b/tests/nn/nets/test_unet.py @@ -2,7 +2,7 @@ import pytest import torch -from mrpro.nn.nets import UNet +from mrpro.nn.nets import AttentionGatedUNet, UNet @pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) @@ -61,3 +61,56 @@ def test_unet_backward(): for name, parameter in unet.named_parameters(): assert parameter.grad is not None, f'{name}.grad is None' assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_gated_unet_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the AttentionGatedUNet.""" + unet = AttentionGatedUNet( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_features=(4, 6, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + unet = unet.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + unet = cast(AttentionGatedUNet, torch.compile(unet)) + y = unet(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + assert y.mean().abs() < 0.1 + + +def test_gated_unet_backward() -> None: + """Test the backward pass of the AttentionGatedUNet.""" + unet = AttentionGatedUNet( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_features=(4, 6, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = unet(x, cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in unet.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From b9a3d4e29dc78596af1bd443ddbca6834ac83469 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 22 Jul 2025 01:31:20 +0200 Subject: [PATCH 138/205] fixes --- src/mrpro/nn/nets/BasicCNN.py | 2 +- tests/nn/nets/test_cnn.py | 6 ++---- tests/nn/nets/test_restormer.py | 3 +-- tests/nn/nets/test_uformer.py | 9 ++------- tests/nn/nets/test_unet.py | 2 -- tests/nn/test_rope.py | 2 +- 6 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py index adca090b6..b39741f85 100644 --- a/src/mrpro/nn/nets/BasicCNN.py +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -102,4 +102,4 @@ def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.T ------- The output tensor. """ - return super().__call__(*x, cond=cond) + return super().__call__(x, cond=cond) diff --git a/tests/nn/nets/test_cnn.py b/tests/nn/nets/test_cnn.py index 6ba3df0ad..db8bf8ffc 100644 --- a/tests/nn/nets/test_cnn.py +++ b/tests/nn/nets/test_cnn.py @@ -24,15 +24,12 @@ def test_cnn_forward(torch_compile: bool, device: str) -> None: ) x = torch.zeros(1, 1, 16, 16, device=device) - cond = torch.zeros(1, 32, device=device) cnn = cnn.to(device) x = x.to(device) - cond = cond.to(device) if torch_compile: cnn = cast(BasicCNN, torch.compile(cnn)) - y = cnn(x, cond=cond) + y = cnn(x) assert y.shape == (1, 1, 16, 16) - assert y.mean().abs() < 0.1 def test_cnn_backward(): @@ -43,6 +40,7 @@ def test_cnn_backward(): norm='instance', activation='silu', n_features=(8, 8), + cond_dim=32, ) x = torch.zeros(1, 1, 16, requires_grad=True) diff --git a/tests/nn/nets/test_restormer.py b/tests/nn/nets/test_restormer.py index 821ba7acb..ca5f04785 100644 --- a/tests/nn/nets/test_restormer.py +++ b/tests/nn/nets/test_restormer.py @@ -34,12 +34,11 @@ def test_restormer_forward(torch_compile: bool, device: str) -> None: restormer = cast(Restormer, torch.compile(restormer)) y = restormer(x, cond=cond) assert y.shape == (1, 1, 16, 16) - assert y.mean().abs() < 0.2 def test_restormer_backward(): restormer = Restormer( - n_dim=2, + n_dim=1, n_channels_in=1, n_channels_out=1, n_heads=(1, 2), diff --git a/tests/nn/nets/test_uformer.py b/tests/nn/nets/test_uformer.py index 2985aeca2..4a7858f65 100644 --- a/tests/nn/nets/test_uformer.py +++ b/tests/nn/nets/test_uformer.py @@ -16,12 +16,7 @@ def test_uformer_forward(torch_compile: bool, device: str) -> None: """Test the forward pass of the uformer.""" uformer = Uformer( - n_dim=2, - n_channels_in=1, - n_channels_out=1, - n_heads=(1, 2, 4), - cond_dim=32, - n_channels_per_head=8, + n_dim=2, n_channels_in=1, n_channels_out=1, n_heads=(1, 2, 4), cond_dim=32, n_channels_per_head=8, window_size=2 ) x = torch.zeros(1, 1, 16, 16, device=device) @@ -33,7 +28,6 @@ def test_uformer_forward(torch_compile: bool, device: str) -> None: uformer = cast(Uformer, torch.compile(uformer)) y = uformer(x, cond=cond) assert y.shape == (1, 1, 16, 16) - assert y.mean().abs() < 0.1 def test_uformer_backward(): @@ -44,6 +38,7 @@ def test_uformer_backward(): n_heads=(1, 2, 4), cond_dim=32, n_channels_per_head=8, + window_size=2, ) x = torch.zeros(1, 1, 16, requires_grad=True) diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py index 20e7659ab..4ef8694c7 100644 --- a/tests/nn/nets/test_unet.py +++ b/tests/nn/nets/test_unet.py @@ -35,7 +35,6 @@ def test_unet_forward(torch_compile: bool, device: str) -> None: unet = cast(UNet, torch.compile(unet)) y = unet(x, cond=cond) assert y.shape == (1, 1, 16, 16) - assert y.mean().abs() < 0.1 def test_unet_backward(): @@ -90,7 +89,6 @@ def test_gated_unet_forward(torch_compile: bool, device: str) -> None: unet = cast(AttentionGatedUNet, torch.compile(unet)) y = unet(x, cond=cond) assert y.shape == (1, 1, 16, 16) - assert y.mean().abs() < 0.1 def test_gated_unet_backward() -> None: diff --git a/tests/nn/test_rope.py b/tests/nn/test_rope.py index 3b484f733..b19dc6c01 100644 --- a/tests/nn/test_rope.py +++ b/tests/nn/test_rope.py @@ -15,7 +15,7 @@ def test_rope(device: torch.device): shape = (10, 10) n_heads = 2 n_channels = 64 - n_embed = int(0.5 * n_channels // n_heads) + n_embed = int(0.5 * n_channels) q, k = RandomGenerator(seed=42).float32_tensor((2, 1, n_heads, *shape, n_channels), low=0.5).to(device) rope = AxialRoPE(embed_fraction=0.5) From da5cc632974bc5d8914afebe0cf3700c54fbe9e9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 22 Jul 2025 02:21:40 +0200 Subject: [PATCH 139/205] update test --- src/mrpro/nn/AxialRoPE.py | 2 +- src/mrpro/nn/PixelShuffle.py | 3 +- src/mrpro/nn/nets/Restormer.py | 2 +- tests/nn/nets/test_dcvae.py | 32 ++++++++++++++++---- tests/nn/nets/test_swinir.py | 54 ++++++++++++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 8 deletions(-) create mode 100644 tests/nn/nets/test_swinir.py diff --git a/src/mrpro/nn/AxialRoPE.py b/src/mrpro/nn/AxialRoPE.py index 60ee2327e..90580db9a 100644 --- a/src/mrpro/nn/AxialRoPE.py +++ b/src/mrpro/nn/AxialRoPE.py @@ -67,7 +67,7 @@ def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: Tensors to apply rotary embeddings to. Shape must be `(batch, heads, *spatial_dims, channels)`. """ - if self.embed_fraction == 1.0: + if self.embed_fraction == 0.0: return tensors shape = tensors[0].shape diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index a39bd7731..6ba74e2eb 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -131,7 +131,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.residual: x = self.pixel_unshuffle(x) - h = h + x.unflatten(1, (h.shape[1], -1)).mean(2) + n = (x.shape[1] // h.shape[1]) * h.shape[1] + h = h + x[:, :n].unflatten(1, (h.shape[1], -1)).mean(2) return h diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index ae1204748..9f85c48ba 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -199,7 +199,7 @@ def blocks(n_heads: int, n_blocks: int): Concat(), ConvND(n_dim)(2 * n_channels_per_head * head, n_channels_per_head * head, kernel_size=1), ) - for head in n_heads[1::-1] + for head in n_heads[-2::-1] ] decoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1] last_block = Sequential( diff --git a/tests/nn/nets/test_dcvae.py b/tests/nn/nets/test_dcvae.py index f5b190546..502e1dcda 100644 --- a/tests/nn/nets/test_dcvae.py +++ b/tests/nn/nets/test_dcvae.py @@ -21,7 +21,7 @@ def test_dcvae_forward(torch_compile: bool, device: str) -> None: latent_dim=4, block_types=('CNN', 'LinearViT', 'ViT'), widths=(32, 64, 32), - depths=(2, 2, 3), + depths=(1, 2, 2), ) x = torch.zeros(1, 1, 16, 16, device=device) @@ -36,8 +36,8 @@ def test_dcvae_forward(torch_compile: bool, device: str) -> None: assert latent.shape == (1, 2 * 4, 2, 2) # 2 because of mean and logvar -def test_dcvae_backward(): - """Test the backward pass of the DCVAE.""" +def test_dcvae_backward_kl(): + """Test the backward pass of the DCVAE wrt kl.""" dcvae = DCVAE( n_dim=1, n_channels=1, @@ -49,11 +49,33 @@ def test_dcvae_backward(): x = torch.zeros(1, 1, 16, requires_grad=True) - y, kl = dcvae(x) - y.sum().backward() + _, kl = dcvae(x) kl.sum().backward() assert x.grad is not None, 'x.grad is None' assert not x.grad.isnan().any(), 'x.grad is NaN' for name, parameter in dcvae.named_parameters(): assert parameter.grad is not None, f'{name}.grad is None' assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +def test_dcvae_backward_y(): + """Test the backward pass of the DCVAE wrt y.""" + dcvae = DCVAE( + n_dim=1, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(8, 12, 16), + depths=(2, 2, 3), + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + + y, _ = dcvae(x) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + # only the encoder parameters can influence kl + for name, parameter in dcvae.encoder.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_swinir.py b/tests/nn/nets/test_swinir.py new file mode 100644 index 000000000..2bef91d13 --- /dev/null +++ b/tests/nn/nets/test_swinir.py @@ -0,0 +1,54 @@ +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import SwinIR + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_swinir_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the UNet.""" + swinir = SwinIR( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_heads=2, + n_channels_per_head=4, + n_blocks=2, + window_size=4, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + swinir = swinir.to(device) + if torch_compile: + swinir = cast(SwinIR, torch.compile(swinir)) + y = swinir(x) + assert y.shape == (1, 1, 16, 16) + + +def test_swinir_backward(): + swinir = SwinIR( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_heads=2, + n_channels_per_head=4, + n_blocks=2, + window_size=4, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + y = swinir(x) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in swinir.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From fb0e0a485cfb2c9867807096c7bba4aec3cffa41 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 22 Jul 2025 02:30:41 +0200 Subject: [PATCH 140/205] fix --- tests/nn/nets/test_dcvae.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/nn/nets/test_dcvae.py b/tests/nn/nets/test_dcvae.py index 502e1dcda..335595e3d 100644 --- a/tests/nn/nets/test_dcvae.py +++ b/tests/nn/nets/test_dcvae.py @@ -53,7 +53,7 @@ def test_dcvae_backward_kl(): kl.sum().backward() assert x.grad is not None, 'x.grad is None' assert not x.grad.isnan().any(), 'x.grad is NaN' - for name, parameter in dcvae.named_parameters(): + for name, parameter in dcvae.encoder.named_parameters(): # only the encoder parameters can influence kl assert parameter.grad is not None, f'{name}.grad is None' assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' @@ -75,7 +75,6 @@ def test_dcvae_backward_y(): y.sum().backward() assert x.grad is not None, 'x.grad is None' assert not x.grad.isnan().any(), 'x.grad is NaN' - # only the encoder parameters can influence kl - for name, parameter in dcvae.encoder.named_parameters(): + for name, parameter in dcvae.named_parameters(): assert parameter.grad is not None, f'{name}.grad is None' assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 6882b6f60fee6b0b0284b12f9b74ac28a72caf8f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 22 Jul 2025 11:09:45 +0200 Subject: [PATCH 141/205] encodings --- ...ncoding.py => AbsolutePositionEncoding.py} | 73 ++++--------------- src/mrpro/nn/FourierFeatures.py | 50 +++++++++++++ src/mrpro/nn/__init__.py | 4 + tests/nn/test_ape.py | 27 +++++++ tests/nn/test_fourierfeatures.py | 24 ++++++ 5 files changed, 121 insertions(+), 57 deletions(-) rename src/mrpro/nn/{encoding.py => AbsolutePositionEncoding.py} (53%) create mode 100644 src/mrpro/nn/FourierFeatures.py create mode 100644 tests/nn/test_ape.py create mode 100644 tests/nn/test_fourierfeatures.py diff --git a/src/mrpro/nn/encoding.py b/src/mrpro/nn/AbsolutePositionEncoding.py similarity index 53% rename from src/mrpro/nn/encoding.py rename to src/mrpro/nn/AbsolutePositionEncoding.py index 94d527b21..f6093d6ba 100644 --- a/src/mrpro/nn/encoding.py +++ b/src/mrpro/nn/AbsolutePositionEncoding.py @@ -1,4 +1,4 @@ -"""Encoding modules for neural networks.""" +"""Absolute position encoding (APE).""" from itertools import combinations from math import ceil @@ -9,73 +9,29 @@ from mrpro.utils.reshape import unsqueeze_right -class FourierFeatures(Module): - """Fourier feature encoding layer. - - Projects input features into a higher dimensional space using random Fourier features. - This is useful for encoding positional information in neural networks. - """ - - weight: torch.Tensor - - def __init__(self, in_features: int, out_features: int, std: float = 1.0): - """Initialize Fourier feature encoding layer. - - Parameters - ---------- - in_features - Number of input features - out_features - Number of output features (must be even) - std - Standard deviation for random initialization - """ - if out_features % 2 != 0: - raise ValueError('out_features must be even.') - super().__init__() - self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - """Apply Fourier feature encoding. - - Parameters - ---------- - x - Input tensor of shape (..., in_features) - - Returns - ------- - Encoded features of shape (..., out_features) - """ - return super().__call__(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply Fourier feature encoding.""" - f = 2 * torch.pi * x @ self.weight.T - return torch.cat([f.cos(), f.sin()], dim=-1) - - class AbsolutePositionEncoding(Module): """Absolute position encoding layer. - Encodes absolute positions in a grid using learned embeddings. + Encodes absolute positions in a grid. Has no learnable parameters. """ encoding: torch.Tensor - def __init__(self, n_dim: int, features: int, include_radii: bool = True, base_resolution: int = 128): + def __init__(self, n_dim: int, n_features: int, include_radii: bool = True, base_resolution: int = 128): """Initialize absolute position encoding layer. Parameters ---------- n_dim Dimensions of the input space (1, 2, or 3) - features - Number of output features + n_features + Number of features to encode. The input to the forward pass needs to have at least + this many features/channels. include_radii Whether to include radius features base_resolution - Base resolution for position encoding + Base resolution for position encoding. + Encodings are generated at this resolution and interpolated to the input shape in the forward pass. """ super().__init__() @@ -84,18 +40,17 @@ def __init__(self, n_dim: int, features: int, include_radii: bool = True, base_r for n in range(2, n_dim + 1): for combination in combinations(coords, n): coords.append((2 * sum([c**2 for c in combination])) ** 0.5 - 1) - n_freqs = ceil(features / len(coords) / 2) + n_freqs = ceil(n_features / len(coords) / 2) freqs = unsqueeze_right((base_resolution) ** torch.linspace(0, 1, n_freqs), n_dim) encoding = [] for coord in coords: encoding.append(torch.sin(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * n_dim))) encoding.append(torch.cos(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * n_dim))) - self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :features]) + self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :n_features]) self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][n_dim - 1] def __call__(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass for encoding. + """Apply absolute position encoding to a tensor. Parameters ---------- @@ -104,8 +59,12 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: Returns ------- - Encoded tensor with absolute position information + Encoded tensor with absolute position information """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply absolute position encoding to a tensor.""" features = self.encoding.shape[1] if features > x.shape[1]: raise ValueError(f'x has {x.shape[1]} features, but {features} are required') diff --git a/src/mrpro/nn/FourierFeatures.py b/src/mrpro/nn/FourierFeatures.py new file mode 100644 index 000000000..847ae3c60 --- /dev/null +++ b/src/mrpro/nn/FourierFeatures.py @@ -0,0 +1,50 @@ +"""Random Fourier feature embedding.""" + +import torch +from torch.nn import Module + + +class FourierFeatures(Module): + """Fourier feature encoding layer. + + Projects input features into a higher dimensional space using random Fourier features. + Used in INRs and to embed the time or other continuous variables. + """ + + weight: torch.Tensor + + def __init__(self, n_features_in: int, n_features_out: int, std: float = 1.0): + """Initialize Fourier feature encoding layer. + + Parameters + ---------- + n_features_in + Number of input features + n_features_out + Number of output features (must be even) + std + Standard deviation for random initialization + """ + if n_features_out % 2 != 0: + raise ValueError('n_features_out must be even.') + super().__init__() + self.register_buffer('weight', torch.randn([n_features_out // 2, n_features_in]) * std) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding. + + Parameters + ---------- + x + Input tensor of shape (..., in_features) + + Returns + ------- + Encoded features of shape (..., out_features) + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding.""" + f = 2 * torch.pi * x @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index a5b64bafc..ee5032f51 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -24,8 +24,11 @@ from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.RMSNorm import RMSNorm from mrpro.nn.AxialRoPE import AxialRoPE +from mrpro.nn.AbsolutePositionEncoding import AbsolutePositionEncoding +from mrpro.nn.FourierFeatures import FourierFeatures __all__ = [ + "AbsolutePositionEncoding", "AdaptiveAvgPoolND", "AvgPoolND", "AxialRoPE", @@ -36,6 +39,7 @@ "ConvTransposeND", "DropPath", "FiLM", + "FourierFeatures", "GroupNorm", "InstanceNormND", "MaxPoolND", diff --git a/tests/nn/test_ape.py b/tests/nn/test_ape.py new file mode 100644 index 000000000..c4fbdc634 --- /dev/null +++ b/tests/nn/test_ape.py @@ -0,0 +1,27 @@ +"""Tests for absolute position encoding""" + +import pytest +import torch +from mrpro.nn import AbsolutePositionEncoding +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_absolute_position_encodings(device) -> None: + """Test absolute position encoding.""" + n_features = 32 + shape = (1, 2 * n_features, 32, 32) + ape = AbsolutePositionEncoding(2, n_features, True, 128).to(device) + rng = RandomGenerator(444) + x1 = rng.float32_tensor(shape).to(device) + x2 = rng.float32_tensor(shape).to(device) + y1, y2 = ape(x1), ape(x2) + assert y1.shape == x1.shape + torch.testing.assert_close(y1 - x1, y2 - x2) + assert x1[:, n_features:] == y1[:, n_features:] # unembedded features diff --git a/tests/nn/test_fourierfeatures.py b/tests/nn/test_fourierfeatures.py new file mode 100644 index 000000000..30f91d785 --- /dev/null +++ b/tests/nn/test_fourierfeatures.py @@ -0,0 +1,24 @@ +"""Test for random fourier features""" + +import pytest +from mrpro.nn import FourierFeatures +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_fourierfeatures(device) -> None: + """Test fourier features""" + n_features_in = 1 + n_features_out = 16 + std = 1.0 + rng = RandomGenerator(444) + x = rng.float32_tensor((1, n_features_in)).to(device) + ff = FourierFeatures(n_features_in, n_features_out, std).to(device) + y = ff(x) + assert y.shape == (1, n_features_out) From dc911f5f7ba2a10a561fb9416685701953d752db Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 22 Jul 2025 11:11:56 +0200 Subject: [PATCH 142/205] nocover --- src/mrpro/nn/AxialRoPE.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mrpro/nn/AxialRoPE.py b/src/mrpro/nn/AxialRoPE.py index 90580db9a..03bf55c16 100644 --- a/src/mrpro/nn/AxialRoPE.py +++ b/src/mrpro/nn/AxialRoPE.py @@ -7,6 +7,7 @@ from torch.nn import Module +# pragma: no cover @torch.compile def get_theta(shape: Sequence[int], n_embedding_channels: int, device: torch.device) -> torch.Tensor: """Get rotation angles. From de43dceac694deca516dc6ff317cc93e7284daaa Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 22 Jul 2025 23:40:07 +0200 Subject: [PATCH 143/205] hourglass --- src/mrpro/nn/GroupNorm.py | 13 +++- src/mrpro/nn/PixelShuffle.py | 59 +++++++++++++++---- .../nn/attention/SpatialTransformerBlock.py | 26 ++++++-- src/mrpro/nn/nets/HourglassTransformer.py | 33 ++++++++--- tests/nn/test_ape.py | 2 +- tests/nn/test_pixelshuffle.py | 8 +-- 6 files changed, 110 insertions(+), 31 deletions(-) diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py index e0090d018..7c647cee8 100644 --- a/src/mrpro/nn/GroupNorm.py +++ b/src/mrpro/nn/GroupNorm.py @@ -9,7 +9,7 @@ class GroupNorm(torch.nn.GroupNorm): Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. """ - def __init__(self, n_channels: int, n_groups: int | None = None, affine: bool = False): + def __init__(self, n_channels: int, n_groups: int | None = None, affine: bool = False, features_last: bool = False): """Initialize GroupNorm. Parameters @@ -21,6 +21,9 @@ def __init__(self, n_channels: int, n_groups: int | None = None, affine: bool = a power of 2 that is less than or equal to 32 and leaves at least 4 channels per group. affine Whether to use learnable affine parameters. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. """ if n_groups is None: groups_, candidate = 1, 2 @@ -28,6 +31,7 @@ def __init__(self, n_channels: int, n_groups: int | None = None, affine: bool = groups_, candidate = candidate, groups_ * 2 else: groups_ = n_groups + self.features_last = features_last super().__init__(groups_, n_channels, affine=affine) def __call__(self, x: torch.Tensor) -> torch.Tensor: @@ -46,4 +50,9 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply GroupNorm.""" - return super().forward(x.float()).type(x.dtype) + if self.features_last: + x = x.moveaxis(-1, 1) + result = super().forward(x.float()).type(x.dtype) + if self.features_last: + result = result.moveaxis(1, -1) + return result diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index 6ba74e2eb..9e5da35e8 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -3,7 +3,7 @@ from math import ceil import torch -from torch.nn import Module +from torch.nn import Linear, Module from mrpro.nn.ndmodules import ConvND @@ -85,7 +85,13 @@ class PixelUnshuffleDownsample(Module): """ def __init__( - self, n_dim: int, n_channels_in: int, n_channels_out: int, downscale_factor: int = 2, residual: bool = False + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + downscale_factor: int = 2, + residual: bool = False, + features_last: bool = False, ): """Initialize a PixelUnshuffleDownsample layer. @@ -101,14 +107,21 @@ def __init__( Factor by which to downscale the input tensor. residual Whether to use a residual connection as proposed in [DCAE]_. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. """ super().__init__() out_ratio = downscale_factor**n_dim if n_channels_out % out_ratio != 0: raise ValueError(f'channels_out must be divisible by downscale_factor**{n_dim}.') - self.conv = ConvND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') + if features_last: + self.projection: Module = Linear(n_channels_in, n_channels_out // out_ratio) + else: + self.projection = ConvND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') + self.features_last = features_last self.residual = residual - self.pixel_unshuffle = PixelUnshuffle(downscale_factor) + self.pixel_unshuffle = PixelUnshuffle(downscale_factor, features_last) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply downsampling. @@ -126,13 +139,17 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply downsampling.""" - h = self.conv(x) + h = self.projection(x) h = self.pixel_unshuffle(h) if self.residual: x = self.pixel_unshuffle(x) - n = (x.shape[1] // h.shape[1]) * h.shape[1] - h = h + x[:, :n].unflatten(1, (h.shape[1], -1)).mean(2) + if self.features_last: + n = (x.shape[-1] // h.shape[-1]) * h.shape[-1] + h = h + x[..., :n].unflatten(-1, (h.shape[-1], -1)).mean(-1) + else: + n = (x.shape[1] // h.shape[1]) * h.shape[1] + h = h + x[:, :n].unflatten(1, (h.shape[1], -1)).mean(2) return h @@ -148,7 +165,13 @@ class PixelShuffleUpsample(Module): """ def __init__( - self, n_dim: int, n_channels_in: int, n_channels_out: int, upscale_factor: int = 2, residual: bool = False + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + upscale_factor: int = 2, + residual: bool = False, + features_last: bool = False, ): """Initialize a PixelShuffleUpsample layer. @@ -164,10 +187,19 @@ def __init__( Factor by which to upscale the input tensor. residual Whether to use a residual connection as proposed in [DCAE]_. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. """ super().__init__() - self.conv = ConvND(n_dim)(n_channels_in, n_channels_out * upscale_factor**n_dim, kernel_size=3, padding='same') - self.pixel_shuffle = PixelShuffle(upscale_factor) + if features_last: + self.projection: Module = Linear(n_channels_in, n_channels_out * upscale_factor**n_dim) + else: + self.projection = ConvND(n_dim)( + n_channels_in, n_channels_out * upscale_factor**n_dim, kernel_size=3, padding='same' + ) + self.features_last = features_last + self.pixel_shuffle = PixelShuffle(upscale_factor, features_last) self.residual = residual def __call__(self, x: torch.Tensor) -> torch.Tensor: @@ -186,9 +218,12 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply upsampling.""" - h = self.conv(x) + h = self.projection(x) if self.residual: - h = h + x.repeat_interleave(ceil(h.shape[1] / x.shape[1]), dim=1)[:, : h.shape[1]] + if self.features_last: + h = h + x.repeat_interleave(ceil(h.shape[-1] / x.shape[-1]), dim=-1)[..., : h.shape[-1]] + else: + h = h + x.repeat_interleave(ceil(h.shape[1] / x.shape[1]), dim=1)[:, : h.shape[1]] out = self.pixel_shuffle(h) return out diff --git a/src/mrpro/nn/attention/SpatialTransformerBlock.py b/src/mrpro/nn/attention/SpatialTransformerBlock.py index f765770ac..18817b26f 100644 --- a/src/mrpro/nn/attention/SpatialTransformerBlock.py +++ b/src/mrpro/nn/attention/SpatialTransformerBlock.py @@ -1,6 +1,7 @@ """Spatial transformer block.""" from collections.abc import Sequence +from typing import Literal import torch from torch.nn import Dropout, Linear, Module @@ -12,6 +13,7 @@ from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.LayerNorm import LayerNorm from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.RMSNorm import RMSNorm from mrpro.nn.Sequential import Sequential @@ -129,6 +131,8 @@ def __init__( cond_dim: int = 0, rope_embed_fraction: float = 0.0, attention_neighborhood: int | None = None, + features_last: bool = False, + norm: Literal['group', 'rms'] = 'group', ): """Initialize the spatial transformer block. @@ -150,14 +154,26 @@ def __init__( Fraction of channels to embed with RoPE. attention_neighborhood If not None, use NeighborhoodSelfAttention with the given neighborhood size instead of MultiHeadAttention. + features_last + Whether the features are last in the input tensor, as common in transformer models. + norm + Whether to use GroupNorm or RMSNorm. """ super().__init__() hidden_dim = n_heads * (channels // n_heads) - self.norm = GroupNorm(channels) + match norm: + case 'group': + self.norm: Module = GroupNorm(channels, features_last=features_last) + case 'rms': + self.norm = RMSNorm(channels, features_last=features_last) + case _: + raise ValueError(f'Invalid norm: {norm}') + self.features_last = features_last self.proj_in = Linear(channels, hidden_dim) self.transformer_blocks = Sequential() for group in (g for _ in range(depth) for g in dim_groups): - group = tuple(g - 1 if g < 0 else g for g in group) + if not self.features_last: + group = tuple(g - 1 if g < 0 else g for g in group) block = BasicTransformerBlock( hidden_dim, n_heads, @@ -174,11 +190,13 @@ def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch """Apply the spatial transformer block.""" skip = x h = self.norm(x) - h = h.movedim(1, -1) + if not self.features_last: + h = h.movedim(1, -1) h = self.proj_in(h) h = self.transformer_blocks(h, cond=cond) h = self.proj_out(h) - h = h.movedim(-1, 1) + if not self.features_last: + h = h.movedim(-1, 1) return skip + h def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: diff --git a/src/mrpro/nn/nets/HourglassTransformer.py b/src/mrpro/nn/nets/HourglassTransformer.py index 594e2c855..6736d5c35 100644 --- a/src/mrpro/nn/nets/HourglassTransformer.py +++ b/src/mrpro/nn/nets/HourglassTransformer.py @@ -1,6 +1,7 @@ """Hourglass Transformer.""" from collections.abc import Sequence +from itertools import pairwise from torch.nn import Module @@ -74,55 +75,71 @@ def __init__( attention_neighborhood_ = to_tuple(n_layers, attention_neighborhood) n_heads_ = to_tuple(n_layers, n_heads) - move_channels_last = RearrangeOp('batch ... channels -> batch ... channels') + move_channels_last = RearrangeOp('batch channels ... -> batch ... channels') first_block = Sequential( - PixelUnshuffleDownsample(n_dim, n_channels_in, n_features_[0], downscale_factor=2), move_channels_last, + PixelUnshuffleDownsample(n_dim, n_channels_in, n_features_[0], downscale_factor=2, features_last=True), ) - dim = (tuple(range(-n_dim, 0)),) # TODO: allow arbitrary dimensions. + dim_group = (tuple(range(-n_dim - 1, -1)),) encoder_blocks: list[Module] = [] decoder_blocks: list[Module] = [] merge_blocks: list[Module] = [] down_blocks: list[Module] = [] up_blocks: list[Module] = [] for channels, depth, neighborhood, head in zip( - n_features_, depths_, attention_neighborhood_, n_heads_, strict=True + n_features_[:-1], + depths_[:-1], + attention_neighborhood_[:-1], + n_heads_[:-1], + strict=True, ): encoder_blocks.append( SpatialTransformerBlock( - dim_groups=dim, + dim_groups=dim_group, channels=channels, depth=depth, attention_neighborhood=neighborhood, n_heads=head, rope_embed_fraction=1.0, cond_dim=cond_dim, + features_last=True, + norm='rms', ) ) decoder_blocks.append( SpatialTransformerBlock( - dim_groups=dim, + dim_groups=dim_group, channels=channels, depth=depth, attention_neighborhood=neighborhood, n_heads=head, rope_embed_fraction=1.0, cond_dim=cond_dim, + features_last=True, + norm='rms', ) ) merge_blocks.append(Interpolate()) + for channels, channels_next in pairwise(n_features_): + down_blocks.append( + PixelUnshuffleDownsample(n_dim, channels, channels_next, downscale_factor=2, features_last=True) + ) + up_blocks.append(PixelShuffleUpsample(n_dim, channels_next, channels, upscale_factor=2, features_last=True)) last_block = Sequential( - move_channels_last.H, PixelShuffleUpsample(n_dim, n_features_[-1], n_channels_out, upscale_factor=2) + PixelShuffleUpsample(n_dim, n_features_[-1], n_channels_out, upscale_factor=2, features_last=True), + move_channels_last.H, # moves channels back to front ) middle_block = SpatialTransformerBlock( - dim_groups=dim, + dim_groups=dim_group, channels=n_features_[-1], depth=depths_[-1], attention_neighborhood=attention_neighborhood_[-1], n_heads=n_heads_[-1], rope_embed_fraction=1.0, cond_dim=cond_dim, + features_last=True, + norm='rms', ) encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) decoder = UNetDecoder(decoder_blocks, up_blocks, merge_blocks, last_block) diff --git a/tests/nn/test_ape.py b/tests/nn/test_ape.py index c4fbdc634..fb5dbef38 100644 --- a/tests/nn/test_ape.py +++ b/tests/nn/test_ape.py @@ -24,4 +24,4 @@ def test_absolute_position_encodings(device) -> None: y1, y2 = ape(x1), ape(x2) assert y1.shape == x1.shape torch.testing.assert_close(y1 - x1, y2 - x2) - assert x1[:, n_features:] == y1[:, n_features:] # unembedded features + assert (x1[:, n_features:] == y1[:, n_features:]).all() # unembedded features diff --git a/tests/nn/test_pixelshuffle.py b/tests/nn/test_pixelshuffle.py index 9f098a4d3..076cf92f7 100644 --- a/tests/nn/test_pixelshuffle.py +++ b/tests/nn/test_pixelshuffle.py @@ -77,10 +77,10 @@ def test_pixelshuffleupsample_pixelunshuffledownsample(): downsample = PixelUnshuffleDownsample(3, 1, 3**3, downscale_factor=3, residual=False) upsample = PixelShuffleUpsample(3, 3**3, 1, upscale_factor=3, residual=False) # Only if the convs are Identity, the upsample and downsample are inverses. - torch.nn.init.dirac_(downsample.conv.weight) - torch.nn.init.dirac_(upsample.conv.weight) - torch.nn.init.zeros_(downsample.conv.bias) # type: ignore[arg-type] - torch.nn.init.zeros_(upsample.conv.bias) # type: ignore[arg-type] + torch.nn.init.dirac_(downsample.projection.weight) + torch.nn.init.dirac_(upsample.projection.weight) + torch.nn.init.zeros_(downsample.projection.bias) # type: ignore[arg-type] + torch.nn.init.zeros_(upsample.projection.bias) # type: ignore[arg-type] y = downsample(upsample(x)) assert y.shape == (1, 3**3, 3, 4, 5) torch.testing.assert_close(y, x, msg='Upsample and downsample are not inverses.') From aa391106eb7ab7659689f2c2c670c9bf598258bf Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 22 Jul 2025 23:43:18 +0200 Subject: [PATCH 144/205] ignore tensorcode warning in tests --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 50c901d46..3ad77c221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,7 @@ filterwarnings = [ "ignore:'write_like_original':DeprecationWarning:pydicom:", "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 + "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] From a3187ae756bdb870e1b244bf97de06875d0c3212 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Jul 2025 00:10:35 +0200 Subject: [PATCH 145/205] filter warning --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 3ad77c221..c003bfa7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,6 +121,7 @@ filterwarnings = [ "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda + "ignore:Online softmax is disabled on the fly since Inductor decides to split the reduction:UserWarningc", # torch.compile ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] From a228291627c767e7de6247bbe11726fed7831406 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Jul 2025 00:20:35 +0200 Subject: [PATCH 146/205] docstring --- ...citation.py => 'test_squeezeexcitation.py} | 8 ++- tests/nn/data_consistency/conftest.py | 2 + .../test_analyticcertesiandc.py | 2 + .../test_conjugategradientdc.py | 2 + .../test_gradientdescentdc.py | 2 + tests/nn/nets/test_cnn.py | 4 +- tests/nn/nets/test_dcvae.py | 6 +- tests/nn/nets/test_hourglass.py | 60 +++++++++++++++++++ tests/nn/nets/test_restormer.py | 2 + tests/nn/nets/test_swinir.py | 4 +- tests/nn/nets/test_uformer.py | 4 +- tests/nn/nets/test_unet.py | 2 + tests/nn/test_ape.py | 2 +- tests/nn/test_convert_linear_conv.py | 2 +- tests/nn/test_droppath.py | 8 +-- tests/nn/test_film.py | 6 +- tests/nn/test_fourierfeatures.py | 4 +- tests/nn/test_geglu.py | 4 +- tests/nn/test_groupnorm.py | 8 ++- tests/nn/test_layernorm.py | 34 +++++++---- tests/nn/test_pixelshuffle.py | 16 ++--- tests/nn/test_resblock.py | 12 +++- tests/nn/test_rmsnorm.py | 4 +- tests/nn/test_rope.py | 5 +- tests/nn/test_sequential.py | 8 ++- tests/nn/test_shiftedwindowattention.py | 3 + tests/nn/test_transposedattention.py | 10 +++- 27 files changed, 179 insertions(+), 45 deletions(-) rename tests/nn/{test_sqeezeexcitation.py => 'test_squeezeexcitation.py} (86%) create mode 100644 tests/nn/nets/test_hourglass.py diff --git a/tests/nn/test_sqeezeexcitation.py b/tests/nn/'test_squeezeexcitation.py similarity index 86% rename from tests/nn/test_sqeezeexcitation.py rename to tests/nn/'test_squeezeexcitation.py index 8b2a9720e..879dd71ca 100644 --- a/tests/nn/test_sqeezeexcitation.py +++ b/tests/nn/'test_squeezeexcitation.py @@ -1,5 +1,7 @@ """Tests for SqueezeExcitation module.""" +from collections.abc import Sequence + import pytest from mrpro.nn.attention import SqueezeExcitation from mrpro.utils import RandomGenerator @@ -12,7 +14,11 @@ (3, (1, 64, 16, 16, 16), 16), ], ) -def test_squeeze_excitation(dim, input_shape, squeeze_channels): +def test_squeeze_excitation( + dim: int, + input_shape: Sequence[int], + squeeze_channels: int, +) -> None: """Test SqueezeExcitation output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).requires_grad_(True) diff --git a/tests/nn/data_consistency/conftest.py b/tests/nn/data_consistency/conftest.py index e15087e04..49fca4cf3 100644 --- a/tests/nn/data_consistency/conftest.py +++ b/tests/nn/data_consistency/conftest.py @@ -1,3 +1,5 @@ +"""Test fixtures for data consistency tests.""" + import pytest from mrpro.data.KData import KData from mrpro.data.SpatialDimension import SpatialDimension diff --git a/tests/nn/data_consistency/test_analyticcertesiandc.py b/tests/nn/data_consistency/test_analyticcertesiandc.py index dc200e8b7..2bd601461 100644 --- a/tests/nn/data_consistency/test_analyticcertesiandc.py +++ b/tests/nn/data_consistency/test_analyticcertesiandc.py @@ -1,3 +1,5 @@ +"""Tests for AnalyticCartesianDC module.""" + import torch from mrpro.data.KData import KData from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC diff --git a/tests/nn/data_consistency/test_conjugategradientdc.py b/tests/nn/data_consistency/test_conjugategradientdc.py index 2527259f4..39dc11c01 100644 --- a/tests/nn/data_consistency/test_conjugategradientdc.py +++ b/tests/nn/data_consistency/test_conjugategradientdc.py @@ -1,3 +1,5 @@ +"""Tests for ConjugateGradientDC module.""" + import torch from mrpro.data.KData import KData from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC diff --git a/tests/nn/data_consistency/test_gradientdescentdc.py b/tests/nn/data_consistency/test_gradientdescentdc.py index fae284b8e..b3ee440c8 100644 --- a/tests/nn/data_consistency/test_gradientdescentdc.py +++ b/tests/nn/data_consistency/test_gradientdescentdc.py @@ -1,3 +1,5 @@ +"""Tests for GradientDescentDC module.""" + import torch from mrpro.data.KData import KData from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC diff --git a/tests/nn/nets/test_cnn.py b/tests/nn/nets/test_cnn.py index db8bf8ffc..12c1dd2f8 100644 --- a/tests/nn/nets/test_cnn.py +++ b/tests/nn/nets/test_cnn.py @@ -1,3 +1,5 @@ +"""Tests for BasicCNN network.""" + from typing import cast import pytest @@ -32,7 +34,7 @@ def test_cnn_forward(torch_compile: bool, device: str) -> None: assert y.shape == (1, 1, 16, 16) -def test_cnn_backward(): +def test_cnn_backward() -> None: cnn = BasicCNN( n_dim=1, n_channels_in=1, diff --git a/tests/nn/nets/test_dcvae.py b/tests/nn/nets/test_dcvae.py index 335595e3d..ff5371b7b 100644 --- a/tests/nn/nets/test_dcvae.py +++ b/tests/nn/nets/test_dcvae.py @@ -1,3 +1,5 @@ +"""Tests for DCVAE network.""" + from typing import cast import pytest @@ -36,7 +38,7 @@ def test_dcvae_forward(torch_compile: bool, device: str) -> None: assert latent.shape == (1, 2 * 4, 2, 2) # 2 because of mean and logvar -def test_dcvae_backward_kl(): +def test_dcvae_backward_kl() -> None: """Test the backward pass of the DCVAE wrt kl.""" dcvae = DCVAE( n_dim=1, @@ -58,7 +60,7 @@ def test_dcvae_backward_kl(): assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' -def test_dcvae_backward_y(): +def test_dcvae_backward_y() -> None: """Test the backward pass of the DCVAE wrt y.""" dcvae = DCVAE( n_dim=1, diff --git a/tests/nn/nets/test_hourglass.py b/tests/nn/nets/test_hourglass.py new file mode 100644 index 000000000..908ef38da --- /dev/null +++ b/tests/nn/nets/test_hourglass.py @@ -0,0 +1,60 @@ +"""Test Hourglass Transformer""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import HourglassTransformer + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_hourglass_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the hourglass.""" + hourglass = HourglassTransformer( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_features=64, + attention_neighborhood=(7, 7, None), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + hourglass = hourglass.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + hourglass = cast(HourglassTransformer, torch.compile(hourglass)) + y = hourglass(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_hourglass_backward() -> None: + hourglass = HourglassTransformer( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_features=64, + attention_neighborhood=(7, 7, None), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = hourglass(x, cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in hourglass.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_restormer.py b/tests/nn/nets/test_restormer.py index ca5f04785..34ec14014 100644 --- a/tests/nn/nets/test_restormer.py +++ b/tests/nn/nets/test_restormer.py @@ -1,3 +1,5 @@ +"""Tests for Restormer network.""" + from typing import cast import pytest diff --git a/tests/nn/nets/test_swinir.py b/tests/nn/nets/test_swinir.py index 2bef91d13..c8dbed58c 100644 --- a/tests/nn/nets/test_swinir.py +++ b/tests/nn/nets/test_swinir.py @@ -1,3 +1,5 @@ +"""Tests for SwinIR network.""" + from typing import cast import pytest @@ -33,7 +35,7 @@ def test_swinir_forward(torch_compile: bool, device: str) -> None: assert y.shape == (1, 1, 16, 16) -def test_swinir_backward(): +def test_swinir_backward() -> None: swinir = SwinIR( n_dim=1, n_channels_in=1, diff --git a/tests/nn/nets/test_uformer.py b/tests/nn/nets/test_uformer.py index 4a7858f65..d3ed29c9b 100644 --- a/tests/nn/nets/test_uformer.py +++ b/tests/nn/nets/test_uformer.py @@ -1,3 +1,5 @@ +"""Tests for Uformer network.""" + from typing import cast import pytest @@ -30,7 +32,7 @@ def test_uformer_forward(torch_compile: bool, device: str) -> None: assert y.shape == (1, 1, 16, 16) -def test_uformer_backward(): +def test_uformer_backward() -> None: uformer = Uformer( n_dim=1, n_channels_in=1, diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py index 4ef8694c7..9400c4363 100644 --- a/tests/nn/nets/test_unet.py +++ b/tests/nn/nets/test_unet.py @@ -1,3 +1,5 @@ +"""Tests for UNet and AttentionGatedUNet networks.""" + from typing import cast import pytest diff --git a/tests/nn/test_ape.py b/tests/nn/test_ape.py index fb5dbef38..9f71444fc 100644 --- a/tests/nn/test_ape.py +++ b/tests/nn/test_ape.py @@ -13,7 +13,7 @@ pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), ], ) -def test_absolute_position_encodings(device) -> None: +def test_absolute_position_encodings(device: str) -> None: """Test absolute position encoding.""" n_features = 32 shape = (1, 2 * n_features, 32, 32) diff --git a/tests/nn/test_convert_linear_conv.py b/tests/nn/test_convert_linear_conv.py index c977f0936..19438b9d9 100644 --- a/tests/nn/test_convert_linear_conv.py +++ b/tests/nn/test_convert_linear_conv.py @@ -118,7 +118,7 @@ def test_conv_to_linear_functional(dim: Literal[1, 2, 3], channels_in: int, chan torch.testing.assert_close(y_conv, y_linear) -def test_conv_to_linear_invalid_kernel(): +def test_conv_to_linear_invalid_kernel() -> None: """Test conv_to_linear with invalid kernel size.""" conv = Conv2d(32, 64, kernel_size=3, bias=True) with pytest.raises(ValueError, match='Kernel size must be 1'): diff --git a/tests/nn/test_droppath.py b/tests/nn/test_droppath.py index ff66c69d5..b4ac7f5d7 100644 --- a/tests/nn/test_droppath.py +++ b/tests/nn/test_droppath.py @@ -12,8 +12,8 @@ pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), ], ) -def test_droppath_no_drop(device): - """Test DropPath.""" +def test_droppath_no_drop(device: str) -> None: + """Test DropPath with zero drop rate (should pass through unchanged).""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3, 4, 5)).to(device) droppath = DropPath(0).to(device) @@ -21,8 +21,8 @@ def test_droppath_no_drop(device): assert (y == x).all() -def test_droppath_drop_all(): - """Test DropPath.""" +def test_droppath_drop_all() -> None: + """Test DropPath with full drop rate (should output zeros).""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3, 4, 5)) droppath = DropPath(1.0) diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py index 40ada0a1b..0e564c675 100644 --- a/tests/nn/test_film.py +++ b/tests/nn/test_film.py @@ -33,10 +33,10 @@ def test_film( assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() assert x.grad is not None, 'No gradient computed for input' - assert cond.grad is not None, 'No gradient computed for condedding' + assert cond.grad is not None, 'No gradient computed for conditioning' assert not output.isnan().any(), 'NaN values in output' - assert not cond.isnan().any(), 'NaN values in condedding' + assert not cond.isnan().any(), 'NaN values in conditioning' assert not x.grad.isnan().any(), 'NaN values in input gradients' - assert not cond.grad.isnan().any(), 'NaN values in condedding gradients' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' assert film.project is not None, 'Linear layer is not initialized' assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' diff --git a/tests/nn/test_fourierfeatures.py b/tests/nn/test_fourierfeatures.py index 30f91d785..9452a369f 100644 --- a/tests/nn/test_fourierfeatures.py +++ b/tests/nn/test_fourierfeatures.py @@ -12,8 +12,8 @@ pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), ], ) -def test_fourierfeatures(device) -> None: - """Test fourier features""" +def test_fourierfeatures(device: str) -> None: + """Test FourierFeatures.""" n_features_in = 1 n_features_out = 16 std = 1.0 diff --git a/tests/nn/test_geglu.py b/tests/nn/test_geglu.py index 9de03103c..061837e51 100644 --- a/tests/nn/test_geglu.py +++ b/tests/nn/test_geglu.py @@ -14,7 +14,7 @@ ], ) def test_geglu(device: str) -> None: - """Test GELU.""" + """Test GEGLU output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3, 4, 5)).to(device).requires_grad_(True) gelu = GEGLU(3, 4).to(device) @@ -27,7 +27,7 @@ def test_geglu(device: str) -> None: def test_geglu_features_last() -> None: - """Test GELU with features last.""" + """Test GEGLU with features_last=True vs features_last=False.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) gelu_last = GEGLU(3, 4, features_last=True) diff --git a/tests/nn/test_groupnorm.py b/tests/nn/test_groupnorm.py index 945860bca..044de17f7 100644 --- a/tests/nn/test_groupnorm.py +++ b/tests/nn/test_groupnorm.py @@ -21,7 +21,13 @@ (64, 8, (2, 64, 16, 16, 16), False), ], ) -def test_groupnorm(n_channels: int, n_groups: int, input_shape: Sequence[int], device: str, affine: bool) -> None: +def test_groupnorm( + n_channels: int, + n_groups: int | None, + input_shape: Sequence[int], + device: str, + affine: bool, +) -> None: """Test GroupNorm output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) diff --git a/tests/nn/test_layernorm.py b/tests/nn/test_layernorm.py index 85dc136a3..ebc11ccb8 100644 --- a/tests/nn/test_layernorm.py +++ b/tests/nn/test_layernorm.py @@ -1,5 +1,7 @@ """Tests for LayerNorm module.""" +from collections.abc import Sequence + import pytest import torch from mrpro.nn.LayerNorm import LayerNorm @@ -22,7 +24,12 @@ (None, True, (2, 16, 16, 64)), ], ) -def test_layernorm_basic(n_channels, features_last, input_shape, device): +def test_layernorm_basic( + n_channels: int | None, + features_last: bool, + input_shape: Sequence[int], + device: str, +) -> None: """Test LayerNorm basic functionality.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) @@ -49,7 +56,12 @@ def test_layernorm_basic(n_channels, features_last, input_shape, device): (64, 32, (2, 64, 16, 16), (2, 32)), ], ) -def test_layernorm_with_conditioning(n_channels, cond_dim, input_shape, cond_shape): +def test_layernorm_with_conditioning( + n_channels: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int], +) -> None: """Test LayerNorm with conditioning.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).requires_grad_(True) @@ -66,8 +78,8 @@ def test_layernorm_with_conditioning(n_channels, cond_dim, input_shape, cond_sha assert norm.cond_proj.weight.grad is not None, 'No gradient computed for cond_proj' -def test_layernorm_features_last(): - """Test LayerNorm with features_last=True.""" +def test_layernorm_features_last() -> None: + """Test LayerNorm with features_last=True vs features_last=False.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) @@ -80,7 +92,7 @@ def test_layernorm_features_last(): torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) -def test_layernorm_no_channels(): +def test_layernorm_no_channels() -> None: """Test LayerNorm without channels (pure normalization).""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) @@ -98,19 +110,19 @@ def test_layernorm_no_channels(): assert torch.allclose(std, torch.ones_like(std), atol=1e-5), 'Std not close to 1' -def test_layernorm_conditioning_without_channels(): +def test_layernorm_conditioning_without_channels() -> None: """Test LayerNorm with conditioning but no channels (should raise error).""" with pytest.raises(ValueError, match='channels must be provided if cond_dim > 0'): LayerNorm(n_channels=None, cond_dim=16) -def test_layernorm_invalid_cond_dim(): +def test_layernorm_invalid_cond_dim() -> None: """Test LayerNorm with invalid cond_dim.""" with pytest.raises(RuntimeError, match='Trying to create tensor with negative dimension'): LayerNorm(n_channels=32, cond_dim=-1) -def test_layernorm_3d_input(): +def test_layernorm_3d_input() -> None: """Test LayerNorm with 3D input.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((2, 64, 128)).requires_grad_(True) @@ -123,7 +135,7 @@ def test_layernorm_3d_input(): assert x.grad is not None, 'No gradient computed for input' -def test_layernorm_5d_input(): +def test_layernorm_5d_input() -> None: """Test LayerNorm with 5D input.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 32, 16, 16, 16)).requires_grad_(True) @@ -136,7 +148,7 @@ def test_layernorm_5d_input(): assert x.grad is not None, 'No gradient computed for input' -def test_layernorm_conditioning_features_last(): +def test_layernorm_conditioning_features_last() -> None: """Test LayerNorm with conditioning and features_last=True.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) @@ -152,7 +164,7 @@ def test_layernorm_conditioning_features_last(): assert cond.grad is not None, 'No gradient computed for conditioning' -def test_layernorm_gradient_flow(): +def test_layernorm_gradient_flow() -> None: """Test that gradients flow properly through LayerNorm.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) diff --git a/tests/nn/test_pixelshuffle.py b/tests/nn/test_pixelshuffle.py index 076cf92f7..0afc8002f 100644 --- a/tests/nn/test_pixelshuffle.py +++ b/tests/nn/test_pixelshuffle.py @@ -5,7 +5,7 @@ from mrpro.utils import RandomGenerator -def test_pixel_shuffle_2d(): +def test_pixel_shuffle_2d() -> None: """Test PixelUnshuffle's fast path for 2D images.""" x = torch.arange(3 * 4 * 8).reshape(1, 3, 4, 8) pixel_unshuffle = PixelUnshuffle(2) @@ -18,7 +18,7 @@ def test_pixel_shuffle_2d(): assert (x == z).all() -def test_pixel_unshuffle_4d(): +def test_pixel_unshuffle_4d() -> None: """Test PixelUnshuffle's general case.""" x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) pixel_unshuffle = PixelUnshuffle(2) @@ -31,7 +31,7 @@ def test_pixel_unshuffle_4d(): assert (x == z).all() -def test_pixelunshuffle_features_last(): +def test_pixelunshuffle_features_last() -> None: """Test PixelUnshuffle with features_last.""" x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) pixel_unshuffle_last = PixelUnshuffle(2, features_last=True) @@ -41,8 +41,8 @@ def test_pixelunshuffle_features_last(): assert (y_last == y_normal).all() -def test_pixelshuffle_features_last(): - """Test PixelS huffle with features_last.""" +def test_pixelshuffle_features_last() -> None: + """Test PixelShuffle with features_last.""" x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, -1, 2, 4, 5, 6) pixel_shuffle_last = PixelShuffle(2, features_last=True) pixel_shuffle = PixelShuffle(2, features_last=False) @@ -51,7 +51,7 @@ def test_pixelshuffle_features_last(): assert (y_last == y_normal).all() -def test_unpixelshuffledownsample_residual(): +def test_unpixelshuffledownsample_residual() -> None: """Test PixelUnshuffleDownsample with residual.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 2, 9, 12, 15)) @@ -60,7 +60,7 @@ def test_unpixelshuffledownsample_residual(): assert y.shape == (1, 27, 3, 4, 5) -def test_pixelshuffleupsample_residual(): +def test_pixelshuffleupsample_residual() -> None: """Test PixelShuffleUpsample with residual.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 2, 3, 4, 5)) @@ -69,7 +69,7 @@ def test_pixelshuffleupsample_residual(): assert y.shape == (1, 1, 9, 12, 15) -def test_pixelshuffleupsample_pixelunshuffledownsample(): +def test_pixelshuffleupsample_pixelunshuffledownsample() -> None: """Test if PixelUnshuffleDownsample is the inverse of PixelShuffleUpsample.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3**3, 3, 4, 5)) diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index 195217638..987cc5f7a 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -1,5 +1,7 @@ """Tests for ResBlock module.""" +from collections.abc import Sequence + import pytest from mrpro.nn import ResBlock from mrpro.utils import RandomGenerator @@ -19,7 +21,15 @@ (3, 64, 32, 0, (2, 64, 16, 16, 16), None), ], ) -def test_resblock(dim, channels_in, channels_out, cond_dim, input_shape, cond_shape, device): +def test_resblock( + dim: int, + channels_in: int, + channels_out: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int] | None, + device: str, +) -> None: """Test ResBlock output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) diff --git a/tests/nn/test_rmsnorm.py b/tests/nn/test_rmsnorm.py index aab133da0..c8ddc0b69 100644 --- a/tests/nn/test_rmsnorm.py +++ b/tests/nn/test_rmsnorm.py @@ -44,8 +44,8 @@ def test_rmsnorm_basic(n_channels: int | None, features_last: bool, input_shape: assert norm.bias.grad is not None, 'No gradient computed for bias' -def test_rmsnorm_features_last(): - """Test RMSNorm with features_last=True.""" +def test_rmsnorm_features_last() -> None: + """Test RMSNorm with features_last=True vs features_last=False.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) diff --git a/tests/nn/test_rope.py b/tests/nn/test_rope.py index b19dc6c01..665c4bed4 100644 --- a/tests/nn/test_rope.py +++ b/tests/nn/test_rope.py @@ -1,3 +1,5 @@ +"""Tests for AxialRoPE module.""" + import pytest import torch from mrpro.nn import AxialRoPE @@ -11,7 +13,8 @@ pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), ], ) -def test_rope(device: torch.device): +def test_rope(device: torch.device) -> None: + """Test AxialRoPE rotation and embedding functionality.""" shape = (10, 10) n_heads = 2 n_channels = 64 diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py index 9d382a6a0..bdf81bf8d 100644 --- a/tests/nn/test_sequential.py +++ b/tests/nn/test_sequential.py @@ -1,5 +1,7 @@ """Tests for Sequential module.""" +from collections.abc import Sequence + import pytest from mrpro.nn import FiLM, Sequential from mrpro.operators import FastFourierOp, MagnitudeOp @@ -21,7 +23,11 @@ ((2, 32), None), ], ) -def test_sequential(input_shape, cond_dim, device): +def test_sequential( + input_shape: Sequence[int], + cond_dim: Sequence[int] | None, + device: str, +) -> None: """Test Sequential output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py index c863a680a..f36a36691 100644 --- a/tests/nn/test_shiftedwindowattention.py +++ b/tests/nn/test_shiftedwindowattention.py @@ -1,3 +1,5 @@ +"""Tests for ShiftedWindowAttention module.""" + import pytest from mrpro.nn.attention import ShiftedWindowAttention from mrpro.utils import RandomGenerator @@ -18,6 +20,7 @@ ], ) def test_shifted_window_attention(dim: int, window_size: int, shifted: bool, device: str) -> None: + """Test ShiftedWindowAttention output shape and backpropagation.""" n_batch, n_channels, n_heads = 2, 8, 2 spatial_shape = (window_size * 4,) * dim rng = RandomGenerator(13) diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py index b2c27e8cf..afbe53494 100644 --- a/tests/nn/test_transposedattention.py +++ b/tests/nn/test_transposedattention.py @@ -1,5 +1,7 @@ """Tests for TransposedAttention module.""" +from collections.abc import Sequence + import pytest from mrpro.nn.attention import TransposedAttention from mrpro.utils import RandomGenerator @@ -19,7 +21,13 @@ (3, 64, 8, (2, 64, 16, 16, 16)), ], ) -def test_transposed_attention(dim, channels, num_heads, input_shape, device): +def test_transposed_attention( + dim: int, + channels: int, + num_heads: int, + input_shape: Sequence[int], + device: str, +) -> None: """Test TransposedAttention output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) From 74b240cd8ed90bfc9ca93f2a0e44799fe20b4ca0 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Jul 2025 00:21:31 +0200 Subject: [PATCH 147/205] formatting --- tests/nn/data_consistency/test_conjugategradientdc.py | 4 +++- tests/nn/nets/test_restormer.py | 2 +- tests/nn/nets/test_unet.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/nn/data_consistency/test_conjugategradientdc.py b/tests/nn/data_consistency/test_conjugategradientdc.py index 39dc11c01..99de6f4f5 100644 --- a/tests/nn/data_consistency/test_conjugategradientdc.py +++ b/tests/nn/data_consistency/test_conjugategradientdc.py @@ -5,7 +5,9 @@ from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC -def test_conjugate_gradient_dc(image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor): +def test_conjugate_gradient_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: image_noisy = image_noisy.clone().requires_grad_(True) dc = ConjugateGradientDC(initial_regularization_weight=1.0) result = dc(image_noisy, kdata_us) diff --git a/tests/nn/nets/test_restormer.py b/tests/nn/nets/test_restormer.py index 34ec14014..370612444 100644 --- a/tests/nn/nets/test_restormer.py +++ b/tests/nn/nets/test_restormer.py @@ -38,7 +38,7 @@ def test_restormer_forward(torch_compile: bool, device: str) -> None: assert y.shape == (1, 1, 16, 16) -def test_restormer_backward(): +def test_restormer_backward() -> None: restormer = Restormer( n_dim=1, n_channels_in=1, diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py index 9400c4363..f411e92e3 100644 --- a/tests/nn/nets/test_unet.py +++ b/tests/nn/nets/test_unet.py @@ -39,7 +39,7 @@ def test_unet_forward(torch_compile: bool, device: str) -> None: assert y.shape == (1, 1, 16, 16) -def test_unet_backward(): +def test_unet_backward() -> None: unet = UNet( n_dim=1, n_channels_in=1, From 3ee4c7657304541a20be601ecf9b04f8cd7a8f64 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Jul 2025 00:25:05 +0200 Subject: [PATCH 148/205] formatting --- tests/nn/data_consistency/test_analyticcertesiandc.py | 4 +++- tests/nn/data_consistency/test_gradientdescentdc.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/nn/data_consistency/test_analyticcertesiandc.py b/tests/nn/data_consistency/test_analyticcertesiandc.py index 2bd601461..9020a8326 100644 --- a/tests/nn/data_consistency/test_analyticcertesiandc.py +++ b/tests/nn/data_consistency/test_analyticcertesiandc.py @@ -5,7 +5,9 @@ from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC -def test_analytic_cartesian_dc(image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor): +def test_analytic_cartesian_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: image_noisy = image_noisy.clone().requires_grad_(True) dc = AnalyticCartesianDC(initial_regularization_weight=1e-6) result = dc(image_noisy, kdata_us) diff --git a/tests/nn/data_consistency/test_gradientdescentdc.py b/tests/nn/data_consistency/test_gradientdescentdc.py index b3ee440c8..00a7648f9 100644 --- a/tests/nn/data_consistency/test_gradientdescentdc.py +++ b/tests/nn/data_consistency/test_gradientdescentdc.py @@ -5,7 +5,9 @@ from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC -def test_gradient_descent_dc(image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor): +def test_gradient_descent_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: image_noisy = image_noisy.clone().requires_grad_(True) dc = GradientDescentDC(initial_stepsize=1.0) result = dc(image_noisy, kdata_us) From f624cb721fe87615f8c2c8a401ca7988888f368d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Jul 2025 00:29:11 +0200 Subject: [PATCH 149/205] typo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c003bfa7f..1b7b0a6fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,7 @@ filterwarnings = [ "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda - "ignore:Online softmax is disabled on the fly since Inductor decides to split the reduction:UserWarningc", # torch.compile + "ignore:Online softmax is disabled on the fly since Inductor decides to split the reduction:UserWarning", # torch.compile ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] From 4a608567210f28cc263ea8af2e98853e7578f397 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 23 Jul 2025 22:27:57 +0200 Subject: [PATCH 150/205] fix NA --- src/mrpro/nn/AxialRoPE.py | 5 +-- .../nn/attention/NeighborhoodSelfAttention.py | 36 ++++++++++++++----- src/mrpro/nn/attention/__init__.py | 16 +++++---- tests/nn/nets/test_hourglass.py | 2 +- tests/nn/test_resblock.py | 12 +++++-- 5 files changed, 50 insertions(+), 21 deletions(-) diff --git a/src/mrpro/nn/AxialRoPE.py b/src/mrpro/nn/AxialRoPE.py index 03bf55c16..87f276b03 100644 --- a/src/mrpro/nn/AxialRoPE.py +++ b/src/mrpro/nn/AxialRoPE.py @@ -7,9 +7,10 @@ from torch.nn import Module -# pragma: no cover @torch.compile -def get_theta(shape: Sequence[int], n_embedding_channels: int, device: torch.device) -> torch.Tensor: +def get_theta( + shape: Sequence[int], n_embedding_channels: int, device: torch.device +) -> torch.Tensor: # pragma: no cover """Get rotation angles. Parameters diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index d389ecd67..e04fe24e6 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -1,8 +1,8 @@ """Neighborhood Self Attention.""" from collections.abc import Sequence -from functools import cache, reduce -from typing import TypeVar +from functools import reduce +from typing import Any, TypeVar import torch from einops import rearrange @@ -15,16 +15,28 @@ T = TypeVar('T') -# coverage does not pick up the use via flex_attention, as the code gets compiled. -# pragma: no cover -@cache +@torch.compiler.disable(recursive=True) +def uncompiled_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: torch.nn.attention.flex_attention._score_mod_signature | None = None, + block_mask: BlockMask | None = None, + scale: float | None = None, + enable_gqa: bool = False, + kernel_options: dict[str, Any] | None = None, +) -> torch.Tensor: + """Wrap flex_attention to disable compilation.""" + return flex_attention(key, query, value, score_mod, block_mask, scale, enable_gqa, kernel_options=kernel_options) # type: ignore[return-value] # wrong type hints + + def neighborhood_mask( device: str, input_size: torch.Size, kernel_size: int | tuple[int, ...], # tuples instead of Sequence for cache dilation: int | tuple[int, ...] = 1, circular: bool | tuple[bool, ...] = False, -) -> BlockMask: +) -> BlockMask: # pragma: no cover """Create a flex attention block mask for neighborhood attention. This function defines which key/value pairs a query can attend to based @@ -192,16 +204,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: qkv, 'batch ... (qkv heads channels) -> qkv batch heads (...) channels', qkv=3, heads=self.n_head ) query, key = self.rope(query, key) # NO-OP if rope_embed_fraction is 0.0 + query, key, value = query.contiguous(), key.contiguous(), value.contiguous() # the mask depends on the input size. To be more flexible if used within CNNs, we compute it here. # The computation is cached.. + device = str(qkv.device) mask = neighborhood_mask( - device=str(qkv.device), + device=device, input_size=spatial_shape, kernel_size=self.kernel_size, dilation=self.dilation, circular=self.circular, ) - out: torch.Tensor = flex_attention(query.contiguous(), key.contiguous(), value.contiguous(), block_mask=mask) # type: ignore[assignment] # wrong type hints + + if device == 'cpu': + # flex attention cannot be compiled on CPU + # https://github.com/pytorch/pytorch/issues/148752 + out: torch.Tensor = uncompiled_flex_attention(query, key, value, block_mask=mask) + else: + out = flex_attention(query, key, value, block_mask=mask) # type: ignore[assignment] # wrong type hints out = rearrange(out, 'batch head sequence channels -> batch sequence(head channels)') out = self.to_out(out) out = out.unflatten(-2, spatial_shape) diff --git a/src/mrpro/nn/attention/__init__.py b/src/mrpro/nn/attention/__init__.py index 7b3d24115..719ff1409 100644 --- a/src/mrpro/nn/attention/__init__.py +++ b/src/mrpro/nn/attention/__init__.py @@ -4,12 +4,14 @@ from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention from mrpro.nn.attention.SqueezeExcitation import SqueezeExcitation from mrpro.nn.attention.TransposedAttention import TransposedAttention +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock __all__ = [ - 'AttentionGate', - 'LinearSelfAttention', - 'NeighborhoodSelfAttention', - 'ShiftedWindowAttention', - 'SqueezeExcitation', - 'TransposedAttention', -] + "AttentionGate", + "LinearSelfAttention", + "NeighborhoodSelfAttention", + "ShiftedWindowAttention", + "SpatialTransformerBlock", + "SqueezeExcitation", + "TransposedAttention" +] \ No newline at end of file diff --git a/tests/nn/nets/test_hourglass.py b/tests/nn/nets/test_hourglass.py index 908ef38da..d6a22d379 100644 --- a/tests/nn/nets/test_hourglass.py +++ b/tests/nn/nets/test_hourglass.py @@ -32,7 +32,7 @@ def test_hourglass_forward(torch_compile: bool, device: str) -> None: x = x.to(device) cond = cond.to(device) if torch_compile: - hourglass = cast(HourglassTransformer, torch.compile(hourglass)) + hourglass = cast(HourglassTransformer, torch.compile(hourglass, dynamic=False)) y = hourglass(x, cond=cond) assert y.shape == (1, 1, 16, 16) diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index 987cc5f7a..b5b20c555 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -1,12 +1,15 @@ """Tests for ResBlock module.""" from collections.abc import Sequence +from typing import cast import pytest +import torch from mrpro.nn import ResBlock from mrpro.utils import RandomGenerator +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) @pytest.mark.parametrize( 'device', [ @@ -29,13 +32,16 @@ def test_resblock( input_shape: Sequence[int], cond_shape: Sequence[int] | None, device: str, + torch_compile: bool, ) -> None: """Test ResBlock output shape and backpropagation.""" rng = RandomGenerator(seed=42) x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None - res = ResBlock(n_dim=dim, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim).to(device) - output = res(x, cond=cond) + block = ResBlock(n_dim=dim, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim).to(device) + if torch_compile: + block = cast(ResBlock, torch.compile(block)) + output = block(x, cond=cond) assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' ) @@ -43,7 +49,7 @@ def test_resblock( assert x.grad is not None, 'No gradient computed for input' assert not output.isnan().any(), 'NaN values in output' assert not x.grad.isnan().any(), 'NaN values in input gradients' - assert res.block[2].weight.grad is not None, 'No gradient computed for first Conv' + assert block.block[2].weight.grad is not None, 'No gradient computed for first Conv' if cond is not None: assert cond.grad is not None, 'No gradient computed for conditioning' assert not cond.isnan().any(), 'NaN values in conditioning' From 09606f5bc33663e226b26540d6eb28d6c0dd4e83 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 24 Jul 2025 15:58:23 +0200 Subject: [PATCH 151/205] python 2.3 --- docker/minimal_requirements.txt | 2 +- pyproject.toml | 4 ++-- src/mrpro/nn/attention/NeighborhoodSelfAttention.py | 6 +++++- tests/nn/nets/test_hourglass.py | 3 +++ tests/nn/test_neighborhoodselfattention.py | 3 +++ 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/docker/minimal_requirements.txt b/docker/minimal_requirements.txt index 2d5b7e24a..2bdffca8c 100644 --- a/docker/minimal_requirements.txt +++ b/docker/minimal_requirements.txt @@ -1,4 +1,4 @@ -torch==2.5.1+cpu +torch==2.3.1+cpu torchvision==0.20.1+cpu numpy==1.23 ismrmrd==1.14.1 diff --git a/pyproject.toml b/pyproject.toml index 1b7b0a6fa..35202f241 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ classifiers = [ ] dependencies = [ "numpy>=1.23, <3.0", - "torch>=2.5.1", + "torch>=2.3.1", "ismrmrd>=1.14.1", "einops>=0.7.0", "pydicom>=3.0.1", @@ -121,7 +121,7 @@ filterwarnings = [ "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda - "ignore:Online softmax is disabled on the fly since Inductor decides to split the reduction:UserWarning", # torch.compile + "ignore:softmax is disabled:UserWarning", # torch.compile ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index e04fe24e6..0f179f281 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -1,11 +1,12 @@ """Neighborhood Self Attention.""" from collections.abc import Sequence -from functools import reduce +from functools import cache, reduce from typing import Any, TypeVar import torch from einops import rearrange +from packaging.version import parse as parse_version from torch.nn import Linear, Module from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention @@ -30,6 +31,7 @@ def uncompiled_flex_attention( return flex_attention(key, query, value, score_mod, block_mask, scale, enable_gqa, kernel_options=kernel_options) # type: ignore[return-value] # wrong type hints +@cache def neighborhood_mask( device: str, input_size: torch.Size, @@ -172,6 +174,8 @@ def __init__( Fraction of channels to embed with RoPE. """ + if parse_version(torch.__version__) < parse_version('2.6.0'): + raise NotImplementedError('NeighborhoodSelfAttention requires PyTorch 2.6.0 or higher') super().__init__() self.n_head = n_heads self.kernel_size = kernel_size if isinstance(kernel_size, int) else tuple(kernel_size) diff --git a/tests/nn/nets/test_hourglass.py b/tests/nn/nets/test_hourglass.py index d6a22d379..2fa1cc75d 100644 --- a/tests/nn/nets/test_hourglass.py +++ b/tests/nn/nets/test_hourglass.py @@ -5,8 +5,10 @@ import pytest import torch from mrpro.nn.nets import HourglassTransformer +from tests.nn.conftest import minimal_torch_26 +@minimal_torch_26 @pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) @pytest.mark.parametrize( 'device', @@ -37,6 +39,7 @@ def test_hourglass_forward(torch_compile: bool, device: str) -> None: assert y.shape == (1, 1, 16, 16) +@minimal_torch_26 def test_hourglass_backward() -> None: hourglass = HourglassTransformer( n_dim=1, diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py index 43f866683..100999e82 100644 --- a/tests/nn/test_neighborhoodselfattention.py +++ b/tests/nn/test_neighborhoodselfattention.py @@ -4,8 +4,10 @@ import torch from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention from mrpro.utils import RandomGenerator +from tests.nn.conftest import minimal_torch_26 +@minimal_torch_26 @pytest.mark.parametrize( 'device', [ @@ -61,6 +63,7 @@ def test_neighborhood_self_attention( assert attn.to_out.bias.grad is not None, 'No gradient computed for to_out.bias' +@minimal_torch_26 @pytest.mark.parametrize( ('kernel_size', 'dilation', 'circular', 'rope'), [ From 3f7f36abec70e7aa01c58ca6a9efec652c41a4d3 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Thu, 24 Jul 2025 20:03:45 +0200 Subject: [PATCH 152/205] torch filter --- tests/conftest.py | 6 ++ tests/nn/nets/test_hourglass.py | 2 +- tests/nn/test_spatialtransformerblock.py | 77 ++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 tests/nn/test_spatialtransformerblock.py diff --git a/tests/conftest.py b/tests/conftest.py index 8490674e9..b2aa1cba2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,11 +12,17 @@ from mrpro.data.enums import AcqFlags from mrpro.utils import RandomGenerator from mrpro.utils.reshape import unsqueeze_tensors_left +from packaging.version import parse as parse_version from xsdata.models.datatype import XmlDate, XmlTime from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData +minimal_torch_26 = pytest.mark.xfail( + parse_version(torch.__version__) < parse_version('2.6'), + reason='Requires PyTorch >= 2.6', +) + def generate_random_encodingcounter_properties(rng: RandomGenerator) -> dict[str, Any]: return { diff --git a/tests/nn/nets/test_hourglass.py b/tests/nn/nets/test_hourglass.py index 2fa1cc75d..c7f4625d6 100644 --- a/tests/nn/nets/test_hourglass.py +++ b/tests/nn/nets/test_hourglass.py @@ -5,7 +5,7 @@ import pytest import torch from mrpro.nn.nets import HourglassTransformer -from tests.nn.conftest import minimal_torch_26 +from tests.conftest import minimal_torch_26 @minimal_torch_26 diff --git a/tests/nn/test_spatialtransformerblock.py b/tests/nn/test_spatialtransformerblock.py new file mode 100644 index 000000000..373b32db0 --- /dev/null +++ b/tests/nn/test_spatialtransformerblock.py @@ -0,0 +1,77 @@ +"""Test SpatialTransformerBlock""" + +from collections.abc import Sequence +from typing import Literal, cast + +import pytest +import torch +from mrpro.nn.attention import SpatialTransformerBlock +from mrpro.utils import RandomGenerator +from tests.conftest import minimal_torch_26 + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('channels', 'cond_dim', 'attention_neighborhood', 'features_last', 'norm', 'input_shape'), + [ + pytest.param(32, 16, 7, False, 'group', (16, 16), id='2d-cond-group-first-NA'), + pytest.param(32, 16, None, True, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-last-global'), + pytest.param(32, 16, 5, True, 'group', (16, 16), id='2d-cond-group-last-NA'), + pytest.param(64, 0, 7, True, 'rms', (16, 8, 16), marks=minimal_torch_26, id='3d-nocond-rms-last-NA'), + ], +) +def test_spatialtransformerblock( + channels: int, + cond_dim: int, + attention_neighborhood: int | None, + features_last: bool, + norm: Literal['group', 'rms'], + input_shape: Sequence[int], + device: str, + torch_compile: bool, +) -> None: + """Test SpatialTransformerBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + + x = rng.float32_tensor((1, channels, *input_shape)).to(device).requires_grad_(True) + cond = rng.float32_tensor((1, cond_dim)).to(device).requires_grad_(True) if cond_dim else None + + if features_last: + dims = tuple(range(-len(input_shape) - 1, -1)) + else: + dims = tuple(range(-len(input_shape), 0)) + + block = SpatialTransformerBlock( + dim_groups=[dims], + channels=channels, + n_heads=4, + depth=1, + p_dropout=0, + cond_dim=cond_dim, + rope_embed_fraction=0.5, + attention_neighborhood=attention_neighborhood, + features_last=features_last, + norm=norm, + ).to(device) + if torch_compile: + block = cast(SpatialTransformerBlock, torch.compile(block, dynamic=False)) + if features_last: + output = block(x.moveaxis(1, -1), cond=cond).moveaxis(-1, 1) + else: + output = block(x, cond=cond) + assert output.shape == x.shape + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' From 6bc0bf6a1c81b7fd141bcb32bd9a82ba90398531 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 25 Jul 2025 00:39:42 +0200 Subject: [PATCH 153/205] fix --- tests/nn/test_neighborhoodselfattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py index 100999e82..677c22b28 100644 --- a/tests/nn/test_neighborhoodselfattention.py +++ b/tests/nn/test_neighborhoodselfattention.py @@ -4,7 +4,7 @@ import torch from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention from mrpro.utils import RandomGenerator -from tests.nn.conftest import minimal_torch_26 +from tests.conftest import minimal_torch_26 @minimal_torch_26 From c466ce3c5d022f36e0f64bb59c16daa3e0eecf76 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 25 Jul 2025 00:47:51 +0200 Subject: [PATCH 154/205] fix tocvhvesion version --- docker/minimal_requirements.txt | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/minimal_requirements.txt b/docker/minimal_requirements.txt index 2bdffca8c..6723b723e 100644 --- a/docker/minimal_requirements.txt +++ b/docker/minimal_requirements.txt @@ -1,5 +1,5 @@ torch==2.3.1+cpu -torchvision==0.20.1+cpu +torchvision==0.18.1+cpu numpy==1.23 ismrmrd==1.14.1 einops==0.7.0 diff --git a/pyproject.toml b/pyproject.toml index 6fb9af3d5..b838802b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,7 @@ dependencies = [ "cufinufft>=2.4.1; platform_system=='Linux'", "scipy>=1.12", "ptwt>=0.1.8, <1.0", - "torchvision>=0.20.1", + "torchvision>=0.18.1", "tqdm>=4.60.0", "typing-extensions>=4.12", "platformdirs>=4.0", From 2480a086178c8e4d0e349264c055c8f1602823a9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 25 Jul 2025 00:54:21 +0200 Subject: [PATCH 155/205] version filter --- src/mrpro/nn/attention/NeighborhoodSelfAttention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index 0f179f281..dfe4f0e6d 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -15,6 +15,9 @@ T = TypeVar('T') +if parse_version(torch.__version__) > parse_version('2.6'): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + @torch.compiler.disable(recursive=True) def uncompiled_flex_attention( From f6ca67011f5f3f7a8424442cd7994967ddddc179 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 25 Jul 2025 00:59:27 +0200 Subject: [PATCH 156/205] fix --- src/mrpro/nn/attention/NeighborhoodSelfAttention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index dfe4f0e6d..8e8affa23 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -8,7 +8,6 @@ from einops import rearrange from packaging.version import parse as parse_version from torch.nn import Linear, Module -from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention from mrpro.nn.AxialRoPE import AxialRoPE from mrpro.utils.to_tuple import to_tuple From 9718c3cf8fe26a23d2fcf6a33e1d613bcaa6910d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 25 Jul 2025 01:04:33 +0200 Subject: [PATCH 157/205] fix --- src/mrpro/nn/attention/NeighborhoodSelfAttention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index 8e8affa23..d6a631d68 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -23,7 +23,7 @@ def uncompiled_flex_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - score_mod: torch.nn.attention.flex_attention._score_mod_signature | None = None, + score_mod: Any = None, # noqa: ANN401 block_mask: BlockMask | None = None, scale: float | None = None, enable_gqa: bool = False, From c130c9e4b136e027976f48661aa772bc0271e883 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 25 Jul 2025 01:07:47 +0200 Subject: [PATCH 158/205] fix --- src/mrpro/nn/attention/NeighborhoodSelfAttention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index d6a631d68..5c98d72e5 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -24,7 +24,7 @@ def uncompiled_flex_attention( key: torch.Tensor, value: torch.Tensor, score_mod: Any = None, # noqa: ANN401 - block_mask: BlockMask | None = None, + block_mask: Any = None, # noqa: ANN401 scale: float | None = None, enable_gqa: bool = False, kernel_options: dict[str, Any] | None = None, From f88b546a8537124b91db6adf13a0934122f5d09e Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 25 Jul 2025 01:14:04 +0200 Subject: [PATCH 159/205] fix --- src/mrpro/nn/attention/NeighborhoodSelfAttention.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index 5c98d72e5..01cb664aa 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from functools import cache, reduce -from typing import Any, TypeVar +from typing import Any, TypeAlias, TypeVar import torch from einops import rearrange @@ -16,6 +16,8 @@ if parse_version(torch.__version__) > parse_version('2.6'): from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention +else: + BlockMask: TypeAlias = Any @torch.compiler.disable(recursive=True) @@ -24,7 +26,7 @@ def uncompiled_flex_attention( key: torch.Tensor, value: torch.Tensor, score_mod: Any = None, # noqa: ANN401 - block_mask: Any = None, # noqa: ANN401 + block_mask: BlockMask | None = None, scale: float | None = None, enable_gqa: bool = False, kernel_options: dict[str, Any] | None = None, From 920ed5af0233067cdbf1021e478f4a7e17c81f39 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 25 Jul 2025 01:18:03 +0200 Subject: [PATCH 160/205] fix --- src/mrpro/nn/attention/NeighborhoodSelfAttention.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index 01cb664aa..93f936c41 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from functools import cache, reduce -from typing import Any, TypeAlias, TypeVar +from typing import Any, TypeVar import torch from einops import rearrange @@ -17,7 +17,9 @@ if parse_version(torch.__version__) > parse_version('2.6'): from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention else: - BlockMask: TypeAlias = Any + + class BlockMask: + """Dummy class for older PyTorch versions.""" @torch.compiler.disable(recursive=True) From ae221966948becff400e14f0c11db9eb791203ed Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 10:13:38 +0200 Subject: [PATCH 161/205] fix? --- pyproject.toml | 1 + src/mrpro/nn/attention/NeighborhoodSelfAttention.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7bd5b19b8..edc434f45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ filterwarnings = [ "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda "ignore:softmax is disabled:UserWarning", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 + "ignore:the load_module() method is deprecated:DeprecationWarning", # torch dynamo ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py index 93f936c41..13de355ff 100644 --- a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from functools import cache, reduce -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar import torch from einops import rearrange @@ -14,7 +14,7 @@ T = TypeVar('T') -if parse_version(torch.__version__) > parse_version('2.6'): +if TYPE_CHECKING or parse_version(torch.__version__) > parse_version('2.6'): from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention else: From d927ffa842ac7f8df7ab787857e785719be74307 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 10:37:54 +0200 Subject: [PATCH 162/205] fix? --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index edc434f45..ab199c432 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,7 @@ filterwarnings = [ "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda "ignore:softmax is disabled:UserWarning", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 - "ignore:the load_module() method is deprecated:DeprecationWarning", # torch dynamo + "ignore:load_module:DeprecationWarning", # torch dynamo ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] From 8365752d9824bcd49832b4d5942589ec838330aa Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 17:27:19 +0200 Subject: [PATCH 163/205] fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ab199c432..673bbec57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,7 @@ filterwarnings = [ "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda "ignore:softmax is disabled:UserWarning", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 - "ignore:load_module:DeprecationWarning", # torch dynamo + "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] From 3ca441dc38c7c198ce73ed2c1aea06128c7ae2e9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 17:41:53 +0200 Subject: [PATCH 164/205] fix?? --- tests/nn/test_neighborhoodselfattention.py | 1 + tests/nn/test_resblock.py | 9 ++++++++- tests/nn/test_spatialtransformerblock.py | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py index 677c22b28..583e2efc3 100644 --- a/tests/nn/test_neighborhoodselfattention.py +++ b/tests/nn/test_neighborhoodselfattention.py @@ -95,6 +95,7 @@ def test_neighborhood_attention_variants(kernel_size: int, dilation: int, circul assert not output.isnan().any(), 'NaN values in output' +@minimal_torch_26 @pytest.mark.parametrize( ('kernel_size', 'circular', 'input_shape'), [ diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index b5b20c555..060712104 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -7,9 +7,16 @@ import torch from mrpro.nn import ResBlock from mrpro.utils import RandomGenerator +from tests.conftest import minimal_torch_26 -@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'torch_compile', + [ + pytest.param(True, id='compiled', marks=minimal_torch_26), + pytest.param(False, id='eager'), + ], +) @pytest.mark.parametrize( 'device', [ diff --git a/tests/nn/test_spatialtransformerblock.py b/tests/nn/test_spatialtransformerblock.py index 373b32db0..161def976 100644 --- a/tests/nn/test_spatialtransformerblock.py +++ b/tests/nn/test_spatialtransformerblock.py @@ -10,6 +10,7 @@ from tests.conftest import minimal_torch_26 +@minimal_torch_26 @pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) @pytest.mark.parametrize( 'device', From 263fec7315dccc2754948860dc34b46e3a2c08e8 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 17:44:09 +0200 Subject: [PATCH 165/205] fix --- tests/nn/test_resblock.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py index 060712104..ea4356173 100644 --- a/tests/nn/test_resblock.py +++ b/tests/nn/test_resblock.py @@ -7,16 +7,9 @@ import torch from mrpro.nn import ResBlock from mrpro.utils import RandomGenerator -from tests.conftest import minimal_torch_26 -@pytest.mark.parametrize( - 'torch_compile', - [ - pytest.param(True, id='compiled', marks=minimal_torch_26), - pytest.param(False, id='eager'), - ], -) +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'eager']) @pytest.mark.parametrize( 'device', [ @@ -47,7 +40,7 @@ def test_resblock( cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None block = ResBlock(n_dim=dim, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim).to(device) if torch_compile: - block = cast(ResBlock, torch.compile(block)) + block = cast(ResBlock, torch.compile(block, dynamic=False)) output = block(x, cond=cond) assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' From 5cb8f724b2f5518be81e2585bdbfa91e6e7cd4b4 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 23:10:14 +0200 Subject: [PATCH 166/205] Add SeparableResBlock implementation and corresponding tests - Introduced SeparableResBlock class with updated parameter names for clarity. - Updated __init__.py to include SeparableResBlock in the module exports. - Added unit tests for SeparableResBlock to validate output shape and backpropagation. - Created test for AnalyticCartesianDC to ensure data consistency functionality. --- src/mrpro/nn/SeparableResBlock.py | 20 +++---- src/mrpro/nn/__init__.py | 2 + ...esiandc.py => test_analyticcartesiandc.py} | 0 tests/nn/test_separableresblock.py | 58 +++++++++++++++++++ 4 files changed, 70 insertions(+), 10 deletions(-) rename tests/nn/data_consistency/{test_analyticcertesiandc.py => test_analyticcartesiandc.py} (100%) create mode 100644 tests/nn/test_separableresblock.py diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py index 770293884..496004cdd 100644 --- a/src/mrpro/nn/SeparableResBlock.py +++ b/src/mrpro/nn/SeparableResBlock.py @@ -18,8 +18,8 @@ class SeparableResBlock(Module): def __init__( self, dim_groups: Sequence[Sequence[int]], - channels_in: int, - channels_out: int, + n_channels_in: int, + n_channels_out: int, cond_dim: int, ) -> None: """Initialize the SeparableResBlock. @@ -35,9 +35,9 @@ def __init__( ---------- dim_groups Sequence of dimension groups to use in the convolutions. - channels_in + n_channels_in Number of input channels. - channels_out + n_channels_out Number of output channels. cond_dim Number of channels in the conditioning tensor. If 0, no conditioning is applied. @@ -49,17 +49,17 @@ def block(dims: Sequence[int], channels_in: int) -> Module: return Sequential( GroupNorm(channels_in), SiLU(), - PermutedBlock(dims, ConvND(len(dims))(channels_in, channels_out, 3, padding=1)), + PermutedBlock(dims, ConvND(len(dims))(channels_in, n_channels_out, 3, padding=1)), ) - blocks = Sequential(*(block(d, channels_in if i == 0 else channels_out) for i, d in enumerate(dim_groups))) + blocks = Sequential(*(block(d, n_channels_in if i == 0 else n_channels_out) for i, d in enumerate(dim_groups))) if cond_dim > 0: - blocks.append(FiLM(channels_out, cond_dim)) - blocks.extend(block(d, channels_out) for d in dim_groups) + blocks.append(FiLM(n_channels_out, cond_dim)) + blocks.extend(block(d, n_channels_out) for d in dim_groups) self.block = blocks self.skip_connection = None - if channels_in != channels_out: - self.skip_connection = torch.nn.Linear(channels_in, channels_out) + if n_channels_in != n_channels_out: + self.skip_connection = torch.nn.Linear(n_channels_in, n_channels_out) def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the SeparableResBlock. diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index ee5032f51..1466897a0 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -26,6 +26,7 @@ from mrpro.nn.AxialRoPE import AxialRoPE from mrpro.nn.AbsolutePositionEncoding import AbsolutePositionEncoding from mrpro.nn.FourierFeatures import FourierFeatures +from mrpro.nn.SeparableResBlock import SeparableResBlock __all__ = [ "AbsolutePositionEncoding", @@ -47,6 +48,7 @@ "RMSNorm", "ResBlock", "Residual", + "SeparableResBlock", "Sequential", "attention", "data_consistency", diff --git a/tests/nn/data_consistency/test_analyticcertesiandc.py b/tests/nn/data_consistency/test_analyticcartesiandc.py similarity index 100% rename from tests/nn/data_consistency/test_analyticcertesiandc.py rename to tests/nn/data_consistency/test_analyticcartesiandc.py diff --git a/tests/nn/test_separableresblock.py b/tests/nn/test_separableresblock.py new file mode 100644 index 000000000..f9b21ec8d --- /dev/null +++ b/tests/nn/test_separableresblock.py @@ -0,0 +1,58 @@ +"""Tests for SeparableResBlock module.""" + +from collections.abc import Sequence +from typing import cast + +import pytest +import torch +from mrpro.nn import SeparableResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'eager']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim_groups', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (((-1, -2),), 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (((-1, -2), (-3,)), 64, 32, 0, (2, 64, 16, 16, 16), None), # 2D + 1D + ], +) +def test_separable_resblock( + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int] | None, + device: str, + torch_compile: bool, +) -> None: + """Test SeparableResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None + block = SeparableResBlock( + dim_groups=dim_groups, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim + ).to(device) + if torch_compile: + block = cast(SeparableResBlock, torch.compile(block, dynamic=False)) + output = block(x, cond=cond) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert block.block[0][2].module.weight.grad is not None, 'No gradient computed for first Conv' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' From 180048b991c7eb49fab3ab89a63dc11dea302ecc Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 23:16:01 +0200 Subject: [PATCH 167/205] fix cuda --- pyproject.toml | 2 +- tests/nn/test_spatialtransformerblock.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 673bbec57..05e7f9c1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,7 @@ filterwarnings = [ "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda - "ignore:softmax is disabled:UserWarning", # torch.compile + "ignore:.*softmax is disabled.*:UserWarning", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 ] diff --git a/tests/nn/test_spatialtransformerblock.py b/tests/nn/test_spatialtransformerblock.py index 161def976..2e06b7046 100644 --- a/tests/nn/test_spatialtransformerblock.py +++ b/tests/nn/test_spatialtransformerblock.py @@ -22,9 +22,9 @@ @pytest.mark.parametrize( ('channels', 'cond_dim', 'attention_neighborhood', 'features_last', 'norm', 'input_shape'), [ - pytest.param(32, 16, 7, False, 'group', (16, 16), id='2d-cond-group-first-NA'), + pytest.param(64, 16, 7, False, 'group', (16, 16), id='2d-cond-group-first-NA'), pytest.param(32, 16, None, True, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-last-global'), - pytest.param(32, 16, 5, True, 'group', (16, 16), id='2d-cond-group-last-NA'), + pytest.param(64, 16, 5, True, 'group', (16, 16), id='2d-cond-group-last-NA'), pytest.param(64, 0, 7, True, 'rms', (16, 8, 16), marks=minimal_torch_26, id='3d-nocond-rms-last-NA'), ], ) From 5444f4f8274ee08f993d8388dd9c24fd7b53ab65 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 23:23:09 +0200 Subject: [PATCH 168/205] mypy --- tests/nn/test_separableresblock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nn/test_separableresblock.py b/tests/nn/test_separableresblock.py index f9b21ec8d..25b6c65e0 100644 --- a/tests/nn/test_separableresblock.py +++ b/tests/nn/test_separableresblock.py @@ -51,7 +51,7 @@ def test_separable_resblock( assert x.grad is not None, 'No gradient computed for input' assert not output.isnan().any(), 'NaN values in output' assert not x.grad.isnan().any(), 'NaN values in input gradients' - assert block.block[0][2].module.weight.grad is not None, 'No gradient computed for first Conv' + assert block.block[0][2].module.weight.grad is not None, 'No gradient computed for first Conv' # type: ignore[union-attr] if cond is not None: assert cond.grad is not None, 'No gradient computed for conditioning' assert not cond.isnan().any(), 'NaN values in conditioning' From a0d0fe80dad9b8046a0abe44755c322640edb161 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 23:30:38 +0200 Subject: [PATCH 169/205] fix? --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 05e7f9c1e..9f52d5654 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,7 @@ filterwarnings = [ "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda - "ignore:.*softmax is disabled.*:UserWarning", # torch.compile + "ignore:.*softmax.*:UserWarning", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 ] From a23d5da4c496467de88f405f68b88e08625980e7 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 28 Jul 2025 23:52:22 +0200 Subject: [PATCH 170/205] test --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9f52d5654..5c255e139 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,7 @@ filterwarnings = [ "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda - "ignore:.*softmax.*:UserWarning", # torch.compile + "ignore:.*softmax.*", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 ] From c8c91e32f85fba9a658228e6c2a7c4490ddb70b6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Jul 2025 00:05:04 +0200 Subject: [PATCH 171/205] test --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5c255e139..04160b43c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,7 @@ dev = ["mrpro[tests, docs]"] [tool.pytest.ini_options] testpaths = ["tests"] filterwarnings = [ - "error", + #"error", "ignore:'write_like_original':DeprecationWarning:pydicom:", "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 From e8bd9a3888311ddd2aed23be606ae35c95183cc7 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Jul 2025 00:24:43 +0200 Subject: [PATCH 172/205] fix? --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 04160b43c..0a9275e08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,12 +116,12 @@ dev = ["mrpro[tests, docs]"] [tool.pytest.ini_options] testpaths = ["tests"] filterwarnings = [ - #"error", + "error", "ignore:'write_like_original':DeprecationWarning:pydicom:", "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda - "ignore:.*softmax.*", # torch.compile + "ignore:Online softmax:UserWarning", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 ] From d52e688e3ec5b0b3c37a985c7c7a6f91bd44f234 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Jul 2025 10:29:11 +0200 Subject: [PATCH 173/205] try --- pyproject.toml | 2 +- tests/nn/nets/test_uformer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0a9275e08..5c255e139 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,7 @@ filterwarnings = [ "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda - "ignore:Online softmax:UserWarning", # torch.compile + "ignore:.*softmax.*", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 ] diff --git a/tests/nn/nets/test_uformer.py b/tests/nn/nets/test_uformer.py index d3ed29c9b..c26cfc7b9 100644 --- a/tests/nn/nets/test_uformer.py +++ b/tests/nn/nets/test_uformer.py @@ -18,10 +18,10 @@ def test_uformer_forward(torch_compile: bool, device: str) -> None: """Test the forward pass of the uformer.""" uformer = Uformer( - n_dim=2, n_channels_in=1, n_channels_out=1, n_heads=(1, 2, 4), cond_dim=32, n_channels_per_head=8, window_size=2 + n_dim=2, n_channels_in=1, n_channels_out=1, n_heads=(1, 2), cond_dim=32, n_channels_per_head=8, window_size=2 ) - x = torch.zeros(1, 1, 16, 16, device=device) + x = torch.zeros(1, 1, 32, 32, device=device) cond = torch.zeros(1, 32, device=device) uformer = uformer.to(device) x = x.to(device) @@ -29,7 +29,7 @@ def test_uformer_forward(torch_compile: bool, device: str) -> None: if torch_compile: uformer = cast(Uformer, torch.compile(uformer)) y = uformer(x, cond=cond) - assert y.shape == (1, 1, 16, 16) + assert y.shape == (1, 1, 32, 32) def test_uformer_backward() -> None: From cd9541c5a81435af760c57cc9b56cc0a44c4908c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Jul 2025 10:47:50 +0200 Subject: [PATCH 174/205] ignore warning --- src/mrpro/nn/attention/ShiftedWindowAttention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mrpro/nn/attention/ShiftedWindowAttention.py b/src/mrpro/nn/attention/ShiftedWindowAttention.py index 5c960dbbb..f6edce929 100644 --- a/src/mrpro/nn/attention/ShiftedWindowAttention.py +++ b/src/mrpro/nn/attention/ShiftedWindowAttention.py @@ -1,5 +1,7 @@ """Shifted Window Attention.""" +import warnings + import torch from einops import rearrange from torch.nn import Linear, Module @@ -100,7 +102,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: qkv=3, ) bias = rearrange(self.relative_position_bias_table[self.rel_position_index], 'wd1 wd2 heads -> 1 heads wd1 wd2') - attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='.*softmax.*') + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) attention = rearrange(attention, '... head sequence channels->... sequence (head channels)') attention = attention.unflatten(-2, windowed.shape[-self.n_dim - 1 : -1]) # permute (in 3d) batch channels z y x wz wy wx -> batch channels wz z wy y wx x From de08c3c25c77d3e75bd17e49c6d75da1f211194c Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Jul 2025 11:25:15 +0200 Subject: [PATCH 175/205] cleanup --- pyproject.toml | 1 - tests/nn/nets/test_uformer.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5c255e139..c8fffb457 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,6 @@ filterwarnings = [ "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda - "ignore:.*softmax.*", # torch.compile "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 ] diff --git a/tests/nn/nets/test_uformer.py b/tests/nn/nets/test_uformer.py index c26cfc7b9..20bacf82e 100644 --- a/tests/nn/nets/test_uformer.py +++ b/tests/nn/nets/test_uformer.py @@ -21,7 +21,7 @@ def test_uformer_forward(torch_compile: bool, device: str) -> None: n_dim=2, n_channels_in=1, n_channels_out=1, n_heads=(1, 2), cond_dim=32, n_channels_per_head=8, window_size=2 ) - x = torch.zeros(1, 1, 32, 32, device=device) + x = torch.zeros(1, 1, 16, 16, device=device) cond = torch.zeros(1, 32, device=device) uformer = uformer.to(device) x = x.to(device) @@ -29,7 +29,7 @@ def test_uformer_forward(torch_compile: bool, device: str) -> None: if torch_compile: uformer = cast(Uformer, torch.compile(uformer)) y = uformer(x, cond=cond) - assert y.shape == (1, 1, 32, 32) + assert y.shape == (1, 1, 16, 16) def test_uformer_backward() -> None: From 51a8960dc3e40c1fb1b262568f4e0b036b08e934 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 29 Jul 2025 16:12:08 +0200 Subject: [PATCH 176/205] rename --- src/mrpro/nn/GluMBConvResBlock.py | 10 +-- src/mrpro/nn/PixelShuffle.py | 6 +- src/mrpro/nn/ResBlock.py | 8 +-- src/mrpro/nn/SeparableResBlock.py | 4 +- src/mrpro/nn/__init__.py | 28 ++++---- src/mrpro/nn/attention/AttentionGate.py | 8 +-- .../nn/attention/ShiftedWindowAttention.py | 1 + src/mrpro/nn/attention/SqueezeExcitation.py | 8 +-- src/mrpro/nn/attention/TransposedAttention.py | 8 +-- src/mrpro/nn/convert_linear_conv.py | 4 +- src/mrpro/nn/ndmodules.py | 14 ++-- src/mrpro/nn/nets/BasicCNN.py | 8 +-- src/mrpro/nn/nets/DCVAE.py | 6 +- src/mrpro/nn/nets/Restormer.py | 18 ++--- src/mrpro/nn/nets/SwinIR.py | 16 ++--- src/mrpro/nn/nets/UNet.py | 16 ++--- src/mrpro/nn/nets/Uformer.py | 22 +++--- tests/nn/test_ndmodules.py | 70 +++++++++---------- ...xcitation.py => test_squeezeexcitation.py} | 0 19 files changed, 128 insertions(+), 127 deletions(-) rename tests/nn/{'test_squeezeexcitation.py => test_squeezeexcitation.py} (100%) diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py index 0455cf118..c17bc019f 100644 --- a/src/mrpro/nn/GluMBConvResBlock.py +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -5,7 +5,7 @@ from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import convND from mrpro.nn.RMSNorm import RMSNorm @@ -56,9 +56,9 @@ def __init__( if stride == 1 and n_channels_in == n_channels_out: self.skip: Module = Identity() else: - self.skip = ConvND(n_dim)(n_channels_in, n_channels_out, kernel_size=1, stride=stride) + self.skip = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1, stride=stride) self.inverted_conv = Sequential( - ConvND(n_dim)( + convND(n_dim)( n_channels_in, channels_mid * 2, kernel_size=1, @@ -66,7 +66,7 @@ def __init__( SiLU(), ) self.depth_conv = Sequential( - ConvND(n_dim)( + convND(n_dim)( channels_mid * 2, channels_mid * 2, kernel_size=kernel_size, @@ -77,7 +77,7 @@ def __init__( SiLU(), ) self.point_conv = Sequential( - ConvND(n_dim)( + convND(n_dim)( channels_mid, n_channels_out, kernel_size=1, diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py index 9e5da35e8..afcff9335 100644 --- a/src/mrpro/nn/PixelShuffle.py +++ b/src/mrpro/nn/PixelShuffle.py @@ -5,7 +5,7 @@ import torch from torch.nn import Linear, Module -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import convND class PixelUnshuffle(Module): @@ -118,7 +118,7 @@ def __init__( if features_last: self.projection: Module = Linear(n_channels_in, n_channels_out // out_ratio) else: - self.projection = ConvND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') + self.projection = convND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') self.features_last = features_last self.residual = residual self.pixel_unshuffle = PixelUnshuffle(downscale_factor, features_last) @@ -195,7 +195,7 @@ def __init__( if features_last: self.projection: Module = Linear(n_channels_in, n_channels_out * upscale_factor**n_dim) else: - self.projection = ConvND(n_dim)( + self.projection = convND(n_dim)( n_channels_in, n_channels_out * upscale_factor**n_dim, kernel_size=3, padding='same' ) self.features_last = features_last diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index f115e205f..32870979f 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -6,7 +6,7 @@ from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import convND from mrpro.nn.Sequential import Sequential @@ -34,10 +34,10 @@ def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, cond_dim self.block = Sequential( GroupNorm(n_channels_in), SiLU(), - ConvND(n_dim)(n_channels_in, n_channels_out, kernel_size=3, padding=1), + convND(n_dim)(n_channels_in, n_channels_out, kernel_size=3, padding=1), GroupNorm(n_channels_out), SiLU(), - ConvND(n_dim)(n_channels_out, n_channels_out, kernel_size=3, padding=1), + convND(n_dim)(n_channels_out, n_channels_out, kernel_size=3, padding=1), ) if cond_dim > 0: self.block.insert(-3, FiLM(n_channels_out, cond_dim)) @@ -45,7 +45,7 @@ def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, cond_dim if n_channels_out == n_channels_in: self.skip_connection: Module = Identity() else: - self.skip_connection = ConvND(n_dim)(n_channels_in, n_channels_out, kernel_size=1) + self.skip_connection = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1) def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the ResBlock. diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py index 496004cdd..a12fda16f 100644 --- a/src/mrpro/nn/SeparableResBlock.py +++ b/src/mrpro/nn/SeparableResBlock.py @@ -7,7 +7,7 @@ from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import convND from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.Sequential import Sequential @@ -49,7 +49,7 @@ def block(dims: Sequence[int], channels_in: int) -> Module: return Sequential( GroupNorm(channels_in), SiLU(), - PermutedBlock(dims, ConvND(len(dims))(channels_in, n_channels_out, 3, padding=1)), + PermutedBlock(dims, convND(len(dims))(channels_in, n_channels_out, 3, padding=1)), ) blocks = Sequential(*(block(d, n_channels_in if i == 0 else n_channels_out) for i, d in enumerate(dim_groups))) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index 1466897a0..e40808890 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -4,13 +4,13 @@ from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm from mrpro.nn.ndmodules import ( - AdaptiveAvgPoolND, - AvgPoolND, - BatchNormND, - ConvND, - ConvTransposeND, - InstanceNormND, - MaxPoolND, + adaptiveAvgPoolND, + avgPoolND, + batchNormND, + convND, + convTransposeND, + instanceNormND, + maxPoolND, ) from mrpro.nn.ResBlock import ResBlock from mrpro.nn.Sequential import Sequential @@ -30,20 +30,20 @@ __all__ = [ "AbsolutePositionEncoding", - "AdaptiveAvgPoolND", - "AvgPoolND", + "adaptiveAvgPoolND", + "avgPoolND", "AxialRoPE", - "BatchNormND", + "batchNormND", "ComplexAsChannel", "CondMixin", - "ConvND", - "ConvTransposeND", + "convND", + "convTransposeND", "DropPath", "FiLM", "FourierFeatures", "GroupNorm", - "InstanceNormND", - "MaxPoolND", + "instanceNormND", + "maxPoolND", "PermutedBlock", "RMSNorm", "ResBlock", diff --git a/src/mrpro/nn/attention/AttentionGate.py b/src/mrpro/nn/attention/AttentionGate.py index 682100650..d7fdfeab4 100644 --- a/src/mrpro/nn/attention/AttentionGate.py +++ b/src/mrpro/nn/attention/AttentionGate.py @@ -3,7 +3,7 @@ import torch from torch.nn import Module, ReLU, Sequential, Sigmoid -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import convND class AttentionGate(Module): @@ -36,11 +36,11 @@ def __init__( Whether to concatenate the gated signal with the gate signal in the channel dimension (1) """ super().__init__() - self.project_gate = ConvND(n_dim)(channels_gate, channels_hidden, kernel_size=1) - self.project_x = ConvND(n_dim)(channels_in, channels_hidden, kernel_size=1) + self.project_gate = convND(n_dim)(channels_gate, channels_hidden, kernel_size=1) + self.project_x = convND(n_dim)(channels_in, channels_hidden, kernel_size=1) self.psi = Sequential( ReLU(), - ConvND(n_dim)(channels_hidden, 1, kernel_size=1), + convND(n_dim)(channels_hidden, 1, kernel_size=1), Sigmoid(), ) self.concatenate = concatenate diff --git a/src/mrpro/nn/attention/ShiftedWindowAttention.py b/src/mrpro/nn/attention/ShiftedWindowAttention.py index f6edce929..6c8d82c9d 100644 --- a/src/mrpro/nn/attention/ShiftedWindowAttention.py +++ b/src/mrpro/nn/attention/ShiftedWindowAttention.py @@ -103,6 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) bias = rearrange(self.relative_position_bias_table[self.rel_position_index], 'wd1 wd2 heads -> 1 heads wd1 wd2') with warnings.catch_warnings(): + # Inductor in torch 2.6 warns for small batch*n_patches*n_heads about suboptimal softmax compilation. warnings.filterwarnings('ignore', message='.*softmax.*') attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) attention = rearrange(attention, '... head sequence channels->... sequence (head channels)') diff --git a/src/mrpro/nn/attention/SqueezeExcitation.py b/src/mrpro/nn/attention/SqueezeExcitation.py index bd0fab4e8..5f7802c75 100644 --- a/src/mrpro/nn/attention/SqueezeExcitation.py +++ b/src/mrpro/nn/attention/SqueezeExcitation.py @@ -3,7 +3,7 @@ import torch from torch.nn import Module, ReLU, Sigmoid -from mrpro.nn.ndmodules import AdaptiveAvgPoolND, ConvND +from mrpro.nn.ndmodules import adaptiveAvgPoolND, convND from mrpro.nn.Sequential import Sequential @@ -31,10 +31,10 @@ def __init__(self, n_dim: int, n_channels_input: int, n_channels_squeeze: int) - """ super().__init__() self.scale = Sequential( - AdaptiveAvgPoolND(n_dim)(1), - ConvND(n_dim)(n_channels_input, n_channels_squeeze, kernel_size=1), + adaptiveAvgPoolND(n_dim)(1), + convND(n_dim)(n_channels_input, n_channels_squeeze, kernel_size=1), ReLU(), - ConvND(n_dim)(n_channels_squeeze, n_channels_input, kernel_size=1), + convND(n_dim)(n_channels_squeeze, n_channels_input, kernel_size=1), Sigmoid(), ) diff --git a/src/mrpro/nn/attention/TransposedAttention.py b/src/mrpro/nn/attention/TransposedAttention.py index 1f99c0fe7..88e993c8f 100644 --- a/src/mrpro/nn/attention/TransposedAttention.py +++ b/src/mrpro/nn/attention/TransposedAttention.py @@ -4,7 +4,7 @@ from einops import rearrange from torch.nn import Module, Parameter -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import convND class TransposedAttention(Module): @@ -37,8 +37,8 @@ def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, n_heads: self.n_heads = n_heads self.temperature = Parameter(torch.ones(n_heads, 1, 1)) channels_per_head = n_channels_in // n_heads - self.to_qkv = ConvND(n_dim)(n_channels_in, channels_per_head * n_heads * 3, kernel_size=1) - self.qkv_dwconv = ConvND(n_dim)( + self.to_qkv = convND(n_dim)(n_channels_in, channels_per_head * n_heads * 3, kernel_size=1) + self.qkv_dwconv = convND(n_dim)( channels_per_head * n_heads * 3, channels_per_head * n_heads * 3, kernel_size=3, @@ -46,7 +46,7 @@ def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, n_heads: padding=1, bias=False, ) - self.to_out = ConvND(n_dim)(channels_per_head * n_heads, n_channels_out, kernel_size=1) + self.to_out = convND(n_dim)(channels_per_head * n_heads, n_channels_out, kernel_size=1) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply transposed attention. diff --git a/src/mrpro/nn/convert_linear_conv.py b/src/mrpro/nn/convert_linear_conv.py index beb09d4b0..767a419ff 100644 --- a/src/mrpro/nn/convert_linear_conv.py +++ b/src/mrpro/nn/convert_linear_conv.py @@ -5,7 +5,7 @@ import torch from torch.nn import Conv1d, Conv2d, Conv3d, Linear -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import convND @overload @@ -48,7 +48,7 @@ def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d ------- A Conv layer with equivalent weights and bias. """ - conv = ConvND(n_dim)( + conv = convND(n_dim)( in_channels=linear_layer.in_features, out_channels=linear_layer.out_features, kernel_size=1, diff --git a/src/mrpro/nn/ndmodules.py b/src/mrpro/nn/ndmodules.py index b7626ab5a..3fb2d894c 100644 --- a/src/mrpro/nn/ndmodules.py +++ b/src/mrpro/nn/ndmodules.py @@ -3,7 +3,7 @@ import torch -def ConvND(n_dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 +def convND(n_dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 """Get the `n_dim`-dimensional convolution class. Parameters @@ -26,7 +26,7 @@ def ConvND(n_dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[t raise NotImplementedError(f'ConvND for dim {n_dim} not implemented. Raise an issue if you need this.') -def ConvTransposeND( # noqa: N802 +def convTransposeND( # noqa: N802 n_dim: int, ) -> type[torch.nn.ConvTranspose1d] | type[torch.nn.ConvTranspose2d] | type[torch.nn.ConvTranspose3d]: """Get the `n_dim`-dimensional transposed convolution class. @@ -53,7 +53,7 @@ def ConvTransposeND( # noqa: N802 ) -def MaxPoolND(n_dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 +def maxPoolND(n_dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 """Get the `n_dim`-dimensional max pooling class. Parameters @@ -76,7 +76,7 @@ def MaxPoolND(n_dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] raise NotImplementedError(f'MaxPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') -def AvgPoolND(n_dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 +def avgPoolND(n_dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 """Get the `n_dim`-dimensional average pooling class. Parameters @@ -99,7 +99,7 @@ def AvgPoolND(n_dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] raise NotImplementedError(f'AvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') -def AdaptiveAvgPoolND( # noqa: N802 +def adaptiveAvgPoolND( # noqa: N802 n_dim: int, ) -> type[torch.nn.AdaptiveAvgPool1d] | type[torch.nn.AdaptiveAvgPool2d] | type[torch.nn.AdaptiveAvgPool3d]: """Get the `n_dim`-dimensional adaptive average pooling class. @@ -126,7 +126,7 @@ def AdaptiveAvgPoolND( # noqa: N802 ) -def InstanceNormND( # noqa: N802 +def instanceNormND( # noqa: N802 n_dim: int, ) -> type[torch.nn.InstanceNorm1d] | type[torch.nn.InstanceNorm2d] | type[torch.nn.InstanceNorm3d]: """Get the `n_dim`-dimensional instance normalization class. @@ -153,7 +153,7 @@ def InstanceNormND( # noqa: N802 ) -def BatchNormND( # noqa: N802 +def batchNormND( # noqa: N802 n_dim: int, ) -> type[torch.nn.BatchNorm1d] | type[torch.nn.BatchNorm2d] | type[torch.nn.BatchNorm3d]: """Get the `n_dim`-dimensional batch normalization class. diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py index b39741f85..f120720e0 100644 --- a/src/mrpro/nn/nets/BasicCNN.py +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -9,7 +9,7 @@ from mrpro.nn.FiLM import FiLM from mrpro.nn.GroupNorm import GroupNorm -from mrpro.nn.ndmodules import BatchNormND, ConvND +from mrpro.nn.ndmodules import batchNormND, convND from mrpro.nn.Sequential import Sequential @@ -59,11 +59,11 @@ def __init__( super().__init__() use_film = cond_dim > 0 - self.append(ConvND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding='same')) + self.append(convND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding='same')) for c_in, c_out in pairwise((*n_features, n_channels_out)): if norm.lower() == 'batch': - self.append(BatchNormND(n_dim)(c_in, affine=not use_film)) + self.append(batchNormND(n_dim)(c_in, affine=not use_film)) elif norm.lower() == 'group': self.append(GroupNorm(c_in, affine=not use_film)) elif norm.lower() == 'instance': @@ -85,7 +85,7 @@ def __init__( else: raise ValueError(f'Invalid activation type: {activation}') - self.append(ConvND(n_dim)(c_in, c_out, kernel_size=3, padding='same')) + self.append(convND(n_dim)(c_in, c_out, kernel_size=3, padding='same')) def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] """Apply the basic CNN to the input tensor. diff --git a/src/mrpro/nn/nets/DCVAE.py b/src/mrpro/nn/nets/DCVAE.py index 5b3da8a01..5faba554b 100644 --- a/src/mrpro/nn/nets/DCVAE.py +++ b/src/mrpro/nn/nets/DCVAE.py @@ -9,7 +9,7 @@ from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock -from mrpro.nn.ndmodules import ConvND +from mrpro.nn.ndmodules import convND from mrpro.nn.nets.VAE import VAE from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.Residual import Residual @@ -44,9 +44,9 @@ def __init__( """ super().__init__( Sequential( - ConvND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1), + convND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1), SiLU(True), - ConvND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1, bias=False), + convND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1, bias=False), RMSNorm(n_channels), ) ) diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py index 9f85c48ba..ae43eaf86 100644 --- a/src/mrpro/nn/nets/Restormer.py +++ b/src/mrpro/nn/nets/Restormer.py @@ -10,7 +10,7 @@ from mrpro.nn.CondMixin import CondMixin from mrpro.nn.FiLM import FiLM from mrpro.nn.join import Concat -from mrpro.nn.ndmodules import ConvND, InstanceNormND +from mrpro.nn.ndmodules import convND, instanceNormND from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample from mrpro.nn.Sequential import Sequential @@ -37,8 +37,8 @@ def __init__(self, n_dim: int, n_channels: int, mlp_ratio: float): super().__init__() hidden_features = int(n_channels * mlp_ratio) - self.project_in = ConvND(n_dim)(n_channels, hidden_features * 2, kernel_size=1) - self.depthwise_conv = ConvND(n_dim)( + self.project_in = convND(n_dim)(n_channels, hidden_features * 2, kernel_size=1) + self.depthwise_conv = convND(n_dim)( hidden_features * 2, hidden_features * 2, kernel_size=3, @@ -46,7 +46,7 @@ def __init__(self, n_dim: int, n_channels: int, mlp_ratio: float): padding=1, groups=hidden_features * 2, ) - self.project_out = ConvND(n_dim)(hidden_features, n_channels, kernel_size=1) + self.project_out = convND(n_dim)(hidden_features, n_channels, kernel_size=1) def __call__(self, x: torch.Tensor) -> torch.Tensor: """Apply the gated depthwise feed forward network. @@ -87,9 +87,9 @@ def __init__(self, n_dim: int, n_channels: int, n_heads: int, mlp_ratio: float, Dimension of conditioning input. If 0, no conditioning is applied. """ super().__init__() - self.norm1 = Sequential(InstanceNormND(n_dim)(n_channels)) + self.norm1 = Sequential(instanceNormND(n_dim)(n_channels)) self.attn = TransposedAttention(n_dim, n_channels, n_channels, n_heads) - self.norm2 = Sequential(InstanceNormND(n_dim)(n_channels)) + self.norm2 = Sequential(instanceNormND(n_dim)(n_channels)) self.ffn = GDFN(n_dim, n_channels, mlp_ratio) if cond_dim > 0: self.norm2.append(FiLM(channels=n_channels, cond_dim=cond_dim)) @@ -176,7 +176,7 @@ def blocks(n_heads: int, n_blocks: int): layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) return layers - first_block = ConvND(n_dim)(n_channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) + first_block = convND(n_dim)(n_channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) encoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)] down_blocks = [ PixelUnshuffleDownsample(n_dim, n_channels_per_head * head_current, n_channels_per_head * head_next) @@ -197,14 +197,14 @@ def blocks(n_heads: int, n_blocks: int): concat_blocks = [ Sequential( Concat(), - ConvND(n_dim)(2 * n_channels_per_head * head, n_channels_per_head * head, kernel_size=1), + convND(n_dim)(2 * n_channels_per_head * head, n_channels_per_head * head, kernel_size=1), ) for head in n_heads[-2::-1] ] decoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1] last_block = Sequential( *(RestormerBlock(n_dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)), - ConvND(n_dim)(n_channels_per_head, n_channels_out, kernel_size=3, stride=1, padding=1), + convND(n_dim)(n_channels_per_head, n_channels_out, kernel_size=3, stride=1, padding=1), ) decoder = UNetDecoder( blocks=decoder_blocks, diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py index af944a753..4cb62ab08 100644 --- a/src/mrpro/nn/nets/SwinIR.py +++ b/src/mrpro/nn/nets/SwinIR.py @@ -6,7 +6,7 @@ from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM -from mrpro.nn.ndmodules import ConvND, InstanceNormND +from mrpro.nn.ndmodules import convND, instanceNormND from mrpro.nn.Sequential import Sequential @@ -46,15 +46,15 @@ def __init__( Droppath probability for MLP """ super().__init__() - self.norm1 = InstanceNormND(n_dim)(n_channels) + self.norm1 = instanceNormND(n_dim)(n_channels) self.attn = ShiftedWindowAttention(n_dim, n_channels, n_channels, n_heads, window_size) - self.norm2 = Sequential(InstanceNormND(n_dim)(n_channels)) + self.norm2 = Sequential(instanceNormND(n_dim)(n_channels)) if emb_dim > 0: self.norm2.append(FiLM(channels=n_channels, cond_dim=emb_dim)) self.mlp = Sequential( - ConvND(n_dim)(n_channels, n_channels * mlp_ratio, 1), + convND(n_dim)(n_channels, n_channels * mlp_ratio, 1), GELU('tanh'), - ConvND(n_dim)(n_channels * mlp_ratio, n_channels, 1), + convND(n_dim)(n_channels * mlp_ratio, n_channels, 1), DropPath(p_droppath), ) @@ -130,7 +130,7 @@ def __init__( for _ in range(depth) ) ) - self.conv = ConvND(n_dim)(n_channels, n_channels, 3, padding=1) + self.conv = convND(n_dim)(n_channels, n_channels, 3, padding=1) def __call__(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the residual Swin Transformer block. @@ -208,7 +208,7 @@ def __init__( The ratio for hidden dimension expansion in MLP. """ super().__init__() - self.first = ConvND(n_dim)(n_channels_in, n_channels_per_head * n_heads, kernel_size=3, padding=1) + self.first = convND(n_dim)(n_channels_in, n_channels_per_head * n_heads, kernel_size=3, padding=1) self.blocks = Sequential( *( ResidualSwinTransformerBlock( @@ -224,7 +224,7 @@ def __init__( for _ in range(n_blocks) ) ) - self.last = ConvND(n_dim)(n_channels_per_head * n_heads, n_channels_out, kernel_size=3, padding=1) + self.last = convND(n_dim)(n_channels_per_head * n_heads, n_channels_out, kernel_size=3, padding=1) def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply SwinIR. diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py index 9a2e6db8f..5c3572075 100644 --- a/src/mrpro/nn/nets/UNet.py +++ b/src/mrpro/nn/nets/UNet.py @@ -12,7 +12,7 @@ from mrpro.nn.CondMixin import call_with_cond from mrpro.nn.FiLM import FiLM from mrpro.nn.join import Concat -from mrpro.nn.ndmodules import ConvND, MaxPoolND +from mrpro.nn.ndmodules import convND, maxPoolND from mrpro.nn.ResBlock import ResBlock from mrpro.nn.Sequential import Sequential from mrpro.nn.Upsample import Upsample @@ -276,7 +276,7 @@ def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: for i_level, (n_feat, n_feat_next) in enumerate(pairwise(n_features)): encoder_blocks.append(blocks(n_feat, n_feat, i_level in attention_depths)) - down_blocks.append(ConvND(n_dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) + down_blocks.append(convND(n_dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) decoder_blocks.append(blocks(n_feat_next + n_feat, n_feat, i_level in attention_depths)) up_blocks.append(Upsample(tuple(range(-n_dim, 0)), scale_factor=2)) @@ -286,13 +286,13 @@ def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: ) if depth - 1 in attention_depths: middle_block.insert(1, attention_block(n_feat_next)) - first_block = ConvND(n_dim)(n_channels_in, n_features[0], 3, padding=1) + first_block = convND(n_dim)(n_channels_in, n_features[0], 3, padding=1) encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) decoder_blocks, up_blocks = decoder_blocks[::-1], up_blocks[::-1] last_block = Sequential( SiLU(), - ConvND(n_dim)(n_features[0], n_channels_out, 3, padding=1), + convND(n_dim)(n_features[0], n_channels_out, 3, padding=1), ) concat_blocks = [Concat() for _ in range(len(decoder_blocks))] decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) @@ -332,9 +332,9 @@ def __init__( def block(channels_in: int, channels_out: int) -> Module: block = Sequential( - ConvND(n_dim)(channels_in, channels_out, 3, padding=1), + convND(n_dim)(channels_in, channels_out, 3, padding=1), ReLU(True), - ConvND(n_dim)(channels_out, channels_out, 3, padding=1), + convND(n_dim)(channels_out, channels_out, 3, padding=1), ReLU(True), ) if cond_dim > 0: @@ -346,7 +346,7 @@ def block(channels_in: int, channels_out: int) -> Module: n_feat_old = n_channels_in for n_feat in n_features[:-1]: encoder_blocks.append(block(n_feat_old, n_feat)) - down_blocks.append(MaxPoolND(n_dim)(2)) + down_blocks.append(maxPoolND(n_dim)(2)) n_feat_old = n_feat middle_block = block(n_features[-2], n_features[-1]) encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) @@ -358,7 +358,7 @@ def block(channels_in: int, channels_out: int) -> Module: concat_blocks.append(AttentionGate(n_dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) up_blocks.append(Upsample(range(-n_dim, 0), scale_factor=2)) - last_block = ConvND(n_dim)(n_features[0], n_channels_out, 1) + last_block = convND(n_dim)(n_features[0], n_channels_out, 1) decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py index 960d14e17..02f1d1cc1 100644 --- a/src/mrpro/nn/nets/Uformer.py +++ b/src/mrpro/nn/nets/Uformer.py @@ -11,7 +11,7 @@ from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM from mrpro.nn.join import Concat -from mrpro.nn.ndmodules import ConvND, ConvTransposeND, InstanceNormND +from mrpro.nn.ndmodules import convND, convTransposeND, instanceNormND from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder from mrpro.nn.Sequential import Sequential @@ -57,7 +57,7 @@ def __init__( super().__init__() channels = n_channels_per_head * n_heads hidden_dim = int(channels * mlp_ratio) - self.norm1 = InstanceNormND(n_dim)(channels) + self.norm1 = instanceNormND(n_dim)(channels) self.attn = ShiftedWindowAttention( n_dim=n_dim, n_channels_in=channels, @@ -66,13 +66,13 @@ def __init__( window_size=window_size, shifted=shifted, ) - self.norm2 = InstanceNormND(n_dim)(channels) + self.norm2 = instanceNormND(n_dim)(channels) self.ff = Sequential( - ConvND(n_dim)(channels, hidden_dim, 1), + convND(n_dim)(channels, hidden_dim, 1), GELU(), - ConvND(n_dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), + convND(n_dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), GELU(), - ConvND(n_dim)(hidden_dim, channels, 1), + convND(n_dim)(hidden_dim, channels, 1), ) if cond_dim > 0: self.ff.append(FiLM(channels, cond_dim)) @@ -178,7 +178,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): ) first_block = torch.nn.Sequential( - ConvND(n_dim)(n_channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), + convND(n_dim)(n_channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), LeakyReLU(), ) drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() @@ -187,7 +187,7 @@ def blocks(n_heads: int, p_droppath: float = 0.0): for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True) ] down_blocks = [ - ConvND(n_dim)( + convND(n_dim)( n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=4, @@ -207,17 +207,17 @@ def blocks(n_heads: int, p_droppath: float = 0.0): decoder_blocks = [blocks(n_heads=2 * n_head, p_droppath=max_droppath_rate) for n_head in reversed(n_heads[:-1])] concat_blocks = [Concat() for _ in range(len(decoder_blocks))] up_blocks = [ - ConvTransposeND(n_dim)( + convTransposeND(n_dim)( n_channels_per_head * n_heads[-1], n_channels_per_head * n_heads[-2], kernel_size=2, stride=2 ) ] for n_head_current, n_head_next in pairwise(reversed(n_heads[:-1])): up_blocks.append( - ConvTransposeND(n_dim)( + convTransposeND(n_dim)( 2 * n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=2, stride=2 ) ) - last_block = ConvND(n_dim)( + last_block = convND(n_dim)( 2 * n_channels_per_head * n_heads[0], n_channels_out, kernel_size=3, stride=1, padding='same' ) decoder = UNetDecoder( diff --git a/tests/nn/test_ndmodules.py b/tests/nn/test_ndmodules.py index b34999170..a0a77a98d 100644 --- a/tests/nn/test_ndmodules.py +++ b/tests/nn/test_ndmodules.py @@ -3,74 +3,74 @@ import pytest import torch from mrpro.nn.ndmodules import ( - AdaptiveAvgPoolND, - AvgPoolND, - BatchNormND, - ConvND, - ConvTransposeND, - InstanceNormND, - MaxPoolND, + adaptiveAvgPoolND, + avgPoolND, + batchNormND, + convND, + convTransposeND, + instanceNormND, + maxPoolND, ) def test_convnd() -> None: """Test ConvND.""" - assert ConvND(1) is torch.nn.Conv1d - assert ConvND(2) is torch.nn.Conv2d - assert ConvND(3) is torch.nn.Conv3d + assert convND(1) is torch.nn.Conv1d + assert convND(2) is torch.nn.Conv2d + assert convND(3) is torch.nn.Conv3d with pytest.raises(NotImplementedError): - ConvND(4) + convND(4) def test_convtransposend() -> None: """Test ConvTransposeND.""" - assert ConvTransposeND(1) is torch.nn.ConvTranspose1d - assert ConvTransposeND(2) is torch.nn.ConvTranspose2d - assert ConvTransposeND(3) is torch.nn.ConvTranspose3d + assert convTransposeND(1) is torch.nn.ConvTranspose1d + assert convTransposeND(2) is torch.nn.ConvTranspose2d + assert convTransposeND(3) is torch.nn.ConvTranspose3d with pytest.raises(NotImplementedError): - ConvTransposeND(4) + convTransposeND(4) def test_maxpoolnd() -> None: """Test MaxPoolND.""" - assert MaxPoolND(1) is torch.nn.MaxPool1d - assert MaxPoolND(2) is torch.nn.MaxPool2d - assert MaxPoolND(3) is torch.nn.MaxPool3d + assert maxPoolND(1) is torch.nn.MaxPool1d + assert maxPoolND(2) is torch.nn.MaxPool2d + assert maxPoolND(3) is torch.nn.MaxPool3d with pytest.raises(NotImplementedError): - MaxPoolND(4) + maxPoolND(4) def test_avgpoolnd() -> None: """Test AvgPoolND.""" - assert AvgPoolND(1) is torch.nn.AvgPool1d - assert AvgPoolND(2) is torch.nn.AvgPool2d - assert AvgPoolND(3) is torch.nn.AvgPool3d + assert avgPoolND(1) is torch.nn.AvgPool1d + assert avgPoolND(2) is torch.nn.AvgPool2d + assert avgPoolND(3) is torch.nn.AvgPool3d with pytest.raises(NotImplementedError): - AvgPoolND(4) + avgPoolND(4) def test_adaptiveavgpoolnd() -> None: """Test AdaptiveAvgPoolND.""" - assert AdaptiveAvgPoolND(1) is torch.nn.AdaptiveAvgPool1d - assert AdaptiveAvgPoolND(2) is torch.nn.AdaptiveAvgPool2d - assert AdaptiveAvgPoolND(3) is torch.nn.AdaptiveAvgPool3d + assert adaptiveAvgPoolND(1) is torch.nn.AdaptiveAvgPool1d + assert adaptiveAvgPoolND(2) is torch.nn.AdaptiveAvgPool2d + assert adaptiveAvgPoolND(3) is torch.nn.AdaptiveAvgPool3d with pytest.raises(NotImplementedError): - AdaptiveAvgPoolND(4) + adaptiveAvgPoolND(4) def test_instancenormnd() -> None: """Test InstanceNormND.""" - assert InstanceNormND(1) is torch.nn.InstanceNorm1d - assert InstanceNormND(2) is torch.nn.InstanceNorm2d - assert InstanceNormND(3) is torch.nn.InstanceNorm3d + assert instanceNormND(1) is torch.nn.InstanceNorm1d + assert instanceNormND(2) is torch.nn.InstanceNorm2d + assert instanceNormND(3) is torch.nn.InstanceNorm3d with pytest.raises(NotImplementedError): - InstanceNormND(4) + instanceNormND(4) def test_batchnormnd() -> None: """Test BatchNormND.""" - assert BatchNormND(1) is torch.nn.BatchNorm1d - assert BatchNormND(2) is torch.nn.BatchNorm2d - assert BatchNormND(3) is torch.nn.BatchNorm3d + assert batchNormND(1) is torch.nn.BatchNorm1d + assert batchNormND(2) is torch.nn.BatchNorm2d + assert batchNormND(3) is torch.nn.BatchNorm3d with pytest.raises(NotImplementedError): - BatchNormND(4) + batchNormND(4) diff --git a/tests/nn/'test_squeezeexcitation.py b/tests/nn/test_squeezeexcitation.py similarity index 100% rename from tests/nn/'test_squeezeexcitation.py rename to tests/nn/test_squeezeexcitation.py From d9928bf6622dbccf5124f9bae5f178c404380978 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 30 Jul 2025 00:21:03 +0200 Subject: [PATCH 177/205] text --- examples/scripts/apply_pinqi.py | 104 ++++++++++++++++++++++++++------ 1 file changed, 85 insertions(+), 19 deletions(-) diff --git a/examples/scripts/apply_pinqi.py b/examples/scripts/apply_pinqi.py index daad3137c..881dc1dfe 100644 --- a/examples/scripts/apply_pinqi.py +++ b/examples/scripts/apply_pinqi.py @@ -1,3 +1,27 @@ +# %% [markdown] +# # End-to-end physics informed network for quantitative MRI (PINQI) +# A recent DL approach, PINQI, approaches learned quantitative MRI by half quadratic splitting to alternate between two +# subproblems. The first is a linear image reconstruction task +# $$ +# \underset{\mathbf{x}}{\min} \frac{1}{2} \| \mathbf{A} \mathbf{x} - \mathbf{y} \|_2^2 + \frac{\lambda_\mathbf{x}}{2} \left\| \mathbf{x} - \mathbf{x}_{\text{reg}} \right\|_2^2 + \frac{\lambda_{\mathbf{q}}}{2} \left\| \mathbf{q}(\mathbf{p}) - \mathbf{x} \right\|_2^2 +# $$ +# with $\mathbf{x}$ being intermediary qualitative images, $\lambda_{\mathbf{x}}$ and $\lambda_{\mathbf{q}}$ being +# regularization strengths and $\mathbf{x}_{\text{reg}}$ denoting an image prior for regularization. +# The second, non-linear, subproblem is finding the quantitative parameters by solving +# $$ +# \underset{\mathbf{p}}{\min} \frac{\lambda_{\mathbf{q}}}{2}\left \| \mathbf{q}(\vec{p}) - \mathbf{x} \right\|_2^2 + \frac{\lambda_{\mathbf{p}}}{2} \left\| \mathbf{p} - \mathbf{p}_{\text{reg}} \right\|_2^2. +# $$ +# Here, $\mathbf{p}_{\text{reg}}$ is a prior on the parameter maps and $\lambda_{\mathbf{p}}$ the associated weight for regularization. +# In PINQI, a solution is found by iterating between both subproblems. In each iteration $k=1,\ldots,T$, the image and parameter priors are updated by +# U-Nets. The network parameters and the regularization strengths are trained end-to-end. +# Here, we apply a trained PINQI model to a validation set. We first define the dataset, then define the PINQI model, before loading the model weights +# and applying it to the dataset. + +# %% [markdown] +# ## Dataset +# We base the dataset on the BrainWeb phantom (`mrpro.phantoms.brainweb.BrainwebSlices`) and simulate Cartesian random undersampling in phase +# encode direction. + # %% from collections.abc import Sequence from copy import deepcopy @@ -8,9 +32,8 @@ import mrpro import torch -# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) - +# %% class BatchType(TypedDict): """Typehint for a batch of data.""" @@ -106,6 +129,12 @@ def __getitem__(self, index: int): return {'kdata': kdata, 'csm': csm, **phantom} +# %% [markdown] +# ## PINQI +# Next, We define the PINQI model. Here we can make use of the diffferntiable optimization operators in MRpro. + + +# %% class PINQI(torch.nn.Module): """PINQI model.""" @@ -157,8 +186,13 @@ def objective_factory( objective_factory, lambda _l, _i, *parameter_reg: parameter_reg, ) + # This can be done once, as the signal model is the same for all samples. def get_linear_solver(self, gram: mrpro.operators.LinearOperator): + """Set up the linear solver.""" + # This needs to be done for each sample, as the undersampling pattern and csm are different for each sample, + # thus the gram operator of the acquisition operator is different for each sample. + def operator_factory( lambda_image: torch.Tensor, lambda_q: torch.Tensor, @@ -181,6 +215,7 @@ def rhs_factory( ) def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]: + """Get the parameter regularization.""" image = einops.rearrange( torch.view_as_real(image), 'batch t 1 1 y x complex-> batch (t complex) y x', @@ -200,6 +235,7 @@ def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[to return tuple(result) def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor: + """Get the image regularization.""" batch = image.shape[0] image = einops.rearrange( torch.view_as_real(image), @@ -211,24 +247,43 @@ def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor return torch.view_as_complex(image.contiguous()) def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + """Estimate the quantitative parameters. + + Parameters + ---------- + kdata + The k-space data. + csm + The coil sensitivity maps. + + Returns + ------- + images + The qualitative images. + parameters + The quantitative parameters. + """ csm_op = csm.as_operator() fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) acquisition_op = fourier_op @ csm_op gram = acquisition_op.gram (zero_filled_image,) = acquisition_op.H(kdata.data) - images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) - parameters = [self.get_parameter_reg(images[-1], 0)] + images = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2) + parameters = self.get_parameter_reg(images, 0) linear_solver = self.get_linear_solver(gram) for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)): - image_reg = self.get_image_reg(images[-1], i + 1) - (signal,) = self.signalmodel(*parameters[-1]) - images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) - parameters_reg = self.get_parameter_reg(images[-1], i + 1) - parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + # linear subproblem 1 + image_reg = self.get_image_reg(images, i) + (signal,) = self.signalmodel(*parameters) + images = linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image) + # nonlinear subproblem 2 + parameters_reg = self.get_parameter_reg(images, i + 1) + parameters = self.nonlinear_solver(lambda_parameter, images, *parameters_reg) if self.constraints_op is not None: - parameters = [self.constraints_op(*p) for p in parameters] - return images, parameters + # map the parameters into the constrained space + parameters = self.constraints_op(*parameters) + return parameters # %% @@ -283,7 +338,7 @@ def baseline_solution( list(range(500)), ) # %% -checkpoint = torch.load('last.ckpt', map_location='cpu') +checkpoint = torch.load('./examples/scripts/last.ckpt', map_location='cpu') hyper_parameters = checkpoint['hyper_parameters'] @@ -332,12 +387,20 @@ def baseline_solution( print(f'SSIM: {ssim_t1.item():.4f}, NRMSE: {nrmse_t1.item():.4f}') -fig, ax = plt.subplots(1, 5, gridspec_kw={'width_ratios': [1, 1, 1, 0.01, 0.075], 'wspace': 0.0}, figsize=(5, 2)) +fig, ax = plt.subplots( + 1, + 5, + gridspec_kw={ + 'width_ratios': [1, 1, 1, 0.28, 0.075], + 'wspace': -0.25, + }, + figsize=(6.5, 2.5), +) baseline_t1 = baseline_t1.squeeze() baseline_t1[~batch['mask']] = torch.nan ax[0].imshow(baseline_t1, vmin=0, vmax=2, cmap=cmap) ax[0].axis('off') -ax[0].set_title('SENSE + Regression') +ax[0].set_title('SENSE + NLS') ax[0].text( 0.5, -0.00, @@ -346,6 +409,7 @@ def baseline_solution( horizontalalignment='center', verticalalignment='top', transform=ax[0].transAxes, + size=11, ) predicted_t1 = predicted_t1.squeeze() predicted_t1[~batch['mask']] = torch.nan @@ -360,22 +424,24 @@ def baseline_solution( horizontalalignment='center', verticalalignment='top', transform=ax[1].transAxes, - size=10, + size=11, ) target_t1 = batch['t1'].squeeze() target_t1[~batch['mask']] = torch.nan im = ax[2].imshow(target_t1, vmin=0, vmax=2, cmap=cmap) ax[2].axis('off') -ax[2].set_title('Ground Truth') -fig.tight_layout() +ax[2].set_title( + 'Ground Truth', +) ax[-2].axis('off') +fig.tight_layout() plt.colorbar(im, cax=ax[-1], label='$T_1$ (s)') -fig.savefig('/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1_2.pdf', bbox_inches='tight') +fig.savefig('/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1_3.pdf', bbox_inches='tight', pad_inches=0) # %% - +1 # %% # %% From 4b8a52b8e5fde86d3c65ddf863edf00b50ae1500 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:24 +0100 Subject: [PATCH 178/205] add core nn foundations, layers, and resize blocks ghstack-source-id: f91e13b1b63eab7f12924afeefed72743836d857 ghstack-comment-id: 3865650347 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/953 --- pyproject.toml | 10 +- src/mrpro/__init__.py | 6 +- src/mrpro/nn/ComplexAsChannel.py | 59 ++++++ src/mrpro/nn/CondMixin.py | 22 +++ src/mrpro/nn/DropPath.py | 55 ++++++ src/mrpro/nn/FiLM.py | 54 +++++ src/mrpro/nn/FourierFeatures.py | 50 +++++ src/mrpro/nn/GEGLU.py | 56 ++++++ src/mrpro/nn/GroupNorm.py | 71 +++++++ src/mrpro/nn/LayerNorm.py | 83 ++++++++ src/mrpro/nn/PermutedBlock.py | 75 +++++++ src/mrpro/nn/PixelShuffle.py | 285 +++++++++++++++++++++++++++ src/mrpro/nn/RMSNorm.py | 73 +++++++ src/mrpro/nn/Residual.py | 45 +++++ src/mrpro/nn/Sequential.py | 60 ++++++ src/mrpro/nn/Upsample.py | 66 +++++++ src/mrpro/nn/__init__.py | 45 +++++ src/mrpro/nn/convert_linear_conv.py | 100 ++++++++++ src/mrpro/nn/join.py | 176 +++++++++++++++++ src/mrpro/nn/ndmodules.py | 178 +++++++++++++++++ src/mrpro/utils/__init__.py | 5 +- src/mrpro/utils/to_tuple.py | 36 ++++ tests/nn/test_complexaschannel.py | 30 +++ tests/nn/test_convert_linear_conv.py | 150 ++++++++++++++ tests/nn/test_droppath.py | 30 +++ tests/nn/test_film.py | 42 ++++ tests/nn/test_fourierfeatures.py | 24 +++ tests/nn/test_geglu.py | 38 ++++ tests/nn/test_groupnorm.py | 45 +++++ tests/nn/test_join.py | 160 +++++++++++++++ tests/nn/test_layernorm.py | 187 ++++++++++++++++++ tests/nn/test_ndmodules.py | 76 +++++++ tests/nn/test_pixelshuffle.py | 92 +++++++++ tests/nn/test_rmsnorm.py | 58 ++++++ tests/nn/test_sequential.py | 50 +++++ 35 files changed, 2586 insertions(+), 6 deletions(-) create mode 100644 src/mrpro/nn/ComplexAsChannel.py create mode 100644 src/mrpro/nn/CondMixin.py create mode 100644 src/mrpro/nn/DropPath.py create mode 100644 src/mrpro/nn/FiLM.py create mode 100644 src/mrpro/nn/FourierFeatures.py create mode 100644 src/mrpro/nn/GEGLU.py create mode 100644 src/mrpro/nn/GroupNorm.py create mode 100644 src/mrpro/nn/LayerNorm.py create mode 100644 src/mrpro/nn/PermutedBlock.py create mode 100644 src/mrpro/nn/PixelShuffle.py create mode 100644 src/mrpro/nn/RMSNorm.py create mode 100644 src/mrpro/nn/Residual.py create mode 100644 src/mrpro/nn/Sequential.py create mode 100644 src/mrpro/nn/Upsample.py create mode 100644 src/mrpro/nn/__init__.py create mode 100644 src/mrpro/nn/convert_linear_conv.py create mode 100644 src/mrpro/nn/join.py create mode 100644 src/mrpro/nn/ndmodules.py create mode 100644 src/mrpro/utils/to_tuple.py create mode 100644 tests/nn/test_complexaschannel.py create mode 100644 tests/nn/test_convert_linear_conv.py create mode 100644 tests/nn/test_droppath.py create mode 100644 tests/nn/test_film.py create mode 100644 tests/nn/test_fourierfeatures.py create mode 100644 tests/nn/test_geglu.py create mode 100644 tests/nn/test_groupnorm.py create mode 100644 tests/nn/test_join.py create mode 100644 tests/nn/test_layernorm.py create mode 100644 tests/nn/test_ndmodules.py create mode 100644 tests/nn/test_pixelshuffle.py create mode 100644 tests/nn/test_rmsnorm.py create mode 100644 tests/nn/test_sequential.py diff --git a/pyproject.toml b/pyproject.toml index 67aeb1928..9863b3856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,7 @@ docs = [ "sphinx-autodoc-typehints>=3, <3.1", "sphinx-copybutton>=0.5, <0.6", "sphinx-last-updated-by-git>=0.3, <0.4", + "snowballstemmer>=2.2, <3.0", ] notebooks = [ "zenodo_get>=2.0", @@ -118,10 +119,12 @@ testpaths = ["tests"] filterwarnings = [ "error", "ignore:'write_like_original':DeprecationWarning:pydicom:", - "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd - "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 + "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd + "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 + "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 - "ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10 + "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 + "ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10 ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] @@ -230,6 +233,7 @@ iy = "iy" arange = "arange" # torch.arange Ba = "Ba" wht = "wht" # Brainweb tissue class +ND = "ND" # Short for N-dimensional [tool.typos.files] extend-exclude = [ diff --git a/src/mrpro/__init__.py b/src/mrpro/__init__.py index 729ae188c..bbd401f1f 100644 --- a/src/mrpro/__init__.py +++ b/src/mrpro/__init__.py @@ -1,10 +1,12 @@ from mrpro._version import __version__ -from mrpro import algorithms, operators, data, phantoms, utils +from mrpro import algorithms, operators, data, phantoms, utils, nn + __all__ = [ "__version__", "algorithms", "data", + "nn", "operators", "phantoms", "utils" -] +] \ No newline at end of file diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py new file mode 100644 index 000000000..22e13458e --- /dev/null +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -0,0 +1,59 @@ +"""ComplexAsChannel: handling complex-valued tensors as channels.""" + +import torch +from einops import rearrange +from torch.nn import Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class ComplexAsChannel(CondMixin, Module): + """Wrap module to treat complex numbers as a channel dimension.""" + + def __init__(self, module: Module, convert_back: bool = True): + """Initialize the ComplexAsChannel module. + + Wraps a module to treat complex numbers as a channel dimension. + For each complex tensor in the input, real and imaginary parts are concatenated along the channel dimension + before being passed to the wrapped module. + + + Parameters + ---------- + module + The module to wrap. Should output a single real tensor. + convert_back + If True, the output is converted back to a complex tensor. + The output should have a number of channels that is a multiple of 2. + """ + super().__init__() + self.module = module + self.convert_back = convert_back + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor (if used by the wrapped module) + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + x_real = [ + rearrange(torch.view_as_real(c), 'batch channel ... complex -> batch (channel complex) ...') + if c.is_complex() + else c + for c in x + ] + + y = call_with_cond(self.module, *x_real, cond=cond) + + if self.convert_back: + y = rearrange(y, 'b (channel complex) ... -> b channel ... complex', complex=2).contiguous() + y = torch.view_as_complex(y) + return y diff --git a/src/mrpro/nn/CondMixin.py b/src/mrpro/nn/CondMixin.py new file mode 100644 index 000000000..6a902c413 --- /dev/null +++ b/src/mrpro/nn/CondMixin.py @@ -0,0 +1,22 @@ +"""Base class for modules using a conditioning.""" + +import torch +from torch.nn import Module + + +def call_with_cond(module: Module, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Call a module with conditioning if it is a CondMixin.""" + if isinstance(module, CondMixin): + return module(*x, cond=cond) + return module(*x) + + +class CondMixin(Module): + """Mixin for modules using a conditioning. + + Used to determine if a module uses a conditioning within a Sequential container. + """ + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module to the input.""" + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/DropPath.py b/src/mrpro/nn/DropPath.py new file mode 100644 index 000000000..7262fd86c --- /dev/null +++ b/src/mrpro/nn/DropPath.py @@ -0,0 +1,55 @@ +"""DropPath (stochastic depth).""" + +import torch +from torch.nn import Module + + +class DropPath(Module): + """Drop path or stochastic depth. + + Drops full samples from batch with probability `droprate`. + Should be used in the main path of a Resblock. + + References + ---------- + .. [HUANG16] Huang, G., Sun, Y., Liu, Z., Sedra, D., & Weinberger, K. Q. Deep networks with stochastic depth. + ECCV 2016. https://link.springer.com/chapter/10.1007/978-3-319-46493-0_39 + """ + + def __init__(self, droprate: float = 0.0, scale_by_keep: bool = False): + """Initialize the DropPath module. + + Parameters + ---------- + droprate + Drop probability + scale_by_keep + If True, the kept samples are scaled by :math:`1/(1-droprate)` + """ + super().__init__() + self.droprate = droprate + self.scale_by_keep = scale_by_keep + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Tensor with batch samples randomly dropped + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath.""" + if self.droprate == 0 or not self.training: + return x + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = ((1 - self.droprate) + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_() + if self.scale_by_keep: + mask = mask.div_(1 - self.droprate) + return x * mask diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py new file mode 100644 index 000000000..92780aae3 --- /dev/null +++ b/src/mrpro/nn/FiLM.py @@ -0,0 +1,54 @@ +"""Feature-wise Linear Modulation.""" + +import torch +from torch.nn import Linear, Module + +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_tensors_right + + +class FiLM(CondMixin, Module): + """Feature-wise Linear Modulation. + + Feature-wise Linear Modulation from [FiLM]_ to condition a network on a conditioning tensor. + + + References + ---------- + ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "FiLM : Visual reasoning with a general + conditioning layer." AAAI (2018). https://arxiv.org/abs/1709.07871 + """ + + def __init__(self, channels: int, cond_dim: int) -> None: + """Initialize FiLM. + + Parameters + ---------- + channels + The number of channels in the input tensor. + cond_dim + The dimension of the conditioning tensor. + """ + super().__init__() + self.project = Linear(cond_dim, 2 * channels) if cond_dim > 0 else None + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply FiLM. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply FiLM.""" + if cond is None or self.project is None: + return x + scale, shift = self.project(cond).chunk(2, dim=1) + + scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) + return x * (1 + scale) + shift diff --git a/src/mrpro/nn/FourierFeatures.py b/src/mrpro/nn/FourierFeatures.py new file mode 100644 index 000000000..847ae3c60 --- /dev/null +++ b/src/mrpro/nn/FourierFeatures.py @@ -0,0 +1,50 @@ +"""Random Fourier feature embedding.""" + +import torch +from torch.nn import Module + + +class FourierFeatures(Module): + """Fourier feature encoding layer. + + Projects input features into a higher dimensional space using random Fourier features. + Used in INRs and to embed the time or other continuous variables. + """ + + weight: torch.Tensor + + def __init__(self, n_features_in: int, n_features_out: int, std: float = 1.0): + """Initialize Fourier feature encoding layer. + + Parameters + ---------- + n_features_in + Number of input features + n_features_out + Number of output features (must be even) + std + Standard deviation for random initialization + """ + if n_features_out % 2 != 0: + raise ValueError('n_features_out must be even.') + super().__init__() + self.register_buffer('weight', torch.randn([n_features_out // 2, n_features_in]) * std) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding. + + Parameters + ---------- + x + Input tensor of shape (..., in_features) + + Returns + ------- + Encoded features of shape (..., out_features) + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding.""" + f = 2 * torch.pi * x @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) diff --git a/src/mrpro/nn/GEGLU.py b/src/mrpro/nn/GEGLU.py new file mode 100644 index 000000000..6151503d2 --- /dev/null +++ b/src/mrpro/nn/GEGLU.py @@ -0,0 +1,56 @@ +"""Gated linear unit activation function.""" + +import torch +from torch.nn import Linear, Module + + +class GEGLU(Module): + r"""Gated linear unit activation function. + + References + ---------- + ..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202 + """ + + def __init__(self, n_channels_in: int, n_channels_out: int | None = None, features_last: bool = False): + """Initialize the GEGLU activation function. + + Parameters + ---------- + n_channels_in + The number of input features/channels. + n_channels_out + The number of output features/channels. If None, the number of + output features is the same as the number of input features. + features_last + If True, the channel dimension is the last dimension, else in the second dimension. + """ + super().__init__() + out_channels_ = n_channels_in if n_channels_out is None else n_channels_out + self.proj = Linear(n_channels_in, out_channels_ * 2) # gate and output stacked + self.features_last = features_last + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation.""" + if not self.features_last: + x = x.moveaxis(1, -1) + h, gate = self.proj(x).chunk(2, dim=-1) + gate = torch.nn.functional.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + out = h * gate + if not self.features_last: + out = out.moveaxis(-1, 1) + return out + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Activated tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py new file mode 100644 index 000000000..e0bbb0a4b --- /dev/null +++ b/src/mrpro/nn/GroupNorm.py @@ -0,0 +1,71 @@ +"""GroupNorm with 32-bit precision.""" + +import torch + + +class GroupNorm(torch.nn.GroupNorm): + """A 32-bit GroupNorm with (optional) automatic group size selection. + + Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. + + If `n_groups` is not provided, the number of groups is selected automatically as follows: + + - start from `1` group, + - try powers of two (`2, 4, 8, ...`), + - keep the largest candidate that divides `n_channels`, + - enforce at most `32` groups and at least `4` channels per group. + + This yields a stable default that stays close to common GroupNorm choices while + adapting to small channel counts. + """ + + features_last: bool + + def __init__(self, n_channels: int, n_groups: int | None = None, affine: bool = False, features_last: bool = False): + """Initialize GroupNorm. + + Parameters + ---------- + n_channels + The number of channels in the input tensor. + n_groups + The number of groups to use. If None, the number of groups is determined automatically as + the largest power of 2 that divides `n_channels`, is less than or equal to 32, + and leaves at least 4 channels per group. + affine + Whether to use learnable affine parameters. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + if n_groups is None: + groups_, candidate = 1, 2 + while (candidate <= min(32, n_channels // 4)) and (n_channels % candidate == 0): + groups_, candidate = candidate, groups_ * 2 + else: + groups_ = n_groups + self.features_last: bool = features_last + super().__init__(groups_, n_channels, affine=affine) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm32. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x.float()).type(x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm.""" + if self.features_last: + x = x.moveaxis(-1, 1) + result = super().forward(x.float()).type(x.dtype) + if self.features_last: + result = result.moveaxis(1, -1) + return result diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py new file mode 100644 index 000000000..7c35eee96 --- /dev/null +++ b/src/mrpro/nn/LayerNorm.py @@ -0,0 +1,83 @@ +"""Layer normalization.""" + +import torch +from torch.nn import Linear, Module, Parameter + +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_at, unsqueeze_right + + +class LayerNorm(CondMixin, Module): + """Layer normalization.""" + + def __init__(self, n_channels: int | None, features_last: bool = False, cond_dim: int = 0) -> None: + """Initialize the layer normalization. + + Parameters + ---------- + n_channels + Number of channels in the input tensor. If `None`, the layer normalization does not do an elementwise + affine transformation. + features_last + If `True`, the channel dimension is the last dimension. + cond_dim + Number of channels in the conditioning tensor. If `0`, no adaptive scaling is applied. + """ + super().__init__() + if n_channels is None and cond_dim == 0: + self.weight: Parameter | None = None + self.bias: Parameter | None = None + self.cond_proj: Linear | None = None + elif n_channels is None and cond_dim > 0: + raise ValueError('channels must be provided if cond_dim > 0') + elif n_channels is not None and cond_dim == 0: + self.weight = Parameter(torch.ones(n_channels)) + self.bias = Parameter(torch.zeros(n_channels)) + self.cond_proj = None + elif n_channels is not None: + self.weight = None + self.bias = None + self.cond_proj = Linear(cond_dim, 2 * n_channels) + else: + raise ValueError('cond_dim must be zero or positive.') + + self.features_last = features_last + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply layer normalization to the input tensor. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor. If `None`, no conditioning is applied. + + Returns + ------- + Normalized output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply layer normalization to the input tensor.""" + dims = tuple(range(1, x.ndim)) + mean = x.mean(dim=dims, keepdim=True) + std = x.std(dim=dims, keepdim=True, unbiased=False) + x = (x - mean) / (std + 1e-5) + + if self.weight is not None and self.bias is not None: + if self.features_last: + x = x * self.weight + self.bias + else: + x = x * unsqueeze_right(self.weight, x.ndim - 2) + unsqueeze_right(self.bias, x.ndim - 2) + + if self.cond_proj is not None and cond is not None: + scale, shift = self.cond_proj(cond).chunk(2, dim=-1) + scale = 1 + scale + if self.features_last: + x = x * unsqueeze_at(scale, 1, x.ndim - 2) + unsqueeze_at(shift, 1, x.ndim - 2) + else: + x = x * unsqueeze_right(scale, x.ndim - 2) + unsqueeze_right(shift, x.ndim - 2) + + return x diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py new file mode 100644 index 000000000..935d114dc --- /dev/null +++ b/src/mrpro/nn/PermutedBlock.py @@ -0,0 +1,75 @@ +"""Block that applies a submodule along selected spatial dimensions.""" + +from collections.abc import Sequence + +import torch +from torch import nn + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class PermutedBlock(CondMixin, nn.Module): + """Apply a submodule along selected spatial dimensions.""" + + apply_along_dim: tuple[int, ...] + module: nn.Module + + def __init__(self, apply_along_dim: Sequence[int], module: nn.Module, features_last: bool = False): + """Initialize the PermutedBlock. + + Parameters + ---------- + apply_along_dim + Spatial dimension indices to use when applying the module. + These will be moved to the last dimensions. + module + Module to apply on the selected dims. + features_last + If True, the features dimension is assumed to be the last dimension, as common in transformer models. + """ + super().__init__() + self.apply_along_dim = tuple(sorted(apply_along_dim)) + self.module = module + self.features_last = features_last + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor, passed to the module if it supports conditioning + (that is, if it is a subclass of `~mrpro.nn.CondMixin`) + + Returns + ------- + Output tensor. + """ + return self.forward(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions.""" + keep = tuple(d % x.ndim for d in self.apply_along_dim) + if 0 in keep: + raise ValueError('Batch dimension should not be in apply_along_dim.') + if self.features_last: + if x.ndim - 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(1, x.ndim - 1) if d not in keep) + permute = (0, *batch_dim, *keep, x.ndim - 1) + else: + if 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(2, x.ndim) if d not in keep) + permute = (0, *batch_dim, 1, *keep) + h = x.permute(permute) + batch_shape = h.shape[: 1 + len(batch_dim)] + h = h.flatten(0, len(batch_dim)) + h = call_with_cond(self.module, h, cond=cond) + h = h.unflatten(0, batch_shape) + permute_back = [0] * x.ndim + for i, p in enumerate(permute): + permute_back[p] = i + return h.permute(tuple(permute_back)) diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py new file mode 100644 index 000000000..afcff9335 --- /dev/null +++ b/src/mrpro/nn/PixelShuffle.py @@ -0,0 +1,285 @@ +"""ND-version of PixelShuffle and PixelUnshuffle.""" + +from math import ceil + +import torch +from torch.nn import Linear, Module + +from mrpro.nn.ndmodules import convND + + +class PixelUnshuffle(Module): + """ND-version of PixelUnshuffle downscaling.""" + + def __init__(self, downscale_factor: int, features_last: bool = False): + """Initialize PixelUnshuffle. + + Reduces spatial dimensions and increases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are downscaled. + + See `mrpro.nn.PixelShuffle` for the inverse operation. + + Parameters + ---------- + downscale_factor + The factor by which to downscale the input tensor. + features_last + Whether the features/channels dimension is the last dimension as common in transformer models or the + second dimension as common in CNN models. + """ + super().__init__() + self.downscale_factor = downscale_factor + self.features_last = features_last + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or `batch, *spatial_dims, channels` (if `features_last`). + + Returns + ------- + Tensor of shape `batch, channels * downscale_factor**dim, *spatial_dims/downscale_factor` or + `batch, *spatial_dims/downscale_factor, channels * downscale_factor**dim` (if `features_last`). + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input.""" + n_dim = x.ndim - 2 + if n_dim == 2 and not self.features_last: # fast path for 2D images + return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor) + + new_shape = list(x.shape[:1]) if self.features_last else list(x.shape[:2]) + source_positions = [] + for i, old in enumerate(x.shape[1:-1] if self.features_last else x.shape[2:]): + if old % self.downscale_factor: + raise ValueError('Spatial size must be divisible by downscale_factor.') + new_shape.append(old // self.downscale_factor) + new_shape.append(self.downscale_factor) + source_positions.append(2 + 2 * i) + if self.features_last: + new_shape.append(x.shape[-1]) + x = x.view(new_shape) + x = x.moveaxis(source_positions, tuple(range(-n_dim, 0))) + if self.features_last: + x = x.flatten(-n_dim - 1) + else: + x = x.flatten(1, -n_dim - 1) + + return x + + +class PixelUnshuffleDownsample(Module): + """PixelUnshuffle Downsampling. + + PixelUnshuffle followed by a convolution. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + downscale_factor: int = 2, + residual: bool = False, + features_last: bool = False, + ): + """Initialize a PixelUnshuffleDownsample layer. + + Parameters + ---------- + n_dim + Dimension of the input tensor. + n_channels_in + Number of channels in the input tensor. + n_channels_out + Number of channels in the output tensor. + downscale_factor + Factor by which to downscale the input tensor. + residual + Whether to use a residual connection as proposed in [DCAE]_. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + super().__init__() + out_ratio = downscale_factor**n_dim + if n_channels_out % out_ratio != 0: + raise ValueError(f'channels_out must be divisible by downscale_factor**{n_dim}.') + if features_last: + self.projection: Module = Linear(n_channels_in, n_channels_out // out_ratio) + else: + self.projection = convND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') + self.features_last = features_last + self.residual = residual + self.pixel_unshuffle = PixelUnshuffle(downscale_factor, features_last) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims/downscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling.""" + h = self.projection(x) + h = self.pixel_unshuffle(h) + + if self.residual: + x = self.pixel_unshuffle(x) + if self.features_last: + n = (x.shape[-1] // h.shape[-1]) * h.shape[-1] + h = h + x[..., :n].unflatten(-1, (h.shape[-1], -1)).mean(-1) + else: + n = (x.shape[1] // h.shape[1]) * h.shape[1] + h = h + x[:, :n].unflatten(1, (h.shape[1], -1)).mean(2) + return h + + +class PixelShuffleUpsample(Module): + """PixelShuffle Upsampling. + + Convolution followed by PixelShuffle. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + upscale_factor: int = 2, + residual: bool = False, + features_last: bool = False, + ): + """Initialize a PixelShuffleUpsample layer. + + Parameters + ---------- + n_dim + Dimension of the input tensor. + n_channels_in + Number of channels in the input tensor. + n_channels_out + Number of channels in the output tensor. + upscale_factor + Factor by which to upscale the input tensor. + residual + Whether to use a residual connection as proposed in [DCAE]_. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + super().__init__() + if features_last: + self.projection: Module = Linear(n_channels_in, n_channels_out * upscale_factor**n_dim) + else: + self.projection = convND(n_dim)( + n_channels_in, n_channels_out * upscale_factor**n_dim, kernel_size=3, padding='same' + ) + self.features_last = features_last + self.pixel_shuffle = PixelShuffle(upscale_factor, features_last) + self.residual = residual + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims * upscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling.""" + h = self.projection(x) + if self.residual: + if self.features_last: + h = h + x.repeat_interleave(ceil(h.shape[-1] / x.shape[-1]), dim=-1)[..., : h.shape[-1]] + else: + h = h + x.repeat_interleave(ceil(h.shape[1] / x.shape[1]), dim=1)[:, : h.shape[1]] + out = self.pixel_shuffle(h) + return out + + +class PixelShuffle(Module): + """ND-version of PixelShuffle upscaling.""" + + def __init__(self, upscale_factor: int, features_last: bool = False): + """Initialize PixelShuffle. + + Upscales spatial dimensions and decreases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are upscaled. + + See `mrpro.nn.PixelUnshuffle` for the inverse operation. + + Parameters + ---------- + upscale_factor + The factor by which to upscale the spatial dimensions. + features_last + Whether the features/channels dimension is the last dimension as common in transformer models or the + second dimension as common in CNN models. + """ + super().__init__() + self.upscale_factor = upscale_factor + self.features_last = features_last + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or `batch, *spatial_dims, channels` (if `features_last`). + + Returns + ------- + Tensor of shape `batch, channels / upscale_factor**n_dim, *spatial_dims * upscale_factor` or + `batch, *spatial_dims * upscale_factor, channels / upscale_factor**n_dim` (if `features_last`). + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input.""" + n_dim = x.ndim - 2 + if n_dim == 2 and not self.features_last: # fast path for 2D + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + if self.features_last: + new_shape = (x.shape[0], *(old * self.upscale_factor for old in x.shape[-n_dim - 1 : -1]), -1) + x = x.unflatten(-1, (-1, *(self.upscale_factor,) * n_dim)) + x = x.moveaxis(tuple(range(-n_dim, 0)), tuple(range(-2 * n_dim, 0, 2))) + else: + new_shape = (x.shape[0], -1, *(old * self.upscale_factor for old in x.shape[-n_dim:])) + x = x.unflatten(1, (-1, *(self.upscale_factor,) * n_dim)) + x = x.moveaxis(tuple(range(2, 2 + n_dim)), tuple(range(-2 * n_dim + 1, 0, 2))) + x = x.reshape(new_shape) + return x diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py new file mode 100644 index 000000000..b97641545 --- /dev/null +++ b/src/mrpro/nn/RMSNorm.py @@ -0,0 +1,73 @@ +"""RMSNorm over the channel dimension.""" + +import torch +from torch.nn import Module, Parameter + + +class RMSNorm(Module): + """RMSNorm over the channel dimension. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_channels: int | None = None, + eps: float = 1e-8, + features_last: bool = False, + ): + """Initialize RMSNorm. + + Includes a learnable weight and bias if n_channels is provided. + + Parameters + ---------- + n_channels + Number of channels. If `None`, no learnable weight and bias are included. + eps + Epsilon value to avoid division by zero. + features_last + If True, the channel dimension is the last dimension. + """ + super().__init__() + if n_channels is not None: + self.weight: Parameter | None = Parameter(torch.zeros(n_channels)) + self.bias: Parameter | None = Parameter(torch.zeros(n_channels)) + else: + self.weight = None + self.bias = None + self.eps = eps + self.channel_dim = -1 if features_last else 1 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension. + + Parameters + ---------- + x + Input tensor. + + Returns + ------- + Normalized tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension.""" + x32 = x.to(torch.float32) # normalization in float32 to stabilize mixed precision training + mean_square = x32.square().mean(dim=self.channel_dim, keepdim=True) + scale = (mean_square + self.eps).rsqrt() + x32 = x32 * scale + if self.weight is not None and self.bias is not None: + shape = [1] * x.ndim + shape[self.channel_dim] = -1 + weight = (self.weight.to(x32.dtype) + 1).view(shape) + bias = self.bias.view(shape) + x32 = x32 * weight + bias + return x32.to(x.dtype) diff --git a/src/mrpro/nn/Residual.py b/src/mrpro/nn/Residual.py new file mode 100644 index 000000000..e524fe169 --- /dev/null +++ b/src/mrpro/nn/Residual.py @@ -0,0 +1,45 @@ +"""Residual connection.""" + +import torch +from torch.nn import Identity, Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class Residual(CondMixin, Module): + """Residual connection.""" + + def __init__(self, module: Module, skip: Module | None = None): + """Initialize the residual connection. + + Parameters + ---------- + module + The main path of the residual connection. + skip + The skip path of the residual connection. If None, the identity function is used. + """ + super().__init__() + self.module = module + self.skip = Identity() if skip is None else skip + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x + The input tensor. + cond + The optional conditioning tensor. If the modules are an instance of `CondMixin`, + the conditioning is passed to the modules. + + Returns + ------- + The output tensor. + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + return call_with_cond(self.module, *x, cond=cond) + call_with_cond(self.skip, *x, cond=cond) diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py new file mode 100644 index 000000000..fed22c3a2 --- /dev/null +++ b/src/mrpro/nn/Sequential.py @@ -0,0 +1,60 @@ +"""Sequential container with support for conditioning and Operators.""" + +from collections import OrderedDict +from typing import cast + +import torch + +from mrpro.nn.CondMixin import CondMixin +from mrpro.operators import Operator + + +class Sequential(CondMixin, torch.nn.Sequential): + """Sequential container with support for conditioning and Operators. + + Allows multiple input tensors and a single output tensor of the sequential block. + + """ + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input. + + Parameters + ---------- + x + The input tensor. + cond + The (optional) conditioning tensor. + + Returns + ------- + The output tensor. + """ + return torch.nn.Sequential.__call__(self, *x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input.""" + for module in self: + if isinstance(module, Operator): + x = cast(tuple[torch.Tensor, ...], module(*x)) # always tuple + else: + ret: torch.Tensor | tuple[torch.Tensor, ...] + if isinstance(module, CondMixin): + ret = module(*x, cond=cond) + else: + ret = module(*x) + if isinstance(ret, tuple): + x = ret + else: + x = (ret,) + return x[0] + + def __getitem__(self, idx: slice | int) -> 'Sequential': + """Get a slice or item from the Sequential container. + + Subclasses will decompose to `Sequential` on indexing. + """ + if isinstance(idx, slice): + return Sequential(OrderedDict(list(self._modules.items())[idx])) + else: + return cast(Sequential, self._get_item_by_idx(self._modules.values(), idx)) diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py new file mode 100644 index 000000000..acced8d48 --- /dev/null +++ b/src/mrpro/nn/Upsample.py @@ -0,0 +1,66 @@ +"""Upsampling by interpolation.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module, Sequential + +from mrpro.nn.PermutedBlock import PermutedBlock + + +class Upsample(Module): + """Upsampling by interpolation.""" + + def __init__( + self, dim: Sequence[int], scale_factor: int = 2, mode: Literal['nearest', 'linear', 'cubic'] = 'linear' + ): + """Initialize the upsampling layer. + + Parameters + ---------- + dim + Dimensions which to upsample + scale_factor + Factor by which to upsample + mode + Interpolation mode. See `torch.nn.functional.interpolate` for details. + """ + super().__init__() + self.scale_factor = scale_factor + if mode == 'nearest': + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = ['nearest'] * len(dim) + elif mode == 'linear': + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = [{1: 'linear', 2: 'bilinear', 3: 'trilinear'}[len(d)] for d in dims] + elif mode == 'cubic': + if not len(dim) == 2: + raise ValueError('Cubic interpolation is only supported for 2D images.') + dims = [tuple(dim)] + modes = ['bicubic'] + + self.blocks = Sequential( + *[ + PermutedBlock(d, torch.nn.Upsample(scale_factor=len(d) * (scale_factor,), mode=m)) + for d, m in zip(dims, modes, strict=False) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor.""" + return self.blocks(x) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Upsampled tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py new file mode 100644 index 000000000..d6541f5c8 --- /dev/null +++ b/src/mrpro/nn/__init__.py @@ -0,0 +1,45 @@ +"""Neural network modules and utilities.""" + +from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.FourierFeatures import FourierFeatures +from mrpro.nn.GEGLU import GEGLU +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Residual import Residual +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ndmodules import ( + adaptiveAvgPoolND, + avgPoolND, + batchNormND, + convND, + convTransposeND, + instanceNormND, + maxPoolND, +) + +__all__ = [ + 'ComplexAsChannel', + 'CondMixin', + 'DropPath', + 'FiLM', + 'FourierFeatures', + 'GEGLU', + 'GroupNorm', + 'LayerNorm', + 'PermutedBlock', + 'RMSNorm', + 'Residual', + 'Sequential', + 'adaptiveAvgPoolND', + 'avgPoolND', + 'batchNormND', + 'convND', + 'convTransposeND', + 'instanceNormND', + 'maxPoolND', +] diff --git a/src/mrpro/nn/convert_linear_conv.py b/src/mrpro/nn/convert_linear_conv.py new file mode 100644 index 000000000..767a419ff --- /dev/null +++ b/src/mrpro/nn/convert_linear_conv.py @@ -0,0 +1,100 @@ +"""Convert Linear layers to kernel size 1 ConvNd layers and vice versa.""" + +from typing import Literal, overload + +import torch +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + +from mrpro.nn.ndmodules import convND + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[1]) -> Conv1d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[2]) -> Conv2d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[3]) -> Conv3d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d: ... + + +def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d: + """Convert a Linear layer to a ConvNd layer with kernel size 1. + + Rearranging the spatial dimensions to the batch dimension, + applying the linear layer and rearranging the spatial dimensions back + is equivalent to applying a kernel size 1 ConvNd layer. + + This function will create the Conv1d, Conv2d, or Conv3d with the correct weights and bias. + + See :func:`conv_to_linear` for the reverse operation. + + + + Parameters + ---------- + linear_layer + The linear layer to convert. + n_dim + The convolution dimension (1, 2, or 3). + + Returns + ------- + A Conv layer with equivalent weights and bias. + """ + conv = convND(n_dim)( + in_channels=linear_layer.in_features, + out_channels=linear_layer.out_features, + kernel_size=1, + bias=linear_layer.bias is not None, + device=linear_layer.weight.device, + dtype=linear_layer.weight.dtype, + ) + + with torch.no_grad(): + conv.weight.copy_(linear_layer.weight.view_as(conv.weight)) + if conv.bias is not None and linear_layer.bias is not None: + conv.bias.copy_(linear_layer.bias) + + return conv + + +def conv_to_linear(conv_layer: Conv1d | Conv2d | Conv3d) -> Linear: + """ + Convert a Conv1d, Conv2d, or Conv3d layer with kernel size 1 to a Linear layer. + + Applying a kernel size 1 ConvNd layer is equivalent to applying a Linear layer to each voxel. + This function will create the Linear layer with the correct weights and bias. + + See :func:`linear_to_conv` for the reverse operation. + + Parameters + ---------- + conv_layer : nn.Module + The convolutional layer to convert. Must have kernel size 1. + + Returns + ------- + A linear layer with equivalent weights and bias. + """ + if not all(k == 1 for k in conv_layer.kernel_size): + raise ValueError('Kernel size must be 1 for conversion.') + linear = Linear( + conv_layer.in_channels, + conv_layer.out_channels, + bias=conv_layer.bias is not None, + device=conv_layer.weight.device, + dtype=conv_layer.weight.dtype, + ) + with torch.no_grad(): + linear.weight.copy_(conv_layer.weight.view_as(linear.weight)) + if linear.bias is not None and conv_layer.bias is not None: + linear.bias.copy_(conv_layer.bias) + + return linear diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py new file mode 100644 index 000000000..d98aeb7b2 --- /dev/null +++ b/src/mrpro/nn/join.py @@ -0,0 +1,176 @@ +"""Modules for concatenating or adding tensors.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module + +from mrpro.utils.interpolate import interpolate +from mrpro.utils.pad_or_crop import pad_or_crop + + +def _fix_shapes( + xs: Sequence[torch.Tensor], + mode: str, + dim: Sequence[int], +) -> tuple[torch.Tensor, ...]: + """Fix shapes of input tensors by padding or cropping.""" + if mode == 'fail': + return tuple(xs) + + shapes = [[x.shape[d] for d in dim] for x in xs] + if mode == 'crop': # smallest as target + target = tuple(min(s) for s in zip(*shapes, strict=True)) + else: # largest as target + target = tuple(max(s) for s in zip(*shapes, strict=True)) + if mode == 'linear' or mode == 'nearest': + return tuple(interpolate(x, target, dim=dim, mode=mode) for x in xs) # type: ignore[arg-type] + if mode == 'zero' or mode == 'crop': + return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) + else: + return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) # type: ignore[arg-type] + + +class Concat(Module): + """Concatenate tensors along the channel dimension.""" + + def __init__( + self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'] = 'fail', dim: int = 1 + ) -> None: + """Initialize Concat. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + - 'linear': linear interpolation to largest spatial size + - 'nearest': nearest neighbor interpolation to largest spatial size + dim + Dimension along which to concatenate. + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.dim = dim + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Concatenate input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) + return torch.cat(xs, dim=self.dim) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Concatenate input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Concatenated tensor + """ + return super().__call__(*xs) + + +class Add(Module): + """Add tensors.""" + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: + """Initialize Add. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Add input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=range(max(x.ndim for x in xs))) + return sum(xs, start=torch.tensor(0.0)) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Add input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Summed tensor + """ + return super().__call__(*xs) + + +class Interpolate(Module): + """Linear interpolate between two tensors. + + As suggestions for the Hourglass Transformer [CR]_ + + References + ---------- + .. [CK] Crowson, Katherine, et al. "Scalable high-resolution pixel-space image synthesis with + hourglass diffusion transformers." ICML 2024, https://arxiv.org/abs/2401.11605 + """ + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: + """Initialize learned linear interpolation. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.weight = torch.nn.Parameter(torch.tensor(0.5)) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Linear interpolate between two tensors.""" + x1, x2 = _fix_shapes((x1, x2), self.mode, dim=range(max(x.ndim for x in (x1, x2)))) + return x1 * self.weight + x2 * (1 - self.weight) + + def __call__(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Linear interpolate between two tensors. + + Parameters + ---------- + x1, x2 + Input tensors + + Returns + ------- + Interpolated tensor + """ + return super().__call__(x1, x2) diff --git a/src/mrpro/nn/ndmodules.py b/src/mrpro/nn/ndmodules.py new file mode 100644 index 000000000..3fb2d894c --- /dev/null +++ b/src/mrpro/nn/ndmodules.py @@ -0,0 +1,178 @@ +"""Helper functions to get the correct N-dimensional module.""" + +import torch + + +def convND(n_dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 + """Get the `n_dim`-dimensional convolution class. + + Parameters + ---------- + n_dim + The dimension of the convolution. + + Returns + ------- + The convolution class. + """ + match n_dim: + case 1: + return torch.nn.Conv1d + case 2: + return torch.nn.Conv2d + case 3: + return torch.nn.Conv3d + case _: + raise NotImplementedError(f'ConvND for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def convTransposeND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.ConvTranspose1d] | type[torch.nn.ConvTranspose2d] | type[torch.nn.ConvTranspose3d]: + """Get the `n_dim`-dimensional transposed convolution class. + + Parameters + ---------- + n_dim + The dimension of the transposed convolution. + + Returns + ------- + The transposed convolution class. + """ + match n_dim: + case 1: + return torch.nn.ConvTranspose1d + case 2: + return torch.nn.ConvTranspose2d + case 3: + return torch.nn.ConvTranspose3d + case _: + raise NotImplementedError( + f'ConvTransposeND for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def maxPoolND(n_dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 + """Get the `n_dim`-dimensional max pooling class. + + Parameters + ---------- + n_dim + The dimension of the max pooling. + + Returns + ------- + The max pooling class. + """ + match n_dim: + case 1: + return torch.nn.MaxPool1d + case 2: + return torch.nn.MaxPool2d + case 3: + return torch.nn.MaxPool3d + case _: + raise NotImplementedError(f'MaxPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def avgPoolND(n_dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 + """Get the `n_dim`-dimensional average pooling class. + + Parameters + ---------- + n_dim + The dimension of the average pooling. + + Returns + ------- + The average pooling class. + """ + match n_dim: + case 1: + return torch.nn.AvgPool1d + case 2: + return torch.nn.AvgPool2d + case 3: + return torch.nn.AvgPool3d + case _: + raise NotImplementedError(f'AvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def adaptiveAvgPoolND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.AdaptiveAvgPool1d] | type[torch.nn.AdaptiveAvgPool2d] | type[torch.nn.AdaptiveAvgPool3d]: + """Get the `n_dim`-dimensional adaptive average pooling class. + + Parameters + ---------- + n_dim + The dimension of the adaptive average pooling. + + Returns + ------- + The adaptive average pooling class. + """ + match n_dim: + case 1: + return torch.nn.AdaptiveAvgPool1d + case 2: + return torch.nn.AdaptiveAvgPool2d + case 3: + return torch.nn.AdaptiveAvgPool3d + case _: + raise NotImplementedError( + f'AdaptiveAvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def instanceNormND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.InstanceNorm1d] | type[torch.nn.InstanceNorm2d] | type[torch.nn.InstanceNorm3d]: + """Get the `n_dim`-dimensional instance normalization class. + + Parameters + ---------- + n_dim + The dimension of the instance normalization. + + Returns + ------- + The instance normalization class. + """ + match n_dim: + case 1: + return torch.nn.InstanceNorm1d + case 2: + return torch.nn.InstanceNorm2d + case 3: + return torch.nn.InstanceNorm3d + case _: + raise NotImplementedError( + f'InstanceNormNd for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def batchNormND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.BatchNorm1d] | type[torch.nn.BatchNorm2d] | type[torch.nn.BatchNorm3d]: + """Get the `n_dim`-dimensional batch normalization class. + + Parameters + ---------- + n_dim + The dimension of the batch normalization. + + Returns + ------- + The batch normalization class. + """ + match n_dim: + case 1: + return torch.nn.BatchNorm1d + case 2: + return torch.nn.BatchNorm2d + case 3: + return torch.nn.BatchNorm3d + case _: + raise NotImplementedError(f'BatchNormNd for dim {n_dim} not implemented. Raise an issue if you need this.') diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 345883e12..2d4eceb2f 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -15,8 +15,10 @@ from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin from mrpro.utils.interpolate import interpolate, apply_lowres from mrpro.utils.RandomGenerator import RandomGenerator - +from mrpro.utils.to_tuple import to_tuple +from mrpro.utils.ema import EMADict __all__ = [ + "EMADict", "Indexer", "RandomGenerator", "TensorAttributeMixin", @@ -38,6 +40,7 @@ "split_idx", "summarize_object", "summarize_values", + "to_tuple", "typing", "unit_conversion", "unsqueeze_at", diff --git a/src/mrpro/utils/to_tuple.py b/src/mrpro/utils/to_tuple.py new file mode 100644 index 000000000..657d7bf56 --- /dev/null +++ b/src/mrpro/utils/to_tuple.py @@ -0,0 +1,36 @@ +"""Standardize an argument to a fixed-length tuple.""" + +from collections.abc import Sequence +from typing import TypeVar + +T = TypeVar('T') + + +def to_tuple(length: int, arg: Sequence[T] | T) -> tuple[T, ...]: + """Standardize an argument to a fixed-length tuple. + + If the argument is a sequence, it checks if its length matches the + specified dimension. If it's a single value, it replicates it `dim` times. + + Parameters + ---------- + length + The expected length of the sequence. + arg + The argument to check. Can be a single value of type T or a + sequence of T. + + Returns + ------- + A tuple of length `dim` containing elements of type T. + + Raises + ------ + ValueError + If `arg` is a sequence and its length does not match `length`. + """ + if isinstance(arg, Sequence): + if not len(arg) == length: + raise ValueError(f'The arguments must be either single values or have length {length}. Got {arg}.') + return tuple(arg) + return (arg,) * length diff --git a/tests/nn/test_complexaschannel.py b/tests/nn/test_complexaschannel.py new file mode 100644 index 000000000..37889f654 --- /dev/null +++ b/tests/nn/test_complexaschannel.py @@ -0,0 +1,30 @@ +"""Tests for ComplexAsChannel module.""" + +import pytest +from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_complexaschannel(device: str) -> None: + """Test ComplexAsChannel output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + input_shape = (1, 32) + x = rng.complex64_tensor(input_shape).to(device).requires_grad_(True) + module = ComplexAsChannel(Linear(input_shape[1] * 2, input_shape[1] * 2)).to(device) + output = module(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.is_complex(), 'Output is not complex' + output.sum().abs().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert module.module.weight.grad is not None, 'No gradient computed for weight' + assert module.module.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_convert_linear_conv.py b/tests/nn/test_convert_linear_conv.py new file mode 100644 index 000000000..19438b9d9 --- /dev/null +++ b/tests/nn/test_convert_linear_conv.py @@ -0,0 +1,150 @@ +"""Tests for converting between Linear and Conv layers.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.convert_linear_conv import conv_to_linear, linear_to_conv +from mrpro.utils import RandomGenerator +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + +DEVICES = pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +SHAPES = pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], + ids=['1d', '2d', '3d', '3d_no_bias'], +) + + +@SHAPES +@DEVICES +def test_linear_to_conv(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Linear to Conv layer.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias).to(device) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + assert isinstance(conv, (Conv1d, Conv2d, Conv3d)[dim - 1]) + + assert conv.in_channels == channels_in + assert conv.out_channels == channels_out + assert conv.kernel_size == (1,) * dim + assert conv.bias is not None if bias else conv.bias is None + + assert conv.weight.device.type == device + if conv.bias is not None: + assert conv.bias.device.type == device + + +@SHAPES +def test_linear_to_conv_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test functional equivalence of Linear to Conv conversion.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + spatial_shape = (4,) * dim + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) + + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + torch.testing.assert_close(y_conv, y_linear) + + +@SHAPES +@DEVICES +def test_conv_to_linear(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Conv layer to Linear.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias).to(device) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + + assert isinstance(linear, Linear) + assert linear.in_features == channels_in + assert linear.out_features == channels_out + assert linear.bias is not None if bias else linear.bias is None + + assert linear.weight.device.type == device + if bias: + assert linear.bias.device.type == device + + +@SHAPES +def test_conv_to_linear_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test functional equivalence of Conv to Linear conversion.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + spatial_shape = (4,) * dim + + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + torch.testing.assert_close(y_conv, y_linear) + + +def test_conv_to_linear_invalid_kernel() -> None: + """Test conv_to_linear with invalid kernel size.""" + conv = Conv2d(32, 64, kernel_size=3, bias=True) + with pytest.raises(ValueError, match='Kernel size must be 1'): + conv_to_linear(conv) + + +@SHAPES +@DEVICES +def test_round_trip_conversion( + device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool +) -> None: + """Test round-trip conversion between Linear and Conv layers.""" + rng = RandomGenerator(seed=42) + + linear1 = Linear(channels_in, channels_out, bias=bias).to(device) + linear1.weight.data = rng.rand_like(linear1.weight) + if bias: + linear1.bias.data = rng.rand_like(linear1.bias) + + conv = linear_to_conv(linear1, dim) + linear2 = conv_to_linear(conv) + + assert linear2.in_features == channels_in + assert linear2.out_features == channels_out + assert linear2.bias is not None if bias else linear2.bias is None + + torch.testing.assert_close(linear2.weight, linear1.weight) + if bias: + torch.testing.assert_close(linear2.bias, linear1.bias) diff --git a/tests/nn/test_droppath.py b/tests/nn/test_droppath.py new file mode 100644 index 000000000..b4ac7f5d7 --- /dev/null +++ b/tests/nn/test_droppath.py @@ -0,0 +1,30 @@ +"""Test DropPath.""" + +import pytest +from mrpro.nn.DropPath import DropPath +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_droppath_no_drop(device: str) -> None: + """Test DropPath with zero drop rate (should pass through unchanged).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).to(device) + droppath = DropPath(0).to(device) + y = droppath(x) + assert (y == x).all() + + +def test_droppath_drop_all() -> None: + """Test DropPath with full drop rate (should output zeros).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)) + droppath = DropPath(1.0) + y = droppath(x) + assert (y == 0).all() diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py new file mode 100644 index 000000000..0e564c675 --- /dev/null +++ b/tests/nn/test_film.py @@ -0,0 +1,42 @@ +"""Tests for FiLM module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn.FiLM import FiLM +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'n_channels_cond', 'input_shape', 'cond_shape'), + [ + (64, 32, (1, 64, 32, 32), (1, 32)), + (32, 16, (2, 32, 16, 16), (2, 16)), + ], +) +def test_film( + n_channels: int, n_channels_cond: int, input_shape: Sequence[int], cond_shape: Sequence[int], device: str +) -> None: + """Test FiLM output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) + film = FiLM(channels=n_channels, cond_dim=n_channels_cond).to(device) + output = film(x, cond=cond) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not output.isnan().any(), 'NaN values in output' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' + assert film.project is not None, 'Linear layer is not initialized' + assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' diff --git a/tests/nn/test_fourierfeatures.py b/tests/nn/test_fourierfeatures.py new file mode 100644 index 000000000..9452a369f --- /dev/null +++ b/tests/nn/test_fourierfeatures.py @@ -0,0 +1,24 @@ +"""Test for random fourier features""" + +import pytest +from mrpro.nn import FourierFeatures +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_fourierfeatures(device: str) -> None: + """Test FourierFeatures.""" + n_features_in = 1 + n_features_out = 16 + std = 1.0 + rng = RandomGenerator(444) + x = rng.float32_tensor((1, n_features_in)).to(device) + ff = FourierFeatures(n_features_in, n_features_out, std).to(device) + y = ff(x) + assert y.shape == (1, n_features_out) diff --git a/tests/nn/test_geglu.py b/tests/nn/test_geglu.py new file mode 100644 index 000000000..061837e51 --- /dev/null +++ b/tests/nn/test_geglu.py @@ -0,0 +1,38 @@ +"""Test GEGLU.""" + +import pytest +import torch +from mrpro.nn.GEGLU import GEGLU +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_geglu(device: str) -> None: + """Test GEGLU output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).to(device).requires_grad_(True) + gelu = GEGLU(3, 4).to(device) + y = gelu(x) + assert y.shape == (1, 4, 4, 5) + + y.sum().backward() + assert x.grad is not None + assert gelu.proj.weight.grad is not None + + +def test_geglu_features_last() -> None: + """Test GEGLU with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + gelu_last = GEGLU(3, 4, features_last=True) + gelu = GEGLU(3, 4, features_last=False) + gelu.proj = gelu_last.proj # need to set the same weights + y_last = gelu_last(x.moveaxis(1, -1)) + y = gelu(x) + torch.testing.assert_close(y, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_groupnorm.py b/tests/nn/test_groupnorm.py new file mode 100644 index 000000000..044de17f7 --- /dev/null +++ b/tests/nn/test_groupnorm.py @@ -0,0 +1,45 @@ +"""Tests for GroupNorm module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn import GroupNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'n_groups', 'input_shape', 'affine'), + [ + (32, None, (1, 32, 32, 32), True), + (64, 8, (2, 64, 16, 16, 16), False), + ], +) +def test_groupnorm( + n_channels: int, + n_groups: int | None, + input_shape: Sequence[int], + device: str, + affine: bool, +) -> None: + """Test GroupNorm output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = GroupNorm(n_channels=n_channels, n_groups=n_groups, affine=affine).to(device) + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if affine: + assert norm.weight is not None, 'Weight should not be None when affine is True' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias is not None, 'Bias should not be None when affine is True' + assert norm.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_join.py b/tests/nn/test_join.py new file mode 100644 index 000000000..f86647ac4 --- /dev/null +++ b/tests/nn/test_join.py @@ -0,0 +1,160 @@ +"""Tests for join modules.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.join import Add, Concat +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('mode', 'input_shapes', 'expected_shape'), + [ + ('crop', [(1, 3, 32, 32), (1, 5, 30, 30)], (1, 8, 30, 30)), + ('zero', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ('linear', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ('nearest', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ], +) +def test_concat_basic( + mode: Literal['crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'], + input_shapes: list[tuple[int, ...]], + expected_shape: tuple[int, ...], + device: str, +) -> None: + """Test Concat basic functionality.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).to(device).requires_grad_(True) for shape in input_shapes] + concat = Concat(mode=mode).to(device) + + output = concat(*xs) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + for x in xs: + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('mode', 'input_shapes', 'expected_shape'), + [ + ('crop', [(1, 3, 32, 32), (1, 3, 30, 30)], (1, 3, 30, 30)), + ('zero', [(1, 3, 32, 32), (1, 3, 34, 34)], (1, 3, 34, 34)), + ('replicate', [(1, 1, 1, 2), (1, 1, 1, 3)], (1, 1, 1, 3)), + ('circular', [(1, 1, 1, 2), (1, 1, 1, 4)], (1, 1, 1, 4)), + ], +) +def test_add_basic( + mode: Literal['crop', 'zero', 'replicate', 'circular'], + input_shapes: list[tuple[int, ...]], + expected_shape: tuple[int, ...], + device: str, +) -> None: + """Test Add basic functionality.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).to(device).requires_grad_(True) for shape in input_shapes] + add = Add(mode=mode).to(device) + + output = add(*xs) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + for x in xs: + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + +@pytest.mark.parametrize( + ('dim', 'input_shapes', 'expected_shape'), + [ + (0, [(1, 3, 32, 32), (1, 3, 32, 32)], (2, 3, 32, 32)), + (1, [(1, 3, 32, 32), (1, 5, 32, 32)], (1, 8, 32, 32)), + (2, [(1, 3, 32, 32), (1, 3, 32, 32)], (1, 3, 64, 32)), + ], +) +def test_concat_dimensions(dim: int, input_shapes: list[tuple[int, ...]], expected_shape: tuple[int, ...]) -> None: + """Test Concat with different concatenation dimensions.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).requires_grad_(True) for shape in input_shapes] + concat = Concat(mode='fail', dim=dim) + output = concat(*xs) + assert output.shape == expected_shape + + +def test_concat_values() -> None: + """Test that Concat preserves input values correctly.""" + x1 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]).requires_grad_(True) + x2 = torch.tensor([[[[5.0, 6.0], [7.0, 8.0]]]]).requires_grad_(True) + + concat = Concat(mode='fail') + output = concat(x1, x2) + + expected = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]) + torch.testing.assert_close(output, expected) + + +def test_add_values() -> None: + """Test that Add correctly sums input values.""" + x1 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]).requires_grad_(True) + x2 = torch.tensor([[[[5.0, 6.0], [7.0, 8.0]]]]).requires_grad_(True) + + add = Add(mode='fail') + output = add(x1, x2) + + expected = torch.tensor([[[[6.0, 8.0], [10.0, 12.0]]]]) + torch.testing.assert_close(output, expected) + + +def test_concat_mode_fail() -> None: + """Test Concat with mode='fail'.""" + rng = RandomGenerator(seed=42) + + x1 = rng.float32_tensor((1, 3, 32, 32)) + x2 = rng.float32_tensor((1, 5, 32, 32)) + concat = Concat(mode='fail') + output = concat(x1, x2) + assert output.shape == (1, 8, 32, 32) + + x3 = rng.float32_tensor((1, 3, 30, 30)) + with pytest.raises(RuntimeError): + concat(x1, x3) + + +def test_add_mode_fail() -> None: + """Test Add with mode='fail'.""" + rng = RandomGenerator(seed=42) + + x1 = rng.float32_tensor((1, 3, 32, 32)) + x2 = rng.float32_tensor((1, 3, 32, 32)) + add = Add(mode='fail') + output = add(x1, x2) + assert output.shape == (1, 3, 32, 32) + + x3 = rng.float32_tensor((1, 3, 30, 30)) + with pytest.raises(RuntimeError): + add(x1, x3) + + +@pytest.mark.parametrize('module_class', [Concat, Add]) +def test_invalid_mode(module_class: type) -> None: + """Test modules with invalid mode.""" + with pytest.raises(ValueError, match='mode must be one of'): + module_class(mode='invalid_mode') diff --git a/tests/nn/test_layernorm.py b/tests/nn/test_layernorm.py new file mode 100644 index 000000000..ebc11ccb8 --- /dev/null +++ b/tests/nn/test_layernorm.py @@ -0,0 +1,187 @@ +"""Tests for LayerNorm module.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'features_last', 'input_shape'), + [ + (32, False, (1, 32, 32, 32)), + (64, True, (2, 16, 16, 64)), + (None, False, (1, 32, 32, 32)), + (None, True, (2, 16, 16, 64)), + ], +) +def test_layernorm_basic( + n_channels: int | None, + features_last: bool, + input_shape: Sequence[int], + device: str, +) -> None: + """Test LayerNorm basic functionality.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = LayerNorm(n_channels=n_channels, features_last=features_last).to(device) + output = norm(x) + + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + if n_channels is not None: + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' + + +@pytest.mark.parametrize( + ('n_channels', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (32, 16, (1, 32, 32, 32), (1, 16)), + (64, 32, (2, 64, 16, 16), (2, 32)), + ], +) +def test_layernorm_with_conditioning( + n_channels: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int], +) -> None: + """Test LayerNorm with conditioning.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).requires_grad_(True) + norm = LayerNorm(n_channels=n_channels, cond_dim=cond_dim) + + output = norm(x, cond=cond) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert norm.cond_proj is not None, 'cond_proj should not be None when cond_dim > 0' + assert norm.cond_proj.weight.grad is not None, 'No gradient computed for cond_proj' + + +def test_layernorm_features_last() -> None: + """Test LayerNorm with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + norm_standard = LayerNorm(n_channels=3, features_last=False) + y_standard = norm_standard(x) + + norm_last = LayerNorm(n_channels=3, features_last=True) + y_last = norm_last(x.moveaxis(1, -1)) + + torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) + + +def test_layernorm_no_channels() -> None: + """Test LayerNorm without channels (pure normalization).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) + norm = LayerNorm(n_channels=None) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + # Check that normalization is applied (mean close to 0, std close to 1) + dims = tuple(range(1, x.ndim)) + mean = output.mean(dim=dims) + std = output.std(dim=dims) + + assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-6), 'Mean not close to 0' + assert torch.allclose(std, torch.ones_like(std), atol=1e-5), 'Std not close to 1' + + +def test_layernorm_conditioning_without_channels() -> None: + """Test LayerNorm with conditioning but no channels (should raise error).""" + with pytest.raises(ValueError, match='channels must be provided if cond_dim > 0'): + LayerNorm(n_channels=None, cond_dim=16) + + +def test_layernorm_invalid_cond_dim() -> None: + """Test LayerNorm with invalid cond_dim.""" + with pytest.raises(RuntimeError, match='Trying to create tensor with negative dimension'): + LayerNorm(n_channels=32, cond_dim=-1) + + +def test_layernorm_3d_input() -> None: + """Test LayerNorm with 3D input.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((2, 64, 128)).requires_grad_(True) + norm = LayerNorm(n_channels=64) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + + +def test_layernorm_5d_input() -> None: + """Test LayerNorm with 5D input.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 16, 16, 16)).requires_grad_(True) + norm = LayerNorm(n_channels=32) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + + +def test_layernorm_conditioning_features_last() -> None: + """Test LayerNorm with conditioning and features_last=True.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + cond = rng.float32_tensor((1, 8)).requires_grad_(True) + + norm = LayerNorm(n_channels=3, features_last=True, cond_dim=8) + output = norm(x.moveaxis(1, -1), cond=cond) + + assert output.shape == x.moveaxis(1, -1).shape, f'Output shape {output.shape} != expected shape' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + + +def test_layernorm_gradient_flow() -> None: + """Test that gradients flow properly through LayerNorm.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) + norm = LayerNorm(n_channels=32) + + output = norm(x) + loss = output.sum() + loss.backward() + + # Check that gradients are computed for all learnable parameters + assert x.grad is not None, 'Input gradients not computed' + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'Weight gradients not computed' + assert norm.bias.grad is not None, 'Bias gradients not computed' + + # Check that gradients are finite + assert torch.isfinite(x.grad).all(), 'Input gradients contain non-finite values' + assert torch.isfinite(norm.weight.grad).all(), 'Weight gradients contain non-finite values' + assert torch.isfinite(norm.bias.grad).all(), 'Bias gradients contain non-finite values' diff --git a/tests/nn/test_ndmodules.py b/tests/nn/test_ndmodules.py new file mode 100644 index 000000000..a0a77a98d --- /dev/null +++ b/tests/nn/test_ndmodules.py @@ -0,0 +1,76 @@ +"""Tests for the ndmodules module.""" + +import pytest +import torch +from mrpro.nn.ndmodules import ( + adaptiveAvgPoolND, + avgPoolND, + batchNormND, + convND, + convTransposeND, + instanceNormND, + maxPoolND, +) + + +def test_convnd() -> None: + """Test ConvND.""" + assert convND(1) is torch.nn.Conv1d + assert convND(2) is torch.nn.Conv2d + assert convND(3) is torch.nn.Conv3d + with pytest.raises(NotImplementedError): + convND(4) + + +def test_convtransposend() -> None: + """Test ConvTransposeND.""" + assert convTransposeND(1) is torch.nn.ConvTranspose1d + assert convTransposeND(2) is torch.nn.ConvTranspose2d + assert convTransposeND(3) is torch.nn.ConvTranspose3d + with pytest.raises(NotImplementedError): + convTransposeND(4) + + +def test_maxpoolnd() -> None: + """Test MaxPoolND.""" + assert maxPoolND(1) is torch.nn.MaxPool1d + assert maxPoolND(2) is torch.nn.MaxPool2d + assert maxPoolND(3) is torch.nn.MaxPool3d + with pytest.raises(NotImplementedError): + maxPoolND(4) + + +def test_avgpoolnd() -> None: + """Test AvgPoolND.""" + assert avgPoolND(1) is torch.nn.AvgPool1d + assert avgPoolND(2) is torch.nn.AvgPool2d + assert avgPoolND(3) is torch.nn.AvgPool3d + with pytest.raises(NotImplementedError): + avgPoolND(4) + + +def test_adaptiveavgpoolnd() -> None: + """Test AdaptiveAvgPoolND.""" + assert adaptiveAvgPoolND(1) is torch.nn.AdaptiveAvgPool1d + assert adaptiveAvgPoolND(2) is torch.nn.AdaptiveAvgPool2d + assert adaptiveAvgPoolND(3) is torch.nn.AdaptiveAvgPool3d + with pytest.raises(NotImplementedError): + adaptiveAvgPoolND(4) + + +def test_instancenormnd() -> None: + """Test InstanceNormND.""" + assert instanceNormND(1) is torch.nn.InstanceNorm1d + assert instanceNormND(2) is torch.nn.InstanceNorm2d + assert instanceNormND(3) is torch.nn.InstanceNorm3d + with pytest.raises(NotImplementedError): + instanceNormND(4) + + +def test_batchnormnd() -> None: + """Test BatchNormND.""" + assert batchNormND(1) is torch.nn.BatchNorm1d + assert batchNormND(2) is torch.nn.BatchNorm2d + assert batchNormND(3) is torch.nn.BatchNorm3d + with pytest.raises(NotImplementedError): + batchNormND(4) diff --git a/tests/nn/test_pixelshuffle.py b/tests/nn/test_pixelshuffle.py new file mode 100644 index 000000000..8d5917a83 --- /dev/null +++ b/tests/nn/test_pixelshuffle.py @@ -0,0 +1,92 @@ +"""Test PixelShuffle and PixelUnshuffle.""" + +from typing import cast + +import torch +from mrpro.nn.PixelShuffle import PixelShuffle, PixelShuffleUpsample, PixelUnshuffle, PixelUnshuffleDownsample +from mrpro.utils import RandomGenerator + + +def test_pixel_shuffle_2d() -> None: + """Test PixelUnshuffle's fast path for 2D images.""" + x = torch.arange(3 * 4 * 8).reshape(1, 3, 4, 8) + pixel_unshuffle = PixelUnshuffle(2) + y = pixel_unshuffle(x) + assert y.shape == (1, 3 * 4, 4 // 2, 8 // 2) + + pixel_shuffle = PixelShuffle(2) + z = pixel_shuffle(y) + assert z.shape == (1, 3, 4, 8) + assert (x == z).all() + + +def test_pixel_unshuffle_4d() -> None: + """Test PixelUnshuffle's general case.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) + pixel_unshuffle = PixelUnshuffle(2) + y = pixel_unshuffle(x) + assert y.shape == (1, 3 * 16, 4 // 2, 8 // 2, 10 // 2, 12 // 2) + + pixel_shuffle = PixelShuffle(2) + z = pixel_shuffle(y) + assert z.shape == (1, 3, 4, 8, 10, 12) + assert (x == z).all() + + +def test_pixelunshuffle_features_last() -> None: + """Test PixelUnshuffle with features_last.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) + pixel_unshuffle_last = PixelUnshuffle(2, features_last=True) + pixel_unshuffle = PixelUnshuffle(2, features_last=False) + y_last = pixel_unshuffle_last(x.moveaxis(1, -1)).moveaxis(-1, 1) + y_normal = pixel_unshuffle(x) + assert (y_last == y_normal).all() + + +def test_pixelshuffle_features_last() -> None: + """Test PixelShuffle with features_last.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, -1, 2, 4, 5, 6) + pixel_shuffle_last = PixelShuffle(2, features_last=True) + pixel_shuffle = PixelShuffle(2, features_last=False) + y_last = pixel_shuffle_last(x.moveaxis(1, -1)).moveaxis(-1, 1) + y_normal = pixel_shuffle(x) + assert (y_last == y_normal).all() + + +def test_unpixelshuffledownsample_residual() -> None: + """Test PixelUnshuffleDownsample with residual.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 2, 9, 12, 15)) + downsample = PixelUnshuffleDownsample(3, 2, 27, downscale_factor=3, residual=True) + y = downsample(x) + assert y.shape == (1, 27, 3, 4, 5) + + +def test_pixelshuffleupsample_residual() -> None: + """Test PixelShuffleUpsample with residual.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 2, 3, 4, 5)) + upsample = PixelShuffleUpsample(3, 2, 1, upscale_factor=3, residual=True) + y = upsample(x) + assert y.shape == (1, 1, 9, 12, 15) + + +def test_pixelshuffleupsample_pixelunshuffledownsample() -> None: + """Test if PixelUnshuffleDownsample is the inverse of PixelShuffleUpsample.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3**3, 3, 4, 5)) + # Only without residual, the upsample and downsample are inverses. + downsample = PixelUnshuffleDownsample(3, 1, 3**3, downscale_factor=3, residual=False) + upsample = PixelShuffleUpsample(3, 3**3, 1, upscale_factor=3, residual=False) + # Only if the convs are Identity, the upsample and downsample are inverses. + torch.nn.init.dirac_(cast(torch.Tensor, downsample.projection.weight)) + torch.nn.init.dirac_(cast(torch.Tensor, upsample.projection.weight)) + downsample_bias = cast(torch.Tensor | None, downsample.projection.bias) + upsample_bias = cast(torch.Tensor | None, upsample.projection.bias) + if downsample_bias is not None: + torch.nn.init.zeros_(downsample_bias) + if upsample_bias is not None: + torch.nn.init.zeros_(upsample_bias) + y = downsample(upsample(x)) + assert y.shape == (1, 3**3, 3, 4, 5) + torch.testing.assert_close(y, x, msg='Upsample and downsample are not inverses.') diff --git a/tests/nn/test_rmsnorm.py b/tests/nn/test_rmsnorm.py new file mode 100644 index 000000000..c8ddc0b69 --- /dev/null +++ b/tests/nn/test_rmsnorm.py @@ -0,0 +1,58 @@ +"""Tests for RMSNorm module.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.nn import RMSNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'features_last', 'input_shape'), + [ + (32, False, (1, 32, 32, 32)), + (64, True, (2, 16, 16, 64)), + (None, False, (1, 32, 32, 32)), + (None, True, (2, 16, 16, 64)), + ], +) +def test_rmsnorm_basic(n_channels: int | None, features_last: bool, input_shape: Sequence[int], device: str) -> None: + """Test RMSNorm basic functionality.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = RMSNorm(n_channels=n_channels, features_last=features_last).to(device) + output = norm(x) + + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + if n_channels is not None: + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' + + +def test_rmsnorm_features_last() -> None: + """Test RMSNorm with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + norm_standard = RMSNorm(n_channels=3, features_last=False) + y_standard = norm_standard(x) + + norm_last = RMSNorm(n_channels=3, features_last=True) + y_last = norm_last(x.moveaxis(1, -1)) + + torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py new file mode 100644 index 000000000..bdf81bf8d --- /dev/null +++ b/tests/nn/test_sequential.py @@ -0,0 +1,50 @@ +"""Tests for Sequential module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn import FiLM, Sequential +from mrpro.operators import FastFourierOp, MagnitudeOp +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('input_shape', 'cond_dim'), + [ + ((1, 32), (1, 16)), + ((2, 32), None), + ], +) +def test_sequential( + input_shape: Sequence[int], + cond_dim: Sequence[int] | None, + device: str, +) -> None: + """Test Sequential output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_dim).to(device).requires_grad_(True) if cond_dim else None + seq = Sequential( + Linear(input_shape[1], 64), + FastFourierOp(dim=(-1,)), + FiLM(channels=64, cond_dim=16), + MagnitudeOp(), + ).to(device) + output = seq(x, cond=cond) + assert output.shape == (input_shape[0], 64) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for cond' + assert not cond.grad.isnan().any(), 'NaN values in cond gradients' + assert seq[0].weight.grad is not None, 'No gradient computed for Linear' From 1e186e5d123e4fa487b577e183935d19139d9987 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:26 +0100 Subject: [PATCH 179/205] add data consistency modules and tests ghstack-source-id: 667260860fb33de9158d56620360ca10e3697941 ghstack-comment-id: 3865650618 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/954 --- docker/minimal_requirements.txt | 2 +- src/mrpro/nn/__init__.py | 2 + .../data_consistency/AnalyticCartesianDC.py | 101 +++++++++++++++ .../data_consistency/ConjugateGradientDC.py | 117 ++++++++++++++++++ .../nn/data_consistency/GradientDescentDC.py | 99 +++++++++++++++ src/mrpro/nn/data_consistency/__init__.py | 5 + tests/nn/data_consistency/conftest.py | 46 +++++++ .../test_analyticcartesiandc.py | 21 ++++ .../test_conjugategradientdc.py | 21 ++++ .../test_gradientdescentdc.py | 21 ++++ 10 files changed, 434 insertions(+), 1 deletion(-) create mode 100644 src/mrpro/nn/data_consistency/AnalyticCartesianDC.py create mode 100644 src/mrpro/nn/data_consistency/ConjugateGradientDC.py create mode 100644 src/mrpro/nn/data_consistency/GradientDescentDC.py create mode 100644 src/mrpro/nn/data_consistency/__init__.py create mode 100644 tests/nn/data_consistency/conftest.py create mode 100644 tests/nn/data_consistency/test_analyticcartesiandc.py create mode 100644 tests/nn/data_consistency/test_conjugategradientdc.py create mode 100644 tests/nn/data_consistency/test_gradientdescentdc.py diff --git a/docker/minimal_requirements.txt b/docker/minimal_requirements.txt index c9be333c9..6723b723e 100644 --- a/docker/minimal_requirements.txt +++ b/docker/minimal_requirements.txt @@ -6,7 +6,7 @@ einops==0.7.0 pydicom==3.0.1 pypulseq==1.4.2 pytorch-finufft==0.1.0 -cufinufft==2.3.1 +cufinufft==2.4.1 scipy==1.12 ptwt==0.1.8 tqdm==4.60.0 diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index d6541f5c8..f988855e7 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -12,6 +12,7 @@ from mrpro.nn.RMSNorm import RMSNorm from mrpro.nn.Residual import Residual from mrpro.nn.Sequential import Sequential +from mrpro.nn import data_consistency from mrpro.nn.ndmodules import ( adaptiveAvgPoolND, avgPoolND, @@ -40,6 +41,7 @@ 'batchNormND', 'convND', 'convTransposeND', + 'data_consistency', 'instanceNormND', 'maxPoolND', ] diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py new file mode 100644 index 000000000..5b2848fca --- /dev/null +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -0,0 +1,101 @@ +"""Analytic Cartesian data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.KData import KData +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.IdentityOp import IdentityOp + + +class AnalyticCartesianDC(Module): + r"""Analytic Cartesian data consistency. + + Solves the following problem: + :math:`\min_x \|Ax - k\|_2^2 + \lambda \|x-p\|_2^2` + where :math:`A` is the acquisition operator and :math:`k` is the data, :math:`\lambda` is the regularization + parameter and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a + special case for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil + data or to apply data consistency to each coil image [NOSENSE]_. + + References + ---------- + .. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without + explicit sensitivity maps." STACOM@MICCAI 2023. https://arxiv.org/abs/2309.15608 + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + + + """ + + def __init__(self, initial_regularization_weight: torch.Tensor | float): + """Initialize the data consistency. + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + Must be a positive scalar. + """ + super().__init__() + weight = torch.as_tensor(initial_regularization_weight) + if weight.ndim != 0: + raise ValueError('Regularization weight must be a scalar') + if weight.item() <= 0: + raise ValueError('Regularization weight must be positive') + self.log_weight = Parameter(weight.log()) + + @overload + def __call__(self, image: torch.Tensor, data: KData, fourier_op: FourierOp | None = None) -> torch.Tensor: ... + + @overload + def __call__(self, image: torch.Tensor, data: torch.Tensor, fourier_op: FourierOp) -> torch.Tensor: ... + + def __call__( + self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate, i.e. the regularized image. + data + k-space data. + fourier_op + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, fourier_op) + + def forward( + self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None + ) -> torch.Tensor: + """Apply the data consistency.""" + if fourier_op is None: + if isinstance(data, KData): + fourier_op = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or a FourierOp is required') + + if not isinstance(fourier_op, FourierOp) or fourier_op._nufft_dims or fourier_op._fast_fourier_op is None: + raise ValueError('Only Cartesian acquisitions are supported') + + data_ = data.data if isinstance(data, KData) else data + fft_op = fourier_op._fast_fourier_op + sampling_op = fourier_op._cart_sampling_op if fourier_op._cart_sampling_op is not None else IdentityOp() + (zero_filled,) = sampling_op.adjoint(data_) + (k_pred,) = fft_op(image) + regularization_weight = self.log_weight.exp() + (k,) = sampling_op.gram((zero_filled - k_pred) / (1 + regularization_weight)) + (delta,) = fft_op.H(k) + return image + delta diff --git a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py new file mode 100644 index 000000000..f81b24f98 --- /dev/null +++ b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py @@ -0,0 +1,117 @@ +"""Conjugate gradient data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.CsmData import CsmData +from mrpro.data.KData import KData +from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.LinearOperator import LinearOperator +from mrpro.operators.SensitivityOp import SensitivityOp + + +class ConjugateGradientDC(Module): + """Conjugate gradient data consistency.""" + + def __init__(self, initial_regularization_weight: torch.Tensor | float): + """Initialize the conjugate gradient data consistency. + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + Must be a positive scalar. + """ + super().__init__() + weight = torch.as_tensor(initial_regularization_weight) + if weight.ndim != 0: + raise ValueError('Regularization weight must be a scalar') + if weight.item() <= 0: + raise ValueError('Regularization weight must be positive') + self.log_weight = Parameter(weight.log()) + + @overload + def __call__( + self, + image: torch.Tensor, + data: KData, + fourier_op: FourierOp | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: ... + + @overload + def __call__( + self, + image: torch.Tensor, + data: torch.Tensor, + fourier_op: FourierOp, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: ... + + def __call__( + self, + image: torch.Tensor, + data: KData | torch.Tensor, + fourier_op: LinearOperator | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate. + data + k-space data. + fourier_op + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. + This operator can already include the coil sensitivity weighting, if gradients wrt the coil sensitivity maps + NOT required. Otherwise, they should be given as an additional argument. + csm + Coil sensitivity maps. If None, no coil sensitivity weighting is applied. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, fourier_op, csm) + + def forward( + self, + image: torch.Tensor, + data: torch.Tensor | KData, + fourier_op: FourierOp | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: + """Apply the data consistency.""" + if fourier_op is None: + if isinstance(data, KData): + fourier_op = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or a FourierOp is required') + + data_ = data.data if isinstance(data, KData) else data + + if csm is None: + csm = torch.tensor(()) + elif isinstance(csm, CsmData): + csm = csm.data + + def operator_factory(csm: torch.Tensor, weight: torch.Tensor, *_): + op = fourier_op.gram + if csm.numel(): + csm_op = SensitivityOp(csm) + op = csm_op.H @ op @ csm_op + op = op + weight + return op + + def rhs_factory(_csm: torch.Tensor, weight: torch.Tensor, zero_filled: torch.Tensor, image: torch.Tensor): + return (zero_filled + weight * image,) + + cg_op = ConjugateGradientOp(operator_factory=operator_factory, rhs_factory=rhs_factory) + (result,) = cg_op(csm, self.log_weight.exp(), fourier_op.adjoint(data_)[0], image) + return result diff --git a/src/mrpro/nn/data_consistency/GradientDescentDC.py b/src/mrpro/nn/data_consistency/GradientDescentDC.py new file mode 100644 index 000000000..579e5181a --- /dev/null +++ b/src/mrpro/nn/data_consistency/GradientDescentDC.py @@ -0,0 +1,99 @@ +"""Gradient descent data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.KData import KData +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.LinearOperator import LinearOperator + + +class GradientDescentDC(Module): + r"""Gradient descent data consistency. + + Performs gradient descent steps on + :math:`\|Ax - k\|_2^2` where :math:`A` is the acquisition operator and :math:`k` is the data. + + Parameters + ---------- + initial_stepsize + Initial stepsize. The stepsize is a trainable parameter. + Must be a positive scalar. + n_steps + Number of gradient descent steps. + + Returns + ------- + The updated image. + """ + + def __init__(self, initial_stepsize: float | torch.Tensor, n_steps: int = 1) -> None: + """Initialize the gradient descent data consistency. + + Parameters + ---------- + initial_stepsize + Initial stepsize. The stepsize is a trainable parameter. + Must be a positive scalar. + n_steps + Number of gradient descent steps. + """ + super().__init__() + stepsize = torch.as_tensor(initial_stepsize) + if stepsize.ndim != 0: + raise ValueError('Stepsize must be a scalar') + if stepsize.item() <= 0: + raise ValueError('Stepsize must be positive') + self.log_stepsize = Parameter(stepsize.log()) + self.n_steps = n_steps + + @overload + def __call__( + self, image: torch.Tensor, data: KData, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: ... + + @overload + def __call__( + self, image: torch.Tensor, data: torch.Tensor, acquisition_operator: LinearOperator + ) -> torch.Tensor: ... + + def __call__( + self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate. + data + k-space data. + acquisition_operator + Acquisition operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` + object, the Fourier operator is automatically created from the data. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, acquisition_operator) + + def forward( + self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: + """Apply the data consistency.""" + if acquisition_operator is None: + if isinstance(data, KData): + acquisition_operator = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or an acquisition operator is required') + + data_ = data.data if isinstance(data, KData) else data + stepsize = self.log_stepsize.exp() + x = image + for _ in range(self.n_steps): + residual = acquisition_operator(x)[0] - data_ + x = x - stepsize * acquisition_operator.adjoint(residual)[0] + return x diff --git a/src/mrpro/nn/data_consistency/__init__.py b/src/mrpro/nn/data_consistency/__init__.py new file mode 100644 index 000000000..cea955c85 --- /dev/null +++ b/src/mrpro/nn/data_consistency/__init__.py @@ -0,0 +1,5 @@ +from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC +from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC +from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC + +__all__ = ["AnalyticCartesianDC", "ConjugateGradientDC", "GradientDescentDC"] \ No newline at end of file diff --git a/tests/nn/data_consistency/conftest.py b/tests/nn/data_consistency/conftest.py new file mode 100644 index 000000000..49fca4cf3 --- /dev/null +++ b/tests/nn/data_consistency/conftest.py @@ -0,0 +1,46 @@ +"""Test fixtures for data consistency tests.""" + +import pytest +from mrpro.data.KData import KData +from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian +from mrpro.operators.FourierOp import FourierOp +from mrpro.phantoms.EllipsePhantom import EllipsePhantom +from mrpro.utils import RandomGenerator + + +@pytest.fixture +def kdata(): + matrix = SpatialDimension(x=128, y=128, z=1) + kdata = EllipsePhantom().kdata(KTrajectoryCartesian.fullysampled(matrix), matrix) + return kdata + + +@pytest.fixture +def kdata_noisy(kdata: KData): + kdata_noisy = kdata.clone() + kdata_noisy.data += 0.1 * RandomGenerator(123).randn_like(kdata_noisy.data) + return kdata_noisy + + +@pytest.fixture +def kdata_us(kdata: KData): + return kdata[..., ::2, :].clone() + + +@pytest.fixture +def image_noisy(kdata_noisy: KData): + fourier_op = FourierOp.from_kdata(kdata_noisy) + return fourier_op.adjoint(kdata_noisy.data)[0] + + +@pytest.fixture +def image(kdata: KData): + fourier_op = FourierOp.from_kdata(kdata) + return fourier_op.adjoint(kdata.data)[0] + + +@pytest.fixture +def image_us(kdata_us: KData): + fourier_op = FourierOp.from_kdata(kdata_us) + return fourier_op.adjoint(kdata_us.data)[0] diff --git a/tests/nn/data_consistency/test_analyticcartesiandc.py b/tests/nn/data_consistency/test_analyticcartesiandc.py new file mode 100644 index 000000000..9020a8326 --- /dev/null +++ b/tests/nn/data_consistency/test_analyticcartesiandc.py @@ -0,0 +1,21 @@ +"""Tests for AnalyticCartesianDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC + + +def test_analytic_cartesian_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = AnalyticCartesianDC(initial_regularization_weight=1e-6) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_weight.grad is not None + assert not dc.log_weight.grad.isnan() + assert not image_noisy.grad.isnan().any() diff --git a/tests/nn/data_consistency/test_conjugategradientdc.py b/tests/nn/data_consistency/test_conjugategradientdc.py new file mode 100644 index 000000000..99de6f4f5 --- /dev/null +++ b/tests/nn/data_consistency/test_conjugategradientdc.py @@ -0,0 +1,21 @@ +"""Tests for ConjugateGradientDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC + + +def test_conjugate_gradient_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = ConjugateGradientDC(initial_regularization_weight=1.0) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_weight.grad is not None + assert not dc.log_weight.grad.isnan() + assert not image_noisy.grad.isnan().any() diff --git a/tests/nn/data_consistency/test_gradientdescentdc.py b/tests/nn/data_consistency/test_gradientdescentdc.py new file mode 100644 index 000000000..00a7648f9 --- /dev/null +++ b/tests/nn/data_consistency/test_gradientdescentdc.py @@ -0,0 +1,21 @@ +"""Tests for GradientDescentDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC + + +def test_gradient_descent_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = GradientDescentDC(initial_stepsize=1.0) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_stepsize.grad is not None + assert not dc.log_stepsize.grad.isnan() + assert not image_noisy.grad.isnan().any() From 9376390b82c49980855c238bc3bf581c1cec614f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:27 +0100 Subject: [PATCH 180/205] add positional encodings and attention modules ghstack-source-id: cc10414c642313d4b8a7356e20d139d60c57f5be ghstack-comment-id: 3865650808 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/955 --- pyproject.toml | 2 + src/mrpro/nn/AbsolutePositionEncoding.py | 74 ++++++ src/mrpro/nn/AxialRoPE.py | 106 ++++++++ src/mrpro/nn/__init__.py | 6 + src/mrpro/nn/attention/AttentionGate.py | 74 ++++++ src/mrpro/nn/attention/LinearSelfAttention.py | 98 ++++++++ src/mrpro/nn/attention/MultiHeadAttention.py | 105 ++++++++ .../nn/attention/NeighborhoodSelfAttention.py | 237 ++++++++++++++++++ .../nn/attention/ShiftedWindowAttention.py | 131 ++++++++++ .../nn/attention/SpatialTransformerBlock.py | 216 ++++++++++++++++ src/mrpro/nn/attention/SqueezeExcitation.py | 57 +++++ src/mrpro/nn/attention/TransposedAttention.py | 76 ++++++ src/mrpro/nn/attention/__init__.py | 17 ++ tests/conftest.py | 6 + tests/nn/test_ape.py | 27 ++ tests/nn/test_attentiongate.py | 51 ++++ tests/nn/test_linearselfattention.py | 58 +++++ tests/nn/test_neighborhoodselfattention.py | 149 +++++++++++ tests/nn/test_rope.py | 36 +++ tests/nn/test_shiftedwindowattention.py | 61 +++++ tests/nn/test_spatialtransformerblock.py | 142 +++++++++++ tests/nn/test_squeezeexcitation.py | 32 +++ tests/nn/test_transposedattention.py | 44 ++++ 23 files changed, 1805 insertions(+) create mode 100644 src/mrpro/nn/AbsolutePositionEncoding.py create mode 100644 src/mrpro/nn/AxialRoPE.py create mode 100644 src/mrpro/nn/attention/AttentionGate.py create mode 100644 src/mrpro/nn/attention/LinearSelfAttention.py create mode 100644 src/mrpro/nn/attention/MultiHeadAttention.py create mode 100644 src/mrpro/nn/attention/NeighborhoodSelfAttention.py create mode 100644 src/mrpro/nn/attention/ShiftedWindowAttention.py create mode 100644 src/mrpro/nn/attention/SpatialTransformerBlock.py create mode 100644 src/mrpro/nn/attention/SqueezeExcitation.py create mode 100644 src/mrpro/nn/attention/TransposedAttention.py create mode 100644 src/mrpro/nn/attention/__init__.py create mode 100644 tests/nn/test_ape.py create mode 100644 tests/nn/test_attentiongate.py create mode 100644 tests/nn/test_linearselfattention.py create mode 100644 tests/nn/test_neighborhoodselfattention.py create mode 100644 tests/nn/test_rope.py create mode 100644 tests/nn/test_shiftedwindowattention.py create mode 100644 tests/nn/test_spatialtransformerblock.py create mode 100644 tests/nn/test_squeezeexcitation.py create mode 100644 tests/nn/test_transposedattention.py diff --git a/pyproject.toml b/pyproject.toml index 9863b3856..a94d9b3c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,7 +124,9 @@ filterwarnings = [ "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 + "ignore:The \\.grad attribute of a Tensor that is not a leaf Tensor is being accessed:UserWarning", # torch dynamo bug in flex attention "ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10 + "ignore:`torch.jit.script_method` is deprecated", # torch 2.10 ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] diff --git a/src/mrpro/nn/AbsolutePositionEncoding.py b/src/mrpro/nn/AbsolutePositionEncoding.py new file mode 100644 index 000000000..f6093d6ba --- /dev/null +++ b/src/mrpro/nn/AbsolutePositionEncoding.py @@ -0,0 +1,74 @@ +"""Absolute position encoding (APE).""" + +from itertools import combinations +from math import ceil + +import torch +from torch.nn import Module + +from mrpro.utils.reshape import unsqueeze_right + + +class AbsolutePositionEncoding(Module): + """Absolute position encoding layer. + + Encodes absolute positions in a grid. Has no learnable parameters. + """ + + encoding: torch.Tensor + + def __init__(self, n_dim: int, n_features: int, include_radii: bool = True, base_resolution: int = 128): + """Initialize absolute position encoding layer. + + Parameters + ---------- + n_dim + Dimensions of the input space (1, 2, or 3) + n_features + Number of features to encode. The input to the forward pass needs to have at least + this many features/channels. + include_radii + Whether to include radius features + base_resolution + Base resolution for position encoding. + Encodings are generated at this resolution and interpolated to the input shape in the forward pass. + """ + super().__init__() + + coords = [unsqueeze_right(torch.linspace(-1, 1, base_resolution), i) for i in range(n_dim)] + if include_radii: + for n in range(2, n_dim + 1): + for combination in combinations(coords, n): + coords.append((2 * sum([c**2 for c in combination])) ** 0.5 - 1) + n_freqs = ceil(n_features / len(coords) / 2) + freqs = unsqueeze_right((base_resolution) ** torch.linspace(0, 1, n_freqs), n_dim) + encoding = [] + for coord in coords: + encoding.append(torch.sin(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * n_dim))) + encoding.append(torch.cos(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * n_dim))) + self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :n_features]) + self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][n_dim - 1] + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply absolute position encoding to a tensor. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Encoded tensor with absolute position information + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply absolute position encoding to a tensor.""" + features = self.encoding.shape[1] + if features > x.shape[1]: + raise ValueError(f'x has {x.shape[1]} features, but {features} are required') + + x_enc, x_unenc = x.split([features, x.shape[1] - features], dim=1) + encoding = torch.nn.functional.interpolate(self.encoding, size=x_unenc.shape[2:], mode=self.interpolation_mode) + return torch.cat((x_enc + encoding, x_unenc), dim=1) diff --git a/src/mrpro/nn/AxialRoPE.py b/src/mrpro/nn/AxialRoPE.py new file mode 100644 index 000000000..7d76a86d1 --- /dev/null +++ b/src/mrpro/nn/AxialRoPE.py @@ -0,0 +1,106 @@ +"""Rotary Position Embedding (RoPE).""" + +from collections.abc import Sequence + +import torch +from einops import rearrange +from torch.nn import Module + + +@torch.compile +def get_theta( + shape: Sequence[int], n_embedding_channels: int, device: torch.device +) -> torch.Tensor: # pragma: no cover + """Get rotation angles. + + Parameters + ---------- + shape + Spatial shape of the input tensor to use for the position embedding, + i.e. the shape excluding batch and channel dimensions. + n_embedding_channels + Number of embedding channels per head + device + Device to create the rotation angles on + + Returns + ------- + Rotation angles + """ + position = torch.stack( + torch.meshgrid([torch.arange(s, device=device) - s // 2 for s in shape], indexing='ij'), dim=-1 + ) + log_min = torch.log(torch.tensor(torch.pi)) + log_max = torch.log(torch.tensor(10000.0)) + freqs = torch.exp(torch.linspace(log_min, log_max, n_embedding_channels // (2 * position.shape[-1]), device=device)) + return rearrange(freqs * position[..., None], '... dim freqs ->... (dim freqs)') + + +class AxialRoPE(Module): + """Axial Rotary Position Embedding. + + Applies rotary position embeddings along each axis independently. + """ + + embed_fraction: float + freqs: torch.Tensor # explicit annotation kept for static type checking + + def __init__( + self, + embed_fraction: float = 1.0, + ): + """Initialize AxialRoPE. + + Parameters + ---------- + embed_fraction + Fraction of channels used for embedding + """ + super().__init__() + self.embed_fraction: float = float(embed_fraction) + if embed_fraction < 0 or embed_fraction > 1: + raise ValueError('embed_fraction must be between 0 and 1') + + def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply rotary embeddings to input tensors. + + Parameters + ---------- + *tensors + Tensors to apply rotary embeddings to. + Shape must be `(batch, heads, *spatial_dims, channels)`. + """ + if self.embed_fraction == 0.0: + return tensors + + shape = tensors[0].shape + if not all(t.shape == shape for t in tensors): + raise ValueError('All tensors must have the same shape') + device = tensors[0].device + if not all(t.device == device for t in tensors): + raise ValueError('All tensors must be on the same device') + + shape, n_channels_per_head = shape[2:-1], shape[-1] + n_embedding_channels = int(n_channels_per_head * self.embed_fraction) + theta = get_theta(shape, n_embedding_channels, device) + return tuple(self.apply_rotary_emb(t, theta) for t in tensors) + + @staticmethod + def apply_rotary_emb(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: + """Add rotary embedding to the input tensor. + + Parameters + ---------- + x + Input tensor to modify + theta + Rotation angles + """ + n_emb = theta.shape[-1] * 2 + if n_emb > x.shape[-1]: + raise ValueError(f'Embedding dimension {n_emb} is larger than input dimension {x.shape[-1]}') + (x1, x2), x_unembed = x[..., :n_emb].chunk(2, dim=-1), x[..., n_emb:] + result = torch.cat( + [x1 * theta.cos() - x2 * theta.sin(), x2 * theta.cos() + x1 * theta.sin(), x_unembed], dim=-1 + ) + return result diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index f988855e7..ffb98843d 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -1,6 +1,8 @@ """Neural network modules and utilities.""" from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.nn.AbsolutePositionEncoding import AbsolutePositionEncoding +from mrpro.nn.AxialRoPE import AxialRoPE from mrpro.nn.CondMixin import CondMixin from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM @@ -12,6 +14,7 @@ from mrpro.nn.RMSNorm import RMSNorm from mrpro.nn.Residual import Residual from mrpro.nn.Sequential import Sequential +from mrpro.nn import attention from mrpro.nn import data_consistency from mrpro.nn.ndmodules import ( adaptiveAvgPoolND, @@ -24,6 +27,8 @@ ) __all__ = [ + 'AbsolutePositionEncoding', + 'AxialRoPE', 'ComplexAsChannel', 'CondMixin', 'DropPath', @@ -37,6 +42,7 @@ 'Residual', 'Sequential', 'adaptiveAvgPoolND', + 'attention', 'avgPoolND', 'batchNormND', 'convND', diff --git a/src/mrpro/nn/attention/AttentionGate.py b/src/mrpro/nn/attention/AttentionGate.py new file mode 100644 index 000000000..d7fdfeab4 --- /dev/null +++ b/src/mrpro/nn/attention/AttentionGate.py @@ -0,0 +1,74 @@ +"""Attention gate from Attention UNet.""" + +import torch +from torch.nn import Module, ReLU, Sequential, Sigmoid + +from mrpro.nn.ndmodules import convND + + +class AttentionGate(Module): + """Attention gate from Attention UNet. + + The attention mechanism from the attention UNet [OKT18]_. + + References + ---------- + ..[OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + def __init__( + self, n_dim: int, channels_gate: int, channels_in: int, channels_hidden: int, concatenate: bool = False + ): + """Initialize the attention gate. + + Parameters + ---------- + n_dim + The dimension, i.e. 1, 2 or 3. + channels_gate + The number of channels in the gate tensor. + channels_in + The number of channels in the input tensor. + channels_hidden + The number of internal, hidden channels. + concatenate + Whether to concatenate the gated signal with the gate signal in the channel dimension (1) + """ + super().__init__() + self.project_gate = convND(n_dim)(channels_gate, channels_hidden, kernel_size=1) + self.project_x = convND(n_dim)(channels_in, channels_hidden, kernel_size=1) + self.psi = Sequential( + ReLU(), + convND(n_dim)(channels_hidden, 1, kernel_size=1), + Sigmoid(), + ) + self.concatenate = concatenate + + def __call__(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Apply the attention gate. + + Parameters + ---------- + x + The input tensor. + gate + The gate tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, gate) + + def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Apply the attention gate.""" + projected_gate = self.project_gate(gate) + projected_x = self.project_x(x) + projected_gate = torch.nn.functional.interpolate(projected_gate, size=x.shape[2:], mode='nearest') + alpha = self.psi(projected_gate + projected_x) + x = x * alpha + if self.concatenate: + gate = torch.nn.functional.interpolate(gate, size=x.shape[2:], mode='nearest') + x = torch.cat([x, gate], dim=1) + return x diff --git a/src/mrpro/nn/attention/LinearSelfAttention.py b/src/mrpro/nn/attention/LinearSelfAttention.py new file mode 100644 index 000000000..2bab08930 --- /dev/null +++ b/src/mrpro/nn/attention/LinearSelfAttention.py @@ -0,0 +1,98 @@ +"""Linear self-attention.""" + +import torch +from einops import rearrange +from torch import Tensor +from torch.nn import Linear, Module, ReLU + + +class LinearSelfAttention(Module): + """Linear multi-head self-attention via kernel trick. + + Uses a ReLU kernel to compute attention in O(N) [KAT20]_ time and space. + + + References + ---------- + .. [KAT20] Katharopoulos, Angelos, et al. Transformers are RNNs: Fast autoregressive transformers with linear + attention. ICML 2020. https://arxiv.org/abs/2006.16236 + """ + + def __init__( + self, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + eps: float = 1e-6, + features_last: bool = False, + ): + """Initialize linear self-attention layer. + + Parameters + ---------- + n_channels_in + Input channel dimension. + n_channels_out + Output channel dimension. + n_heads + Number of attention heads. + eps + Small epsilon for numerical stability in normalization. + features_last + Whether the channel dimension is the last dimension, as common in transformer models, + or the second dimension, as common in image models. + """ + super().__init__() + self.features_last = features_last + self.eps = eps + self.n_heads = n_heads + channels_per_head = n_channels_in // n_heads + self.to_qkv = Linear(n_channels_in, 3 * channels_per_head * n_heads) + self.kernel_function = ReLU() + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) + + def __call__(self, x: Tensor) -> Tensor: + """Apply linear self-attention. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or (`batch, *spatial_dims, channels` if `features_last`) + + Returns + ------- + Tensor after attention, same shape as input. + """ + return super().__call__(x) + + def forward(self, x: Tensor) -> Tensor: + """Apply linear self-attention.""" + orig_dtype = x.dtype + if x.dtype == torch.float16: + x = x.float() + if not self.features_last: + x = x.moveaxis(1, -1) + spatial_shape = x.shape[1:-1] + + qkv = self.to_qkv(x) + query, key, value = rearrange( + qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channels', qkv=3, head=self.n_heads + ) + + query = self.kernel_function(query) + key = self.kernel_function(key) + + # trick to avoid second attention calculation: add normalization slot + value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode='constant', value=1.0) + + value_key = value @ key.transpose(-1, -2) + value_key_query = value_key @ query + normalization = value_key_query[..., -1:, :] + self.eps + attn = value_key_query[..., :-1, :] / normalization + attn = attn.moveaxis(1, -1).flatten(-2) # join heads and channels + out = self.to_out(attn) + out = out.to(orig_dtype) + out = out.unflatten(-2, spatial_shape) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/attention/MultiHeadAttention.py b/src/mrpro/nn/attention/MultiHeadAttention.py new file mode 100644 index 000000000..212bc68eb --- /dev/null +++ b/src/mrpro/nn/attention/MultiHeadAttention.py @@ -0,0 +1,105 @@ +"""Multi-head Attention.""" + +import torch +from einops import rearrange +from torch.nn import Linear, Module + +from mrpro.nn.AxialRoPE import AxialRoPE + + +class MultiHeadAttention(Module): + """Multi-head Attention. + + Implements multihead scaled dot-product attention and supports "image-like" inputs, + i.e. `batch, channels, *spatial_dims` as well as "transformer-like" inputs, `batch, sequence, features`. + """ + + def __init__( + self, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + features_last: bool = False, + p_dropout: float = 0.0, + n_channels_cross: int | None = None, + rope_embed_fraction: float = 0.0, + ): + """Initialize the Multi-head Attention. + + Parameters + ---------- + n_channels_in + Number of input channels. + n_channels_out + Number of output channels. + n_heads + number of attention heads + features_last + Whether the features dimension is the last dimension, as common in transformer models, + or the second dimension, as common in image models. + p_dropout + Dropout probability. + n_channels_cross + Number of channels for cross-attention. If `None`, use `n_channels_in`. + rope_embed_fraction + Fraction of channels to embed with RoPE. + """ + super().__init__() + channels_per_head_q = n_channels_in // n_heads + channels_per_head_kv = n_channels_cross // n_heads if n_channels_cross is not None else n_channels_in // n_heads + self.to_q = Linear(n_channels_in, channels_per_head_q * n_heads) + self.to_kv = Linear(n_channels_in, channels_per_head_kv * n_heads * 2) + self.p_dropout = p_dropout + self.features_last = features_last + self.to_out = Linear(n_channels_in, n_channels_out) + self.n_heads = n_heads + self.rope = AxialRoPE(rope_embed_fraction) + + def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + """Apply multi-head attention. + + Parameters + ---------- + x + The input tensor. + cross_attention + The key and value tensors for cross-attention. If `None`, self-attention is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cross_attention) + + def _reshape(self, x: torch.Tensor) -> torch.Tensor: + if not self.features_last: + x = x.moveaxis(1, -1) + return x.flatten(1, -2) + + def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + """Apply multi-head attention.""" + if cross_attention is None: + cross_attention = x + if not self.features_last: + x = x.moveaxis(1, -1) + cross_attention = cross_attention.moveaxis(1, -1) + + query = rearrange(self.to_q(x), 'batch ... (heads channels) -> batch heads ... channels ', heads=self.n_heads) + key, value = rearrange( + self.to_kv(cross_attention), + 'batch ... (kv heads channels) -> kv batch heads ... channels ', + heads=self.n_heads, + kv=2, + ) + query, key = self.rope(query, key) # NO-OP if rope_embed_fraction is 0.0 + query, key, value = query.flatten(2, -2), key.flatten(2, -2), value.flatten(2, -2) + y = torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=self.p_dropout, is_causal=False + ) + y = rearrange(y, '... heads L channels -> ... L (heads channels)') + out = self.to_out(y).reshape(x.shape) + + if not self.features_last: + out = out.moveaxis(-1, 1) + + return out diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py new file mode 100644 index 000000000..27916ee6f --- /dev/null +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -0,0 +1,237 @@ +"""Neighborhood Self Attention.""" + +from collections.abc import Sequence +from functools import cache, reduce +from typing import TYPE_CHECKING, TypeVar, cast + +import torch +from einops import rearrange +from packaging.version import parse as parse_version +from torch.nn import Linear, Module + +from mrpro.nn.AxialRoPE import AxialRoPE +from mrpro.utils.to_tuple import to_tuple + +T = TypeVar('T') + +if TYPE_CHECKING or parse_version(torch.__version__) >= parse_version('2.6'): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention +else: + + class BlockMask: + """Dummy class for older PyTorch versions.""" + + +_compiled_flex_attention = torch.compile( + lambda q, k, v, mask: flex_attention(q, k, v, block_mask=mask), + dynamic=False, +) + + +@torch.compiler.disable +@cache +def neighborhood_mask( + device: str, + input_size: torch.Size, + kernel_size: int | tuple[int, ...], + dilation: int | tuple[int, ...] = 1, + circular: bool | tuple[bool, ...] = False, +) -> BlockMask: # pragma: no cover + """Create a flex attention block mask for neighborhood attention. + + This function defines which key/value pairs a query can attend to based + on a local neighborhood. The neighborhood is defined by `kernel_size` + and `dilation` and can be circular (wrapping around edges). + + Parameters + ---------- + input_size + The dimensions of the input tensor (e.g., (H, W) for 2D). + kernel_size + The size of the attention neighborhood window. Can be a single + integer for a symmetric window or a sequence of integers for + each dimension. + dilation + The dilation factor for the neighborhood + Can be a single integer for a symmetric window or a sequence + of integers for each dimension. + circular + Whether the neighborhood wraps around the edges (circular padding). + Can be a single boolean or a sequence of booleans. + device + The device to create the mask on. + + Returns + ------- + A mask object suitable for `flex_attention` that defines the + allowed attention connections. + """ + kernel_size_tuple, dilation_tuple, circular_tuple = ( + to_tuple(len(input_size), x) for x in (kernel_size, dilation, circular) + ) + + def unravel_index(idx: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Convert a flat 1D index into multi-dimensional coordinates.""" + idx = idx.clone() + coords = [] + for dim in reversed(input_size): + coords.append(idx % dim) + idx = torch.div(idx, dim, rounding_mode='floor').long() + coords.reverse() + return tuple(coords) + + def mask( + _batch: torch.Tensor, + _head: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + """Determine if a query can attend to a key/value pair.""" + q_coord = unravel_index(q_idx) + kv_coord = unravel_index(kv_idx) + + masks = [] + for input_, kernel_, dilation_, circular_, q_, kv_ in zip( + input_size, + kernel_size_tuple, + dilation_tuple, + circular_tuple, + q_coord, + kv_coord, + strict=False, + ): + masks.append((q_ % dilation_) == (kv_ % dilation_)) + kernel_dilation = kernel_ * dilation_ + window_left = kernel_dilation // 2 + window_right = (kernel_dilation // 2) + ((kernel_dilation % 2) - 1) + if circular_: + left = (q_ - kv_ + input_) % input_ + right = (kv_ - q_ + input_) % input_ + masks.append((left <= window_left) | (right <= window_right)) + else: + center = q_.clamp(window_left, input_ - 1 - window_right) + left = center - kv_ + right = kv_ - center + masks.append(((left >= 0) & (left <= window_left)) | ((right >= 0) & (right <= window_right))) + return reduce(lambda x, y: x & y, masks) + + qkv_len = input_size.numel() + return create_block_mask(mask, B=None, H=None, Q_LEN=qkv_len, KV_LEN=qkv_len, device=torch.device(device)) + + +class NeighborhoodSelfAttention(Module): + """Attention where each query attends to a neighborhood of the key and value. + + Neighborhood attention is a type of attention where each query attends to a neighborhood of the key and value. + It is a more efficient alternative to regular attention, especially for large input sizes [NAT]_. + + This implementation uses `~torch.nn.attention.flex_attention`. For a more efficient implementation, + see also [NATTEN]_. + + + References + ---------- + .. [NAT] Hassani, A. et al. "Neighborhood Attention Transformer" CVPR, 2023, https://arxiv.org/abs/2204.07143 + .. [NATTEN] https://github.com/SHI-Labs/NATTEN/ + """ + + n_head: int + kernel_size: int | tuple[int, ...] + dilation: int | tuple[int, ...] + circular: bool | tuple[bool, ...] + features_last: bool + + def __init__( + self, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + kernel_size: int | Sequence[int], + dilation: int | Sequence[int] = 1, + circular: bool | Sequence[bool] = False, + features_last: bool = False, + rope_embed_fraction: float = 1.0, + ) -> None: + """Initialize a neighborhood attention module. + + The parameters `kernel_size`, `dilation`, and `circular` can either be sequences, interpreted as per-dimension + values, or scalars, interpreted as the same value for all dimensions. + + Parameters + ---------- + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + n_heads + The number of attention heads. + kernel_size + The size of the attention neighborhood window. + dilation + The dilation factor for the neighborhood. + circular + Whether the neighborhood wraps around the edges (circular padding) + features_last + Whether the channels are in the last dimension of the tensor, as common in visíon transformers. + Otherwise, assume the channels are in the second dimension, as common in CNN models. + rope_embed_fraction + Fraction of channels to embed with RoPE. + + """ + if parse_version(torch.__version__) < parse_version('2.6.0'): + raise NotImplementedError('NeighborhoodSelfAttention requires PyTorch 2.6.0 or higher') + super().__init__() + self.n_head = n_heads + self.kernel_size = kernel_size if isinstance(kernel_size, int) else tuple(kernel_size) + self.dilation = dilation if isinstance(dilation, int) else tuple(dilation) + self.circular = circular if isinstance(circular, bool) else tuple(circular) + self.features_last = features_last + channels_per_head = n_channels_in // n_heads + self.to_qkv = Linear(n_channels_in, 3 * channels_per_head * n_heads) + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) + self.rope = AxialRoPE(rope_embed_fraction) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply neighborhood attention to the input tensor. + + Parameters + ---------- + x + The input tensor, with shape `(batch, channels, *spatial_dims)` + or `(batch, *spatial_dims, channels)` (if `features_last`). + + Returns + ------- + The output tensor after attention, with the same shape as the input tensor. + """ + if not self.features_last: + x = x.moveaxis(1, -1) + spatial_shape = x.shape[1:-1] + qkv = self.to_qkv(x) + query, key, value = rearrange( + qkv, + 'batch ... (qkv heads channels) -> qkv batch heads (...) channels', + qkv=3, + heads=self.n_head, + ) + query, key = self.rope(query, key) + query, key, value = query.contiguous(), key.contiguous(), value.contiguous() + device = str(qkv.device) + mask = neighborhood_mask( + device=device, + input_size=spatial_shape, + kernel_size=self.kernel_size, + dilation=self.dilation, + circular=self.circular, + ) + mask = torch.compiler.assume_constant_result(mask) + if torch.compiler.is_compiling(): + out = cast(torch.Tensor, flex_attention(query, key, value, block_mask=mask)) + else: + out = cast(torch.Tensor, _compiled_flex_attention(query, key, value, mask)) + out = rearrange(out, 'batch head sequence channels -> batch sequence (head channels)') + out = self.to_out(out) + out = out.unflatten(-2, spatial_shape) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/attention/ShiftedWindowAttention.py b/src/mrpro/nn/attention/ShiftedWindowAttention.py new file mode 100644 index 000000000..0935ff63d --- /dev/null +++ b/src/mrpro/nn/attention/ShiftedWindowAttention.py @@ -0,0 +1,131 @@ +"""Shifted Window Attention.""" + +import warnings + +import torch +from einops import rearrange +from torch.nn import Linear, Module + +from mrpro.utils.reshape import ravel_multi_index +from mrpro.utils.sliding_window import sliding_window + + +class ShiftedWindowAttention(Module): + """Shifted Window Attention. + + (Shifted) Window Attention calculates attention over windows of the input. + It was introduced in Swin Transformer [SWIN]_ and is used in Uformer. + + References + ---------- + .. [SWIN] Liu, Ze, et al. "Swin transformer: Hierarchical vision transformer using shifted windows." ICCV 2021. + """ + + rel_position_index: torch.Tensor + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + window_size: int = 7, + shifted: bool = True, + features_last: bool = False, + ): + """Initialize the ShiftedWindowAttention module. + + Parameters + ---------- + n_dim + The dimension of the input. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + n_heads + The number of attention heads. The number if channels per head is ``channels // n_heads``. + window_size + The size of the window. + shifted + Whether to shift the window. + features_last + Whether the features are last in the input tensor or in the second dimension. + """ + super().__init__() + self.n_heads = n_heads + self.window_size = window_size + self.shifted = shifted + self.features_last = features_last + channels_per_head = n_channels_in // n_heads + self.to_qkv = Linear(channels_per_head * n_heads, 3 * channels_per_head * n_heads) + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) + self.n_dim = n_dim + coords_1d = torch.arange(window_size) + coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * n_dim), indexing='ij'), 0).flatten(1) + rel_coords = coords_nd[:, :, None] - coords_nd[:, None, :] # (dim, window_size**dim, window_size**dim) + rel_coords += window_size - 1 # shift to >=0 + rel_position_index = ravel_multi_index(tuple(rel_coords), (2 * window_size - 1,) * n_dim) + self.register_buffer('rel_position_index', rel_position_index) + + self.relative_position_bias_table = torch.nn.Parameter(torch.empty((2 * window_size - 1) ** n_dim, n_heads)) + torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02, a=-0.04, b=0.04) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the ShiftedWindowAttention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the ShiftedWindowAttention.""" + if not self.features_last: + x = x.moveaxis(1, -1) # now it is features last + if self.shifted: + x = torch.roll(x, (-(self.window_size // 2),) * self.n_dim, dims=tuple(range(-self.n_dim - 1, -1))) + + padding = [] + for s in x.shape[-self.n_dim - 1 : -1]: + target = ((s + self.window_size - 1) // self.window_size) * self.window_size + padding.extend([target - s, 0]) + x_padded = torch.nn.functional.pad(x, (0, 0, *padding[::-1]), mode='circular') if any(padding) else x + + qkv = self.to_qkv(x_padded) + windowed = sliding_window( + qkv, window_shape=self.window_size, stride=self.window_size, dim=range(-self.n_dim - 1, -1) + ) + q, k, v = rearrange( + windowed.flatten(-self.n_dim - 1, -2), + '... sequence (qkv heads channels)->qkv ... heads sequence channels', + heads=self.n_heads, + qkv=3, + ) + bias = rearrange(self.relative_position_bias_table[self.rel_position_index], 'wd1 wd2 heads -> 1 heads wd1 wd2') + with warnings.catch_warnings(): + # Inductor in torch 2.6 warns for small batch*n_patches*n_heads about suboptimal softmax compilation. + warnings.filterwarnings('ignore', message='.*softmax.*') + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) + attention = rearrange(attention, '... head sequence channels->... sequence (head channels)') + attention = attention.unflatten(-2, windowed.shape[-self.n_dim - 1 : -1]) + # permute (in 3d) batch channels z y x wz wy wx -> batch channels wz z wy y wx x + attention = attention.moveaxis(list(range(self.n_dim)), list(range(2, 2 + 2 * self.n_dim, 2))) + attention = attention.reshape(x_padded.shape) + if any(padding): + crop_idx = (Ellipsis, *[slice(0, s) for s in x.shape[-self.n_dim - 1 : -1]], slice(None)) + attention = attention[crop_idx] + if self.shifted: + attention = torch.roll( + attention, (self.window_size // 2,) * self.n_dim, dims=tuple(range(-self.n_dim - 1, -1)) + ) + out = self.to_out(attention) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/attention/SpatialTransformerBlock.py b/src/mrpro/nn/attention/SpatialTransformerBlock.py new file mode 100644 index 000000000..18817b26f --- /dev/null +++ b/src/mrpro/nn/attention/SpatialTransformerBlock.py @@ -0,0 +1,216 @@ +"""Spatial transformer block.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Dropout, Linear, Module + +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.GEGLU import GEGLU +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Sequential import Sequential + + +def zero_init(m: Module) -> Module: + """Initialize module weights and bias to zero.""" + if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor): + torch.nn.init.zeros_(m.weight) + if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): + torch.nn.init.zeros_(m.bias) + return m + + +class BasicTransformerBlock(CondMixin, Module): + """Basic vision transformer block.""" + + def __init__( + self, + channels: int, + n_heads: int, + p_dropout: float = 0.0, + cond_dim: int = 0, + mlp_ratio: float = 4, + features_last: bool = False, + rope_embed_fraction: float = 0.0, + attention_neighborhood: int | None = None, + ): + """Initialize the basic transformer block. + + Parameters + ---------- + channels + Number of channels in the input and output. + n_heads + Number of attention heads. + p_dropout + Dropout probability. + cond_dim + Number of channels in the conditioning tensor. + mlp_ratio + Ratio of the hidden dimension to the input dimension. + features_last + Whether the features are last in the input tensor. + rope_embed_fraction + Fraction of channels to embed with RoPE. + attention_neighborhood + If not None, use neighborhood self attention with the given neighborhood size instead + of global self attention. + """ + super().__init__() + self.features_last = features_last + + if attention_neighborhood is None: + attention: Module = MultiHeadAttention( + n_channels_in=channels, + n_channels_out=channels, + n_heads=n_heads, + p_dropout=p_dropout, + features_last=True, + rope_embed_fraction=rope_embed_fraction, + ) + else: + if p_dropout > 0: + raise ValueError('p_dropout > 0 is not supported for neighborhood self attention') + attention = NeighborhoodSelfAttention( + n_channels_in=channels, + n_channels_out=channels, + n_heads=n_heads, + features_last=True, + kernel_size=attention_neighborhood, + circular=True, + rope_embed_fraction=rope_embed_fraction, + ) + self.selfattention = Sequential(LayerNorm(channels, features_last=True), attention) + hidden_dim = int(channels * mlp_ratio) + self.ff = Sequential( + LayerNorm(channels, features_last=True, cond_dim=cond_dim), + GEGLU(channels, hidden_dim, features_last=True), + Dropout(p_dropout), + Linear(hidden_dim, channels), + ) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the basic transformer block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the basic transformer block.""" + if not self.features_last: + x = x.moveaxis(1, -1).contiguous() + x = self.selfattention(x) + x + x = self.ff(x, cond=cond) + x + if not self.features_last: + x = x.moveaxis(-1, 1).contiguous() + return x + + +class SpatialTransformerBlock(CondMixin, Module): + """Spatial transformer block.""" + + def __init__( + self, + dim_groups: Sequence[tuple[int, ...]], + channels: int, + n_heads: int, + depth: int = 1, + p_dropout: float = 0.0, + cond_dim: int = 0, + rope_embed_fraction: float = 0.0, + attention_neighborhood: int | None = None, + features_last: bool = False, + norm: Literal['group', 'rms'] = 'group', + ): + """Initialize the spatial transformer block. + + Parameters + ---------- + dim_groups + Groups of spatial dimensions for separate attention mechanisms. + channels + Number of channels in the input and output. + n_heads + Number of attention heads for each group. + depth + Number of transformer blocks for each group. + p_dropout + Dropout probability. + cond_dim + Dimension of the conditioning tensor. + rope_embed_fraction + Fraction of channels to embed with RoPE. + attention_neighborhood + If not None, use NeighborhoodSelfAttention with the given neighborhood size instead of MultiHeadAttention. + features_last + Whether the features are last in the input tensor, as common in transformer models. + norm + Whether to use GroupNorm or RMSNorm. + """ + super().__init__() + hidden_dim = n_heads * (channels // n_heads) + match norm: + case 'group': + self.norm: Module = GroupNorm(channels, features_last=features_last) + case 'rms': + self.norm = RMSNorm(channels, features_last=features_last) + case _: + raise ValueError(f'Invalid norm: {norm}') + self.features_last = features_last + self.proj_in = Linear(channels, hidden_dim) + self.transformer_blocks = Sequential() + for group in (g for _ in range(depth) for g in dim_groups): + if not self.features_last: + group = tuple(g - 1 if g < 0 else g for g in group) + block = BasicTransformerBlock( + hidden_dim, + n_heads, + p_dropout=p_dropout, + cond_dim=cond_dim, + features_last=True, + rope_embed_fraction=rope_embed_fraction, + attention_neighborhood=attention_neighborhood, + ) + self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) + self.proj_out = Linear(hidden_dim, channels) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the spatial transformer block.""" + skip = x + h = self.norm(x) + if not self.features_last: + h = h.movedim(1, -1) + h = self.proj_in(h) + h = self.transformer_blocks(h, cond=cond) + h = self.proj_out(h) + if not self.features_last: + h = h.movedim(-1, 1) + return skip + h + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the spatial transformer block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor. + """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/attention/SqueezeExcitation.py b/src/mrpro/nn/attention/SqueezeExcitation.py new file mode 100644 index 000000000..5f7802c75 --- /dev/null +++ b/src/mrpro/nn/attention/SqueezeExcitation.py @@ -0,0 +1,57 @@ +"""Squeeze-and-Excitation block.""" + +import torch +from torch.nn import Module, ReLU, Sigmoid + +from mrpro.nn.ndmodules import adaptiveAvgPoolND, convND +from mrpro.nn.Sequential import Sequential + + +class SqueezeExcitation(Module): + """Squeeze-and-Excitation block. + + Sequeeze-and-Excitation block from [SE]_. + + References + ---------- + ..[SE] Hu, Jie, Li Shen, and Gang Sun. "Squeeze-and-excitation networks." CVPR 2018, https://arxiv.org/abs/1709.01507 + """ + + def __init__(self, n_dim: int, n_channels_input: int, n_channels_squeeze: int) -> None: + """Initialize SqueezeExcitation. + + Parameters + ---------- + n_dim + The dimension of the input tensor. + n_channels_input + The number of channels in the input tensor. + n_channels_squeeze + The number of channels in the squeeze tensor. + """ + super().__init__() + self.scale = Sequential( + adaptiveAvgPoolND(n_dim)(1), + convND(n_dim)(n_channels_input, n_channels_squeeze, kernel_size=1), + ReLU(), + convND(n_dim)(n_channels_squeeze, n_channels_input, kernel_size=1), + Sigmoid(), + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation.""" + return x * self.scale(x) diff --git a/src/mrpro/nn/attention/TransposedAttention.py b/src/mrpro/nn/attention/TransposedAttention.py new file mode 100644 index 000000000..88e993c8f --- /dev/null +++ b/src/mrpro/nn/attention/TransposedAttention.py @@ -0,0 +1,76 @@ +"""Transposed Attention from Restormer.""" + +import torch +from einops import rearrange +from torch.nn import Module, Parameter + +from mrpro.nn.ndmodules import convND + + +class TransposedAttention(Module): + """Transposed Self Attention from Restormer. + + Implements the transposed self-attention, i.e. channel-wise multihead self-attention, + layer from Restormer [ZAM22]_. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + """ + + def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, n_heads: int): + """Initialize a TransposedAttention layer. + + Parameters + ---------- + n_dim + input dimension + n_channels_in + Number of channels in the input tensor. + n_channels_out + Number of channels in the output tensor. + n_heads + Number of attention heads. + """ + super().__init__() + self.n_heads = n_heads + self.temperature = Parameter(torch.ones(n_heads, 1, 1)) + channels_per_head = n_channels_in // n_heads + self.to_qkv = convND(n_dim)(n_channels_in, channels_per_head * n_heads * 3, kernel_size=1) + self.qkv_dwconv = convND(n_dim)( + channels_per_head * n_heads * 3, + channels_per_head * n_heads * 3, + kernel_size=3, + groups=n_channels_in * 3, + padding=1, + bias=False, + ) + self.to_out = convND(n_dim)(channels_per_head * n_heads, n_channels_out, kernel_size=1) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed attention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed attention.""" + qkv = self.qkv_dwconv(self.to_qkv(x)) + q, k, v = rearrange(qkv, 'b (qkv heads channels) ... -> qkv b heads (...) channels', heads=self.n_heads, qkv=3) + q = torch.nn.functional.normalize(q, dim=-1) * self.temperature + k = torch.nn.functional.normalize(k, dim=-1) + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0) + out = rearrange(attention, '... heads points channels -> ... (heads channels) points').unflatten( + -1, x.shape[2:] + ) + out = self.to_out(out) + return out diff --git a/src/mrpro/nn/attention/__init__.py b/src/mrpro/nn/attention/__init__.py new file mode 100644 index 000000000..719ff1409 --- /dev/null +++ b/src/mrpro/nn/attention/__init__.py @@ -0,0 +1,17 @@ +from mrpro.nn.attention.AttentionGate import AttentionGate +from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.attention.SqueezeExcitation import SqueezeExcitation +from mrpro.nn.attention.TransposedAttention import TransposedAttention +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock + +__all__ = [ + "AttentionGate", + "LinearSelfAttention", + "NeighborhoodSelfAttention", + "ShiftedWindowAttention", + "SpatialTransformerBlock", + "SqueezeExcitation", + "TransposedAttention" +] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 8490674e9..b2aa1cba2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,11 +12,17 @@ from mrpro.data.enums import AcqFlags from mrpro.utils import RandomGenerator from mrpro.utils.reshape import unsqueeze_tensors_left +from packaging.version import parse as parse_version from xsdata.models.datatype import XmlDate, XmlTime from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData +minimal_torch_26 = pytest.mark.xfail( + parse_version(torch.__version__) < parse_version('2.6'), + reason='Requires PyTorch >= 2.6', +) + def generate_random_encodingcounter_properties(rng: RandomGenerator) -> dict[str, Any]: return { diff --git a/tests/nn/test_ape.py b/tests/nn/test_ape.py new file mode 100644 index 000000000..9f71444fc --- /dev/null +++ b/tests/nn/test_ape.py @@ -0,0 +1,27 @@ +"""Tests for absolute position encoding""" + +import pytest +import torch +from mrpro.nn import AbsolutePositionEncoding +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_absolute_position_encodings(device: str) -> None: + """Test absolute position encoding.""" + n_features = 32 + shape = (1, 2 * n_features, 32, 32) + ape = AbsolutePositionEncoding(2, n_features, True, 128).to(device) + rng = RandomGenerator(444) + x1 = rng.float32_tensor(shape).to(device) + x2 = rng.float32_tensor(shape).to(device) + y1, y2 = ape(x1), ape(x2) + assert y1.shape == x1.shape + torch.testing.assert_close(y1 - x1, y2 - x2) + assert (x1[:, n_features:] == y1[:, n_features:]).all() # unembedded features diff --git a/tests/nn/test_attentiongate.py b/tests/nn/test_attentiongate.py new file mode 100644 index 000000000..10d30cb07 --- /dev/null +++ b/tests/nn/test_attentiongate.py @@ -0,0 +1,51 @@ +"""Tests for AttentionGate module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn.attention import AttentionGate +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_dim', 'n_channels_gate', 'n_channels_in', 'n_channels_hidden', 'input_shape', 'gate_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 32, 16, 16)), + (3, 32, 4, 8, (2, 4, 16, 16, 16), (2, 32, 16, 16, 16)), + ], +) +def test_attention_gate( + n_dim: int, + n_channels_gate: int, + n_channels_in: int, + n_channels_hidden: int, + input_shape: Sequence[int], + gate_shape: Sequence[int], + device: str, +) -> None: + """Test AttentionGate output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + gate = rng.float32_tensor(gate_shape).to(device).requires_grad_(True) + attn = AttentionGate( + n_dim=n_dim, channels_gate=n_channels_gate, channels_in=n_channels_in, channels_hidden=n_channels_hidden + ).to(device) + output = attn(x, gate) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert gate.grad is not None, 'No gradient computed for gate' + assert not output.isnan().any(), 'NaN values in output' + assert not gate.isnan().any(), 'NaN values in gate' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not gate.grad.isnan().any(), 'NaN values in gate gradients' + assert attn.project_gate.weight.grad is not None, 'No gradient computed for project_gate' + assert attn.project_x.weight.grad is not None, 'No gradient computed for project_x' + assert attn.psi[1].weight.grad is not None, 'No gradient computed for psi' diff --git a/tests/nn/test_linearselfattention.py b/tests/nn/test_linearselfattention.py new file mode 100644 index 000000000..11d17b301 --- /dev/null +++ b/tests/nn/test_linearselfattention.py @@ -0,0 +1,58 @@ +"""Tests for LinearSelfAttention module.""" + +import pytest +from mrpro.nn.attention import LinearSelfAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels_in', 'n_channels_out', 'n_heads', 'input_shape', 'features_last'), + [ + (32, 32, 4, (1, 32, 32, 32), False), + (64, 64, 8, (2, 64, 16, 16), False), + (16, 16, 2, (1, 16, 16, 16), True), + ], +) +def test_linear_self_attention( + n_channels_in: int, + n_channels_out: int, + n_heads: int, + input_shape: tuple[int, ...], + features_last: bool, + device: str, +) -> None: + """Test LinearSelfAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + + attn = LinearSelfAttention( + n_channels_in=n_channels_in, + n_channels_out=n_channels_out, + n_heads=n_heads, + features_last=features_last, + ).to(device) + + if features_last: + output = attn(x.moveaxis(1, -1)).moveaxis(-1, 1) + else: + output = attn(x) + + expected_shape = (x.shape[0], n_channels_out, *x.shape[2:]) + assert output.shape == expected_shape, f'Output shape {output.shape} != expected shape {expected_shape}' + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + assert attn.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert attn.to_qkv.bias.grad is not None, 'No gradient computed for to_qkv.bias' + assert attn.to_out.weight.grad is not None, 'No gradient computed for to_out.weight' + assert attn.to_out.bias.grad is not None, 'No gradient computed for to_out.bias' diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py new file mode 100644 index 000000000..3b51e963d --- /dev/null +++ b/tests/nn/test_neighborhoodselfattention.py @@ -0,0 +1,149 @@ +"""Tests for NeighborhoodSelfAttention module.""" + +import pytest +import torch +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.utils import RandomGenerator +from tests.conftest import minimal_torch_26 + + +@minimal_torch_26 +@pytest.mark.parametrize( + 'device', + [ + pytest.param( + 'cpu', + id='cpu', + marks=pytest.mark.skip( + reason='Flex Attention backward not supported on CPU. https://github.com/pytorch/pytorch/issues/148752' + ), + ), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels_in', 'n_channels_out', 'n_heads', 'kernel_size', 'input_shape', 'features_last'), + [ + (2, 3, 1, 2, (1, 2, 16, 16), False), + (3, 2, 2, 4, (1, 3, 8, 8, 8, 8), True), + ], + ids=['2d_kernel2', '4d_features-last_kernel4'], +) +def test_neighborhood_self_attention_backward( + n_channels_in: int, + n_channels_out: int, + n_heads: int, + kernel_size: int, + input_shape: tuple[int, ...], + features_last: bool, + device: str, +) -> None: + """Test NeighborhoodSelfAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + + attention = NeighborhoodSelfAttention( + n_channels_in=n_channels_in, + n_channels_out=n_channels_out, + n_heads=n_heads, + kernel_size=kernel_size, + features_last=features_last, + ).to(device) + + if features_last: + output = attention(x.moveaxis(1, -1)).moveaxis(-1, 1) + else: + output = attention(x) + + expected_shape = (input_shape[0], n_channels_out, *input_shape[2:]) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + assert attention.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert attention.to_qkv.bias.grad is not None, 'No gradient computed for to_qkv.bias' + assert attention.to_out.weight.grad is not None, 'No gradient computed for to_out.weight' + assert attention.to_out.bias.grad is not None, 'No gradient computed for to_out.bias' + + +@minimal_torch_26 +@pytest.mark.cuda +@pytest.mark.parametrize( + ('kernel_size', 'dilation', 'circular', 'rope'), + [ + (3, 1, False, True), + (5, 2, True, False), + (7, 1, False, True), + ], +) +def test_neighborhood_attention_variants(kernel_size: int, dilation: int, circular: bool, rope: bool) -> None: + """Test NeighborhoodSelfAttention with different neighborhood configurations.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 16, 16)).cuda() + + attention = NeighborhoodSelfAttention( + n_channels_in=32, + n_channels_out=32, + n_heads=4, + kernel_size=kernel_size, + dilation=dilation, + circular=circular, + rope_embed_fraction=1.0 if rope else 0.0, + ) + output = attention(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + +@minimal_torch_26 +@pytest.mark.parametrize( + ('kernel_size', 'circular', 'input_shape'), + [ + (11, False, (1, 8, 32, 32)), + (3, True, (1, 8, 64, 64)), + ], + ids=['regular', 'circular'], +) +@torch.no_grad() +def test_neighborhood_constraint(kernel_size: int, circular: bool, input_shape: tuple[int, int, int, int]) -> None: + """Test that neighborhood attention only affects pixels within the kernel window.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape) + attention = NeighborhoodSelfAttention( + n_channels_in=8, + n_channels_out=8, + n_heads=2, + kernel_size=kernel_size, + dilation=1, + circular=circular, + ) + output_original = attention(x) + x_modified = x.clone() + test_point = (input_shape[-2] - 2, input_shape[-1] - 2) + x_modified[..., test_point[0], test_point[1]] += 1.0 + output_modified = attention(x_modified) + + diff = output_modified - output_original + changed_pixels = torch.abs(diff).sum(dim=(0, 1)) > 1e-6 + + half_kernel = kernel_size // 2 + h, w = input_shape[2], input_shape[3] + + i_coords, j_coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') + + if circular: + h_dist = torch.minimum((i_coords - test_point[0]) % h, (test_point[0] - i_coords) % h) + w_dist = torch.minimum((j_coords - test_point[1]) % w, (test_point[1] - j_coords) % w) + in_neighborhood = (h_dist <= half_kernel) & (w_dist <= half_kernel) + else: + h_min, h_max = max(0, test_point[0] - half_kernel), min(h, test_point[0] + half_kernel + 1) + w_min, w_max = max(0, test_point[1] - half_kernel), min(w, test_point[1] + half_kernel + 1) + in_neighborhood = (i_coords >= h_min) & (i_coords < h_max) & (j_coords >= w_min) & (j_coords < w_max) + + neighborhood_changed = changed_pixels[in_neighborhood].all() + outside_changed = changed_pixels[~in_neighborhood].any() + + assert neighborhood_changed, 'Not all pixels in the neighborhood changed, which indicates a problem' + assert not outside_changed, 'Pixels outside the neighborhood changed, which violates the constraint' diff --git a/tests/nn/test_rope.py b/tests/nn/test_rope.py new file mode 100644 index 000000000..665c4bed4 --- /dev/null +++ b/tests/nn/test_rope.py @@ -0,0 +1,36 @@ +"""Tests for AxialRoPE module.""" + +import pytest +import torch +from mrpro.nn import AxialRoPE +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_rope(device: torch.device) -> None: + """Test AxialRoPE rotation and embedding functionality.""" + shape = (10, 10) + n_heads = 2 + n_channels = 64 + n_embed = int(0.5 * n_channels) + q, k = RandomGenerator(seed=42).float32_tensor((2, 1, n_heads, *shape, n_channels), low=0.5).to(device) + + rope = AxialRoPE(embed_fraction=0.5) + (q_rope, k_rope) = rope(q, k) + + assert q_rope.shape == q.shape + assert k_rope.shape == k.shape + + # non embedded channels should be the same + torch.testing.assert_close(q[..., n_embed:], q_rope[..., n_embed:]) + torch.testing.assert_close(k[..., n_embed:], k_rope[..., n_embed:]) + + # other should change + assert not torch.isclose(q_rope[..., :n_embed], q[..., :n_embed]).all() + assert not torch.isclose(k_rope[..., :n_embed], k[..., :n_embed]).all() diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py new file mode 100644 index 000000000..5d416efee --- /dev/null +++ b/tests/nn/test_shiftedwindowattention.py @@ -0,0 +1,61 @@ +"""Tests for ShiftedWindowAttention module.""" + +import pytest +from mrpro.nn.attention import ShiftedWindowAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_dim', 'window_size', 'shifted'), + [ + (2, 8, False), + (4, 4, True), + ], +) +def test_shifted_window_attention(n_dim: int, window_size: int, shifted: bool, device: str) -> None: + """Test ShiftedWindowAttention output shape and backpropagation.""" + n_batch, n_channels, n_heads = 2, 8, 2 + spatial_shape = (window_size * 4,) * n_dim + rng = RandomGenerator(13) + x = rng.float32_tensor((n_batch, n_channels, *spatial_shape)).to(device).requires_grad_(True) + swin = ShiftedWindowAttention( + n_dim=n_dim, + n_channels_in=n_channels, + n_channels_out=n_channels, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, + ).to(device) + out = swin(x) + assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' + assert not out.isnan().any(), 'NaN values in output' + out.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert swin.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert swin.relative_position_bias_table.grad is not None, 'No gradient computed for relative_position_bias_table' + + +@pytest.mark.parametrize('shifted', [True, False], ids=['shifted', 'non-shifted']) +def test_shifted_window_attention_size_mismatch(shifted: bool): + n_batch, n_channels, n_heads, n_dim, window_size = 3, 4, 2, 2, 7 + spatial_shape = (window_size * 4 + 1,) * n_dim + rng = RandomGenerator(13) + x = rng.float32_tensor((n_batch, n_channels, *spatial_shape)) + swin = ShiftedWindowAttention( + n_dim=n_dim, + n_channels_in=n_channels, + n_channels_out=n_channels, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, + ) + out = swin(x) + assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' diff --git a/tests/nn/test_spatialtransformerblock.py b/tests/nn/test_spatialtransformerblock.py new file mode 100644 index 000000000..4c78ceecd --- /dev/null +++ b/tests/nn/test_spatialtransformerblock.py @@ -0,0 +1,142 @@ +"""Test SpatialTransformerBlock""" + +from collections.abc import Sequence +from typing import Literal, cast + +import pytest +import torch +from mrpro.nn.attention import SpatialTransformerBlock +from mrpro.utils import RandomGenerator +from tests.conftest import minimal_torch_26 + + +@minimal_torch_26 +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + ('channels', 'cond_dim', 'attention_neighborhood', 'features_last', 'norm', 'input_shape'), + [ + pytest.param(32, 16, None, True, 'group', (16, 16), id='2d-cond-group-last-global'), + pytest.param(64, 16, 7, False, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-first-NA'), + pytest.param(64, 16, 5, True, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-last-NA'), + pytest.param(64, 0, 7, True, 'rms', (16, 8, 16), marks=minimal_torch_26, id='3d-nocond-rms-last-NA'), + ], +) +def test_spatialtransformerblock_backward( + channels: int, + cond_dim: int, + attention_neighborhood: int | None, + features_last: bool, + norm: Literal['group', 'rms'], + input_shape: Sequence[int], + device: str, + torch_compile: bool, +) -> None: + """Test SpatialTransformerBlock output shape and backpropagation.""" + if device == 'cpu' and attention_neighborhood is not None: + pytest.skip( + 'CompiledFlex Attention backward not supported on CPU. https://github.com/pytorch/pytorch/issues/148752' + ) + rng = RandomGenerator(seed=42) + + x = rng.float32_tensor((1, channels, *input_shape)).to(device).requires_grad_(True) + cond = rng.float32_tensor((1, cond_dim)).to(device).requires_grad_(True) if cond_dim else None + + if features_last: + dims = tuple(range(-len(input_shape) - 1, -1)) + else: + dims = tuple(range(-len(input_shape), 0)) + + block = SpatialTransformerBlock( + dim_groups=[dims], + channels=channels, + n_heads=4, + depth=1, + p_dropout=0, + cond_dim=cond_dim, + rope_embed_fraction=0.5, + attention_neighborhood=attention_neighborhood, + features_last=features_last, + norm=norm, + ).to(device) + if torch_compile: + block = cast(SpatialTransformerBlock, torch.compile(block, dynamic=False)) + if features_last: + output = block(x.moveaxis(1, -1), cond=cond).moveaxis(-1, 1) + else: + output = block(x, cond=cond) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' + + +@minimal_torch_26 +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize('torch_compile', [False, True]) +@pytest.mark.parametrize( + ('channels', 'cond_dim', 'attention_neighborhood', 'features_last', 'norm', 'input_shape'), + [ + pytest.param(32, 16, None, True, 'group', (16, 16), id='2d-cond-group-last-global'), + pytest.param(64, 16, 5, True, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-last-NA'), + pytest.param(64, 16, 7, False, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-first-NA'), + pytest.param(64, 0, 7, True, 'rms', (16, 8, 16), marks=minimal_torch_26, id='3d-nocond-rms-last-NA'), + ], +) +def test_spatialtransformerblock_forward( + channels: int, + cond_dim: int, + attention_neighborhood: int | None, + features_last: bool, + norm: Literal['group', 'rms'], + input_shape: Sequence[int], + device: str, + torch_compile: bool, +) -> None: + """Test SpatialTransformerBlock output shape and backpropagation.""" + + rng = RandomGenerator(seed=42) + + x = rng.float32_tensor((1, channels, *input_shape)).to(device).requires_grad_(True) + cond = rng.float32_tensor((1, cond_dim)).to(device).requires_grad_(True) if cond_dim else None + + if features_last: + dims = tuple(range(-len(input_shape) - 1, -1)) + else: + dims = tuple(range(-len(input_shape), 0)) + + block = SpatialTransformerBlock( + dim_groups=[dims], + channels=channels, + n_heads=4, + depth=1, + p_dropout=0, + cond_dim=cond_dim, + rope_embed_fraction=0.5, + attention_neighborhood=attention_neighborhood, + features_last=features_last, + norm=norm, + ).to(device) + if torch_compile: + block = cast(SpatialTransformerBlock, torch.compile(block, dynamic=False)) + with torch.no_grad(): + if features_last: + output = block(x.moveaxis(1, -1), cond=cond).moveaxis(-1, 1) + else: + output = block(x, cond=cond) + assert output.shape == x.shape + assert not output.isnan().any(), 'NaN values in output' diff --git a/tests/nn/test_squeezeexcitation.py b/tests/nn/test_squeezeexcitation.py new file mode 100644 index 000000000..879dd71ca --- /dev/null +++ b/tests/nn/test_squeezeexcitation.py @@ -0,0 +1,32 @@ +"""Tests for SqueezeExcitation module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn.attention import SqueezeExcitation +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + ('dim', 'input_shape', 'squeeze_channels'), + [ + (2, (1, 64, 32, 32), 16), + (3, (1, 64, 16, 16, 16), 16), + ], +) +def test_squeeze_excitation( + dim: int, + input_shape: Sequence[int], + squeeze_channels: int, +) -> None: + """Test SqueezeExcitation output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + se = SqueezeExcitation(n_dim=dim, n_channels_input=input_shape[1], n_channels_squeeze=squeeze_channels) + output = se(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert se.scale[1].weight.grad is not None, 'No gradient computed for Conv' diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py new file mode 100644 index 000000000..afbe53494 --- /dev/null +++ b/tests/nn/test_transposedattention.py @@ -0,0 +1,44 @@ +"""Tests for TransposedAttention module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn.attention import TransposedAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels', 'num_heads', 'input_shape'), + [ + (2, 32, 4, (1, 32, 32, 32)), + (3, 64, 8, (2, 64, 16, 16, 16)), + ], +) +def test_transposed_attention( + dim: int, + channels: int, + num_heads: int, + input_shape: Sequence[int], + device: str, +) -> None: + """Test TransposedAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + attn = TransposedAttention(n_dim=dim, n_channels_in=channels, n_channels_out=channels, n_heads=num_heads).to(device) + output = attn(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert attn.to_qkv.weight.grad is not None, 'No gradient computed for qkv' + assert attn.qkv_dwconv.weight.grad is not None, 'No gradient computed for qkv_dwconv' + assert attn.to_out.weight.grad is not None, 'No gradient computed for project_out' + assert attn.temperature.grad is not None, 'No gradient computed for temperature' From 1513fb3c19d683a6352d86504225847946b664ca Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:29 +0100 Subject: [PATCH 181/205] add unet, basic cnn, and residual blocks ghstack-source-id: 7bc5fb55c3853583a882536e0faac0fc3baa3645 ghstack-comment-id: 3865651070 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/956 --- src/mrpro/nn/ResBlock.py | 70 ++++++ src/mrpro/nn/SeparableResBlock.py | 89 +++++++ src/mrpro/nn/__init__.py | 6 + src/mrpro/nn/nets/BasicCNN.py | 105 +++++++++ src/mrpro/nn/nets/UNet.py | 364 +++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 8 + tests/nn/nets/test_cnn.py | 58 +++++ tests/nn/nets/test_unet.py | 116 +++++++++ tests/nn/test_resblock.py | 56 +++++ tests/nn/test_separableresblock.py | 58 +++++ 10 files changed, 930 insertions(+) create mode 100644 src/mrpro/nn/ResBlock.py create mode 100644 src/mrpro/nn/SeparableResBlock.py create mode 100644 src/mrpro/nn/nets/BasicCNN.py create mode 100644 src/mrpro/nn/nets/UNet.py create mode 100644 src/mrpro/nn/nets/__init__.py create mode 100644 tests/nn/nets/test_cnn.py create mode 100644 tests/nn/nets/test_unet.py create mode 100644 tests/nn/test_resblock.py create mode 100644 tests/nn/test_separableresblock.py diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py new file mode 100644 index 000000000..32870979f --- /dev/null +++ b/src/mrpro/nn/ResBlock.py @@ -0,0 +1,70 @@ +"""Residual convolution block with two convolutions.""" + +import torch +from torch.nn import Identity, Module, SiLU + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import convND +from mrpro.nn.Sequential import Sequential + + +class ResBlock(CondMixin, Module): + """Residual convolution block with two convolutions.""" + + def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, cond_dim: int) -> None: + """Initialize the ResBlock. + + Parameters + ---------- + n_dim + The number of dimensions, i.e. 1, 2 or 3. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + cond_dim + The number of features in the conditioning tensor used in a FiLM. + If set to 0 no FiLM is used. + + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + self.block = Sequential( + GroupNorm(n_channels_in), + SiLU(), + convND(n_dim)(n_channels_in, n_channels_out, kernel_size=3, padding=1), + GroupNorm(n_channels_out), + SiLU(), + convND(n_dim)(n_channels_out, n_channels_out, kernel_size=3, padding=1), + ) + if cond_dim > 0: + self.block.insert(-3, FiLM(n_channels_out, cond_dim)) + + if n_channels_out == n_channels_in: + self.skip_connection: Module = Identity() + else: + self.skip_connection = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the ResBlock. + + Parameters + ---------- + x + The input tensor. + cond + A conditioning tensor to be used for FiLM. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the ResBlock.""" + h = self.block(x, cond=cond) + x = self.skip_connection(x) + self.rezero * h + return x diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py new file mode 100644 index 000000000..a12fda16f --- /dev/null +++ b/src/mrpro/nn/SeparableResBlock.py @@ -0,0 +1,89 @@ +"""Residual block with separable convolutions.""" + +from collections.abc import Sequence + +import torch +from torch.nn import Module, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import convND +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.Sequential import Sequential + + +class SeparableResBlock(Module): + """Residual block with separable convolutions.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + n_channels_in: int, + n_channels_out: int, + cond_dim: int, + ) -> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SilU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + n_channels_in + Number of input channels. + n_channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, convND(len(dims))(channels_in, n_channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, n_channels_in if i == 0 else n_channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(n_channels_out, cond_dim)) + blocks.extend(block(d, n_channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if n_channels_in != n_channels_out: + self.skip_connection = torch.nn.Linear(n_channels_in, n_channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index ffb98843d..94f28625e 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -12,10 +12,13 @@ from mrpro.nn.LayerNorm import LayerNorm from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.ResBlock import ResBlock from mrpro.nn.Residual import Residual +from mrpro.nn.SeparableResBlock import SeparableResBlock from mrpro.nn.Sequential import Sequential from mrpro.nn import attention from mrpro.nn import data_consistency +from mrpro.nn import nets from mrpro.nn.ndmodules import ( adaptiveAvgPoolND, avgPoolND, @@ -39,7 +42,9 @@ 'LayerNorm', 'PermutedBlock', 'RMSNorm', + 'ResBlock', 'Residual', + 'SeparableResBlock', 'Sequential', 'adaptiveAvgPoolND', 'attention', @@ -50,4 +55,5 @@ 'data_consistency', 'instanceNormND', 'maxPoolND', + 'nets', ] diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py new file mode 100644 index 000000000..5bf911eef --- /dev/null +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -0,0 +1,105 @@ +"""Basic CNN.""" + +from collections.abc import Sequence +from itertools import pairwise +from typing import Literal + +import torch +from torch.nn import LeakyReLU, ReLU, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import batchNormND, convND +from mrpro.nn.Sequential import Sequential + + +class BasicCNN(Sequential): + """Basic CNN. + + A series of convolutions (window 3, stride 1, padding 1), normalization and activation. + Allows to use FiLM conditioning. + Order is Conv -> Norm (optional) -> FiLM (optional) -> Activation. + + If you need more flexibility, use `~mrpro.nn.Sequential` directly. + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + norm: Literal['batch', 'group', 'instance', 'none', 'layer'] = 'none', + activation: Literal['relu', 'silu', 'leaky_relu'] = 'relu', + n_features: Sequence[int] = (64, 64, 64), + cond_dim: int = 0, + ): + """Initialize a basic CNN. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + norm + The type of normalization to use. If 'batch', use batch normalization. If 'group', use group normalization, + if 'instance', use instance normalization, and if `layer`, use layer normalization. + If 'none', use no normalization. + activation + The type of activation to use. If 'relu', use ReLU. If 'silu', use SiLU. If 'leaky_relu', use LeakyReLU. + n_features + The number of features in the hidden layers. The length of this sequence determines the number of hidden + layers. The total number of convolutions is `len(n_features) + 1`. + cond_dim + The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + Otherwise, between convolutions, after normalization, FiLM conditioning is applied. + """ + super().__init__() + use_film = cond_dim > 0 + + self.append(convND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding='same')) + + for c_in, c_out in pairwise((*n_features, n_channels_out)): + if norm.lower() == 'batch': + self.append(batchNormND(n_dim)(c_in, affine=not use_film)) + elif norm.lower() == 'group': + self.append(GroupNorm(c_in, affine=not use_film)) + elif norm.lower() == 'instance': + self.append(GroupNorm(c_in, n_groups=c_in, affine=not use_film)) # is instance norm + elif norm.lower() == 'layer': + self.append(GroupNorm(c_in, n_groups=1, affine=not use_film)) # is layer norm + elif norm.lower() != 'none': + raise ValueError(f'Invalid normalization type: {norm}') + + if use_film: + self.append(FiLM(c_in, cond_dim)) + + if activation.lower() == 'relu': + self.append(ReLU(True)) + elif activation.lower() == 'silu': + self.append(SiLU(inplace=True)) + elif activation.lower() == 'leaky_relu': + self.append(LeakyReLU(inplace=True)) + else: + raise ValueError(f'Invalid activation type: {activation}') + + self.append(convND(n_dim)(c_in, c_out, kernel_size=3, padding='same')) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] + """Apply the basic CNN to the input tensor. + + Parameters + ---------- + x + The input tensor. Should be of shape `(batch_size, channels_in, *spatial dimensions)` + with `spatial dimensions` being of length `dim`. + cond + The condition tensor. If None, no FiLM conditioning is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py new file mode 100644 index 000000000..7b92b1c73 --- /dev/null +++ b/src/mrpro/nn/nets/UNet.py @@ -0,0 +1,364 @@ +"""UNet variants.""" + +from collections.abc import Sequence +from functools import partial +from itertools import pairwise + +import torch +from torch.nn import Identity, Module, ModuleList, ReLU, SiLU + +from mrpro.nn.attention.AttentionGate import AttentionGate +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.CondMixin import call_with_cond +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import convND, maxPoolND +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.Upsample import Upsample + + +class UNetEncoder(Module): + """Encoder.""" + + def __init__( + self, + first_block: Module, + blocks: Sequence[Module], + down_blocks: Sequence[Module], + middle_block: Module, + ) -> None: + """Initialize the UNetEncoder.""" + super().__init__() + self.first = first_block + """The first block. Should expand from the number of input channels.""" + + self.blocks = ModuleList(blocks) + """The encoder blocks. Order is highest resolution to lowest resolution.""" + + self.down_blocks = ModuleList(down_blocks) + """The downsampling blocks""" + + self.middle_block = middle_block + """Also called bottleneck block""" + + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.down_blocks) + 1 + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = call(self.first, x) + + xs = [] + for block, down in zip(self.blocks, self.down_blocks, strict=True): + x = call(block, x) + xs.append(x) + x = call(down, x) + + x = call(self.middle_block, x) + + return (*xs, x) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The tensors at the different resolutions, highest resolution first. + """ + return super().__call__(x, cond=cond) + + +class UNetDecoder(Module): + """Decoder.""" + + def __init__( + self, + blocks: Sequence[Module], + up_blocks: Sequence[Module], + concat_blocks: Sequence[Module], + last_block: Module, + ) -> None: + """Initialize the UNetDecoder.""" + super().__init__() + self.blocks = ModuleList(blocks) + """The decoder blocks. Order is lowest resolution to highest resolution.""" + + self.up_blocks = ModuleList(up_blocks) + """The upsampling blocks""" + + self.concat_blocks = ModuleList(concat_blocks) + """Joins the skip connections with the upsampled features from a lower resolution level""" + + self.last_block = last_block + """The last block. Should reduce to the number of output channels.""" + + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.up_blocks) + 1 + + def forward(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = hs[-1] # lowest resolution, from middle block + for block, up, concat, h in zip(self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True): + x = call(up, x) + x = concat(h, x) + x = call(block, x) + x = call(self.last_block, x) + return x + + def __call__(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + hs + The tensors at the different resolutions, highest resolution first. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(hs, cond=cond) + + +class UNetBase(Module): + """Base class for U-shaped networks.""" + + def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequence[Module] | None = None) -> None: + """Initialize the UNetBase.""" + super().__init__() + self.encoder = encoder + """The encoder.""" + + self.decoder = decoder + """The decoder.""" + + self.skip_blocks = ModuleList() + """Modifications of the skip connections.""" + + if len(decoder) != len(encoder): + raise ValueError( + 'The number of resolutions in the encoder and decoder must be the same, ' + f'got {len(decoder)} and {len(encoder)}' + ) + + if skip_blocks is None: + self.skip_blocks.extend(Identity() for _ in range(len(decoder))) + elif len(skip_blocks) != len(decoder): + raise ValueError( + f'The number of skip blocks must be the same as the number of resolutions, ' + f'got {len(skip_blocks)} and {len(encoder)}' + ) + else: + self.skip_blocks.extend(skip_blocks) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network.""" + xs = self.encoder(x, cond=cond) + xs = tuple( + call_with_cond(self.skip_blocks[i], x, cond=cond) if i < len(self.skip_blocks) else x + for i, x in enumerate(xs) + ) + x = self.decoder(xs, cond=cond) + return x + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + +class UNet(UNetBase): + """UNet. + + U-shaped convolutional network with optional patch attention. + Inspired by [NOSENSE_] and the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_. + significant differences to the vanilla UNet [UNET]_ include: + - Spatial transformer blocks + - Convolutional downsampling, nearest neighbor upsampling + - Residual convolution blocks with pre-act group normalization and SiLU activation + + + References + ---------- + .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image + segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 + .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py + .. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without + explicit sensitivity maps." STACOM 2023. https://github.com/fzimmermann89/CMRxRecon/blob/master/src/cmrxrecon/nets/unet.py + + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + attention_depths: Sequence[int] = (-1,), + n_features: Sequence[int] = (64, 128, 192, 256), + n_heads: int = 8, + cond_dim: int = 0, + encoder_blocks_per_scale: int = 2, + ) -> None: + """Initialize the UNet. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + attention_depths + The depths at which to apply attention. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + n_heads + Number of attention heads. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + encoder_blocks_per_scale + Number of encoder blocks per resolution level. The number of decoder blocks is one more. + """ + depth = len(n_features) + if not all(-depth <= d < depth for d in attention_depths): + raise ValueError( + f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' + ) + attention_depths = tuple(d % depth for d in attention_depths) + if len(attention_depths) != len(set(attention_depths)): + raise ValueError(f'attention_depths must be unique, got {attention_depths=}') + + def attention_block(channels: int) -> Module: + dim_groups = (tuple(range(-n_dim, 0)),) + return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) + + def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: + blocks = Sequential() + for _ in range(encoder_blocks_per_scale): + blocks.append(ResBlock(n_dim, channels_in, channels_out, cond_dim)) + if attention: + blocks.append(attention_block(channels_out)) + channels_in = channels_out + return blocks + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [] + + for i_level, (n_feat, n_feat_next) in enumerate(pairwise(n_features)): + encoder_blocks.append(blocks(n_feat, n_feat, i_level in attention_depths)) + down_blocks.append(convND(n_dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) + decoder_blocks.append(blocks(n_feat_next + n_feat, n_feat, i_level in attention_depths)) + up_blocks.append(Upsample(tuple(range(-n_dim, 0)), scale_factor=2)) + + middle_block = Sequential( + ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim), + ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim), + ) + if depth - 1 in attention_depths: + middle_block.insert(1, attention_block(n_feat_next)) + first_block = convND(n_dim)(n_channels_in, n_features[0], 3, padding=1) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + + decoder_blocks, up_blocks = decoder_blocks[::-1], up_blocks[::-1] + last_block = Sequential( + SiLU(), + convND(n_dim)(n_features[0], n_channels_out, 3, padding=1), + ) + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) + + +class AttentionGatedUNet(UNetBase): + """UNet with attention gates. + + Basic UNet with attention gating of the skip signals by the lower resolution features [OKT18]_. + + References + ---------- + .. [OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + def __init__( + self, n_dim: int, n_channels_in: int, n_channels_out: int, n_features: Sequence[int], cond_dim: int = 0 + ): + """Initialize the AttentionGatedUNet. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + + def block(channels_in: int, channels_out: int) -> Module: + block = Sequential( + convND(n_dim)(channels_in, channels_out, 3, padding=1), + ReLU(True), + convND(n_dim)(channels_out, channels_out, 3, padding=1), + ReLU(True), + ) + if cond_dim > 0: + block.insert(2, FiLM(channels_out, cond_dim)) + return block + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + n_feat_old = n_channels_in + for n_feat in n_features[:-1]: + encoder_blocks.append(block(n_feat_old, n_feat)) + down_blocks.append(maxPoolND(n_dim)(2)) + n_feat_old = n_feat + middle_block = block(n_features[-2], n_features[-1]) + encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) + + concat_blocks = [] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [] + for n_feat, n_feat_skip in pairwise(n_features[::-1]): + concat_blocks.append(AttentionGate(n_dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) + decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) + up_blocks.append(Upsample(range(-n_dim, 0), scale_factor=2)) + last_block = convND(n_dim)(n_features[0], n_channels_out, 1) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py new file mode 100644 index 000000000..06271d970 --- /dev/null +++ b/src/mrpro/nn/nets/__init__.py @@ -0,0 +1,8 @@ +from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet + +__all__ = [ + 'AttentionGatedUNet', + 'BasicCNN', + 'UNet', +] diff --git a/tests/nn/nets/test_cnn.py b/tests/nn/nets/test_cnn.py new file mode 100644 index 000000000..5dbbd5ad6 --- /dev/null +++ b/tests/nn/nets/test_cnn.py @@ -0,0 +1,58 @@ +"""Tests for BasicCNN network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import BasicCNN + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_cnn_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the cnn.""" + cnn = BasicCNN( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + norm='layer', + n_features=(8, 8), + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cnn = cnn.to(device) + x = x.to(device) + if torch_compile: + cnn = cast(BasicCNN, torch.compile(cnn)) + y = cnn(x) + assert y.shape == (1, 1, 16, 16) + + +def test_cnn_backward() -> None: + cnn = BasicCNN( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + norm='instance', + activation='silu', + n_features=(8, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = cnn(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in cnn.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py new file mode 100644 index 000000000..fdf2f5250 --- /dev/null +++ b/tests/nn/nets/test_unet.py @@ -0,0 +1,116 @@ +"""Tests for UNet and AttentionGatedUNet networks.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import AttentionGatedUNet, UNet + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_unet_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the UNet.""" + unet = UNet( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + attention_depths=(-1,), + n_features=(4, 6, 8), + n_heads=2, + cond_dim=32, + encoder_blocks_per_scale=1, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + unet = unet.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + unet = cast(UNet, torch.compile(unet)) + y = unet(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_unet_backward() -> None: + unet = UNet( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + attention_depths=(-1,), + n_features=(4, 6, 8), + n_heads=2, + cond_dim=32, + encoder_blocks_per_scale=1, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = unet(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in unet.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_gated_unet_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the AttentionGatedUNet.""" + unet = AttentionGatedUNet( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_features=(4, 6, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + unet = unet.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + unet = cast(AttentionGatedUNet, torch.compile(unet)) + y = unet(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_gated_unet_backward() -> None: + """Test the backward pass of the AttentionGatedUNet.""" + unet = AttentionGatedUNet( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_features=(4, 6, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = unet(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in unet.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py new file mode 100644 index 000000000..ea4356173 --- /dev/null +++ b/tests/nn/test_resblock.py @@ -0,0 +1,56 @@ +"""Tests for ResBlock module.""" + +from collections.abc import Sequence +from typing import cast + +import pytest +import torch +from mrpro.nn import ResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'eager']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (3, 64, 32, 0, (2, 64, 16, 16, 16), None), + ], +) +def test_resblock( + dim: int, + channels_in: int, + channels_out: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int] | None, + device: str, + torch_compile: bool, +) -> None: + """Test ResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None + block = ResBlock(n_dim=dim, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim).to(device) + if torch_compile: + block = cast(ResBlock, torch.compile(block, dynamic=False)) + output = block(x, cond=cond) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert block.block[2].weight.grad is not None, 'No gradient computed for first Conv' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' diff --git a/tests/nn/test_separableresblock.py b/tests/nn/test_separableresblock.py new file mode 100644 index 000000000..25b6c65e0 --- /dev/null +++ b/tests/nn/test_separableresblock.py @@ -0,0 +1,58 @@ +"""Tests for SeparableResBlock module.""" + +from collections.abc import Sequence +from typing import cast + +import pytest +import torch +from mrpro.nn import SeparableResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'eager']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim_groups', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (((-1, -2),), 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (((-1, -2), (-3,)), 64, 32, 0, (2, 64, 16, 16, 16), None), # 2D + 1D + ], +) +def test_separable_resblock( + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int] | None, + device: str, + torch_compile: bool, +) -> None: + """Test SeparableResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None + block = SeparableResBlock( + dim_groups=dim_groups, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim + ).to(device) + if torch_compile: + block = cast(SeparableResBlock, torch.compile(block, dynamic=False)) + output = block(x, cond=cond) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert block.block[0][2].module.weight.grad is not None, 'No gradient computed for first Conv' # type: ignore[union-attr] + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' From 100057b02a117dc76899e5a265426768b4fcaf43 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:30 +0100 Subject: [PATCH 182/205] add restormer architecture and tests ghstack-source-id: 9fa6f1dde7cd0771d4637360acc4315918a3040c ghstack-comment-id: 3865651248 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/957 --- src/mrpro/nn/nets/Restormer.py | 223 ++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 2 + tests/nn/nets/test_restormer.py | 62 +++++++++ 3 files changed, 287 insertions(+) create mode 100644 src/mrpro/nn/nets/Restormer.py create mode 100644 tests/nn/nets/test_restormer.py diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py new file mode 100644 index 000000000..357bebf51 --- /dev/null +++ b/src/mrpro/nn/nets/Restormer.py @@ -0,0 +1,223 @@ +"""Restormer implementation.""" + +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import Module + +from mrpro.nn.attention.TransposedAttention import TransposedAttention +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import convND, instanceNormND +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Sequential import Sequential + + +class GDFN(Module): + """Gated depthwise feed forward network. + + Feed-forward block used in Restormer [ZAM22]_. It first expands channels, + applies a depthwise convolution, then uses a gated interaction between two + channel splits before projecting back to the input width. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for + high-resolution image restoration." CVPR 2022. + """ + + def __init__(self, n_dim: int, n_channels: int, mlp_ratio: float): + """Initialize GDFN. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + mlp_ratio + Ratio for hidden dimension expansion + """ + super().__init__() + + hidden_features = int(n_channels * mlp_ratio) + self.project_in = convND(n_dim)(n_channels, hidden_features * 2, kernel_size=1) + self.depthwise_conv = convND(n_dim)( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + ) + self.project_out = convND(n_dim)(hidden_features, n_channels, kernel_size=1) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the gated depthwise feed forward network. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Output tensor + """ + x = self.project_in(x) + x1, x2 = self.depthwise_conv(x).chunk(2, dim=1) + x = x1 * torch.sigmoid(x2) + x = self.project_out(x) + return x + + +class RestormerBlock(CondMixin, Module): + """Transformer block with transposed attention and gated depthwise feed forward network.""" + + def __init__(self, n_dim: int, n_channels: int, n_heads: int, mlp_ratio: float, cond_dim: int = 0): + """Initialize RestormerBlock. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads + Number of attention heads + mlp_ratio + Ratio for hidden dimension expansion + cond_dim + Dimension of conditioning input. If 0, no conditioning is applied. + """ + super().__init__() + self.norm1 = Sequential(instanceNormND(n_dim)(n_channels)) + self.attn = TransposedAttention(n_dim, n_channels, n_channels, n_heads) + self.norm2 = Sequential(instanceNormND(n_dim)(n_channels)) + self.ffn = GDFN(n_dim, n_channels, mlp_ratio) + if cond_dim > 0: + self.norm2.append(FiLM(channels=n_channels, cond_dim=cond_dim)) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply Restormer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Forward pass for RestormerBlock.""" + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x, cond=cond)) + return x + + +class Restormer(UNetBase): + """Restormer architecture. + + Implements the Restormer [ZAM22]_ network, which is a U-shaped transformer + with channel wise attention and depthwise convolutions in the feed forward network. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_blocks: Sequence[int] = (4, 6, 6, 8), + n_refinement_blocks: int = 4, + n_heads: Sequence[int] = (1, 2, 4, 8), + n_channels_per_head: int = 48, + mlp_ratio: float = 2.66, + cond_dim: int = 0, + ): + """Initialize Restormer. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + n_blocks + Number of blocks in each stage + n_refinement_blocks + Number of refinement blocks + n_heads + Number of attention heads in each stage + n_channels_per_head + Number of channels per attention head + mlp_ratio + Ratio for hidden dimension expansion + cond_dim + Dimension of conditioning input. If 0, no conditioning is applied. + """ + if len(n_blocks) != len(n_heads): + raise ValueError('n_blocks and n_heads must have the same length.') + + def blocks(n_heads: int, n_blocks: int): + layers = Sequential( + *(RestormerBlock(n_dim, n_channels_per_head * n_heads, n_heads, mlp_ratio) for _ in range(n_blocks)) + ) + + if cond_dim > 0 and n_blocks > 1: + layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) + return layers + + first_block = convND(n_dim)(n_channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) + encoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)] + down_blocks = [ + PixelUnshuffleDownsample(n_dim, n_channels_per_head * head_current, n_channels_per_head * head_next) + for head_current, head_next in pairwise(n_heads) + ] + middle_block = blocks(n_heads[-1], n_blocks[-1]) + encoder = UNetEncoder( + first_block=first_block, + blocks=encoder_blocks, + down_blocks=down_blocks, + middle_block=middle_block, + ) + + up_blocks = [ + PixelShuffleUpsample(n_dim, n_channels_per_head * head_next, n_channels_per_head * head_current) + for head_current, head_next in pairwise(n_heads) + ][::-1] + concat_blocks = [ + Sequential( + Concat(), + convND(n_dim)(2 * n_channels_per_head * head, n_channels_per_head * head, kernel_size=1), + ) + for head in n_heads[-2::-1] + ] + decoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1] + last_block = Sequential( + *(RestormerBlock(n_dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)), + convND(n_dim)(n_channels_per_head, n_channels_out, kernel_size=3, stride=1, padding=1), + ) + decoder = UNetDecoder( + blocks=decoder_blocks, + up_blocks=up_blocks, + concat_blocks=concat_blocks, + last_block=last_block, + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 06271d970..850c95ca6 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,8 +1,10 @@ from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet __all__ = [ 'AttentionGatedUNet', 'BasicCNN', + 'Restormer', 'UNet', ] diff --git a/tests/nn/nets/test_restormer.py b/tests/nn/nets/test_restormer.py new file mode 100644 index 000000000..68c84a689 --- /dev/null +++ b/tests/nn/nets/test_restormer.py @@ -0,0 +1,62 @@ +"""Tests for Restormer network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import Restormer + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_restormer_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the restormer.""" + restormer = Restormer( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2, 4), + n_blocks=(2, 1, 1), + cond_dim=32, + n_channels_per_head=2, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + restormer = restormer.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + restormer = cast(Restormer, torch.compile(restormer)) + y = restormer(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_restormer_backward() -> None: + restormer = Restormer( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2), + n_blocks=(2, 2), + cond_dim=32, + n_channels_per_head=4, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = restormer(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in restormer.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 497e58f7c9838f04fe22a2fa28e8f8b54d3c9e85 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:31 +0100 Subject: [PATCH 183/205] add swinir architecture and tests ghstack-source-id: dad29a088353d8f916000d8c550dd832160f8587 ghstack-comment-id: 3865651450 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/958 --- src/mrpro/nn/nets/SwinIR.py | 247 ++++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 2 + tests/nn/nets/test_swinir.py | 56 ++++++++ 3 files changed, 305 insertions(+) create mode 100644 src/mrpro/nn/nets/SwinIR.py create mode 100644 tests/nn/nets/test_swinir.py diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py new file mode 100644 index 000000000..bc064848d --- /dev/null +++ b/src/mrpro/nn/nets/SwinIR.py @@ -0,0 +1,247 @@ +"""SwinIR implementation.""" + +import torch +from torch.nn import GELU, Module + +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import convND, instanceNormND +from mrpro.nn.Sequential import Sequential + + +class SwinTransformerLayer(Module): + """Swin Transformer layer. + + Implements a single layer of the Swin Transformer architecture. + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + n_heads: int, + window_size: int, + mlp_ratio: int = 4, + emb_dim: int = 0, + p_droppath: float = 0.0, + ): + """Initialize SwinTransformerLayer. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads + Number of attention heads + window_size + Size of the attention window + mlp_ratio + Ratio for hidden dimension expansion in MLP + emb_dim + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath + Droppath probability for MLP + """ + super().__init__() + self.norm1 = instanceNormND(n_dim)(n_channels) + self.attn = ShiftedWindowAttention(n_dim, n_channels, n_channels, n_heads, window_size) + self.norm2 = Sequential(instanceNormND(n_dim)(n_channels)) + if emb_dim > 0: + self.norm2.append(FiLM(channels=n_channels, cond_dim=emb_dim)) + self.mlp = Sequential( + convND(n_dim)(n_channels, n_channels * mlp_ratio, 1), + GELU('tanh'), + convND(n_dim)(n_channels * mlp_ratio, n_channels, 1), + DropPath(p_droppath), + ) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the Swin Transformer layer. + + Parameters + ---------- + x + Input tensor + cond + Conditioning input + + Returns + ------- + torch.Tensor + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the Swin Transformer layer.""" + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x, cond=cond)) + return x + + +class ResidualSwinTransformerBlock(Module): + """Residual Swin Transformer block (RSTB). + + Combines a Swin Transformer layer with a residual connection, + as used in the SwinIR architecture. + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + n_heads: int, + window_size: int, + depth: int, + emb_dim: int = 0, + p_droppath: float = 0.0, + mlp_ratio: int = 4, + ): + """Initialize ResidualSwinTransformerBlock. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads + Number of attention heads + window_size + Size of the attention window + depth + Number of Swin Transformer layers + emb_dim + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath + Droppath probability for MLP. + mlp_ratio + Ratio for hidden dimension expansion in MLP + """ + super().__init__() + self.layers = Sequential( + *( + SwinTransformerLayer( + n_dim, n_channels, n_heads, window_size, emb_dim=emb_dim, p_droppath=p_droppath, mlp_ratio=mlp_ratio + ) + for _ in range(depth) + ) + ) + self.conv = convND(n_dim)(n_channels, n_channels, 3, padding=1) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the residual Swin Transformer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning input. If None, no FiLM conditioning is used. + + Returns + ------- + torch.Tensor + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the residual Swin Transformer block.""" + return x + self.conv(self.layers(x, cond=cond)) + + +class SwinIR(Module): + """SwinIR architecture. + + Implements the SwinIR [LZL21]_ network, which is a Swin Transformer based + image restoration network. + + References + ---------- + .. [LZL21] Liang, Jie, et al. "SwinIR: Image restoration using swin transformer." + ICCVW 2021, https://arxiv.org/pdf/2108.10257.pdf + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_channels_per_head: int = 16, + n_heads: int = 6, + window_size: int = 64, + n_blocks: int = 6, + n_attn_per_block: int = 6, + emb_dim: int = 0, + p_droppath: float = 0.0, + mlp_ratio: int = 4, + ): + """Initialize SwinIR. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + n_channels_per_head + The number of channels per attention head. + n_heads + The number of attention heads. + window_size + The size of the attention window. Inputs sizes must be divisible by this value. + n_blocks + The number of residual blocks. + n_attn_per_block + The number of attention layers per block. + emb_dim + The dimension of the conditioning input. If 0, no FiLM conditioning is used. + p_droppath + The droppath probability for MLP. + mlp_ratio + The ratio for hidden dimension expansion in MLP. + """ + super().__init__() + self.first = convND(n_dim)(n_channels_in, n_channels_per_head * n_heads, kernel_size=3, padding=1) + self.blocks = Sequential( + *( + ResidualSwinTransformerBlock( + n_dim, + n_channels_per_head * n_heads, + n_heads, + window_size, + n_attn_per_block, + emb_dim, + p_droppath, + mlp_ratio, + ) + for _ in range(n_blocks) + ) + ) + self.last = convND(n_dim)(n_channels_per_head * n_heads, n_channels_out, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply SwinIR. + + Parameters + ---------- + x + Input tensor + cond + Conditioning input. If None, no FiLM conditioning is used. + + Returns + ------- + torch.Tensor + Output tensor + """ + x = self.first(x) + x = self.blocks(x, cond=cond) + x = self.last(x) + return x diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 850c95ca6..01840cfba 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,10 +1,12 @@ from mrpro.nn.nets.BasicCNN import BasicCNN from mrpro.nn.nets.Restormer import Restormer +from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet __all__ = [ 'AttentionGatedUNet', 'BasicCNN', 'Restormer', + 'SwinIR', 'UNet', ] diff --git a/tests/nn/nets/test_swinir.py b/tests/nn/nets/test_swinir.py new file mode 100644 index 000000000..c8dbed58c --- /dev/null +++ b/tests/nn/nets/test_swinir.py @@ -0,0 +1,56 @@ +"""Tests for SwinIR network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import SwinIR + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_swinir_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the UNet.""" + swinir = SwinIR( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_heads=2, + n_channels_per_head=4, + n_blocks=2, + window_size=4, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + swinir = swinir.to(device) + if torch_compile: + swinir = cast(SwinIR, torch.compile(swinir)) + y = swinir(x) + assert y.shape == (1, 1, 16, 16) + + +def test_swinir_backward() -> None: + swinir = SwinIR( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_heads=2, + n_channels_per_head=4, + n_blocks=2, + window_size=4, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + y = swinir(x) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in swinir.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 920d1f6fc2d36c0f6d33b72916c65d0d9d033058 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:33 +0100 Subject: [PATCH 184/205] add uformer architecture and tests ghstack-source-id: ed8ad497ac670136c60fb97295213102d6552198 ghstack-comment-id: 3865651637 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/959 --- src/mrpro/nn/nets/Uformer.py | 230 ++++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 2 + tests/nn/nets/test_uformer.py | 56 +++++++++ 3 files changed, 288 insertions(+) create mode 100644 src/mrpro/nn/nets/Uformer.py create mode 100644 tests/nn/nets/test_uformer.py diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py new file mode 100644 index 000000000..02f1d1cc1 --- /dev/null +++ b/src/mrpro/nn/nets/Uformer.py @@ -0,0 +1,230 @@ +"""Uformer: U-Net with window attention.""" + +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import GELU, LeakyReLU, Module + +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import convND, convTransposeND, instanceNormND +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.Sequential import Sequential + + +class LeWinTransformerBlock(CondMixin, Module): + """Locally-enhanced windowed attention transformer block. + + Part of the Uformer architecture. + """ + + def __init__( + self, + n_dim: int, + n_channels_per_head: int, + n_heads: int, + window_size: int = 8, + shifted: bool = False, + mlp_ratio: float = 4.0, + p_droppath: float = 0.0, + cond_dim: int = 0, + ) -> None: + """Initialize the LeWinTransformerBlock module. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_per_head + The number of features per head. + n_heads + Number of attention heads + window_size + Size of the attention window + shifted + Whether to use shifted variant of the attention + mlp_ratio + Ratio of the hidden dimension to the input dimension + p_droppath + Dropout probability for the drop path. + cond_dim + Dimension of a conditioning tensor. If `0`, no FiLM layers are added. + """ + super().__init__() + channels = n_channels_per_head * n_heads + hidden_dim = int(channels * mlp_ratio) + self.norm1 = instanceNormND(n_dim)(channels) + self.attn = ShiftedWindowAttention( + n_dim=n_dim, + n_channels_in=channels, + n_channels_out=channels, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, + ) + self.norm2 = instanceNormND(n_dim)(channels) + self.ff = Sequential( + convND(n_dim)(channels, hidden_dim, 1), + GELU(), + convND(n_dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), + GELU(), + convND(n_dim)(hidden_dim, channels, 1), + ) + if cond_dim > 0: + self.ff.append(FiLM(channels, cond_dim)) + self.modulator = torch.nn.Parameter(torch.empty(channels, *((window_size,) * n_dim))) + torch.nn.init.trunc_normal_(self.modulator) + self.drop_path = DropPath(droprate=p_droppath) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the transformer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor + + Returns + ------- + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the transformer block.""" + modulator = self.modulator.tile([t // s for t, s in zip(x.shape[1:], self.modulator.shape, strict=False)]) + x_mod = self.norm1(x) + modulator + x_attn = self.attn(x_mod) + x_ff = self.ff(self.norm2(x_attn), cond=cond) + return x + self.drop_path(x_ff) + + +class Uformer(UNetBase): + """Uformer: U-Net with window attention. + + Implements the Uformer network proposed in [WANG21]_ + It is SWin-Transformer/U-Net hybrid consisting of (shifted) windows attention transformer layers at different + resolution levels, extended by FiLM layers for conditioning. + + References + ---------- + .. [WANG21] Wang, Z., Cun, X., Bao, J., Zhou, W., Liu, J., & Li, H. Uformer: A general u-shaped transformer for + image restoration. CVPR 2022. https://doi.org/10.48550/arXiv.2106.03106 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_channels_per_head: int = 32, + n_heads: Sequence[int] = (1, 2, 4, 8), + n_blocks: int = 2, + cond_dim: int = 0, + window_size: int = 8, + mlp_ratio: float = 4.0, + max_droppath_rate: float = 0.1, + ): + """Initialize the Uformer module. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + n_channels_per_head + The number of features per head. The number of features at a resolution level is given by + `n_channels_per_head * n_heads`. + n_heads + Number of attention heads at each resolution level. + n_blocks + The number of transformer blocks at each resolution level in the input and output path + cond_dim + Dimension of a conditioning tensor. If `0`, no FiLM layers are added. + window_size + The size of the attention windows in the (shifted) window attention layers. + mlp_ratio + Ratio of the hidden dimension to the input dimension in the feed-forward blocks + max_droppath_rate + Maximum drop path rate. As in the original implementation, the drop path rate in the input path + is linearly increased from `0` to `max_droppath_rate` with decreasing resolution. The rate in output + blocks is fixed to `max_droppath_rate`. + """ + + def blocks(n_heads: int, p_droppath: float = 0.0): + return Sequential( + *( + LeWinTransformerBlock( + n_dim=n_dim, + n_heads=n_heads, + n_channels_per_head=n_channels_per_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + shifted=bool(i % 2), + p_droppath=p_droppath, + cond_dim=cond_dim, + ) + for i in range(n_blocks) + ) + ) + + first_block = torch.nn.Sequential( + convND(n_dim)(n_channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), + LeakyReLU(), + ) + drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() + encoder_blocks = [ + blocks(n_heads=n_head, p_droppath=p_droppath_input) + for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True) + ] + down_blocks = [ + convND(n_dim)( + n_channels_per_head * n_head_current, + n_channels_per_head * n_head_next, + kernel_size=4, + stride=2, + padding=1, + ) + for n_head_current, n_head_next in pairwise(n_heads) + ] + middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) + encoder = UNetEncoder( + first_block=first_block, + blocks=encoder_blocks, + down_blocks=down_blocks, + middle_block=middle_block, + ) + + decoder_blocks = [blocks(n_heads=2 * n_head, p_droppath=max_droppath_rate) for n_head in reversed(n_heads[:-1])] + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] + up_blocks = [ + convTransposeND(n_dim)( + n_channels_per_head * n_heads[-1], n_channels_per_head * n_heads[-2], kernel_size=2, stride=2 + ) + ] + for n_head_current, n_head_next in pairwise(reversed(n_heads[:-1])): + up_blocks.append( + convTransposeND(n_dim)( + 2 * n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=2, stride=2 + ) + ) + last_block = convND(n_dim)( + 2 * n_channels_per_head * n_heads[0], n_channels_out, kernel_size=3, stride=1, padding='same' + ) + decoder = UNetDecoder( + blocks=decoder_blocks, + concat_blocks=concat_blocks, + up_blocks=up_blocks, + last_block=last_block, + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 01840cfba..facfe4a6a 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -2,6 +2,7 @@ from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet +from mrpro.nn.nets.Uformer import Uformer __all__ = [ 'AttentionGatedUNet', @@ -9,4 +10,5 @@ 'Restormer', 'SwinIR', 'UNet', + 'Uformer', ] diff --git a/tests/nn/nets/test_uformer.py b/tests/nn/nets/test_uformer.py new file mode 100644 index 000000000..f4315702e --- /dev/null +++ b/tests/nn/nets/test_uformer.py @@ -0,0 +1,56 @@ +"""Tests for Uformer network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import Uformer + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_uformer_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the uformer.""" + uformer = Uformer( + n_dim=2, n_channels_in=1, n_channels_out=1, n_heads=(1, 2), cond_dim=32, n_channels_per_head=8, window_size=2 + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + uformer = uformer.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + uformer = cast(Uformer, torch.compile(uformer)) + y = uformer(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_uformer_backward() -> None: + uformer = Uformer( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2, 4), + cond_dim=32, + n_channels_per_head=8, + window_size=2, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = uformer(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in uformer.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From f9baf2a6150be6ab069409133bc9abb4f10f0455 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:34 +0100 Subject: [PATCH 185/205] add hourglass transformer architecture and tests ghstack-source-id: 9347f2c619724e79e9ad51e5d24b0a41b17c4986 ghstack-comment-id: 3865651799 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/960 --- src/mrpro/nn/nets/HourglassTransformer.py | 147 ++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 2 + tests/nn/nets/test_hourglass.py | 65 ++++++++++ 3 files changed, 214 insertions(+) create mode 100644 src/mrpro/nn/nets/HourglassTransformer.py create mode 100644 tests/nn/nets/test_hourglass.py diff --git a/src/mrpro/nn/nets/HourglassTransformer.py b/src/mrpro/nn/nets/HourglassTransformer.py new file mode 100644 index 000000000..ee4d9d44c --- /dev/null +++ b/src/mrpro/nn/nets/HourglassTransformer.py @@ -0,0 +1,147 @@ +"""Hourglass Transformer.""" + +from collections.abc import Sequence +from itertools import pairwise + +from torch.nn import Module + +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.join import Interpolate +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Sequential import Sequential +from mrpro.operators.RearrangeOp import RearrangeOp +from mrpro.utils.to_tuple import to_tuple + + +class HourglassTransformer(UNetBase): + """Hourglass Transformer. + + A U-shaped transformer [CK]_ with neighborhood self-attention [NAT]_. + + References + ---------- + .. [CK] Crowson, Katherine, et al. "Scalable high-resolution pixel-space image synthesis with + hourglass diffusion transformers." ICML 2024, https://arxiv.org/abs/2401.11605 + .. [NAT] Hassani, A. et al. "Neighborhood Attention Transformer" CVPR, 2023, https://arxiv.org/abs/2204.07143 + + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_features: Sequence[int] | int, + depths: Sequence[int] | int = 3, + attention_neighborhood: Sequence[None | int] | int | None = 11, + n_heads: int | Sequence[int] = 4, + cond_dim: int = 0, + ): + """Initialize the Hourglass Transformer. + + Parameters + ---------- + n_dim + Number of (spatial)dimensions of the input data. + n_channels_in + Number of channels in the input data. + n_channels_out + Number of channels in the output data. + n_features + Number of features in each stage. + depths + Number of layers in each stage. + attention_neighborhood + Neighborhood size for the neighborhood self-attention. If None, use global attention + for that stage. + n_heads + Number of heads in each stage. + cond_dim + Number of dimensions of the conditioning tensor. + """ + n_layers_ = [ + len(x) + for x in (n_features, depths, attention_neighborhood, n_heads) + if (x is not None and not isinstance(x, int)) + ] + n_layers = n_layers_[0] + + if any(x != n_layers_[0] for x in n_layers_): + raise ValueError('All arguments must have the same length or be scalars') + + n_features_ = to_tuple(n_layers, n_features) + depths_ = to_tuple(n_layers, depths) + attention_neighborhood_ = to_tuple(n_layers, attention_neighborhood) + n_heads_ = to_tuple(n_layers, n_heads) + + move_channels_last = RearrangeOp('batch channels ... -> batch ... channels') + first_block = Sequential( + move_channels_last, + PixelUnshuffleDownsample(n_dim, n_channels_in, n_features_[0], downscale_factor=2, features_last=True), + ) + dim_group = (tuple(range(-n_dim - 1, -1)),) + encoder_blocks: list[Module] = [] + decoder_blocks: list[Module] = [] + merge_blocks: list[Module] = [] + down_blocks: list[Module] = [] + up_blocks: list[Module] = [] + for channels, depth, neighborhood, head in zip( + n_features_[:-1], + depths_[:-1], + attention_neighborhood_[:-1], + n_heads_[:-1], + strict=True, + ): + encoder_blocks.append( + SpatialTransformerBlock( + dim_groups=dim_group, + channels=channels, + depth=depth, + attention_neighborhood=neighborhood, + n_heads=head, + rope_embed_fraction=1.0, + cond_dim=cond_dim, + features_last=True, + norm='rms', + ) + ) + decoder_blocks.append( + SpatialTransformerBlock( + dim_groups=dim_group, + channels=channels, + depth=depth, + attention_neighborhood=neighborhood, + n_heads=head, + rope_embed_fraction=1.0, + cond_dim=cond_dim, + features_last=True, + norm='rms', + ) + ) + merge_blocks.append(Interpolate()) + for channels, channels_next in pairwise(n_features_): + down_blocks.append( + PixelUnshuffleDownsample(n_dim, channels, channels_next, downscale_factor=2, features_last=True) + ) + up_blocks.append(PixelShuffleUpsample(n_dim, channels_next, channels, upscale_factor=2, features_last=True)) + + last_block = Sequential( + PixelShuffleUpsample(n_dim, n_features_[-1], n_channels_out, upscale_factor=2, features_last=True), + move_channels_last.H, # moves channels back to front + ) + middle_block = SpatialTransformerBlock( + dim_groups=dim_group, + channels=n_features_[-1], + depth=depths_[-1], + attention_neighborhood=attention_neighborhood_[-1], + n_heads=n_heads_[-1], + rope_embed_fraction=1.0, + cond_dim=cond_dim, + features_last=True, + norm='rms', + ) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + decoder = UNetDecoder(decoder_blocks, up_blocks, merge_blocks, last_block) + + super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index facfe4a6a..b5a6a76af 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,4 +1,5 @@ from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.HourglassTransformer import HourglassTransformer from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet @@ -7,6 +8,7 @@ __all__ = [ 'AttentionGatedUNet', 'BasicCNN', + 'HourglassTransformer', 'Restormer', 'SwinIR', 'UNet', diff --git a/tests/nn/nets/test_hourglass.py b/tests/nn/nets/test_hourglass.py new file mode 100644 index 000000000..717e7e203 --- /dev/null +++ b/tests/nn/nets/test_hourglass.py @@ -0,0 +1,65 @@ +"""Test Hourglass Transformer""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import HourglassTransformer +from tests.conftest import minimal_torch_26 + + +@minimal_torch_26 +@torch.no_grad() +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_hourglass_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the hourglass.""" + hourglass = HourglassTransformer( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_features=64, + attention_neighborhood=(7, 7, None), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + hourglass = hourglass.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + hourglass = cast(HourglassTransformer, torch.compile(hourglass, dynamic=False)) + y = hourglass(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +@minimal_torch_26 +@pytest.mark.cuda +def test_hourglass_backward() -> None: + hourglass = HourglassTransformer( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_features=64, + attention_neighborhood=(7, 7, None), + cond_dim=32, + ).cuda() + + x = torch.zeros(1, 1, 16, requires_grad=True).cuda() + cond = torch.zeros(1, 32, requires_grad=True).cuda() + y = hourglass(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in hourglass.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 5aedd68e53d430adce12bdb188e8f1a777511f67 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 8 Feb 2026 22:57:36 +0100 Subject: [PATCH 186/205] add vae and dcvae architectures with mbconv block ghstack-source-id: ed0f93cbf218727ec024f11d56e68966e415cf12 ghstack-comment-id: 3865652021 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/961 --- src/mrpro/nn/GluMBConvResBlock.py | 118 ++++++++++++ src/mrpro/nn/nets/DCVAE.py | 302 ++++++++++++++++++++++++++++++ src/mrpro/nn/nets/VAE.py | 64 +++++++ src/mrpro/nn/nets/__init__.py | 4 + tests/nn/nets/test_dcvae.py | 82 ++++++++ 5 files changed, 570 insertions(+) create mode 100644 src/mrpro/nn/GluMBConvResBlock.py create mode 100644 src/mrpro/nn/nets/DCVAE.py create mode 100644 src/mrpro/nn/nets/VAE.py create mode 100644 tests/nn/nets/test_dcvae.py diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py new file mode 100644 index 000000000..c17bc019f --- /dev/null +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -0,0 +1,118 @@ +"""Gateded MBConv Residual Block.""" + +import torch +from torch.nn import Identity, Module, Sequential, SiLU + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import convND +from mrpro.nn.RMSNorm import RMSNorm + + +class GluMBConvResBlock(CondMixin, Module): + """Gated MBConv residual block. + + Gated variant [DCAE]_ of the MBConv block [EffNet]_ with a residual connection and (optional) conditioning. + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + .. [EffNet] Tan et al. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. ICML 2019 + https://arxiv.org/abs/1905.11946 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + expand_ratio: int = 6, + stride: int = 1, + kernel_size: int = 3, + cond_dim: int = 0, + ): + """Initialize MBConv block. + + Parameters + ---------- + n_dim + Number of spatial dimensions. + n_channels_in + Number of input channels. + n_channels_out + Number of output channels. + expand_ratio + Expansion ratio inside the block. + stride + Stride of the depthwise convolution. + kernel_size + Kernel size of the depthwise convolution. + cond_dim + Dimension of the conditioning tensor used in a FiLM. If 0, no FiLM is used. + """ + super().__init__() + channels_mid = n_channels_in * expand_ratio + if stride == 1 and n_channels_in == n_channels_out: + self.skip: Module = Identity() + else: + self.skip = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1, stride=stride) + self.inverted_conv = Sequential( + convND(n_dim)( + n_channels_in, + channels_mid * 2, + kernel_size=1, + ), + SiLU(), + ) + self.depth_conv = Sequential( + convND(n_dim)( + channels_mid * 2, + channels_mid * 2, + kernel_size=kernel_size, + stride=stride, + padding='same', + groups=channels_mid * 2, + ), + SiLU(), + ) + self.point_conv = Sequential( + convND(n_dim)( + channels_mid, + n_channels_out, + kernel_size=1, + ), + RMSNorm(n_channels_out), + SiLU(), + ) + if cond_dim > 0: + self.film: FiLM | None = FiLM(channels_mid, cond_dim) + else: + self.film = None + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply MBConv block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If `None`, no conditioning is applied. + + Returns + ------- + Output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply MBConv block.""" + h = self.inverted_conv(x) + h = self.depth_conv(h) + h, gate = torch.chunk(h, 2, dim=1) + h = h * torch.nn.functional.silu(gate) + if self.film is not None: + h = self.film(h, cond=cond) + h = self.point_conv(h) + return self.skip(x) + h diff --git a/src/mrpro/nn/nets/DCVAE.py b/src/mrpro/nn/nets/DCVAE.py new file mode 100644 index 000000000..5faba554b --- /dev/null +++ b/src/mrpro/nn/nets/DCVAE.py @@ -0,0 +1,302 @@ +"""Deep Compression Autoencoder.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module, ReLU, SiLU + +from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock +from mrpro.nn.ndmodules import convND +from mrpro.nn.nets.VAE import VAE +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Residual import Residual +from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Sequential import Sequential + + +class CNNBlock(Residual): + """Block with two convolutions and normalization. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + ): + """Initialize the CNNBlock. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + """ + super().__init__( + Sequential( + convND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1), + SiLU(True), + convND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1, bias=False), + RMSNorm(n_channels), + ) + ) + + +class EfficientViTBlock(Module): + """Efficient Vision Transformer block with optional linear attention. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + n_heads: int, + expand_ratio: int = 4, + linear_attn: bool = False, + ): + """Initialize the EfficientViTBlock. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads + The number of attention heads. + expand_ratio + The expansion ratio of the GluMBConvResBlock. + linear_attn + Whether to use linear attention instead of softmax attention with quadratic complexity. + """ + super().__init__() + if linear_attn: + attention: Module = LinearSelfAttention(n_channels, n_channels, n_heads) + else: + attention = MultiHeadAttention(n_channels, n_channels, n_heads, features_last=False) + self.context_module = Residual(Sequential(attention, RMSNorm(n_channels))) + self.local_module = GluMBConvResBlock( + n_dim=n_dim, + n_channels_in=n_channels, + n_channels_out=n_channels, + expand_ratio=expand_ratio, + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the EfficientViTBlock. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Output tensor + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for EfficientViTBlock.""" + x = self.context_module(x) + x = self.local_module(x) + return x + + +class Encoder(Sequential): + """Encoder for DCAE. + + As used in the DC-Autoencoder [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int = 2, + n_channels_in: int = 3, + n_channels_out: int = 32, + block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), + widths: Sequence[int] = (256, 512, 512, 1024, 1024), + depths: Sequence[int] = (4, 6, 2, 2, 2), + ): + """Initialize the Encoder. + + The length of the `block_types`, `widths`, and `depths` must be the same and determine + the number of stages in the encoder. Between the stages, downsampling is performed. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor, i.e. the latent space + n_channels_out + The number of channels in the output tensor, i.e. the original space + block_types + The types of blocks to use in the decoder. + widths + The widths of the blocks in the decoder, i.e. the number of channels in the blocks + depths + The depths of the blocks in the decoder, i.e. the number blocks in the stage + """ + super().__init__() + self.append(PixelUnshuffleDownsample(n_dim, n_channels_in, widths[0], downscale_factor=2, residual=False)) + if len(block_types) != len(widths) or len(block_types) != len(depths): + raise ValueError('block_types, widths, and depths must have the same length') + for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): + match block_type: + case 'CNN': + stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] + case 'LinearViT': + stage = [ + EfficientViTBlock(n_dim, width, max(1, width // 32), linear_attn=True) for _ in range(depth) + ] + case 'ViT': + stage = [EfficientViTBlock(n_dim, width, max(1, width // 32)) for _ in range(depth)] + case _: + raise ValueError(f'Block type {block_type} not supported') + self.append(Sequential(*stage)) + if next_width: + self.append(PixelUnshuffleDownsample(n_dim, width, next_width, downscale_factor=2, residual=True)) + self.append( + Sequential( + RMSNorm(widths[-1]), + ReLU(), + PixelUnshuffleDownsample(n_dim, widths[-1], n_channels_out, downscale_factor=1, residual=True), + ) + ) + + +class Decoder(Sequential): + """Decoder for DCAE. + + As used in the DC-Autoencoder [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int = 2, + n_channels_in: int = 32, + n_channels_out: int = 3, + block_types: Sequence[Literal['ViT', 'LinearViT', 'CNN']] = ('ViT', 'LinearViT', 'LinearViT', 'CNN', 'CNN'), + widths: Sequence[int] = (1024, 1024, 512, 512, 256), + depths: Sequence[int] = (2, 2, 2, 6, 4), + ): + """Initialize the Decoder. + + The length of the `block_types`, `widths`, and `depths` must be the same and determine + the number of stages in the decoder. Between the stages, upsampling is performed. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor, i.e. the latent space + n_channels_out + The number of channels in the output tensor, i.e. the original space + block_types + The types of blocks to use in the decoder. + widths + The widths of the blocks in the decoder, i.e. the number of channels in the blocks + depths + The depths of the blocks in the decoder, i.e. the number blocks in the stage + """ + super().__init__() + if not (len(block_types) == len(widths) == len(depths)): + raise ValueError('block_types, widths, and depths must have the same length') + self.append(PixelShuffleUpsample(n_dim, n_channels_in, widths[0], upscale_factor=1, residual=True)) + + for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): + match block_type: + case 'CNN': + stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] + case 'LinearViT': + stage = [ + EfficientViTBlock(n_dim, width, n_heads=max(1, width // 32), linear_attn=True) + for _ in range(depth) + ] + case 'ViT': + stage = [ + EfficientViTBlock(n_dim, width, n_heads=max(1, width // 32), linear_attn=False) + for _ in range(depth) + ] + case _: + raise ValueError(f'Block type {block_type} not supported') + self.append(Sequential(*stage)) + if next_width: + self.append(PixelShuffleUpsample(n_dim, width, next_width, upscale_factor=2, residual=True)) + + self.append( + Sequential( + RMSNorm(widths[-1]), + ReLU(), + PixelShuffleUpsample(n_dim, widths[-1], n_channels_out, upscale_factor=2), + ) + ) + + +class DCVAE(VAE): + """Variational Autoencoder based on DCAE. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + latent_dim: int = 32, + block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), + widths: Sequence[int] = (256, 512, 512, 1024, 1024), + depths: Sequence[int] = (4, 6, 2, 2, 2), + ): + """Initialize the DCVAE. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + latent_dim + The number of channels in the latent space. + block_types + The types of blocks to use in the encoder and decoder. + widths + The widths of the blocks in the encoder and decoder. + depths + The depths of the blocks in the encoder and decoder. + """ + encoder = Encoder(n_dim, n_channels, latent_dim * 2, block_types, widths, depths) + decoder = Decoder(n_dim, latent_dim, n_channels, block_types[::-1], widths[::-1], depths[::-1]) + super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py new file mode 100644 index 000000000..e0b1bfc58 --- /dev/null +++ b/src/mrpro/nn/nets/VAE.py @@ -0,0 +1,64 @@ +"""Variational Autoencoder with a Gaussian latent space.""" + +import torch +from torch.nn import Module + + +class VAE(Module): + """Basic Variational Autoencoder. + + Consists of an encoder to transform the input into a latent space and a decoder to transform the latent space back + into the original space. The encoder should return twice the number of channels as the decoder needs to reconstruct + the input: half of the channels are the mean and the other half the log variance of the latent space. + The reparameterization trick is used to sample from the latent space. + The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard + normal distribution. + """ + + def __init__(self, encoder: Module, decoder: Module): + """Initialize the VAE. + + Parameters + ---------- + encoder + Encoder module. Should return double the number of channels of the latent space. + decoder + Decoder module + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE. + + Calculates the reconstruction as well as the KL divergence between the latent space and the + standard normal distribution. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + tuple of the reconstructed image and + the KL divergence between the latent space and the standard normal distribution. + """ + return self.forward(x) + + def mode(self, x: torch.Tensor) -> torch.Tensor: + """Mode of the VAE.""" + z = self.encoder(x) + mean, _ = z.chunk(2, dim=1) + return self.decoder(mean) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE.""" + z = self.encoder(x) + mean, logvar = z.chunk(2, dim=1) + std = torch.exp(0.5 * logvar) + sample = mean + torch.randn_like(std) * std + reconstruction = self.decoder(sample) + kl = -0.5 * torch.sum(1 + logvar - mean.square() - std.square()) + return reconstruction, kl diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index b5a6a76af..87d9075f7 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,16 +1,20 @@ from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.DCVAE import DCVAE from mrpro.nn.nets.HourglassTransformer import HourglassTransformer from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet from mrpro.nn.nets.Uformer import Uformer +from mrpro.nn.nets.VAE import VAE __all__ = [ 'AttentionGatedUNet', 'BasicCNN', + 'DCVAE', 'HourglassTransformer', 'Restormer', 'SwinIR', 'UNet', 'Uformer', + 'VAE', ] diff --git a/tests/nn/nets/test_dcvae.py b/tests/nn/nets/test_dcvae.py new file mode 100644 index 000000000..ff5371b7b --- /dev/null +++ b/tests/nn/nets/test_dcvae.py @@ -0,0 +1,82 @@ +"""Tests for DCVAE network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import DCVAE + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_dcvae_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the DCVAE.""" + dcvae = DCVAE( + n_dim=2, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(32, 64, 32), + depths=(1, 2, 2), + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + dcvae = dcvae.to(device) + x = x.to(device) + if torch_compile: + dcvae = cast(DCVAE, torch.compile(dcvae)) + y, kl = dcvae(x) + assert y.shape == (1, 1, 16, 16) + assert kl.shape == () + latent = dcvae.encoder(x) + assert latent.shape == (1, 2 * 4, 2, 2) # 2 because of mean and logvar + + +def test_dcvae_backward_kl() -> None: + """Test the backward pass of the DCVAE wrt kl.""" + dcvae = DCVAE( + n_dim=1, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(8, 12, 16), + depths=(2, 2, 3), + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + + _, kl = dcvae(x) + kl.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in dcvae.encoder.named_parameters(): # only the encoder parameters can influence kl + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +def test_dcvae_backward_y() -> None: + """Test the backward pass of the DCVAE wrt y.""" + dcvae = DCVAE( + n_dim=1, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(8, 12, 16), + depths=(2, 2, 3), + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + + y, _ = dcvae(x) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in dcvae.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 515e1ecbd70dfe81672d6a78cc3d3f7f3b90d475 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 9 Feb 2026 13:37:05 +0100 Subject: [PATCH 187/205] fix precommit --- examples/notebooks/apply_pinqi.ipynb | 609 +++++++++++++++++++ examples/notebooks/modl.ipynb | 266 +++++++++ examples/notebooks/train_pinqi.ipynb | 840 +++++++++++++++++++++++++++ examples/scripts/apply_pinqi.py | 41 +- examples/scripts/modl.py | 54 +- examples/scripts/train_pinqi.py | 39 +- 6 files changed, 1808 insertions(+), 41 deletions(-) create mode 100644 examples/notebooks/apply_pinqi.ipynb create mode 100644 examples/notebooks/modl.ipynb create mode 100644 examples/notebooks/train_pinqi.ipynb diff --git a/examples/notebooks/apply_pinqi.ipynb b/examples/notebooks/apply_pinqi.ipynb new file mode 100644 index 000000000..0e8ed1b52 --- /dev/null +++ b/examples/notebooks/apply_pinqi.ipynb @@ -0,0 +1,609 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "82f66a37", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PTB-MR/mrpro/blob/main/examples/notebooks/apply_pinqi.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19a028c0", + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "import importlib\n", + "\n", + "if not importlib.util.find_spec('mrpro'):\n", + " %pip install mrpro[notebooks]" + ] + }, + { + "cell_type": "markdown", + "id": "08651577", + "metadata": {}, + "source": [ + "# End-to-end physics informed network for quantitative MRI (PINQI)\n", + "A recent DL approach, PINQI, approaches learned quantitative MRI by half quadratic splitting to alternate between two\n", + "subproblems. The first is a linear image reconstruction task\n", + "$$\n", + "\\underset{\\mathbf{x}}{\\min} \\frac{1}{2} \\| \\mathbf{A} \\mathbf{x} - \\mathbf{y} \\|_2^2\n", + "+ \\frac{\\lambda_\\mathbf{x}}{2} \\left\\| \\mathbf{x} - \\mathbf{x}_{\\text{reg}} \\right\\|_2^2\n", + "+ \\frac{\\lambda_{\\mathbf{q}}}{2} \\left\\| \\mathbf{q}(\\mathbf{p}) - \\mathbf{x} \\right\\|_2^2\n", + "$$\n", + "with $\\mathbf{x}$ being intermediary qualitative images, $\\lambda_{\\mathbf{x}}$ and $\\lambda_{\\mathbf{q}}$ being\n", + "regularization strengths and $\\mathbf{x}_{\\text{reg}}$ denoting an image prior for regularization.\n", + "The second, non-linear, subproblem is finding the quantitative parameters by solving\n", + "$$\n", + "\\underset{\\mathbf{p}}{\\min} \\frac{\\lambda_{\\mathbf{q}}}{2}\\left \\| \\mathbf{q}(\\vec{p}) - \\mathbf{x} \\right\\|_2^2\n", + "+ \\frac{\\lambda_{\\mathbf{p}}}{2} \\left\\| \\mathbf{p} - \\mathbf{p}_{\\text{reg}} \\right\\|_2^2.\n", + "$$\n", + "Here, $\\mathbf{p}_{\\text{reg}}$ is a prior on the parameter maps and $\\lambda_{\\mathbf{p}}$ the associated weight for\n", + "regularization.\n", + "In PINQI, a solution is found by iterating between both subproblems. In each iteration $k=1,\\ldots,T$,\n", + "the image and parameter priors are updated by U-Nets. The network parameters and the regularization strengths\n", + "are trained end-to-end.\n", + "Here, we apply a trained PINQI model to a validation set. We first define the dataset, then define the PINQI model,\n", + "before loading the model weights and applying it to the dataset." + ] + }, + { + "cell_type": "markdown", + "id": "275fcb68", + "metadata": {}, + "source": [ + "## Dataset\n", + "We base the dataset on the BrainWeb phantom (`mrpro.phantoms.brainweb.BrainwebSlices`) and simulate Cartesian random\n", + "undersampling in phase encode direction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d60b908c", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from collections.abc import Sequence\n", + "from copy import deepcopy\n", + "from pathlib import Path\n", + "from typing import Literal, TypedDict\n", + "\n", + "import einops\n", + "import mrpro\n", + "import torch\n", + "\n", + "# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2927cf00", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "class BatchType(TypedDict):\n", + " \"\"\"Typehint for a batch of data.\"\"\"\n", + "\n", + " kdata: mrpro.data.KData\n", + " csm: mrpro.data.CsmData\n", + " m0: torch.Tensor\n", + " t1: torch.Tensor\n", + " mask: torch.Tensor\n", + "\n", + "\n", + "class Dataset(torch.utils.data.Dataset[BatchType]):\n", + " \"\"\"A brainweb based cartesian qMRI dataset.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " folder: Path,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " n_images: int,\n", + " size: int,\n", + " acceleration: int,\n", + " n_coils: int,\n", + " max_noise: float,\n", + " orientation: Sequence[Literal['axial', 'coronal', 'sagittal']],\n", + " random: bool = True,\n", + " ):\n", + " \"\"\"Initialize the dataset.\"\"\"\n", + " if random:\n", + " augment = mrpro.phantoms.brainweb.augment(size=size)\n", + " else:\n", + " augment = mrpro.phantoms.brainweb.augment(\n", + " size=size,\n", + " max_random_shear=0,\n", + " max_random_rotation=0,\n", + " max_random_scaling_factor=0,\n", + " p_horizontal_flip=0,\n", + " p_vertical_flip=1.0,\n", + " )\n", + " self.phantom = mrpro.phantoms.brainweb.BrainwebSlices(\n", + " folder=folder,\n", + " what=('m0', 't1', 'mask'),\n", + " seed='index' if not random else 'random',\n", + " slice_preparation=augment,\n", + " orientation=orientation,\n", + " )\n", + " self.signalmodel = deepcopy(signalmodel)\n", + " self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size)\n", + " self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25)\n", + " self.acceleration = acceleration\n", + " self.n_coils = n_coils\n", + " self._random = random\n", + " self.max_noise = max_noise\n", + " self._n_images = n_images\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"Get the length of the dataset.\"\"\"\n", + " return len(self.phantom)\n", + "\n", + " def __getitem__(self, index: int):\n", + " \"\"\"Get an item from the dataset.\"\"\"\n", + " phantom = self.phantom[index]\n", + " (images,) = self.signalmodel(phantom['m0'], phantom['t1'])\n", + " seed = int(torch.randint(0, 1000000, (1,))) if self._random else index\n", + "\n", + " traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density(\n", + " encoding_matrix=self.encoding_matrix,\n", + " seed=seed,\n", + " acceleration=self.acceleration,\n", + " fwhm_ratio=1.5,\n", + " n_center=10,\n", + " n_other=(self._n_images,),\n", + " )\n", + " header = mrpro.data.KHeader(\n", + " encoding_matrix=self.encoding_matrix,\n", + " recon_matrix=self.encoding_matrix,\n", + " recon_fov=self.fov,\n", + " encoding_fov=self.fov,\n", + " )\n", + "\n", + " if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery):\n", + " header.ti = self.signalmodel.saturation_time.tolist()\n", + " elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery):\n", + " header.ti = self.signalmodel.ti.tolist()\n", + "\n", + " fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj)\n", + " csm = mrpro.data.CsmData(\n", + " mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix),\n", + " header,\n", + " )\n", + " images = einops.rearrange(images, 't y x -> t 1 1 y x')\n", + " (data,) = (fourier_op @ csm.as_operator())(images)\n", + " data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std()\n", + " kdata = mrpro.data.KData(header, data, traj)\n", + " return {'kdata': kdata, 'csm': csm, **phantom}" + ] + }, + { + "cell_type": "markdown", + "id": "8b611057", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "## PINQI\n", + "Next, We define the PINQI model. Here we can make use of the diffferntiable optimization operators in MRpro." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2777b221", + "metadata": {}, + "outputs": [], + "source": [ + "class PINQI(torch.nn.Module):\n", + " \"\"\"PINQI model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " n_images: int,\n", + " n_iterations: int,\n", + " n_features_parameter_net: Sequence[int],\n", + " n_features_image_net: Sequence[int],\n", + " ):\n", + " \"\"\"Initialize the PINQI model.\"\"\"\n", + " super().__init__()\n", + " self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op\n", + " self.constraints_op = constraints_op\n", + " self._n_images = n_images\n", + " self._parameter_is_complex = parameter_is_complex\n", + " real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex)\n", + " self.parameter_net = mrpro.nn.nets.UNet(\n", + " dim=2,\n", + " channels_in=n_images * 2,\n", + " channels_out=real_parameters,\n", + " attention_depths=(-1, -2),\n", + " n_features=n_features_parameter_net,\n", + " cond_dim=128,\n", + " )\n", + "\n", + " self.image_net = mrpro.nn.nets.UNet(\n", + " 2,\n", + " channels_in=2,\n", + " channels_out=2,\n", + " attention_depths=(),\n", + " n_features=n_features_image_net,\n", + " cond_dim=128,\n", + " )\n", + " self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3))\n", + " self.softplus = torch.nn.Softplus(beta=5)\n", + " self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128)\n", + "\n", + " def objective_factory(\n", + " lambda_parameters: torch.Tensor,\n", + " image: torch.Tensor,\n", + " *parameter_reg: torch.Tensor,\n", + " ) -> torch.operators.Operator:\n", + " dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel\n", + " reg = mrpro.operators.ProximableFunctionalSeparableSum(\n", + " *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg]\n", + " )\n", + " return dc + lambda_parameters * reg\n", + "\n", + " self.nonlinear_solver = mrpro.operators.OptimizerOp(\n", + " objective_factory,\n", + " lambda _l, _i, *parameter_reg: parameter_reg,\n", + " )\n", + " # This can be done once, as the signal model is the same for all samples.\n", + "\n", + " def get_linear_solver(self, gram: mrpro.operators.LinearOperator) -> mrpro.operators.ConjugateGradientOp:\n", + " \"\"\"Set up the linear solver.\"\"\"\n", + " # This needs to be done for each sample, as the undersampling pattern and csm are different for each sample,\n", + " # thus the gram operator of the acquisition operator is different for each sample.\n", + "\n", + " def operator_factory(\n", + " lambda_image: torch.Tensor,\n", + " lambda_q: torch.Tensor,\n", + " *_,\n", + " ):\n", + " return gram + lambda_image + lambda_q\n", + "\n", + " def rhs_factory(\n", + " lambda_image: torch.Tensor,\n", + " lambda_q: torch.Tensor,\n", + " image_reg: torch.Tensor,\n", + " signal: torch.Tensor,\n", + " zero_filled_image: torch.Tensor,\n", + " ):\n", + " return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,)\n", + "\n", + " return mrpro.operators.ConjugateGradientOp(\n", + " operator_factory=operator_factory,\n", + " rhs_factory=rhs_factory,\n", + " )\n", + "\n", + " def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]:\n", + " \"\"\"Get the parameter regularization.\"\"\"\n", + " image = einops.rearrange(\n", + " torch.view_as_real(image),\n", + " 'batch t 1 1 y x complex-> batch (t complex) y x',\n", + " )\n", + " cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None]\n", + " parameters = self.parameter_net(image.contiguous(), cond=cond)\n", + " parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x')\n", + " i = 0\n", + " result = []\n", + " for is_complex in self._parameter_is_complex:\n", + " if is_complex:\n", + " result.append(torch.complex(parameters[i], parameters[i + 1]))\n", + " i += 2\n", + " else:\n", + " result.append(parameters[i])\n", + " i += 1\n", + " return tuple(result)\n", + "\n", + " def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor:\n", + " \"\"\"Get the image regularization.\"\"\"\n", + " batch = image.shape[0]\n", + " image = einops.rearrange(\n", + " torch.view_as_real(image),\n", + " 'batch t 1 1 y x complex-> (batch t) complex y x',\n", + " )\n", + " cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None]\n", + " image = image + self.image_net(image.contiguous(), cond=cond)\n", + " image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch)\n", + " return torch.view_as_complex(image.contiguous())\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]:\n", + " \"\"\"Estimate the quantitative parameters.\n", + "\n", + " Parameters\n", + " ----------\n", + " kdata\n", + " The k-space data.\n", + " csm\n", + " The coil sensitivity maps.\n", + "\n", + " Returns\n", + " -------\n", + " images\n", + " The qualitative images.\n", + " parameters\n", + " The quantitative parameters.\n", + " \"\"\"\n", + " csm_op = csm.as_operator()\n", + " fourier_op = mrpro.operators.FourierOp.from_kdata(kdata)\n", + " acquisition_op = fourier_op @ csm_op\n", + " gram = acquisition_op.gram\n", + " (zero_filled_image,) = acquisition_op.H(kdata.data)\n", + " images = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)\n", + " parameters = self.get_parameter_reg(images, 0)\n", + " linear_solver = self.get_linear_solver(gram)\n", + "\n", + " for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)):\n", + " # linear subproblem 1\n", + " image_reg = self.get_image_reg(images, i)\n", + " (signal,) = self.signalmodel(*parameters)\n", + " images = linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)\n", + " # nonlinear subproblem 2\n", + " parameters_reg = self.get_parameter_reg(images, i + 1)\n", + " parameters = self.nonlinear_solver(lambda_parameter, images, *parameters_reg)\n", + " if self.constraints_op is not None:\n", + " # map the parameters into the constrained space\n", + " parameters = self.constraints_op(*parameters)\n", + " return parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08494939", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# As a baseline methods for comparison, we use a simple non-learned approach. We reconstruct the qualitative images at\n", + "# different saturation times using iterative SENSE. We then perform a constrained non-linear least squares regression\n", + "# using L-BFGS to obtain the parameter maps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "249b9f7f", + "metadata": {}, + "outputs": [], + "source": [ + "def baseline_solution(\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " kdata: mrpro.data.KData,\n", + " csm: mrpro.data.CsmData,\n", + ") -> tuple[torch.Tensor, ...]:\n", + " \"\"\"Compute a baseline solution using SENSE + Regression.\"\"\"\n", + " sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm)\n", + " images = sense(kdata)\n", + " objective = mrpro.operators.functionals.L2NormSquared(images.data) @ signalmodel @ constraints_op\n", + " initial_values = tuple(\n", + " torch.zeros(\n", + " images.shape[1:],\n", + " device=images.device,\n", + " dtype=torch.complex64 if is_complex else torch.float32,\n", + " )\n", + " for is_complex in parameter_is_complex\n", + " )\n", + " solution = constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values))\n", + " return solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e88d174", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "data_folder = Path('/home/zimmer08/.cache/mrpro/brainweb')\n", + "\n", + "signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0))\n", + "constraints_op = mrpro.operators.ConstraintsOp(\n", + " bounds=(\n", + " (-2, 2), # M0 in [-2, 2]\n", + " (0.01, 6.0), # T1 is constrained between 10 ms and 6 s\n", + " )\n", + ")\n", + "n_images = len(signalmodel.saturation_time)\n", + "parameter_is_complex = [True, False]\n", + "\n", + "\n", + "dataset = torch.utils.data.Subset(\n", + " Dataset(\n", + " folder=data_folder,\n", + " signalmodel=signalmodel,\n", + " n_images=n_images,\n", + " size=192,\n", + " acceleration=8,\n", + " n_coils=8,\n", + " max_noise=0.05,\n", + " orientation=('axial',),\n", + " random=False,\n", + " ),\n", + " list(range(500)),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e804c074", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "checkpoint = torch.load('./examples/scripts/last.ckpt', map_location='cpu')\n", + "hyper_parameters = checkpoint['hyper_parameters']\n", + "\n", + "\n", + "pinqi = PINQI(\n", + " signalmodel=signalmodel,\n", + " constraints_op=constraints_op,\n", + " parameter_is_complex=parameter_is_complex,\n", + " n_images=n_images,\n", + " n_iterations=hyper_parameters['n_iterations'],\n", + " n_features_parameter_net=hyper_parameters['n_features_parameter_net'],\n", + " n_features_image_net=hyper_parameters['n_features_image_net'],\n", + ")\n", + "state_dict = {\n", + " k.replace('pinqi.', '').replace('_orig_mod.', ''): v\n", + " for k, v in checkpoint['state_dict'].items()\n", + " if 'baseline' not in k\n", + "}\n", + "pinqi.load_state_dict(state_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ae2c56b", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "batch = dataset[40]\n", + "csm, kdata = batch['csm'], batch['kdata']\n", + "\n", + "if torch.cuda.is_available():\n", + " pinqi, csm, kdata = pinqi.cuda(), csm.cuda(), kdata.cuda()\n", + "images, parameters = pinqi(kdata[None], csm[None])\n", + "with torch.no_grad():\n", + " predicted_m0, predicted_t1 = (p.cpu().detach().squeeze() for p in parameters[-1])\n", + "baseline_m0, baseline_t1 = baseline_solution(signalmodel, constraints_op, parameter_is_complex, kdata, csm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1f460d1", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "(ssim_t1,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(predicted_t1[None])\n", + "(mse_t1,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(predicted_t1)\n", + "\n", + "(mse_baseline,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(baseline_t1)\n", + "nrmse_t1 = torch.sqrt(mse_t1) / batch['t1'][batch['mask']].max()\n", + "(ssim_baseline,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(baseline_t1[None])\n", + "nrmse_baseline = torch.sqrt(mse_baseline) / batch['t1'][batch['mask']].max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e89d5e0", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from cmap import Colormap\n", + "\n", + "cmap = Colormap('lipari').to_matplotlib()\n", + "\n", + "print(f'SSIM: {ssim_baseline.item():.4f}, NRMSE: {nrmse_baseline.item():.4f}')\n", + "print(f'SSIM: {ssim_t1.item():.4f}, NRMSE: {nrmse_t1.item():.4f}')\n", + "\n", + "\n", + "fig, ax = plt.subplots(\n", + " 1,\n", + " 5,\n", + " gridspec_kw={\n", + " 'width_ratios': [1, 1, 1, 0.28, 0.075],\n", + " 'wspace': -0.25,\n", + " },\n", + " figsize=(6.5, 2.5),\n", + ")\n", + "baseline_t1 = baseline_t1.squeeze()\n", + "baseline_t1[~batch['mask']] = torch.nan\n", + "ax[0].imshow(baseline_t1, vmin=0, vmax=2, cmap=cmap)\n", + "ax[0].axis('off')\n", + "ax[0].set_title('SENSE + NLS')\n", + "ax[0].text(\n", + " 0.5,\n", + " -0.00,\n", + " f'SSIM: {ssim_baseline.item():.2f}',\n", + " color='black',\n", + " horizontalalignment='center',\n", + " verticalalignment='top',\n", + " transform=ax[0].transAxes,\n", + " size=11,\n", + ")\n", + "predicted_t1 = predicted_t1.squeeze()\n", + "predicted_t1[~batch['mask']] = torch.nan\n", + "ax[1].imshow(predicted_t1, vmin=0, vmax=2, cmap=cmap)\n", + "ax[1].axis('off')\n", + "ax[1].set_title('PINQI')\n", + "ax[1].text(\n", + " 0.5,\n", + " -0.0,\n", + " f'SSIM: {ssim_t1.item():.2f}',\n", + " color='black',\n", + " horizontalalignment='center',\n", + " verticalalignment='top',\n", + " transform=ax[1].transAxes,\n", + " size=11,\n", + ")\n", + "\n", + "target_t1 = batch['t1'].squeeze()\n", + "target_t1[~batch['mask']] = torch.nan\n", + "im = ax[2].imshow(target_t1, vmin=0, vmax=2, cmap=cmap)\n", + "ax[2].axis('off')\n", + "ax[2].set_title(\n", + " 'Ground Truth',\n", + ")\n", + "ax[-2].axis('off')\n", + "fig.tight_layout()\n", + "plt.colorbar(im, cax=ax[-1], label='$T_1$ (s)')\n", + "fig.savefig(\n", + " '/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1_3.pdf',\n", + " bbox_inches='tight',\n", + " pad_inches=0,\n", + ")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "jupytext": { + "cell_metadata_filter": "mystnb,tags,-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/modl.ipynb b/examples/notebooks/modl.ipynb new file mode 100644 index 000000000..e385e03a2 --- /dev/null +++ b/examples/notebooks/modl.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ca8663d2", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PTB-MR/mrpro/blob/main/examples/notebooks/modl.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc3bb31f", + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "import importlib\n", + "\n", + "if not importlib.util.find_spec('mrpro'):\n", + " %pip install mrpro[notebooks]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ed7a66e", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "from collections.abc import Sequence\n", + "from pathlib import Path\n", + "from typing import TypedDict\n", + "\n", + "import matplotlib.axes\n", + "import matplotlib.pyplot as plt\n", + "import mrpro\n", + "import torch\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "class BatchType(TypedDict):\n", + " \"\"\"A single Batch.\"\"\"\n", + "\n", + " data: mrpro.data.KData\n", + " target: mrpro.data.IData\n", + " csm: mrpro.data.CsmData\n", + "\n", + "\n", + "class AcceleratedFastMRI(torch.utils.data.Dataset):\n", + " \"\"\"An undersampled FastMRI Dataset.\"\"\"\n", + "\n", + " def __init__(self, path: Path, acceleration: float = 12, noise_level: float = 0.1):\n", + " \"\"\"Create an undersampled FastMRI Dataset.\n", + "\n", + " Parameters\n", + " ----------\n", + " path\n", + " Path to the FastMRI dataset.\n", + " acceleration\n", + " Undersampling factor; higher values mean more acceleration. Default is 12.\n", + " noise_level\n", + " Level of additive Gaussian noise applied to the FastMRI dataset. Default is 0.1.\n", + " \"\"\"\n", + " self.acceleration = acceleration\n", + " files = list(path.glob('*AXT1*'))\n", + " self.dataset = mrpro.phantoms.FastMRIKDataDataset(files)\n", + " self.noise_level = noise_level\n", + "\n", + " def __len__(self):\n", + " \"\"\"Get length of the dataset.\"\"\"\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, index: int) -> BatchType:\n", + " \"\"\"Get a single batch of data.\n", + "\n", + " Parameters\n", + " ----------\n", + " index\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " A single batch of data with keys 'data', 'target', and 'csm'.and\n", + " \"\"\"\n", + " data = self.dataset[index]\n", + " data = data.remove_readout_os()\n", + " data.data /= data.data.std()\n", + " reconstruction = mrpro.algorithms.reconstruction.DirectReconstruction(\n", + " data, csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64)\n", + " )\n", + " csm = reconstruction.csm\n", + " target = reconstruction(data)\n", + "\n", + " n = max(data.data.shape[-2:])\n", + " distance = (torch.linspace(-1, 1, n)[:, None] ** 2 + torch.linspace(-1, 1, n) ** 2).sqrt()\n", + " random = 0.1 / (distance + 0.1) + torch.rand_like(distance)\n", + " threshold = torch.kthvalue(random.ravel(), int(n**2 * (1 - 1 / self.acceleration))).values\n", + " undersampling_mask = mrpro.utils.pad_or_crop(random > threshold, data.data.shape[-2:])\n", + " data_undersampled = data[..., undersampling_mask].rearrange('k ... 1 -> ... k')\n", + "\n", + " noise = mrpro.utils.RandomGenerator(seed=index).randn_like(data_undersampled.data)\n", + " data_undersampled.data += self.noise_level * noise\n", + "\n", + " assert csm is not None # for mypy\n", + " return {'data': data_undersampled, 'target': target, 'csm': csm}\n", + "\n", + "\n", + "class MODL(torch.nn.Module):\n", + " \"\"\"MODL network.\"\"\"\n", + "\n", + " def __init__(self, iterations: int = 8, n_features: Sequence[int] = (64, 64, 64, 64)):\n", + " \"\"\"Initialize MODL network.\n", + "\n", + " Parameters\n", + " ----------\n", + " iterations\n", + " Number of iterations.\n", + " n_features\n", + " Number of features in the network.\n", + " \"\"\"\n", + " super().__init__()\n", + " cnn = mrpro.nn.nets.BasicCNN(\n", + " dim=2,\n", + " channels_in=2,\n", + " channels_out=2,\n", + " n_features=n_features,\n", + " batch_norm=True,\n", + " )\n", + " self.network = mrpro.nn.Residual(mrpro.nn.ComplexAsChannel(mrpro.nn.PermutedBlock((-1, -2), cnn)))\n", + " self.network = torch.compile(self.network, dynamic=True, fullgraph=True)\n", + " self.iterations = iterations\n", + " self.regularization_weights = torch.nn.Parameter(0.2 * torch.ones(iterations))\n", + "\n", + " def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData:\n", + " \"\"\"Apply MODL network.\n", + "\n", + " Parameters\n", + " ----------\n", + " kdata\n", + " The k-space data.\n", + " csm\n", + " The coil sensitivity maps.\n", + "\n", + " Returns\n", + " -------\n", + " The reconstructed image.\n", + " \"\"\"\n", + " return super().__call__(kdata, csm)\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData:\n", + " \"\"\"Apply the MODL network.\"\"\"\n", + " fourier_op = mrpro.operators.FourierOp.from_kdata(kdata)\n", + " acquisition_op = fourier_op @ csm.as_operator()\n", + " (zero_filled_image,) = acquisition_op.H(kdata.data)\n", + " gram = acquisition_op.gram\n", + " data_consistency_op = mrpro.operators.ConjugateGradientOp(\n", + " operator_factory=lambda _image, weight: gram + weight,\n", + " rhs_factory=lambda image, weight: zero_filled_image + weight * image,\n", + " )\n", + "\n", + " (image,) = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=5)\n", + " for iteration in range(self.iterations):\n", + " regularization = self.network(image)\n", + " (image,) = data_consistency_op(regularization, self.regularization_weights[iteration])\n", + "\n", + " return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header))\n", + "\n", + "\n", + "def plot(batch: BatchType, prediction: mrpro.data.IData, step: int) -> None:\n", + " \"\"\"Plot the direct, sense, and modl reconstructions.\"\"\"\n", + " target = batch['target'].rss().cpu().squeeze()\n", + " direct = mrpro.algorithms.reconstruction.DirectReconstruction(batch['data'], csm=batch['csm'])(batch['data'])\n", + " direct = direct.rss().cpu().squeeze()\n", + " direct *= target.std() / direct.std()\n", + " sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(batch['data'], csm=batch['csm'])(batch['data'])\n", + " sense = sense.rss().cpu().squeeze()\n", + " prediction_ = prediction.rss().cpu().squeeze().detach()\n", + "\n", + " ssim = mrpro.operators.functionals.SSIM(mrpro.utils.pad_or_crop(target[None], (320, 320)))\n", + "\n", + " def show(ax: matplotlib.axes.Axes, data: torch.Tensor, label: str):\n", + " data = mrpro.utils.pad_or_crop(data, (320, 320))\n", + " ax.imshow(data, vmin=0, vmax=target.max().item(), cmap='gray')\n", + " if label != 'Ground Truth':\n", + " (ssim_value,) = ssim(data[None])\n", + " ax.text(\n", + " 0.98,\n", + " 0.1,\n", + " f'SSIM: {ssim_value.item():.2f}',\n", + " color='white',\n", + " horizontalalignment='right',\n", + " verticalalignment='top',\n", + " transform=ax.transAxes,\n", + " )\n", + " ax.set_title(label)\n", + " ax.set_axis_off()\n", + "\n", + " fig, ax = plt.subplots(1, 4)\n", + " show(ax[0], direct, 'Direct')\n", + " show(ax[1], sense, 'CG-SENSE')\n", + " show(ax[2], prediction_, 'MODL')\n", + " show(ax[3], target, 'Ground Truth')\n", + " fig.tight_layout()\n", + " fig.savefig(f'modl_{step}.pdf', bbox_inches='tight', pad_inches=0)\n", + "\n", + "\n", + "# %%.\n", + "path = Path('/echo/allgemein/resources/publicTrainingData/fastmri/brain_multicoil_train/')\n", + "dataset = AcceleratedFastMRI(path)\n", + "dataloader = torch.utils.data.DataLoader(dataset, num_workers=16, shuffle=True, collate_fn=lambda batch: batch[0])\n", + "modl = MODL().cuda()\n", + "optimizer = torch.optim.Adam(modl.parameters(), lr=1e-3)\n", + "pbar = tqdm(dataloader)\n", + "for i, batch in enumerate(pbar):\n", + " optimizer.zero_grad()\n", + " kdata, csm, target = (batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda())\n", + " prediction = modl(kdata, csm)\n", + " objective = 0.5 * mrpro.operators.functionals.MSE(target.data) - mrpro.operators.functionals.SSIM(target.data)\n", + " (loss,) = objective(prediction.data)\n", + " loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(modl.parameters(), 5.0)\n", + " optimizer.step()\n", + "\n", + " pbar.set_postfix(loss=loss.item())\n", + " if i % 200 == 0:\n", + " plot(batch, prediction, i)\n", + " print(modl.regularization_weights)\n", + " state = {'modl': modl.state_dict(), 'optimizer': optimizer.state_dict()}\n", + " torch.save(state, f'modl_{i}.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "601b0ff9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "jupytext": { + "cell_metadata_filter": "mystnb,tags,-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/train_pinqi.ipynb b/examples/notebooks/train_pinqi.ipynb new file mode 100644 index 000000000..9a012816c --- /dev/null +++ b/examples/notebooks/train_pinqi.ipynb @@ -0,0 +1,840 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a79be4b8", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PTB-MR/mrpro/blob/main/examples/notebooks/train_pinqi.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9a50a13", + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "import importlib\n", + "\n", + "if not importlib.util.find_spec('mrpro'):\n", + " %pip install mrpro[notebooks]" + ] + }, + { + "cell_type": "markdown", + "id": "7d5f4c31", + "metadata": {}, + "source": [ + "ruff: noqa: D102, ANN201" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fced8aa", + "metadata": {}, + "outputs": [], + "source": [ + "import collections\n", + "from collections.abc import Sequence\n", + "from copy import deepcopy\n", + "from pathlib import Path\n", + "from typing import Any, Literal, TypedDict, cast" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec1c97c4", + "metadata": {}, + "outputs": [], + "source": [ + "import einops\n", + "import matplotlib.pyplot as plt\n", + "import mrpro\n", + "import numpy as np\n", + "import pytorch_lightning as pl # type:ignore[import-not-found]\n", + "import torch\n", + "import torch.utils.data._utils\n", + "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint # type:ignore[import-not-found]\n", + "from pytorch_lightning.loggers import NeptuneLogger # type:ignore[import-not-found]\n", + "from pytorch_lightning.strategies import DDPStrategy # type:ignore[import-not-found]" + ] + }, + { + "cell_type": "markdown", + "id": "06c40aff", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61c14780", + "metadata": {}, + "outputs": [], + "source": [ + "class BatchType(TypedDict):\n", + " \"\"\"Typehint for a batch of data.\"\"\"\n", + "\n", + " kdata: mrpro.data.KData\n", + " csm: mrpro.data.CsmData\n", + " m0: torch.Tensor\n", + " t1: torch.Tensor\n", + " mask: torch.Tensor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbeed14e", + "metadata": {}, + "outputs": [], + "source": [ + "class Dataset(torch.utils.data.Dataset):\n", + " \"\"\"A brainweb based cartesian qMRI dataset.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " folder: Path,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " n_images: int,\n", + " size: int,\n", + " acceleration: int,\n", + " n_coils: int,\n", + " max_noise: float,\n", + " orientation: Sequence[Literal['axial', 'coronal', 'sagittal']],\n", + " random: bool = True,\n", + " ):\n", + " \"\"\"Initialize the dataset.\"\"\"\n", + " if random:\n", + " augment = mrpro.phantoms.brainweb.augment(size=size)\n", + " else:\n", + " augment = mrpro.phantoms.brainweb.augment(\n", + " size=size,\n", + " max_random_shear=0,\n", + " max_random_rotation=0,\n", + " max_random_scaling_factor=0,\n", + " p_horizontal_flip=0,\n", + " p_vertical_flip=1.0,\n", + " )\n", + " self.phantom = mrpro.phantoms.brainweb.BrainwebSlices(\n", + " folder=folder,\n", + " what=('m0', 't1', 'mask'),\n", + " seed='index' if not random else 'random',\n", + " slice_preparation=augment,\n", + " orientation=orientation,\n", + " )\n", + " self.signalmodel = signalmodel\n", + " self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size)\n", + " self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25)\n", + " self.acceleration = acceleration\n", + " self.n_coils = n_coils\n", + " self._random = random\n", + " self.max_noise = max_noise\n", + " self._n_images = n_images\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"Get the length of the dataset.\"\"\"\n", + " return len(self.phantom)\n", + "\n", + " def __getitem__(self, index: int):\n", + " \"\"\"Get an item from the dataset.\"\"\"\n", + " phantom = self.phantom[index]\n", + " (images,) = self.signalmodel(phantom['m0'], phantom['t1'])\n", + " seed = int(torch.randint(0, 1000000, (1,))) if self._random else index\n", + "\n", + " traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density(\n", + " encoding_matrix=self.encoding_matrix,\n", + " seed=seed,\n", + " acceleration=self.acceleration,\n", + " fwhm_ratio=1.5,\n", + " n_center=10,\n", + " n_other=(self._n_images,),\n", + " )\n", + " header = mrpro.data.KHeader(\n", + " encoding_matrix=self.encoding_matrix,\n", + " recon_matrix=self.encoding_matrix,\n", + " recon_fov=self.fov,\n", + " encoding_fov=self.fov,\n", + " )\n", + "\n", + " if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery):\n", + " header.ti = self.signalmodel.saturation_time.tolist()\n", + " elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery):\n", + " header.ti = self.signalmodel.ti.tolist()\n", + "\n", + " fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj)\n", + " csm = mrpro.data.CsmData(\n", + " mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix),\n", + " header,\n", + " )\n", + " images = einops.rearrange(images, 't y x -> t 1 1 y x')\n", + " (data,) = (fourier_op @ csm.as_operator())(images)\n", + " data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std()\n", + " kdata = mrpro.data.KData(header, data, traj)\n", + " return {'kdata': kdata, 'csm': csm, **phantom}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed276d85", + "metadata": {}, + "outputs": [], + "source": [ + "def collate_fn(batch: Any): # noqa: ANN401\n", + " \"\"\"Join dataclasses to a batch.\"\"\"\n", + " return torch.utils.data._utils.collate.collate(\n", + " batch,\n", + " collate_fn_map={\n", + " mrpro.data.Dataclass: lambda batch, *, _collate_fn_map: batch[0].stack(*batch[1:]),\n", + " **torch.utils.data._utils.collate.default_collate_fn_map,\n", + " },\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "446bad6e", + "metadata": {}, + "outputs": [], + "source": [ + "class PINQI(torch.nn.Module):\n", + " \"\"\"PINQI model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " n_images: int,\n", + " n_iterations: int,\n", + " n_features_parameter_net: Sequence[int],\n", + " n_features_image_net: Sequence[int],\n", + " ):\n", + " \"\"\"Initialize the PINQI model.\"\"\"\n", + " super().__init__()\n", + " self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op\n", + " self.constraints_op = constraints_op\n", + " self._n_images = n_images\n", + " self._parameter_is_complex = parameter_is_complex\n", + " real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex)\n", + " self.parameter_net = torch.compile(\n", + " mrpro.nn.nets.UNet(\n", + " n_dim=2,\n", + " n_channels_in=n_images * 2,\n", + " n_channels_out=real_parameters,\n", + " attention_depths=(-1, -2),\n", + " n_features=n_features_parameter_net,\n", + " cond_dim=128,\n", + " ),\n", + " dynamic=False,\n", + " fullgraph=True,\n", + " )\n", + " self.image_net = torch.compile(\n", + " mrpro.nn.nets.UNet(\n", + " n_dim=2,\n", + " n_channels_in=2,\n", + " n_channels_out=2,\n", + " attention_depths=(),\n", + " n_features=n_features_image_net,\n", + " cond_dim=128,\n", + " ),\n", + " dynamic=False,\n", + " fullgraph=True,\n", + " )\n", + " self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3))\n", + " self.softplus = torch.nn.Softplus(beta=5)\n", + " self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128)\n", + "\n", + " def objective_factory(\n", + " lambda_parameters: torch.Tensor,\n", + " image: torch.Tensor,\n", + " *parameter_reg: torch.Tensor,\n", + " ):\n", + " dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel\n", + " reg = mrpro.operators.ProximableFunctionalSeparableSum(\n", + " *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg]\n", + " )\n", + " return dc + lambda_parameters * reg\n", + "\n", + " self.nonlinear_solver = mrpro.operators.OptimizerOp(\n", + " objective_factory,\n", + " lambda _l, _i, *parameter_reg: parameter_reg,\n", + " )\n", + "\n", + " def get_linear_solver(self, gram: mrpro.operators.LinearOperator):\n", + " def operator_factory(\n", + " lambda_image: torch.Tensor,\n", + " lambda_q: torch.Tensor,\n", + " *_,\n", + " ):\n", + " return gram + lambda_image + lambda_q\n", + "\n", + " def rhs_factory(\n", + " lambda_image: torch.Tensor,\n", + " lambda_q: torch.Tensor,\n", + " image_reg: torch.Tensor,\n", + " signal: torch.Tensor,\n", + " zero_filled_image: torch.Tensor,\n", + " ):\n", + " return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,)\n", + "\n", + " return mrpro.operators.ConjugateGradientOp(\n", + " operator_factory=operator_factory,\n", + " rhs_factory=rhs_factory,\n", + " )\n", + "\n", + " def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]:\n", + " image = einops.rearrange(\n", + " torch.view_as_real(image),\n", + " 'batch t 1 1 y x complex-> batch (t complex) y x',\n", + " )\n", + " cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None]\n", + " parameters = self.parameter_net(image.contiguous(), cond=cond)\n", + " parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x')\n", + " i = 0\n", + " result = []\n", + " for is_complex in self._parameter_is_complex:\n", + " if is_complex:\n", + " result.append(torch.complex(parameters[i], parameters[i + 1]))\n", + " i += 2\n", + " else:\n", + " result.append(parameters[i])\n", + " i += 1\n", + " return tuple(result)\n", + "\n", + " def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor:\n", + " batch = image.shape[0]\n", + " image = einops.rearrange(\n", + " torch.view_as_real(image),\n", + " 'batch t 1 1 y x complex-> (batch t) complex y x',\n", + " )\n", + " cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None]\n", + " image = image + self.image_net(image.contiguous(), cond=cond)\n", + " image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch)\n", + " return torch.view_as_complex(image.contiguous())\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData):\n", + " csm_op = csm.as_operator()\n", + " fourier_op = mrpro.operators.FourierOp.from_kdata(kdata)\n", + " acquisition_op = fourier_op @ csm_op\n", + " gram = acquisition_op.gram\n", + " (zero_filled_image,) = acquisition_op.H(kdata.data)\n", + " images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2))\n", + " parameters = [self.get_parameter_reg(images[-1], 0)]\n", + " linear_solver = self.get_linear_solver(gram)\n", + "\n", + " for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)):\n", + " image_reg = self.get_image_reg(images[-1], i + 1)\n", + " (signal,) = self.signalmodel(*parameters[-1])\n", + " images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image))\n", + " parameters_reg = self.get_parameter_reg(images[-1], i + 1)\n", + " parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg))\n", + " if self.constraints_op is not None:\n", + " parameters = [self.constraints_op(*p) for p in parameters]\n", + " return images, parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "831f3559", + "metadata": {}, + "outputs": [], + "source": [ + "class DataModule(pl.LightningDataModule):\n", + " \"\"\"Data module for training the PINQI model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " folder: Path,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " n_images: int,\n", + " size: int = 192,\n", + " acceleration: int = 10,\n", + " n_coils: int = 8,\n", + " max_noise: float = 0.1,\n", + " orientation_train: Sequence[Literal['axial', 'coronal', 'sagittal']] = (\n", + " 'axial',\n", + " 'coronal',\n", + " 'sagittal',\n", + " ),\n", + " orientation_val: Sequence[Literal['axial', 'coronal', 'sagittal']] = ('axial',),\n", + " batch_size: int = 16,\n", + " num_workers: int = 4,\n", + " ):\n", + " \"\"\"Initialize the data module.\"\"\"\n", + " super().__init__()\n", + " self.save_hyperparameters(ignore=['signalmodel', 'folder', 'num_workers'])\n", + " self.batch_size = batch_size\n", + " self.num_workers = num_workers\n", + " self.train_dataset = Dataset(\n", + " folder=folder,\n", + " signalmodel=signalmodel,\n", + " n_images=n_images,\n", + " size=size,\n", + " acceleration=acceleration,\n", + " n_coils=n_coils,\n", + " max_noise=max_noise,\n", + " orientation=orientation_train,\n", + " random=True,\n", + " )\n", + " self.val_dataset = torch.utils.data.Subset(\n", + " Dataset(\n", + " folder=folder,\n", + " signalmodel=signalmodel,\n", + " n_images=n_images,\n", + " size=size,\n", + " acceleration=acceleration,\n", + " n_coils=n_coils,\n", + " max_noise=max_noise,\n", + " orientation=orientation_val,\n", + " random=False,\n", + " ),\n", + " list(range(30, 500, 20)),\n", + " )\n", + "\n", + " def train_dataloader(self):\n", + " return torch.utils.data.DataLoader(\n", + " self.train_dataset,\n", + " batch_size=self.batch_size,\n", + " shuffle=True,\n", + " num_workers=self.num_workers,\n", + " pin_memory=False,\n", + " persistent_workers=self.num_workers > 0,\n", + " collate_fn=collate_fn,\n", + " worker_init_fn=lambda *_: torch.set_num_threads(1),\n", + " )\n", + "\n", + " def val_dataloader(self):\n", + " return torch.utils.data.DataLoader(\n", + " self.val_dataset,\n", + " batch_size=1,\n", + " shuffle=False,\n", + " num_workers=self.num_workers,\n", + " pin_memory=False,\n", + " persistent_workers=self.num_workers > 0,\n", + " collate_fn=collate_fn,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71c8de37", + "metadata": {}, + "outputs": [], + "source": [ + "class PinqiModule(pl.LightningModule):\n", + " \"\"\"Module for training the PINQI model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " n_images: int,\n", + " n_iterations: int = 4,\n", + " n_features_parameter_net: Sequence[int] = (64, 128, 192, 256),\n", + " n_features_image_net: Sequence[int] = (32, 48, 64, 96),\n", + " lr: float = 3e-4, # noqa: ARG002\n", + " weight_decay: float = 1e-3, # noqa: ARG002\n", + " loss_weights: Sequence[float] = (0.2, 0.1, 0.1, 0.1, 0.8),\n", + " ):\n", + " \"\"\"Initialize the PINQI module.\"\"\"\n", + " super().__init__()\n", + " self.save_hyperparameters(ignore=['signalmodel', 'constraints_op'])\n", + " if len(loss_weights) != n_iterations + 1:\n", + " raise ValueError(f'loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations')\n", + " signalmodel = deepcopy(signalmodel)\n", + " constraints_op = deepcopy(constraints_op)\n", + " self.pinqi = PINQI(\n", + " signalmodel=signalmodel,\n", + " constraints_op=constraints_op,\n", + " parameter_is_complex=parameter_is_complex,\n", + " n_images=n_images,\n", + " n_iterations=n_iterations,\n", + " n_features_parameter_net=n_features_parameter_net,\n", + " n_features_image_net=n_features_image_net,\n", + " )\n", + "\n", + " self.validation_step_outputs: dict[str, list] = collections.defaultdict(list)\n", + " self.baseline = Baseline(signalmodel, constraints_op, parameter_is_complex)\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData):\n", + " \"\"\"Apply the PINQI model to the data.\"\"\"\n", + " return self.pinqi(kdata, csm)\n", + "\n", + " def loss(self, predictions: Sequence[torch.Tensor], batch: BatchType) -> torch.Tensor:\n", + " \"\"\"Compute the loss.\"\"\"\n", + " loss = torch.tensor(0.0, device=self.device)\n", + " target_m0, target_t1, mask = map(torch.squeeze, (batch['m0'], batch['t1'], batch['mask']))\n", + " for prediction, weight in zip(predictions, self.hparams.loss_weights, strict=False):\n", + " prediction_m0, prediction_t1 = map(torch.squeeze, prediction)\n", + " loss_t1 = torch.nn.functional.mse_loss(prediction_t1[mask], target_t1[mask])\n", + " loss_m0 = torch.nn.functional.mse_loss(\n", + " torch.view_as_real(prediction_m0[mask]),\n", + " torch.view_as_real(target_m0[mask]),\n", + " )\n", + " loss_outside = prediction_m0[~mask].abs().mean()\n", + " loss = loss + weight * (loss_t1 + 0.5 * loss_m0 + 0.1 * loss_outside)\n", + " return loss\n", + "\n", + " def training_step(self, batch: BatchType, _batch_idx: int) -> torch.Tensor:\n", + " \"\"\"Training step.\"\"\"\n", + " _images, parameters = self(batch['kdata'], batch['csm'])\n", + " loss = self.loss(parameters, batch)\n", + " self.log(\n", + " 'train/loss',\n", + " loss,\n", + " on_step=True,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " sync_dist=True,\n", + " batch_size=len(batch['mask']),\n", + " )\n", + " return loss\n", + "\n", + " def validation_step(self, batch: BatchType, batch_idx: int) -> None:\n", + " \"\"\"Validate.\n", + "\n", + " Needs to be adapted for other signal models than Saturation Recovery.\n", + " \"\"\"\n", + " _images, parameters = self(batch['kdata'], batch['csm'])\n", + " loss = self.loss(parameters, batch)\n", + "\n", + " pred_m0, pred_t1 = parameters[-1]\n", + " target_t1, target_m0 = batch['t1'], batch['m0']\n", + " mask = batch['mask']\n", + " batch_size = len(batch['mask'])\n", + " (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1)\n", + " (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1)\n", + " (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0)\n", + " self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", + " self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", + " self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", + " self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", + "\n", + " if batch_idx == 0:\n", + " self.validation_step_outputs['target_t1'].append(batch['t1'])\n", + " self.validation_step_outputs['pred_t1'].append(pred_t1)\n", + " self.validation_step_outputs['pred_m0'].append(pred_m0)\n", + " self.validation_step_outputs['target_m0'].append(target_m0)\n", + " self.validation_step_outputs['mask'].append(batch['mask'])\n", + " baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm'])\n", + " self.validation_step_outputs['baseline_t1'].append(baseline_t1)\n", + " self.validation_step_outputs['baseline_m0'].append(baseline_m0)\n", + "\n", + " def on_validation_epoch_end(self):\n", + " \"\"\"Validate.\n", + "\n", + " Needs to be adapted for other signal models than Saturation Recovery.\n", + " \"\"\"\n", + " outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()}\n", + " self.validation_step_outputs.clear()\n", + " outputs = cast(dict[str, torch.Tensor], self.all_gather(outputs))\n", + "\n", + " if not self.trainer.is_global_zero:\n", + " return\n", + " outputs = {k: v.flatten(0, 1).cpu() for k, v in outputs.items()}\n", + "\n", + " samples = len(outputs['mask'])\n", + " fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16))\n", + "\n", + " for i in range(samples):\n", + " self.result_plot(\n", + " outputs['target_t1'][i],\n", + " outputs['pred_t1'][i],\n", + " outputs['mask'][i],\n", + " axes[:, i],\n", + " outputs['baseline_t1'][i],\n", + " '$T_1$ (s)',\n", + " )\n", + " fig.suptitle(f'$T_1$ Epoch {self.current_epoch}')\n", + " self.logger.run['val/images/t1'].log(fig)\n", + " plt.close(fig)\n", + "\n", + " fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 12))\n", + " for i in range(samples):\n", + " self.result_plot(\n", + " outputs['target_m0'][i].abs(),\n", + " outputs['pred_m0'][i].abs(),\n", + " outputs['mask'][i],\n", + " axes[:, i],\n", + " outputs['baseline_m0'][i].abs(),\n", + " '$|M_0|$ (a.u.)',\n", + " )\n", + " fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}')\n", + " self.logger.run['val/images/m0'].log(fig)\n", + " plt.close(fig)\n", + "\n", + " def result_plot(\n", + " self,\n", + " target: torch.Tensor,\n", + " pred: torch.Tensor,\n", + " mask: torch.Tensor,\n", + " axes: Sequence[plt.Axes],\n", + " baseline: torch.Tensor,\n", + " label: str,\n", + " ) -> None:\n", + " \"\"\"Plot the results.\"\"\"\n", + " target = target.squeeze().cpu()\n", + " pred = pred.squeeze().detach().cpu()\n", + " mask = mask.squeeze().detach().bool().cpu()\n", + " baseline = baseline.squeeze().detach().cpu()\n", + "\n", + " target[~mask] = torch.nan\n", + " pred[~mask] = torch.nan\n", + " baseline[~mask] = torch.nan\n", + " difference = (target - pred) / target * 100\n", + " vmax = np.nanmax(target.numpy())\n", + "\n", + " im0 = axes[0].imshow(target, vmin=0, vmax=vmax)\n", + " axes[0].set_title('Ground Truth')\n", + " axes[0].axis('off')\n", + " plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label=label)\n", + "\n", + " im1 = axes[1].imshow(baseline, vmin=0, vmax=vmax)\n", + " axes[1].set_title('SENSE + Regression')\n", + " axes[1].axis('off')\n", + " plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label=label)\n", + "\n", + " im2 = axes[2].imshow(pred, vmin=0, vmax=vmax)\n", + " axes[2].set_title('PINQI')\n", + " axes[2].axis('off')\n", + " plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04, label=label)\n", + "\n", + " diff_vmax = np.nanpercentile(difference.abs().numpy(), 90)\n", + " im3 = axes[3].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax)\n", + " axes[3].set_title('rel. Error')\n", + " axes[3].axis('off')\n", + " plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04, label='%')\n", + "\n", + " def configure_optimizers(\n", + " self,\n", + " ) -> dict:\n", + " \"\"\"Configure the optimizer and the learning rate scheduler.\"\"\"\n", + " scalars = ('lambdas_raw', 'rezero')\n", + " params, scalar_params = [], []\n", + " for n, p in self.named_parameters():\n", + " if not p.requires_grad:\n", + " continue\n", + " if any(s in n for s in scalars):\n", + " scalar_params.append(p)\n", + " else:\n", + " params.append(p)\n", + " optimizer = torch.optim.AdamW(\n", + " [\n", + " {\n", + " 'params': params,\n", + " 'weight_decay': self.hparams.weight_decay,\n", + " 'lr': self.hparams.lr,\n", + " },\n", + " {\n", + " 'params': scalar_params,\n", + " 'weight_decay': 0.0,\n", + " 'lr': self.hparams.lr * 10,\n", + " },\n", + " ],\n", + " )\n", + " scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", + " optimizer,\n", + " max_lr=[self.hparams.lr, 10 * self.hparams.lr],\n", + " total_steps=self.trainer.estimated_stepping_batches,\n", + " pct_start=0.1,\n", + " div_factor=20,\n", + " final_div_factor=300,\n", + " )\n", + " return {\n", + " 'optimizer': optimizer,\n", + " 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'},\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b17c48e1", + "metadata": {}, + "outputs": [], + "source": [ + "class Baseline(torch.nn.Module):\n", + " \"\"\"Baseline solution using SENSE + Regression.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " ):\n", + " \"\"\"Initialize the baseline.\"\"\"\n", + " super().__init__()\n", + " self.signalmodel = signalmodel\n", + " self.constraints_op = constraints_op\n", + " self.parameter_is_complex = parameter_is_complex\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]:\n", + " \"\"\"Compute the baseline solution.\"\"\"\n", + " sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm)\n", + " images = sense(kdata).rearrange('batch time ...-> time batch ...')\n", + "\n", + " objective = mrpro.operators.functionals.L2NormSquared(images.data) @ self.signalmodel @ self.constraints_op\n", + " initial_values = tuple(\n", + " torch.zeros(\n", + " images.shape[1:],\n", + " device=images.device,\n", + " dtype=torch.complex64 if is_complex else torch.float32,\n", + " )\n", + " for is_complex in self.parameter_is_complex\n", + " )\n", + " solution = self.constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values))\n", + " return solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "198ef2c0", + "metadata": {}, + "outputs": [], + "source": [ + "class LogLambdasCallback(pl.Callback):\n", + " \"\"\"Log the lambdas.\"\"\"\n", + "\n", + " def on_train_batch_end(\n", + " self,\n", + " trainer: pl.Trainer,\n", + " pl_module: PinqiModule,\n", + " _outputs: dict,\n", + " _batch: BatchType,\n", + " _batch_idx: int,\n", + " ) -> None:\n", + " if trainer.global_step % 10 == 0:\n", + " lambdas = pl_module.pinqi.softplus(pl_module.pinqi.lambdas_raw).detach().cpu().numpy()\n", + " for iteration, (lambda_image, lambda_q, lambda_parameter) in enumerate(lambdas):\n", + " self.log_dict(\n", + " {\n", + " f'parameter/lambda_image_{iteration}': lambda_image,\n", + " f'parameter/lambda_q_{iteration}': lambda_q,\n", + " f'parameter/lambda_parameter_{iteration}': lambda_parameter,\n", + " },\n", + " on_step=True,\n", + " on_epoch=False,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "632cc485", + "metadata": {}, + "outputs": [], + "source": [ + "if __name__ == '__main__':\n", + " torch.multiprocessing.set_sharing_strategy('file_system')\n", + " torch.set_float32_matmul_precision('high')\n", + " torch._inductor.config.compile_threads = 4\n", + " torch._inductor.config.worker_start_method = 'fork'\n", + " torch._dynamo.config.capture_scalar_outputs = True\n", + " torch._dynamo.config.cache_size_limit = 256\n", + " torch._functorch.config.activation_memory_budget = 0.95\n", + "\n", + " data_folder = Path('/scratch/zimmer08/brainweb')\n", + "\n", + " signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0))\n", + " constraints_op = mrpro.operators.ConstraintsOp(\n", + " bounds=(\n", + " (-2, 2), # M0 in [-2, 2]\n", + " (0.01, 6.0), # T1 is constrained between 10 ms and 6 s\n", + " )\n", + " )\n", + " n_images = len(signalmodel.saturation_time)\n", + " parameter_is_complex = [True, False]\n", + "\n", + " dm = DataModule(\n", + " folder=data_folder,\n", + " signalmodel=signalmodel,\n", + " n_images=n_images,\n", + " batch_size=16,\n", + " num_workers=16,\n", + " size=192,\n", + " acceleration=8,\n", + " n_coils=8,\n", + " max_noise=0.1,\n", + " )\n", + "\n", + " model = PinqiModule(\n", + " signalmodel=signalmodel,\n", + " constraints_op=constraints_op,\n", + " parameter_is_complex=parameter_is_complex,\n", + " n_images=n_images,\n", + " )\n", + "\n", + " neptune_logger = NeptuneLogger(\n", + " log_model_checkpoints=False,\n", + " dependencies='infer',\n", + " )\n", + " neptune_logger.log_model_summary(model=model, max_depth=-1)\n", + "\n", + " checkpoint_callback = ModelCheckpoint(\n", + " monitor='val/loss',\n", + " mode='min',\n", + " save_top_k=2,\n", + " dirpath=Path('checkpoints') / str(neptune_logger.version),\n", + " filename='{epoch:02d}-{val/loss:.4f}',\n", + " save_last=True,\n", + " )\n", + "\n", + " strategy = DDPStrategy(find_unused_parameters=False)\n", + " trainer = pl.Trainer(\n", + " max_epochs=100,\n", + " accelerator='gpu',\n", + " devices=4,\n", + " strategy=strategy,\n", + " logger=neptune_logger,\n", + " callbacks=[\n", + " LearningRateMonitor(logging_interval='step'),\n", + " checkpoint_callback,\n", + " LogLambdasCallback(),\n", + " ],\n", + " log_every_n_steps=10,\n", + " gradient_clip_algorithm='norm',\n", + " gradient_clip_val=5.0,\n", + " )\n", + "\n", + " trainer.fit(model, datamodule=dm)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "jupytext": { + "cell_metadata_filter": "mystnb,tags,-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/scripts/apply_pinqi.py b/examples/scripts/apply_pinqi.py index 92dd8584c..9d7bde957 100644 --- a/examples/scripts/apply_pinqi.py +++ b/examples/scripts/apply_pinqi.py @@ -3,24 +3,29 @@ # A recent DL approach, PINQI, approaches learned quantitative MRI by half quadratic splitting to alternate between two # subproblems. The first is a linear image reconstruction task # $$ -# \underset{\mathbf{x}}{\min} \frac{1}{2} \| \mathbf{A} \mathbf{x} - \mathbf{y} \|_2^2 + \frac{\lambda_\mathbf{x}}{2} \left\| \mathbf{x} - \mathbf{x}_{\text{reg}} \right\|_2^2 + \frac{\lambda_{\mathbf{q}}}{2} \left\| \mathbf{q}(\mathbf{p}) - \mathbf{x} \right\|_2^2 +# \underset{\mathbf{x}}{\min} \frac{1}{2} \| \mathbf{A} \mathbf{x} - \mathbf{y} \|_2^2 +# + \frac{\lambda_\mathbf{x}}{2} \left\| \mathbf{x} - \mathbf{x}_{\text{reg}} \right\|_2^2 +# + \frac{\lambda_{\mathbf{q}}}{2} \left\| \mathbf{q}(\mathbf{p}) - \mathbf{x} \right\|_2^2 # $$ # with $\mathbf{x}$ being intermediary qualitative images, $\lambda_{\mathbf{x}}$ and $\lambda_{\mathbf{q}}$ being # regularization strengths and $\mathbf{x}_{\text{reg}}$ denoting an image prior for regularization. # The second, non-linear, subproblem is finding the quantitative parameters by solving # $$ -# \underset{\mathbf{p}}{\min} \frac{\lambda_{\mathbf{q}}}{2}\left \| \mathbf{q}(\vec{p}) - \mathbf{x} \right\|_2^2 + \frac{\lambda_{\mathbf{p}}}{2} \left\| \mathbf{p} - \mathbf{p}_{\text{reg}} \right\|_2^2. +# \underset{\mathbf{p}}{\min} \frac{\lambda_{\mathbf{q}}}{2}\left \| \mathbf{q}(\vec{p}) - \mathbf{x} \right\|_2^2 +# + \frac{\lambda_{\mathbf{p}}}{2} \left\| \mathbf{p} - \mathbf{p}_{\text{reg}} \right\|_2^2. # $$ -# Here, $\mathbf{p}_{\text{reg}}$ is a prior on the parameter maps and $\lambda_{\mathbf{p}}$ the associated weight for regularization. -# In PINQI, a solution is found by iterating between both subproblems. In each iteration $k=1,\ldots,T$, the image and parameter priors are updated by -# U-Nets. The network parameters and the regularization strengths are trained end-to-end. -# Here, we apply a trained PINQI model to a validation set. We first define the dataset, then define the PINQI model, before loading the model weights -# and applying it to the dataset. +# Here, $\mathbf{p}_{\text{reg}}$ is a prior on the parameter maps and $\lambda_{\mathbf{p}}$ the associated weight for +# regularization. +# In PINQI, a solution is found by iterating between both subproblems. In each iteration $k=1,\ldots,T$, +# the image and parameter priors are updated by U-Nets. The network parameters and the regularization strengths +# are trained end-to-end. +# Here, we apply a trained PINQI model to a validation set. We first define the dataset, then define the PINQI model, +# before loading the model weights and applying it to the dataset. # %% [markdown] # ## Dataset -# We base the dataset on the BrainWeb phantom (`mrpro.phantoms.brainweb.BrainwebSlices`) and simulate Cartesian random undersampling in phase -# encode direction. +# We base the dataset on the BrainWeb phantom (`mrpro.phantoms.brainweb.BrainwebSlices`) and simulate Cartesian random +# undersampling in phase encode direction. # %% from collections.abc import Sequence @@ -182,7 +187,7 @@ def objective_factory( lambda_parameters: torch.Tensor, image: torch.Tensor, *parameter_reg: torch.Tensor, - ): + ) -> torch.operators.Operator: dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel reg = mrpro.operators.ProximableFunctionalSeparableSum( *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] @@ -195,7 +200,7 @@ def objective_factory( ) # This can be done once, as the signal model is the same for all samples. - def get_linear_solver(self, gram: mrpro.operators.LinearOperator): + def get_linear_solver(self, gram: mrpro.operators.LinearOperator) -> mrpro.operators.ConjugateGradientOp: """Set up the linear solver.""" # This needs to be done for each sample, as the undersampling pattern and csm are different for each sample, # thus the gram operator of the acquisition operator is different for each sample. @@ -253,7 +258,7 @@ def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) return torch.view_as_complex(image.contiguous()) - def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]: """Estimate the quantitative parameters. Parameters @@ -294,8 +299,9 @@ def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): # %% -# As a baseline methods for comparision, we use a simple non-learned approach. We reconstruct the qualitative images at different saturation times using iterative SENSE. -# We then perform a constrained non-linear least squares regression usingL-BFGS to obtain the parameter maps. +# As a baseline methods for comparison, we use a simple non-learned approach. We reconstruct the qualitative images at +# different saturation times using iterative SENSE. We then perform a constrained non-linear least squares regression +# using L-BFGS to obtain the parameter maps. # %% def baseline_solution( signalmodel: mrpro.operators.SignalModel, @@ -453,10 +459,3 @@ def baseline_solution( bbox_inches='tight', pad_inches=0, ) - - -# %% - -1 -# %% -# %% diff --git a/examples/scripts/modl.py b/examples/scripts/modl.py index 5039d233a..e2fb227e0 100644 --- a/examples/scripts/modl.py +++ b/examples/scripts/modl.py @@ -12,22 +12,49 @@ class BatchType(TypedDict): + """A single Batch.""" + data: mrpro.data.KData target: mrpro.data.IData csm: mrpro.data.CsmData class AcceleratedFastMRI(torch.utils.data.Dataset): + """An undersampled FastMRI Dataset.""" + def __init__(self, path: Path, acceleration: float = 12, noise_level: float = 0.1): + """Create an undersampled FastMRI Dataset. + + Parameters + ---------- + path + Path to the FastMRI dataset. + acceleration + Undersampling factor; higher values mean more acceleration. Default is 12. + noise_level + Level of additive Gaussian noise applied to the FastMRI dataset. Default is 0.1. + """ self.acceleration = acceleration files = list(path.glob('*AXT1*')) self.dataset = mrpro.phantoms.FastMRIKDataDataset(files) self.noise_level = noise_level def __len__(self): + """Get length of the dataset.""" return len(self.dataset) def __getitem__(self, index: int) -> BatchType: + """Get a single batch of data. + + Parameters + ---------- + index + Index of the batch. + + Returns + ------- + A single batch of data with keys 'data', 'target', and 'csm'.and + """ data = self.dataset[index] data = data.remove_readout_os() data.data /= data.data.std() @@ -52,7 +79,18 @@ def __getitem__(self, index: int) -> BatchType: class MODL(torch.nn.Module): + """MODL network.""" + def __init__(self, iterations: int = 8, n_features: Sequence[int] = (64, 64, 64, 64)): + """Initialize MODL network. + + Parameters + ---------- + iterations + Number of iterations. + n_features + Number of features in the network. + """ super().__init__() cnn = mrpro.nn.nets.BasicCNN( dim=2, @@ -67,9 +105,23 @@ def __init__(self, iterations: int = 8, n_features: Sequence[int] = (64, 64, 64, self.regularization_weights = torch.nn.Parameter(0.2 * torch.ones(iterations)) def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: + """Apply MODL network. + + Parameters + ---------- + kdata + The k-space data. + csm + The coil sensitivity maps. + + Returns + ------- + The reconstructed image. + """ return super().__call__(kdata, csm) def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: + """Apply the MODL network.""" fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) acquisition_op = fourier_op @ csm.as_operator() (zero_filled_image,) = acquisition_op.H(kdata.data) @@ -87,7 +139,7 @@ def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.dat return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header)) -def plot(batch: BatchType, prediction: mrpro.data.IData, step: int): +def plot(batch: BatchType, prediction: mrpro.data.IData, step: int) -> None: """Plot the direct, sense, and modl reconstructions.""" target = batch['target'].rss().cpu().squeeze() direct = mrpro.algorithms.reconstruction.DirectReconstruction(batch['data'], csm=batch['csm'])(batch['data']) diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py index 762899e2a..a8f055b7c 100644 --- a/examples/scripts/train_pinqi.py +++ b/examples/scripts/train_pinqi.py @@ -10,12 +10,12 @@ import matplotlib.pyplot as plt import mrpro import numpy as np -import pytorch_lightning as pl +import pytorch_lightning as pl # type:ignore[import-not-found] import torch import torch.utils.data._utils -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.loggers import NeptuneLogger -from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint # type:ignore[import-not-found] +from pytorch_lightning.loggers import NeptuneLogger # type:ignore[import-not-found] +from pytorch_lightning.strategies import DDPStrategy # type:ignore[import-not-found] # mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) @@ -120,7 +120,7 @@ def collate_fn(batch: Any): # noqa: ANN401 return torch.utils.data._utils.collate.collate( batch, collate_fn_map={ - mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), + mrpro.data.Dataclass: lambda batch, *, _collate_fn_map: batch[0].stack(*batch[1:]), **torch.utils.data._utils.collate.default_collate_fn_map, }, ) @@ -360,7 +360,8 @@ def __init__( self.save_hyperparameters(ignore=['signalmodel', 'constraints_op']) if len(loss_weights) != n_iterations + 1: raise ValueError(f'loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations') - signalmodel, constraints_op = map(deepcopy, (signalmodel, constraints_op)) + signalmodel = deepcopy(signalmodel) + constraints_op = deepcopy(constraints_op) self.pinqi = PINQI( signalmodel=signalmodel, constraints_op=constraints_op, @@ -371,7 +372,7 @@ def __init__( n_features_image_net=n_features_image_net, ) - self.validation_step_outputs = collections.defaultdict(list) + self.validation_step_outputs: dict[str, list] = collections.defaultdict(list) self.baseline = Baseline(signalmodel, constraints_op, parameter_is_complex) def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): @@ -395,7 +396,7 @@ def loss(self, predictions: Sequence[torch.Tensor], batch: BatchType) -> torch.T def training_step(self, batch: BatchType, _batch_idx: int) -> torch.Tensor: """Training step.""" - images, parameters = self(batch['kdata'], batch['csm']) + _images, parameters = self(batch['kdata'], batch['csm']) loss = self.loss(parameters, batch) self.log( 'train/loss', @@ -413,7 +414,7 @@ def validation_step(self, batch: BatchType, batch_idx: int) -> None: Needs to be adapted for other signal models than Saturation Recovery. """ - images, parameters = self(batch['kdata'], batch['csm']) + _images, parameters = self(batch['kdata'], batch['csm']) loss = self.loss(parameters, batch) pred_m0, pred_t1 = parameters[-1] @@ -491,16 +492,16 @@ def result_plot( label: str, ) -> None: """Plot the results.""" - target = target.squeeze().numpy() - pred = pred.squeeze().detach().numpy() - mask = mask.squeeze().detach().numpy().astype(bool) - baseline = baseline.squeeze().detach().numpy() - - target[~mask] = np.nan - pred[~mask] = np.nan - baseline[~mask] = np.nan + target = target.squeeze().cpu() + pred = pred.squeeze().detach().cpu() + mask = mask.squeeze().detach().bool().cpu() + baseline = baseline.squeeze().detach().cpu() + + target[~mask] = torch.nan + pred[~mask] = torch.nan + baseline[~mask] = torch.nan difference = (target - pred) / target * 100 - vmax = np.nanmax(target) + vmax = np.nanmax(target.numpy()) im0 = axes[0].imshow(target, vmin=0, vmax=vmax) axes[0].set_title('Ground Truth') @@ -517,7 +518,7 @@ def result_plot( axes[2].axis('off') plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04, label=label) - diff_vmax = np.nanpercentile(np.abs(difference), 90) + diff_vmax = np.nanpercentile(difference.abs().numpy(), 90) im3 = axes[3].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax) axes[3].set_title('rel. Error') axes[3].axis('off') From 102dc4445eb4438bdc031dd38483dfcc72ae1d53 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 9 Feb 2026 16:35:31 +0100 Subject: [PATCH 188/205] Brainweb: Add ULF values --- src/mrpro/phantoms/brainweb.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/mrpro/phantoms/brainweb.py b/src/mrpro/phantoms/brainweb.py index 0764bb84c..076728ae5 100644 --- a/src/mrpro/phantoms/brainweb.py +++ b/src/mrpro/phantoms/brainweb.py @@ -227,6 +227,22 @@ def trim_indices(mask: torch.Tensor) -> tuple[slice, slice]: return slice(row_min, row_max), slice(col_min, col_max) +VALUES_ULF_RANDOMIZED: Mapping[TClassNames, BrainwebTissue] = MappingProxyType( + { + 'skl': BrainwebTissue((0.100, 0.400), (0.005, 0.015), (0.00, 0.05), (-0.2, 0.2)), + 'gry': BrainwebTissue((0.350, 0.430), (0.090, 0.115), (0.70, 1.00), (-0.2, 0.2)), + 'wht': BrainwebTissue((0.240, 0.280), (0.075, 0.085), (0.50, 0.90), (-0.2, 0.2)), + 'csf': BrainwebTissue((1.500, 2.500), (1.000, 1.600), (0.95, 1.00), (-0.2, 0.2)), + 'mrw': BrainwebTissue((0.150, 0.250), (0.060, 0.100), (0.70, 1.00), (-0.2, 0.2)), + 'dura': BrainwebTissue((0.300, 0.600), (0.100, 0.200), (0.90, 1.00), (-0.2, 0.2)), + 'fat': BrainwebTissue((0.120, 0.160), (0.080, 0.130), (0.90, 1.00), (-0.2, 0.2)), + 'fat2': BrainwebTissue((0.140, 0.180), (0.080, 0.130), (0.60, 0.90), (-0.2, 0.2)), + 'mus': BrainwebTissue((0.160, 0.200), (0.035, 0.045), (0.90, 1.00), (-0.2, 0.2)), + 'm-s': BrainwebTissue((0.200, 0.400), (0.100, 0.250), (0.90, 1.00), (-0.2, 0.2)), + 'ves': BrainwebTissue((0.300, 0.500), (0.150, 0.300), (0.80, 1.00), (-0.2, 0.2)), + } +) + VALUES_3T_RANDOMIZED: Mapping[TClassNames, BrainwebTissue] = MappingProxyType( { 'skl': BrainwebTissue((0.000, 2.000), (0.000, 0.010), (0.00, 0.05), (-0.2, 0.2)), From 863edae3bb893d5bcdfac90567349540742e99d2 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 9 Feb 2026 16:35:45 +0100 Subject: [PATCH 189/205] pinqi --- examples/scripts/train_pinqi.py | 78 +++++++++++++++++---------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py index a8f055b7c..d640b8c18 100644 --- a/examples/scripts/train_pinqi.py +++ b/examples/scripts/train_pinqi.py @@ -1,10 +1,9 @@ +# %% # ruff: noqa: D102, ANN201 - -import collections from collections.abc import Sequence from copy import deepcopy from pathlib import Path -from typing import Any, Literal, TypedDict, cast +from typing import Any, Literal, TypedDict import einops import matplotlib.pyplot as plt @@ -15,9 +14,6 @@ import torch.utils.data._utils from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint # type:ignore[import-not-found] from pytorch_lightning.loggers import NeptuneLogger # type:ignore[import-not-found] -from pytorch_lightning.strategies import DDPStrategy # type:ignore[import-not-found] - -# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) class BatchType(TypedDict): @@ -104,10 +100,11 @@ def __getitem__(self, index: int): header.ti = self.signalmodel.ti.tolist() fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) - csm = mrpro.data.CsmData( - mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), - header, - ) + if self.n_coils > 1: + csm_tensor = mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix) + else: + csm_tensor = torch.ones(1, 1, *self.encoding_matrix.zyx) + csm = mrpro.data.CsmData(csm_tensor, header) images = einops.rearrange(images, 't y x -> t 1 1 y x') (data,) = (fourier_op @ csm.as_operator())(images) data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() @@ -120,7 +117,7 @@ def collate_fn(batch: Any): # noqa: ANN401 return torch.utils.data._utils.collate.collate( batch, collate_fn_map={ - mrpro.data.Dataclass: lambda batch, *, _collate_fn_map: batch[0].stack(*batch[1:]), + mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), # noqa: ARG005 **torch.utils.data._utils.collate.default_collate_fn_map, }, ) @@ -330,7 +327,7 @@ def train_dataloader(self): def val_dataloader(self): return torch.utils.data.DataLoader( self.val_dataset, - batch_size=1, + batch_size=4, shuffle=False, num_workers=self.num_workers, pin_memory=False, @@ -372,7 +369,7 @@ def __init__( n_features_image_net=n_features_image_net, ) - self.validation_step_outputs: dict[str, list] = collections.defaultdict(list) + self.validation_step_outputs: dict[str, list] = {} self.baseline = Baseline(signalmodel, constraints_op, parameter_is_complex) def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): @@ -418,7 +415,7 @@ def validation_step(self, batch: BatchType, batch_idx: int) -> None: loss = self.loss(parameters, batch) pred_m0, pred_t1 = parameters[-1] - target_t1, target_m0 = batch['t1'], batch['m0'] + target_t1, target_m0 = batch['t1'][:, None, None], batch['m0'][:, None, None] mask = batch['mask'] batch_size = len(batch['mask']) (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1) @@ -429,31 +426,27 @@ def validation_step(self, batch: BatchType, batch_idx: int) -> None: self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size) self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size) - if batch_idx == 0: - self.validation_step_outputs['target_t1'].append(batch['t1']) - self.validation_step_outputs['pred_t1'].append(pred_t1) - self.validation_step_outputs['pred_m0'].append(pred_m0) - self.validation_step_outputs['target_m0'].append(target_m0) - self.validation_step_outputs['mask'].append(batch['mask']) + if batch_idx == 0 and self.trainer.is_global_zero: + self.validation_step_outputs['target_t1'] = batch['t1'].cpu() + self.validation_step_outputs['pred_t1'] = pred_t1.cpu() + self.validation_step_outputs['pred_m0'] = pred_m0.cpu() + self.validation_step_outputs['target_m0'] = target_m0.cpu() + self.validation_step_outputs['mask'] = batch['mask'].cpu() baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm']) - self.validation_step_outputs['baseline_t1'].append(baseline_t1) - self.validation_step_outputs['baseline_m0'].append(baseline_m0) + self.validation_step_outputs['baseline_t1'] = baseline_t1.cpu() + self.validation_step_outputs['baseline_m0'] = baseline_m0.cpu() def on_validation_epoch_end(self): """Validate. Needs to be adapted for other signal models than Saturation Recovery. """ - outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()} - self.validation_step_outputs.clear() - outputs = cast(dict[str, torch.Tensor], self.all_gather(outputs)) - if not self.trainer.is_global_zero: return - outputs = {k: v.flatten(0, 1).cpu() for k, v in outputs.items()} + outputs = self.validation_step_outputs samples = len(outputs['mask']) - fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16)) + fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16), squeeze=False) for i in range(samples): self.result_plot( @@ -481,6 +474,7 @@ def on_validation_epoch_end(self): fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}') self.logger.run['val/images/m0'].log(fig) plt.close(fig) + self.validation_step_outputs.clear() def result_plot( self, @@ -496,7 +490,6 @@ def result_plot( pred = pred.squeeze().detach().cpu() mask = mask.squeeze().detach().bool().cpu() baseline = baseline.squeeze().detach().cpu() - target[~mask] = torch.nan pred[~mask] = torch.nan baseline[~mask] = torch.nan @@ -624,15 +617,24 @@ def on_train_batch_end( if __name__ == '__main__': + import os + + os.environ['NEPTUNE_API_TOKEN'] = ( + 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIyOTdlYTM3NS0wMWU1LTRlMzMtYWU1Ny01MzMzN2ExNTcwMDcifQ==' + ) + os.environ['NEPTUNE_PROJECT'] = 'ptb/pinqi' torch.multiprocessing.set_sharing_strategy('file_system') torch.set_float32_matmul_precision('high') torch._inductor.config.compile_threads = 4 torch._inductor.config.worker_start_method = 'fork' torch._dynamo.config.capture_scalar_outputs = True torch._dynamo.config.cache_size_limit = 256 - torch._functorch.config.activation_memory_budget = 0.95 + torch._functorch.config.activation_memory_budget = 0.8 - data_folder = Path('/scratch/zimmer08/brainweb') + data_folder = Path(' /echo/zimmer08/brainweb') + if not data_folder.exists(): + data_folder.mkdir(parents=True, exist_ok=True) + mrpro.phantoms.brainweb.download_brainweb(output_directory=data_folder, workers=2, progress=True) signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0)) constraints_op = mrpro.operators.ConstraintsOp( @@ -648,12 +650,12 @@ def on_train_batch_end( folder=data_folder, signalmodel=signalmodel, n_images=n_images, - batch_size=16, - num_workers=16, + batch_size=4, + num_workers=4, size=192, acceleration=8, - n_coils=8, - max_noise=0.1, + n_coils=1, + max_noise=0.3, ) model = PinqiModule( @@ -678,11 +680,11 @@ def on_train_batch_end( save_last=True, ) - strategy = DDPStrategy(find_unused_parameters=False) + strategy = 'auto' # DDPStrategy(find_unused_parameters=False) trainer = pl.Trainer( max_epochs=100, accelerator='gpu', - devices=4, + devices=1, strategy=strategy, logger=neptune_logger, callbacks=[ @@ -696,3 +698,5 @@ def on_train_batch_end( ) trainer.fit(model, datamodule=dm) + +# %% From 3e85c27ebcc7a90a90a466ddb548b4d6de95077e Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:25 +0100 Subject: [PATCH 190/205] Add fast path to PatchOp adjoint ghstack-source-id: ce486f1c161829a2109abd998ea46e6446587643 ghstack-comment-id: 3874745101 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/962 --- src/mrpro/operators/PatchOp.py | 88 +++++++++++++++++++++++--------- tests/operators/test_patch_op.py | 9 ++-- 2 files changed, 69 insertions(+), 28 deletions(-) diff --git a/src/mrpro/operators/PatchOp.py b/src/mrpro/operators/PatchOp.py index 190bf28fb..9339f74d0 100644 --- a/src/mrpro/operators/PatchOp.py +++ b/src/mrpro/operators/PatchOp.py @@ -21,6 +21,7 @@ def __init__( stride: Sequence[int] | int | None = None, dilation: Sequence[int] | int = 1, domain_size: int | Sequence[int] | None = None, + flatten_patches: bool = True, ) -> None: """Initialize the PatchOp. @@ -38,6 +39,9 @@ def __init__( Size of the domain in the dimnsions `dim`. If None, it is inferred from the input tensor on the first call. This is only used in the adjoint method. + flatten_patches + If True, flatten the leading grid dimensions to a single patch dimension. + If False, keep shape ``(*grid_size, ...)`` for the forward output. """ super().__init__() self.dim = (dim,) if isinstance(dim, int) else dim @@ -60,6 +64,7 @@ def check(param: int | Sequence[int], name: str) -> tuple[int, ...]: self.stride = check(stride, 'stride') if stride is not None else self.patch_size self.dilation = check(dilation, 'dilation') self.domain_size = check(domain_size, 'domain_size') if domain_size is not None else None + self.flatten_patches = flatten_patches def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]: """Extract N-dimensional patches from an input tensor using a sliding window. @@ -100,9 +105,59 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: stride=self.stride, dilation=self.dilation, ) - patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) + if self.flatten_patches: + patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) return (patches,) + def _adjoint_fast(self, patches: torch.Tensor) -> torch.Tensor: + """Adjoint via reshape/permute for non-overlapping patches.""" + assert self.domain_size is not None # mypy # noqa: S101 + grid = tuple(s // p for s, p in zip(self.domain_size, self.patch_size, strict=True)) + n_dim = len(grid) + if self.flatten_patches: + patches = patches.unflatten(0, grid) + permutation: list[int] = [] + reshape: list[int] = [] + dim = [d % (patches.ndim - n_dim) for d in self.dim] + for i, size in enumerate(patches.shape[n_dim:]): + if i in dim: + j = dim.index(i) + permutation.extend([j, n_dim + i]) + reshape.append(grid[j] * self.patch_size[j]) + else: + permutation.append(n_dim + i) + reshape.append(size) + return patches.permute(*permutation).reshape(reshape) + + def _adjoint_scatter(self, patches: torch.Tensor) -> torch.Tensor: + """Adjoint via scatter for overlapping patches.""" + assert self.domain_size is not None # mypy # noqa: S101 + k = len(self.dim) + if not self.flatten_patches: + patches = patches.flatten(start_dim=0, end_dim=k - 1) + output_shape_ = list(patches.shape[1:]) + for dim, size in zip(self.dim, self.domain_size, strict=True): + output_shape_[dim] = size + output_shape = torch.Size(output_shape_) + indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_) + windowed_indices = sliding_window( + x=indices, + window_shape=self.patch_size, + dim=self.dim, + stride=self.stride, + dilation=self.dilation, + ).flatten(start_dim=0, end_dim=k - 1) + if windowed_indices.shape[0] != patches.shape[0]: + raise ValueError( + f'Number of patches {patches.shape[0]} does not match the number of ' + f'expected patches {windowed_indices.shape[0]}' + ) + + assembled = patches.new_zeros(output_shape.numel()) + assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten()) + assembled = assembled.reshape(output_shape) + return assembled + def adjoint( self, patches: torch.Tensor, @@ -127,26 +182,11 @@ def adjoint( """ if self.domain_size is None: raise ValueError('Domain size is not set. Please call forward first or set it at initialization.') - - output_shape_ = list(patches.shape[1:]) - for dim, size in zip(self.dim, self.domain_size, strict=True): - output_shape_[dim] = size - output_shape = torch.Size(output_shape_) - indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_) - windowed_indices = sliding_window( - x=indices, - window_shape=self.patch_size, - dim=self.dim, - stride=self.stride, - dilation=self.dilation, - ).flatten(start_dim=0, end_dim=len(self.dim) - 1) - if windowed_indices.shape[0] != patches.shape[0]: - raise ValueError( - f'Number of patches {patches.shape[0]} does not match the number of ' - f'expected patches {windowed_indices.shape[0]}' - ) - - assembled = patches.new_zeros(output_shape.numel()) - assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten()) - assembled = assembled.reshape(output_shape) - return (assembled,) + if ( + self.stride == self.patch_size # no overlap + and all(d == 1 for d in self.dilation) # no dilation + and all(s % p == 0 for s, p in zip(self.domain_size, self.patch_size, strict=True)) # divisible + ): + return (self._adjoint_fast(patches),) + else: + return (self._adjoint_scatter(patches),) diff --git a/tests/operators/test_patch_op.py b/tests/operators/test_patch_op.py index 46b856178..49395f10c 100644 --- a/tests/operators/test_patch_op.py +++ b/tests/operators/test_patch_op.py @@ -1,6 +1,7 @@ """Tests for Rearrange Operator.""" from collections.abc import Sequence +from typing import Any import pytest import torch @@ -14,13 +15,15 @@ [ ((3, 4, 5), {'dim': (0, 1), 'patch_size': (1, 3), 'stride': (3, 1), 'dilation': (2, 1)}, (2, 1, 3, 5)), ((1, 20), {'dim': -1, 'patch_size': 3, 'stride': 3, 'dilation': 5}, (4, 1, 3)), + ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': (2, 3), 'dilation': 1}, (24, 3, 2)), + ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': None, 'flatten_patches': False}, (8, 3, 3, 2)), ], ) @TESTCASES def test_patch_op_adjointness( - input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int] + input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int] ) -> None: """Test adjointness and shape of Rearrange Op.""" rng = RandomGenerator(seed=0) @@ -36,9 +39,7 @@ def test_patch_op_adjointness( @TESTCASES -def test_patch_op_autodiff( - input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int] -) -> None: +def test_patch_op_autodiff(input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int]) -> None: """Test autodiff works for PatchOp.""" rng = RandomGenerator(seed=0) u = rng.complex64_tensor(size=input_shape) From c0be007732a13b58aa9ce9ce631ea9bff1af85de Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:26 +0100 Subject: [PATCH 191/205] add core nn foundations, layers, and resize blocks ghstack-source-id: 17c69fc80ce0e6e8390cb563c732b4f2b8aea912 ghstack-comment-id: 3865650347 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/953 --- pyproject.toml | 10 +- src/mrpro/__init__.py | 6 +- src/mrpro/nn/ComplexAsChannel.py | 59 ++++++ src/mrpro/nn/CondMixin.py | 22 +++ src/mrpro/nn/DropPath.py | 55 ++++++ src/mrpro/nn/FiLM.py | 68 +++++++ src/mrpro/nn/FourierFeatures.py | 50 +++++ src/mrpro/nn/GEGLU.py | 56 ++++++ src/mrpro/nn/GroupNorm.py | 71 +++++++ src/mrpro/nn/LayerNorm.py | 85 ++++++++ src/mrpro/nn/PermutedBlock.py | 75 +++++++ src/mrpro/nn/PixelShuffle.py | 285 +++++++++++++++++++++++++++ src/mrpro/nn/RMSNorm.py | 73 +++++++ src/mrpro/nn/Residual.py | 45 +++++ src/mrpro/nn/Sequential.py | 60 ++++++ src/mrpro/nn/Upsample.py | 66 +++++++ src/mrpro/nn/__init__.py | 45 +++++ src/mrpro/nn/convert_linear_conv.py | 100 ++++++++++ src/mrpro/nn/join.py | 176 +++++++++++++++++ src/mrpro/nn/ndmodules.py | 178 +++++++++++++++++ src/mrpro/utils/__init__.py | 5 +- src/mrpro/utils/to_tuple.py | 36 ++++ tests/nn/test_complexaschannel.py | 30 +++ tests/nn/test_convert_linear_conv.py | 150 ++++++++++++++ tests/nn/test_droppath.py | 30 +++ tests/nn/test_film.py | 58 ++++++ tests/nn/test_fourierfeatures.py | 24 +++ tests/nn/test_geglu.py | 38 ++++ tests/nn/test_groupnorm.py | 45 +++++ tests/nn/test_join.py | 160 +++++++++++++++ tests/nn/test_layernorm.py | 186 +++++++++++++++++ tests/nn/test_ndmodules.py | 76 +++++++ tests/nn/test_pixelshuffle.py | 92 +++++++++ tests/nn/test_rmsnorm.py | 58 ++++++ tests/nn/test_sequential.py | 50 +++++ 35 files changed, 2617 insertions(+), 6 deletions(-) create mode 100644 src/mrpro/nn/ComplexAsChannel.py create mode 100644 src/mrpro/nn/CondMixin.py create mode 100644 src/mrpro/nn/DropPath.py create mode 100644 src/mrpro/nn/FiLM.py create mode 100644 src/mrpro/nn/FourierFeatures.py create mode 100644 src/mrpro/nn/GEGLU.py create mode 100644 src/mrpro/nn/GroupNorm.py create mode 100644 src/mrpro/nn/LayerNorm.py create mode 100644 src/mrpro/nn/PermutedBlock.py create mode 100644 src/mrpro/nn/PixelShuffle.py create mode 100644 src/mrpro/nn/RMSNorm.py create mode 100644 src/mrpro/nn/Residual.py create mode 100644 src/mrpro/nn/Sequential.py create mode 100644 src/mrpro/nn/Upsample.py create mode 100644 src/mrpro/nn/__init__.py create mode 100644 src/mrpro/nn/convert_linear_conv.py create mode 100644 src/mrpro/nn/join.py create mode 100644 src/mrpro/nn/ndmodules.py create mode 100644 src/mrpro/utils/to_tuple.py create mode 100644 tests/nn/test_complexaschannel.py create mode 100644 tests/nn/test_convert_linear_conv.py create mode 100644 tests/nn/test_droppath.py create mode 100644 tests/nn/test_film.py create mode 100644 tests/nn/test_fourierfeatures.py create mode 100644 tests/nn/test_geglu.py create mode 100644 tests/nn/test_groupnorm.py create mode 100644 tests/nn/test_join.py create mode 100644 tests/nn/test_layernorm.py create mode 100644 tests/nn/test_ndmodules.py create mode 100644 tests/nn/test_pixelshuffle.py create mode 100644 tests/nn/test_rmsnorm.py create mode 100644 tests/nn/test_sequential.py diff --git a/pyproject.toml b/pyproject.toml index 67aeb1928..9863b3856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,7 @@ docs = [ "sphinx-autodoc-typehints>=3, <3.1", "sphinx-copybutton>=0.5, <0.6", "sphinx-last-updated-by-git>=0.3, <0.4", + "snowballstemmer>=2.2, <3.0", ] notebooks = [ "zenodo_get>=2.0", @@ -118,10 +119,12 @@ testpaths = ["tests"] filterwarnings = [ "error", "ignore:'write_like_original':DeprecationWarning:pydicom:", - "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd - "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 + "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd + "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 + "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 - "ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10 + "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 + "ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10 ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] @@ -230,6 +233,7 @@ iy = "iy" arange = "arange" # torch.arange Ba = "Ba" wht = "wht" # Brainweb tissue class +ND = "ND" # Short for N-dimensional [tool.typos.files] extend-exclude = [ diff --git a/src/mrpro/__init__.py b/src/mrpro/__init__.py index 729ae188c..bbd401f1f 100644 --- a/src/mrpro/__init__.py +++ b/src/mrpro/__init__.py @@ -1,10 +1,12 @@ from mrpro._version import __version__ -from mrpro import algorithms, operators, data, phantoms, utils +from mrpro import algorithms, operators, data, phantoms, utils, nn + __all__ = [ "__version__", "algorithms", "data", + "nn", "operators", "phantoms", "utils" -] +] \ No newline at end of file diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py new file mode 100644 index 000000000..22e13458e --- /dev/null +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -0,0 +1,59 @@ +"""ComplexAsChannel: handling complex-valued tensors as channels.""" + +import torch +from einops import rearrange +from torch.nn import Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class ComplexAsChannel(CondMixin, Module): + """Wrap module to treat complex numbers as a channel dimension.""" + + def __init__(self, module: Module, convert_back: bool = True): + """Initialize the ComplexAsChannel module. + + Wraps a module to treat complex numbers as a channel dimension. + For each complex tensor in the input, real and imaginary parts are concatenated along the channel dimension + before being passed to the wrapped module. + + + Parameters + ---------- + module + The module to wrap. Should output a single real tensor. + convert_back + If True, the output is converted back to a complex tensor. + The output should have a number of channels that is a multiple of 2. + """ + super().__init__() + self.module = module + self.convert_back = convert_back + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor (if used by the wrapped module) + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + x_real = [ + rearrange(torch.view_as_real(c), 'batch channel ... complex -> batch (channel complex) ...') + if c.is_complex() + else c + for c in x + ] + + y = call_with_cond(self.module, *x_real, cond=cond) + + if self.convert_back: + y = rearrange(y, 'b (channel complex) ... -> b channel ... complex', complex=2).contiguous() + y = torch.view_as_complex(y) + return y diff --git a/src/mrpro/nn/CondMixin.py b/src/mrpro/nn/CondMixin.py new file mode 100644 index 000000000..6a902c413 --- /dev/null +++ b/src/mrpro/nn/CondMixin.py @@ -0,0 +1,22 @@ +"""Base class for modules using a conditioning.""" + +import torch +from torch.nn import Module + + +def call_with_cond(module: Module, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Call a module with conditioning if it is a CondMixin.""" + if isinstance(module, CondMixin): + return module(*x, cond=cond) + return module(*x) + + +class CondMixin(Module): + """Mixin for modules using a conditioning. + + Used to determine if a module uses a conditioning within a Sequential container. + """ + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module to the input.""" + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/DropPath.py b/src/mrpro/nn/DropPath.py new file mode 100644 index 000000000..7262fd86c --- /dev/null +++ b/src/mrpro/nn/DropPath.py @@ -0,0 +1,55 @@ +"""DropPath (stochastic depth).""" + +import torch +from torch.nn import Module + + +class DropPath(Module): + """Drop path or stochastic depth. + + Drops full samples from batch with probability `droprate`. + Should be used in the main path of a Resblock. + + References + ---------- + .. [HUANG16] Huang, G., Sun, Y., Liu, Z., Sedra, D., & Weinberger, K. Q. Deep networks with stochastic depth. + ECCV 2016. https://link.springer.com/chapter/10.1007/978-3-319-46493-0_39 + """ + + def __init__(self, droprate: float = 0.0, scale_by_keep: bool = False): + """Initialize the DropPath module. + + Parameters + ---------- + droprate + Drop probability + scale_by_keep + If True, the kept samples are scaled by :math:`1/(1-droprate)` + """ + super().__init__() + self.droprate = droprate + self.scale_by_keep = scale_by_keep + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Tensor with batch samples randomly dropped + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath.""" + if self.droprate == 0 or not self.training: + return x + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = ((1 - self.droprate) + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_() + if self.scale_by_keep: + mask = mask.div_(1 - self.droprate) + return x * mask diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py new file mode 100644 index 000000000..8c249ea3e --- /dev/null +++ b/src/mrpro/nn/FiLM.py @@ -0,0 +1,68 @@ +"""Feature-wise Linear Modulation.""" + +import torch +from torch.nn import Linear, Module + +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_tensors_right + + +class FiLM(CondMixin, Module): + """Feature-wise Linear Modulation. + + Feature-wise Linear Modulation from [FiLM]_ to condition a network on a conditioning tensor. + + + References + ---------- + ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "FiLM : Visual reasoning with a general + conditioning layer." AAAI (2018). https://arxiv.org/abs/1709.07871 + """ + + features_last: bool + + def __init__(self, channels: int, cond_dim: int, features_last: bool = False) -> None: + """Initialize FiLM. + + Parameters + ---------- + channels + The number of channels in the input tensor. + cond_dim + The dimension of the conditioning tensor. + features_last + Whether the features are in the last dimension of the input tensor (e.g. transformer tokens) + or in the second dimension (e.g. image tensors). + """ + super().__init__() + self.project = Linear(cond_dim, 2 * channels) if cond_dim > 0 else None + self.features_last = features_last + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply FiLM. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply FiLM.""" + if cond is None or self.project is None: + return x + + if self.features_last: + x = x.moveaxis(-1, 1) + + scale, shift = self.project(cond).chunk(2, dim=1) + scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) + x = x * (1 + scale) + shift + + if self.features_last: + x = x.moveaxis(1, -1) + + return x diff --git a/src/mrpro/nn/FourierFeatures.py b/src/mrpro/nn/FourierFeatures.py new file mode 100644 index 000000000..847ae3c60 --- /dev/null +++ b/src/mrpro/nn/FourierFeatures.py @@ -0,0 +1,50 @@ +"""Random Fourier feature embedding.""" + +import torch +from torch.nn import Module + + +class FourierFeatures(Module): + """Fourier feature encoding layer. + + Projects input features into a higher dimensional space using random Fourier features. + Used in INRs and to embed the time or other continuous variables. + """ + + weight: torch.Tensor + + def __init__(self, n_features_in: int, n_features_out: int, std: float = 1.0): + """Initialize Fourier feature encoding layer. + + Parameters + ---------- + n_features_in + Number of input features + n_features_out + Number of output features (must be even) + std + Standard deviation for random initialization + """ + if n_features_out % 2 != 0: + raise ValueError('n_features_out must be even.') + super().__init__() + self.register_buffer('weight', torch.randn([n_features_out // 2, n_features_in]) * std) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding. + + Parameters + ---------- + x + Input tensor of shape (..., in_features) + + Returns + ------- + Encoded features of shape (..., out_features) + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding.""" + f = 2 * torch.pi * x @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) diff --git a/src/mrpro/nn/GEGLU.py b/src/mrpro/nn/GEGLU.py new file mode 100644 index 000000000..6151503d2 --- /dev/null +++ b/src/mrpro/nn/GEGLU.py @@ -0,0 +1,56 @@ +"""Gated linear unit activation function.""" + +import torch +from torch.nn import Linear, Module + + +class GEGLU(Module): + r"""Gated linear unit activation function. + + References + ---------- + ..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202 + """ + + def __init__(self, n_channels_in: int, n_channels_out: int | None = None, features_last: bool = False): + """Initialize the GEGLU activation function. + + Parameters + ---------- + n_channels_in + The number of input features/channels. + n_channels_out + The number of output features/channels. If None, the number of + output features is the same as the number of input features. + features_last + If True, the channel dimension is the last dimension, else in the second dimension. + """ + super().__init__() + out_channels_ = n_channels_in if n_channels_out is None else n_channels_out + self.proj = Linear(n_channels_in, out_channels_ * 2) # gate and output stacked + self.features_last = features_last + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation.""" + if not self.features_last: + x = x.moveaxis(1, -1) + h, gate = self.proj(x).chunk(2, dim=-1) + gate = torch.nn.functional.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + out = h * gate + if not self.features_last: + out = out.moveaxis(-1, 1) + return out + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Activated tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py new file mode 100644 index 000000000..e0bbb0a4b --- /dev/null +++ b/src/mrpro/nn/GroupNorm.py @@ -0,0 +1,71 @@ +"""GroupNorm with 32-bit precision.""" + +import torch + + +class GroupNorm(torch.nn.GroupNorm): + """A 32-bit GroupNorm with (optional) automatic group size selection. + + Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. + + If `n_groups` is not provided, the number of groups is selected automatically as follows: + + - start from `1` group, + - try powers of two (`2, 4, 8, ...`), + - keep the largest candidate that divides `n_channels`, + - enforce at most `32` groups and at least `4` channels per group. + + This yields a stable default that stays close to common GroupNorm choices while + adapting to small channel counts. + """ + + features_last: bool + + def __init__(self, n_channels: int, n_groups: int | None = None, affine: bool = False, features_last: bool = False): + """Initialize GroupNorm. + + Parameters + ---------- + n_channels + The number of channels in the input tensor. + n_groups + The number of groups to use. If None, the number of groups is determined automatically as + the largest power of 2 that divides `n_channels`, is less than or equal to 32, + and leaves at least 4 channels per group. + affine + Whether to use learnable affine parameters. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + if n_groups is None: + groups_, candidate = 1, 2 + while (candidate <= min(32, n_channels // 4)) and (n_channels % candidate == 0): + groups_, candidate = candidate, groups_ * 2 + else: + groups_ = n_groups + self.features_last: bool = features_last + super().__init__(groups_, n_channels, affine=affine) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm32. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x.float()).type(x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm.""" + if self.features_last: + x = x.moveaxis(-1, 1) + result = super().forward(x.float()).type(x.dtype) + if self.features_last: + result = result.moveaxis(1, -1) + return result diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py new file mode 100644 index 000000000..4a7e5df1f --- /dev/null +++ b/src/mrpro/nn/LayerNorm.py @@ -0,0 +1,85 @@ +"""Layer normalization.""" + +import torch +from torch.nn import Linear, Module, Parameter + +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_at, unsqueeze_right + + +class LayerNorm(CondMixin, Module): + """Layer normalization.""" + + def __init__(self, n_channels: int | None, features_last: bool = False, cond_dim: int = 0) -> None: + """Initialize the layer normalization. + + Parameters + ---------- + n_channels + Number of channels in the input tensor. If `None`, the layer normalization does not do an elementwise + affine transformation. + features_last + If `True`, the channel dimension is the last dimension. + cond_dim + Number of channels in the conditioning tensor. If `0`, no adaptive scaling is applied. + """ + super().__init__() + if n_channels is None and cond_dim == 0: + self.weight: Parameter | None = None + self.bias: Parameter | None = None + self.cond_proj: Linear | None = None + elif n_channels is None and cond_dim > 0: + raise ValueError('channels must be provided if cond_dim > 0') + elif n_channels is not None and cond_dim == 0: + self.weight = Parameter(torch.ones(n_channels)) + self.bias = Parameter(torch.zeros(n_channels)) + self.cond_proj = None + elif n_channels is not None: + self.weight = None + self.bias = None + self.cond_proj = Linear(cond_dim, 2 * n_channels) + else: + raise ValueError('cond_dim must be zero or positive.') + + self.features_last = features_last + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply layer normalization to the input tensor. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor. If `None`, no conditioning is applied. + + Returns + ------- + Normalized output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply layer normalization to the input tensor.""" + dim = -1 if self.features_last else 1 + dtype = x.dtype + x = x.float() + var, mean = torch.var_mean(x, dim=dim, unbiased=False, keepdim=True) + x = (x - mean) * (var + 1e-5).rsqrt() + x = x.to(dtype) + + if self.weight is not None and self.bias is not None: + if self.features_last: + x = x * self.weight + self.bias + else: + x = x * unsqueeze_right(self.weight, x.ndim - 2) + unsqueeze_right(self.bias, x.ndim - 2) + + if self.cond_proj is not None and cond is not None: + scale, shift = self.cond_proj(cond).chunk(2, dim=-1) + scale = 1 + scale + if self.features_last: + x = x * unsqueeze_at(scale, 1, x.ndim - 2) + unsqueeze_at(shift, 1, x.ndim - 2) + else: + x = x * unsqueeze_right(scale, x.ndim - 2) + unsqueeze_right(shift, x.ndim - 2) + + return x diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py new file mode 100644 index 000000000..935d114dc --- /dev/null +++ b/src/mrpro/nn/PermutedBlock.py @@ -0,0 +1,75 @@ +"""Block that applies a submodule along selected spatial dimensions.""" + +from collections.abc import Sequence + +import torch +from torch import nn + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class PermutedBlock(CondMixin, nn.Module): + """Apply a submodule along selected spatial dimensions.""" + + apply_along_dim: tuple[int, ...] + module: nn.Module + + def __init__(self, apply_along_dim: Sequence[int], module: nn.Module, features_last: bool = False): + """Initialize the PermutedBlock. + + Parameters + ---------- + apply_along_dim + Spatial dimension indices to use when applying the module. + These will be moved to the last dimensions. + module + Module to apply on the selected dims. + features_last + If True, the features dimension is assumed to be the last dimension, as common in transformer models. + """ + super().__init__() + self.apply_along_dim = tuple(sorted(apply_along_dim)) + self.module = module + self.features_last = features_last + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor, passed to the module if it supports conditioning + (that is, if it is a subclass of `~mrpro.nn.CondMixin`) + + Returns + ------- + Output tensor. + """ + return self.forward(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions.""" + keep = tuple(d % x.ndim for d in self.apply_along_dim) + if 0 in keep: + raise ValueError('Batch dimension should not be in apply_along_dim.') + if self.features_last: + if x.ndim - 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(1, x.ndim - 1) if d not in keep) + permute = (0, *batch_dim, *keep, x.ndim - 1) + else: + if 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(2, x.ndim) if d not in keep) + permute = (0, *batch_dim, 1, *keep) + h = x.permute(permute) + batch_shape = h.shape[: 1 + len(batch_dim)] + h = h.flatten(0, len(batch_dim)) + h = call_with_cond(self.module, h, cond=cond) + h = h.unflatten(0, batch_shape) + permute_back = [0] * x.ndim + for i, p in enumerate(permute): + permute_back[p] = i + return h.permute(tuple(permute_back)) diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py new file mode 100644 index 000000000..afcff9335 --- /dev/null +++ b/src/mrpro/nn/PixelShuffle.py @@ -0,0 +1,285 @@ +"""ND-version of PixelShuffle and PixelUnshuffle.""" + +from math import ceil + +import torch +from torch.nn import Linear, Module + +from mrpro.nn.ndmodules import convND + + +class PixelUnshuffle(Module): + """ND-version of PixelUnshuffle downscaling.""" + + def __init__(self, downscale_factor: int, features_last: bool = False): + """Initialize PixelUnshuffle. + + Reduces spatial dimensions and increases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are downscaled. + + See `mrpro.nn.PixelShuffle` for the inverse operation. + + Parameters + ---------- + downscale_factor + The factor by which to downscale the input tensor. + features_last + Whether the features/channels dimension is the last dimension as common in transformer models or the + second dimension as common in CNN models. + """ + super().__init__() + self.downscale_factor = downscale_factor + self.features_last = features_last + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or `batch, *spatial_dims, channels` (if `features_last`). + + Returns + ------- + Tensor of shape `batch, channels * downscale_factor**dim, *spatial_dims/downscale_factor` or + `batch, *spatial_dims/downscale_factor, channels * downscale_factor**dim` (if `features_last`). + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input.""" + n_dim = x.ndim - 2 + if n_dim == 2 and not self.features_last: # fast path for 2D images + return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor) + + new_shape = list(x.shape[:1]) if self.features_last else list(x.shape[:2]) + source_positions = [] + for i, old in enumerate(x.shape[1:-1] if self.features_last else x.shape[2:]): + if old % self.downscale_factor: + raise ValueError('Spatial size must be divisible by downscale_factor.') + new_shape.append(old // self.downscale_factor) + new_shape.append(self.downscale_factor) + source_positions.append(2 + 2 * i) + if self.features_last: + new_shape.append(x.shape[-1]) + x = x.view(new_shape) + x = x.moveaxis(source_positions, tuple(range(-n_dim, 0))) + if self.features_last: + x = x.flatten(-n_dim - 1) + else: + x = x.flatten(1, -n_dim - 1) + + return x + + +class PixelUnshuffleDownsample(Module): + """PixelUnshuffle Downsampling. + + PixelUnshuffle followed by a convolution. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + downscale_factor: int = 2, + residual: bool = False, + features_last: bool = False, + ): + """Initialize a PixelUnshuffleDownsample layer. + + Parameters + ---------- + n_dim + Dimension of the input tensor. + n_channels_in + Number of channels in the input tensor. + n_channels_out + Number of channels in the output tensor. + downscale_factor + Factor by which to downscale the input tensor. + residual + Whether to use a residual connection as proposed in [DCAE]_. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + super().__init__() + out_ratio = downscale_factor**n_dim + if n_channels_out % out_ratio != 0: + raise ValueError(f'channels_out must be divisible by downscale_factor**{n_dim}.') + if features_last: + self.projection: Module = Linear(n_channels_in, n_channels_out // out_ratio) + else: + self.projection = convND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') + self.features_last = features_last + self.residual = residual + self.pixel_unshuffle = PixelUnshuffle(downscale_factor, features_last) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims/downscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling.""" + h = self.projection(x) + h = self.pixel_unshuffle(h) + + if self.residual: + x = self.pixel_unshuffle(x) + if self.features_last: + n = (x.shape[-1] // h.shape[-1]) * h.shape[-1] + h = h + x[..., :n].unflatten(-1, (h.shape[-1], -1)).mean(-1) + else: + n = (x.shape[1] // h.shape[1]) * h.shape[1] + h = h + x[:, :n].unflatten(1, (h.shape[1], -1)).mean(2) + return h + + +class PixelShuffleUpsample(Module): + """PixelShuffle Upsampling. + + Convolution followed by PixelShuffle. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + upscale_factor: int = 2, + residual: bool = False, + features_last: bool = False, + ): + """Initialize a PixelShuffleUpsample layer. + + Parameters + ---------- + n_dim + Dimension of the input tensor. + n_channels_in + Number of channels in the input tensor. + n_channels_out + Number of channels in the output tensor. + upscale_factor + Factor by which to upscale the input tensor. + residual + Whether to use a residual connection as proposed in [DCAE]_. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + super().__init__() + if features_last: + self.projection: Module = Linear(n_channels_in, n_channels_out * upscale_factor**n_dim) + else: + self.projection = convND(n_dim)( + n_channels_in, n_channels_out * upscale_factor**n_dim, kernel_size=3, padding='same' + ) + self.features_last = features_last + self.pixel_shuffle = PixelShuffle(upscale_factor, features_last) + self.residual = residual + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims * upscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling.""" + h = self.projection(x) + if self.residual: + if self.features_last: + h = h + x.repeat_interleave(ceil(h.shape[-1] / x.shape[-1]), dim=-1)[..., : h.shape[-1]] + else: + h = h + x.repeat_interleave(ceil(h.shape[1] / x.shape[1]), dim=1)[:, : h.shape[1]] + out = self.pixel_shuffle(h) + return out + + +class PixelShuffle(Module): + """ND-version of PixelShuffle upscaling.""" + + def __init__(self, upscale_factor: int, features_last: bool = False): + """Initialize PixelShuffle. + + Upscales spatial dimensions and decreases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are upscaled. + + See `mrpro.nn.PixelUnshuffle` for the inverse operation. + + Parameters + ---------- + upscale_factor + The factor by which to upscale the spatial dimensions. + features_last + Whether the features/channels dimension is the last dimension as common in transformer models or the + second dimension as common in CNN models. + """ + super().__init__() + self.upscale_factor = upscale_factor + self.features_last = features_last + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or `batch, *spatial_dims, channels` (if `features_last`). + + Returns + ------- + Tensor of shape `batch, channels / upscale_factor**n_dim, *spatial_dims * upscale_factor` or + `batch, *spatial_dims * upscale_factor, channels / upscale_factor**n_dim` (if `features_last`). + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input.""" + n_dim = x.ndim - 2 + if n_dim == 2 and not self.features_last: # fast path for 2D + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + if self.features_last: + new_shape = (x.shape[0], *(old * self.upscale_factor for old in x.shape[-n_dim - 1 : -1]), -1) + x = x.unflatten(-1, (-1, *(self.upscale_factor,) * n_dim)) + x = x.moveaxis(tuple(range(-n_dim, 0)), tuple(range(-2 * n_dim, 0, 2))) + else: + new_shape = (x.shape[0], -1, *(old * self.upscale_factor for old in x.shape[-n_dim:])) + x = x.unflatten(1, (-1, *(self.upscale_factor,) * n_dim)) + x = x.moveaxis(tuple(range(2, 2 + n_dim)), tuple(range(-2 * n_dim + 1, 0, 2))) + x = x.reshape(new_shape) + return x diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py new file mode 100644 index 000000000..b97641545 --- /dev/null +++ b/src/mrpro/nn/RMSNorm.py @@ -0,0 +1,73 @@ +"""RMSNorm over the channel dimension.""" + +import torch +from torch.nn import Module, Parameter + + +class RMSNorm(Module): + """RMSNorm over the channel dimension. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_channels: int | None = None, + eps: float = 1e-8, + features_last: bool = False, + ): + """Initialize RMSNorm. + + Includes a learnable weight and bias if n_channels is provided. + + Parameters + ---------- + n_channels + Number of channels. If `None`, no learnable weight and bias are included. + eps + Epsilon value to avoid division by zero. + features_last + If True, the channel dimension is the last dimension. + """ + super().__init__() + if n_channels is not None: + self.weight: Parameter | None = Parameter(torch.zeros(n_channels)) + self.bias: Parameter | None = Parameter(torch.zeros(n_channels)) + else: + self.weight = None + self.bias = None + self.eps = eps + self.channel_dim = -1 if features_last else 1 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension. + + Parameters + ---------- + x + Input tensor. + + Returns + ------- + Normalized tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension.""" + x32 = x.to(torch.float32) # normalization in float32 to stabilize mixed precision training + mean_square = x32.square().mean(dim=self.channel_dim, keepdim=True) + scale = (mean_square + self.eps).rsqrt() + x32 = x32 * scale + if self.weight is not None and self.bias is not None: + shape = [1] * x.ndim + shape[self.channel_dim] = -1 + weight = (self.weight.to(x32.dtype) + 1).view(shape) + bias = self.bias.view(shape) + x32 = x32 * weight + bias + return x32.to(x.dtype) diff --git a/src/mrpro/nn/Residual.py b/src/mrpro/nn/Residual.py new file mode 100644 index 000000000..e524fe169 --- /dev/null +++ b/src/mrpro/nn/Residual.py @@ -0,0 +1,45 @@ +"""Residual connection.""" + +import torch +from torch.nn import Identity, Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class Residual(CondMixin, Module): + """Residual connection.""" + + def __init__(self, module: Module, skip: Module | None = None): + """Initialize the residual connection. + + Parameters + ---------- + module + The main path of the residual connection. + skip + The skip path of the residual connection. If None, the identity function is used. + """ + super().__init__() + self.module = module + self.skip = Identity() if skip is None else skip + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x + The input tensor. + cond + The optional conditioning tensor. If the modules are an instance of `CondMixin`, + the conditioning is passed to the modules. + + Returns + ------- + The output tensor. + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + return call_with_cond(self.module, *x, cond=cond) + call_with_cond(self.skip, *x, cond=cond) diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py new file mode 100644 index 000000000..fed22c3a2 --- /dev/null +++ b/src/mrpro/nn/Sequential.py @@ -0,0 +1,60 @@ +"""Sequential container with support for conditioning and Operators.""" + +from collections import OrderedDict +from typing import cast + +import torch + +from mrpro.nn.CondMixin import CondMixin +from mrpro.operators import Operator + + +class Sequential(CondMixin, torch.nn.Sequential): + """Sequential container with support for conditioning and Operators. + + Allows multiple input tensors and a single output tensor of the sequential block. + + """ + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input. + + Parameters + ---------- + x + The input tensor. + cond + The (optional) conditioning tensor. + + Returns + ------- + The output tensor. + """ + return torch.nn.Sequential.__call__(self, *x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input.""" + for module in self: + if isinstance(module, Operator): + x = cast(tuple[torch.Tensor, ...], module(*x)) # always tuple + else: + ret: torch.Tensor | tuple[torch.Tensor, ...] + if isinstance(module, CondMixin): + ret = module(*x, cond=cond) + else: + ret = module(*x) + if isinstance(ret, tuple): + x = ret + else: + x = (ret,) + return x[0] + + def __getitem__(self, idx: slice | int) -> 'Sequential': + """Get a slice or item from the Sequential container. + + Subclasses will decompose to `Sequential` on indexing. + """ + if isinstance(idx, slice): + return Sequential(OrderedDict(list(self._modules.items())[idx])) + else: + return cast(Sequential, self._get_item_by_idx(self._modules.values(), idx)) diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py new file mode 100644 index 000000000..acced8d48 --- /dev/null +++ b/src/mrpro/nn/Upsample.py @@ -0,0 +1,66 @@ +"""Upsampling by interpolation.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module, Sequential + +from mrpro.nn.PermutedBlock import PermutedBlock + + +class Upsample(Module): + """Upsampling by interpolation.""" + + def __init__( + self, dim: Sequence[int], scale_factor: int = 2, mode: Literal['nearest', 'linear', 'cubic'] = 'linear' + ): + """Initialize the upsampling layer. + + Parameters + ---------- + dim + Dimensions which to upsample + scale_factor + Factor by which to upsample + mode + Interpolation mode. See `torch.nn.functional.interpolate` for details. + """ + super().__init__() + self.scale_factor = scale_factor + if mode == 'nearest': + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = ['nearest'] * len(dim) + elif mode == 'linear': + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = [{1: 'linear', 2: 'bilinear', 3: 'trilinear'}[len(d)] for d in dims] + elif mode == 'cubic': + if not len(dim) == 2: + raise ValueError('Cubic interpolation is only supported for 2D images.') + dims = [tuple(dim)] + modes = ['bicubic'] + + self.blocks = Sequential( + *[ + PermutedBlock(d, torch.nn.Upsample(scale_factor=len(d) * (scale_factor,), mode=m)) + for d, m in zip(dims, modes, strict=False) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor.""" + return self.blocks(x) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Upsampled tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py new file mode 100644 index 000000000..d6541f5c8 --- /dev/null +++ b/src/mrpro/nn/__init__.py @@ -0,0 +1,45 @@ +"""Neural network modules and utilities.""" + +from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.FourierFeatures import FourierFeatures +from mrpro.nn.GEGLU import GEGLU +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Residual import Residual +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ndmodules import ( + adaptiveAvgPoolND, + avgPoolND, + batchNormND, + convND, + convTransposeND, + instanceNormND, + maxPoolND, +) + +__all__ = [ + 'ComplexAsChannel', + 'CondMixin', + 'DropPath', + 'FiLM', + 'FourierFeatures', + 'GEGLU', + 'GroupNorm', + 'LayerNorm', + 'PermutedBlock', + 'RMSNorm', + 'Residual', + 'Sequential', + 'adaptiveAvgPoolND', + 'avgPoolND', + 'batchNormND', + 'convND', + 'convTransposeND', + 'instanceNormND', + 'maxPoolND', +] diff --git a/src/mrpro/nn/convert_linear_conv.py b/src/mrpro/nn/convert_linear_conv.py new file mode 100644 index 000000000..767a419ff --- /dev/null +++ b/src/mrpro/nn/convert_linear_conv.py @@ -0,0 +1,100 @@ +"""Convert Linear layers to kernel size 1 ConvNd layers and vice versa.""" + +from typing import Literal, overload + +import torch +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + +from mrpro.nn.ndmodules import convND + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[1]) -> Conv1d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[2]) -> Conv2d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[3]) -> Conv3d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d: ... + + +def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d: + """Convert a Linear layer to a ConvNd layer with kernel size 1. + + Rearranging the spatial dimensions to the batch dimension, + applying the linear layer and rearranging the spatial dimensions back + is equivalent to applying a kernel size 1 ConvNd layer. + + This function will create the Conv1d, Conv2d, or Conv3d with the correct weights and bias. + + See :func:`conv_to_linear` for the reverse operation. + + + + Parameters + ---------- + linear_layer + The linear layer to convert. + n_dim + The convolution dimension (1, 2, or 3). + + Returns + ------- + A Conv layer with equivalent weights and bias. + """ + conv = convND(n_dim)( + in_channels=linear_layer.in_features, + out_channels=linear_layer.out_features, + kernel_size=1, + bias=linear_layer.bias is not None, + device=linear_layer.weight.device, + dtype=linear_layer.weight.dtype, + ) + + with torch.no_grad(): + conv.weight.copy_(linear_layer.weight.view_as(conv.weight)) + if conv.bias is not None and linear_layer.bias is not None: + conv.bias.copy_(linear_layer.bias) + + return conv + + +def conv_to_linear(conv_layer: Conv1d | Conv2d | Conv3d) -> Linear: + """ + Convert a Conv1d, Conv2d, or Conv3d layer with kernel size 1 to a Linear layer. + + Applying a kernel size 1 ConvNd layer is equivalent to applying a Linear layer to each voxel. + This function will create the Linear layer with the correct weights and bias. + + See :func:`linear_to_conv` for the reverse operation. + + Parameters + ---------- + conv_layer : nn.Module + The convolutional layer to convert. Must have kernel size 1. + + Returns + ------- + A linear layer with equivalent weights and bias. + """ + if not all(k == 1 for k in conv_layer.kernel_size): + raise ValueError('Kernel size must be 1 for conversion.') + linear = Linear( + conv_layer.in_channels, + conv_layer.out_channels, + bias=conv_layer.bias is not None, + device=conv_layer.weight.device, + dtype=conv_layer.weight.dtype, + ) + with torch.no_grad(): + linear.weight.copy_(conv_layer.weight.view_as(linear.weight)) + if linear.bias is not None and conv_layer.bias is not None: + linear.bias.copy_(conv_layer.bias) + + return linear diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py new file mode 100644 index 000000000..d98aeb7b2 --- /dev/null +++ b/src/mrpro/nn/join.py @@ -0,0 +1,176 @@ +"""Modules for concatenating or adding tensors.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module + +from mrpro.utils.interpolate import interpolate +from mrpro.utils.pad_or_crop import pad_or_crop + + +def _fix_shapes( + xs: Sequence[torch.Tensor], + mode: str, + dim: Sequence[int], +) -> tuple[torch.Tensor, ...]: + """Fix shapes of input tensors by padding or cropping.""" + if mode == 'fail': + return tuple(xs) + + shapes = [[x.shape[d] for d in dim] for x in xs] + if mode == 'crop': # smallest as target + target = tuple(min(s) for s in zip(*shapes, strict=True)) + else: # largest as target + target = tuple(max(s) for s in zip(*shapes, strict=True)) + if mode == 'linear' or mode == 'nearest': + return tuple(interpolate(x, target, dim=dim, mode=mode) for x in xs) # type: ignore[arg-type] + if mode == 'zero' or mode == 'crop': + return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) + else: + return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) # type: ignore[arg-type] + + +class Concat(Module): + """Concatenate tensors along the channel dimension.""" + + def __init__( + self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'] = 'fail', dim: int = 1 + ) -> None: + """Initialize Concat. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + - 'linear': linear interpolation to largest spatial size + - 'nearest': nearest neighbor interpolation to largest spatial size + dim + Dimension along which to concatenate. + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.dim = dim + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Concatenate input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) + return torch.cat(xs, dim=self.dim) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Concatenate input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Concatenated tensor + """ + return super().__call__(*xs) + + +class Add(Module): + """Add tensors.""" + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: + """Initialize Add. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Add input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=range(max(x.ndim for x in xs))) + return sum(xs, start=torch.tensor(0.0)) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Add input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Summed tensor + """ + return super().__call__(*xs) + + +class Interpolate(Module): + """Linear interpolate between two tensors. + + As suggestions for the Hourglass Transformer [CR]_ + + References + ---------- + .. [CK] Crowson, Katherine, et al. "Scalable high-resolution pixel-space image synthesis with + hourglass diffusion transformers." ICML 2024, https://arxiv.org/abs/2401.11605 + """ + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: + """Initialize learned linear interpolation. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.weight = torch.nn.Parameter(torch.tensor(0.5)) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Linear interpolate between two tensors.""" + x1, x2 = _fix_shapes((x1, x2), self.mode, dim=range(max(x.ndim for x in (x1, x2)))) + return x1 * self.weight + x2 * (1 - self.weight) + + def __call__(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Linear interpolate between two tensors. + + Parameters + ---------- + x1, x2 + Input tensors + + Returns + ------- + Interpolated tensor + """ + return super().__call__(x1, x2) diff --git a/src/mrpro/nn/ndmodules.py b/src/mrpro/nn/ndmodules.py new file mode 100644 index 000000000..3fb2d894c --- /dev/null +++ b/src/mrpro/nn/ndmodules.py @@ -0,0 +1,178 @@ +"""Helper functions to get the correct N-dimensional module.""" + +import torch + + +def convND(n_dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 + """Get the `n_dim`-dimensional convolution class. + + Parameters + ---------- + n_dim + The dimension of the convolution. + + Returns + ------- + The convolution class. + """ + match n_dim: + case 1: + return torch.nn.Conv1d + case 2: + return torch.nn.Conv2d + case 3: + return torch.nn.Conv3d + case _: + raise NotImplementedError(f'ConvND for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def convTransposeND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.ConvTranspose1d] | type[torch.nn.ConvTranspose2d] | type[torch.nn.ConvTranspose3d]: + """Get the `n_dim`-dimensional transposed convolution class. + + Parameters + ---------- + n_dim + The dimension of the transposed convolution. + + Returns + ------- + The transposed convolution class. + """ + match n_dim: + case 1: + return torch.nn.ConvTranspose1d + case 2: + return torch.nn.ConvTranspose2d + case 3: + return torch.nn.ConvTranspose3d + case _: + raise NotImplementedError( + f'ConvTransposeND for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def maxPoolND(n_dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 + """Get the `n_dim`-dimensional max pooling class. + + Parameters + ---------- + n_dim + The dimension of the max pooling. + + Returns + ------- + The max pooling class. + """ + match n_dim: + case 1: + return torch.nn.MaxPool1d + case 2: + return torch.nn.MaxPool2d + case 3: + return torch.nn.MaxPool3d + case _: + raise NotImplementedError(f'MaxPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def avgPoolND(n_dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 + """Get the `n_dim`-dimensional average pooling class. + + Parameters + ---------- + n_dim + The dimension of the average pooling. + + Returns + ------- + The average pooling class. + """ + match n_dim: + case 1: + return torch.nn.AvgPool1d + case 2: + return torch.nn.AvgPool2d + case 3: + return torch.nn.AvgPool3d + case _: + raise NotImplementedError(f'AvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def adaptiveAvgPoolND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.AdaptiveAvgPool1d] | type[torch.nn.AdaptiveAvgPool2d] | type[torch.nn.AdaptiveAvgPool3d]: + """Get the `n_dim`-dimensional adaptive average pooling class. + + Parameters + ---------- + n_dim + The dimension of the adaptive average pooling. + + Returns + ------- + The adaptive average pooling class. + """ + match n_dim: + case 1: + return torch.nn.AdaptiveAvgPool1d + case 2: + return torch.nn.AdaptiveAvgPool2d + case 3: + return torch.nn.AdaptiveAvgPool3d + case _: + raise NotImplementedError( + f'AdaptiveAvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def instanceNormND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.InstanceNorm1d] | type[torch.nn.InstanceNorm2d] | type[torch.nn.InstanceNorm3d]: + """Get the `n_dim`-dimensional instance normalization class. + + Parameters + ---------- + n_dim + The dimension of the instance normalization. + + Returns + ------- + The instance normalization class. + """ + match n_dim: + case 1: + return torch.nn.InstanceNorm1d + case 2: + return torch.nn.InstanceNorm2d + case 3: + return torch.nn.InstanceNorm3d + case _: + raise NotImplementedError( + f'InstanceNormNd for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def batchNormND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.BatchNorm1d] | type[torch.nn.BatchNorm2d] | type[torch.nn.BatchNorm3d]: + """Get the `n_dim`-dimensional batch normalization class. + + Parameters + ---------- + n_dim + The dimension of the batch normalization. + + Returns + ------- + The batch normalization class. + """ + match n_dim: + case 1: + return torch.nn.BatchNorm1d + case 2: + return torch.nn.BatchNorm2d + case 3: + return torch.nn.BatchNorm3d + case _: + raise NotImplementedError(f'BatchNormNd for dim {n_dim} not implemented. Raise an issue if you need this.') diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 345883e12..2d4eceb2f 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -15,8 +15,10 @@ from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin from mrpro.utils.interpolate import interpolate, apply_lowres from mrpro.utils.RandomGenerator import RandomGenerator - +from mrpro.utils.to_tuple import to_tuple +from mrpro.utils.ema import EMADict __all__ = [ + "EMADict", "Indexer", "RandomGenerator", "TensorAttributeMixin", @@ -38,6 +40,7 @@ "split_idx", "summarize_object", "summarize_values", + "to_tuple", "typing", "unit_conversion", "unsqueeze_at", diff --git a/src/mrpro/utils/to_tuple.py b/src/mrpro/utils/to_tuple.py new file mode 100644 index 000000000..657d7bf56 --- /dev/null +++ b/src/mrpro/utils/to_tuple.py @@ -0,0 +1,36 @@ +"""Standardize an argument to a fixed-length tuple.""" + +from collections.abc import Sequence +from typing import TypeVar + +T = TypeVar('T') + + +def to_tuple(length: int, arg: Sequence[T] | T) -> tuple[T, ...]: + """Standardize an argument to a fixed-length tuple. + + If the argument is a sequence, it checks if its length matches the + specified dimension. If it's a single value, it replicates it `dim` times. + + Parameters + ---------- + length + The expected length of the sequence. + arg + The argument to check. Can be a single value of type T or a + sequence of T. + + Returns + ------- + A tuple of length `dim` containing elements of type T. + + Raises + ------ + ValueError + If `arg` is a sequence and its length does not match `length`. + """ + if isinstance(arg, Sequence): + if not len(arg) == length: + raise ValueError(f'The arguments must be either single values or have length {length}. Got {arg}.') + return tuple(arg) + return (arg,) * length diff --git a/tests/nn/test_complexaschannel.py b/tests/nn/test_complexaschannel.py new file mode 100644 index 000000000..37889f654 --- /dev/null +++ b/tests/nn/test_complexaschannel.py @@ -0,0 +1,30 @@ +"""Tests for ComplexAsChannel module.""" + +import pytest +from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_complexaschannel(device: str) -> None: + """Test ComplexAsChannel output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + input_shape = (1, 32) + x = rng.complex64_tensor(input_shape).to(device).requires_grad_(True) + module = ComplexAsChannel(Linear(input_shape[1] * 2, input_shape[1] * 2)).to(device) + output = module(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.is_complex(), 'Output is not complex' + output.sum().abs().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert module.module.weight.grad is not None, 'No gradient computed for weight' + assert module.module.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_convert_linear_conv.py b/tests/nn/test_convert_linear_conv.py new file mode 100644 index 000000000..19438b9d9 --- /dev/null +++ b/tests/nn/test_convert_linear_conv.py @@ -0,0 +1,150 @@ +"""Tests for converting between Linear and Conv layers.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.convert_linear_conv import conv_to_linear, linear_to_conv +from mrpro.utils import RandomGenerator +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + +DEVICES = pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +SHAPES = pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], + ids=['1d', '2d', '3d', '3d_no_bias'], +) + + +@SHAPES +@DEVICES +def test_linear_to_conv(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Linear to Conv layer.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias).to(device) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + assert isinstance(conv, (Conv1d, Conv2d, Conv3d)[dim - 1]) + + assert conv.in_channels == channels_in + assert conv.out_channels == channels_out + assert conv.kernel_size == (1,) * dim + assert conv.bias is not None if bias else conv.bias is None + + assert conv.weight.device.type == device + if conv.bias is not None: + assert conv.bias.device.type == device + + +@SHAPES +def test_linear_to_conv_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test functional equivalence of Linear to Conv conversion.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + spatial_shape = (4,) * dim + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) + + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + torch.testing.assert_close(y_conv, y_linear) + + +@SHAPES +@DEVICES +def test_conv_to_linear(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Conv layer to Linear.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias).to(device) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + + assert isinstance(linear, Linear) + assert linear.in_features == channels_in + assert linear.out_features == channels_out + assert linear.bias is not None if bias else linear.bias is None + + assert linear.weight.device.type == device + if bias: + assert linear.bias.device.type == device + + +@SHAPES +def test_conv_to_linear_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test functional equivalence of Conv to Linear conversion.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + spatial_shape = (4,) * dim + + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + torch.testing.assert_close(y_conv, y_linear) + + +def test_conv_to_linear_invalid_kernel() -> None: + """Test conv_to_linear with invalid kernel size.""" + conv = Conv2d(32, 64, kernel_size=3, bias=True) + with pytest.raises(ValueError, match='Kernel size must be 1'): + conv_to_linear(conv) + + +@SHAPES +@DEVICES +def test_round_trip_conversion( + device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool +) -> None: + """Test round-trip conversion between Linear and Conv layers.""" + rng = RandomGenerator(seed=42) + + linear1 = Linear(channels_in, channels_out, bias=bias).to(device) + linear1.weight.data = rng.rand_like(linear1.weight) + if bias: + linear1.bias.data = rng.rand_like(linear1.bias) + + conv = linear_to_conv(linear1, dim) + linear2 = conv_to_linear(conv) + + assert linear2.in_features == channels_in + assert linear2.out_features == channels_out + assert linear2.bias is not None if bias else linear2.bias is None + + torch.testing.assert_close(linear2.weight, linear1.weight) + if bias: + torch.testing.assert_close(linear2.bias, linear1.bias) diff --git a/tests/nn/test_droppath.py b/tests/nn/test_droppath.py new file mode 100644 index 000000000..b4ac7f5d7 --- /dev/null +++ b/tests/nn/test_droppath.py @@ -0,0 +1,30 @@ +"""Test DropPath.""" + +import pytest +from mrpro.nn.DropPath import DropPath +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_droppath_no_drop(device: str) -> None: + """Test DropPath with zero drop rate (should pass through unchanged).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).to(device) + droppath = DropPath(0).to(device) + y = droppath(x) + assert (y == x).all() + + +def test_droppath_drop_all() -> None: + """Test DropPath with full drop rate (should output zeros).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)) + droppath = DropPath(1.0) + y = droppath(x) + assert (y == 0).all() diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py new file mode 100644 index 000000000..bf9940dfc --- /dev/null +++ b/tests/nn/test_film.py @@ -0,0 +1,58 @@ +"""Tests for FiLM module.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.nn.FiLM import FiLM +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'n_channels_cond', 'input_shape', 'cond_shape'), + [ + (64, 32, (1, 64, 32, 32), (1, 32)), + (32, 16, (2, 32, 16, 16), (2, 16)), + ], +) +def test_film( + n_channels: int, n_channels_cond: int, input_shape: Sequence[int], cond_shape: Sequence[int], device: str +) -> None: + """Test FiLM output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) + film = FiLM(channels=n_channels, cond_dim=n_channels_cond).to(device) + output = film(x, cond=cond) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not output.isnan().any(), 'NaN values in output' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' + assert film.project is not None, 'Linear layer is not initialized' + assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' + + +def test_film_features_last() -> None: + """Test FiLM with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)) + cond = rng.float32_tensor((1, 8)) + + film_last = FiLM(channels=3, cond_dim=8, features_last=True) + film = FiLM(channels=3, cond_dim=8, features_last=False) + film.load_state_dict(film_last.state_dict()) + + y_last = film_last(x.moveaxis(1, -1), cond=cond) + y = film(x, cond=cond) + torch.testing.assert_close(y, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_fourierfeatures.py b/tests/nn/test_fourierfeatures.py new file mode 100644 index 000000000..9452a369f --- /dev/null +++ b/tests/nn/test_fourierfeatures.py @@ -0,0 +1,24 @@ +"""Test for random fourier features""" + +import pytest +from mrpro.nn import FourierFeatures +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_fourierfeatures(device: str) -> None: + """Test FourierFeatures.""" + n_features_in = 1 + n_features_out = 16 + std = 1.0 + rng = RandomGenerator(444) + x = rng.float32_tensor((1, n_features_in)).to(device) + ff = FourierFeatures(n_features_in, n_features_out, std).to(device) + y = ff(x) + assert y.shape == (1, n_features_out) diff --git a/tests/nn/test_geglu.py b/tests/nn/test_geglu.py new file mode 100644 index 000000000..061837e51 --- /dev/null +++ b/tests/nn/test_geglu.py @@ -0,0 +1,38 @@ +"""Test GEGLU.""" + +import pytest +import torch +from mrpro.nn.GEGLU import GEGLU +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_geglu(device: str) -> None: + """Test GEGLU output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).to(device).requires_grad_(True) + gelu = GEGLU(3, 4).to(device) + y = gelu(x) + assert y.shape == (1, 4, 4, 5) + + y.sum().backward() + assert x.grad is not None + assert gelu.proj.weight.grad is not None + + +def test_geglu_features_last() -> None: + """Test GEGLU with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + gelu_last = GEGLU(3, 4, features_last=True) + gelu = GEGLU(3, 4, features_last=False) + gelu.proj = gelu_last.proj # need to set the same weights + y_last = gelu_last(x.moveaxis(1, -1)) + y = gelu(x) + torch.testing.assert_close(y, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_groupnorm.py b/tests/nn/test_groupnorm.py new file mode 100644 index 000000000..044de17f7 --- /dev/null +++ b/tests/nn/test_groupnorm.py @@ -0,0 +1,45 @@ +"""Tests for GroupNorm module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn import GroupNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'n_groups', 'input_shape', 'affine'), + [ + (32, None, (1, 32, 32, 32), True), + (64, 8, (2, 64, 16, 16, 16), False), + ], +) +def test_groupnorm( + n_channels: int, + n_groups: int | None, + input_shape: Sequence[int], + device: str, + affine: bool, +) -> None: + """Test GroupNorm output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = GroupNorm(n_channels=n_channels, n_groups=n_groups, affine=affine).to(device) + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if affine: + assert norm.weight is not None, 'Weight should not be None when affine is True' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias is not None, 'Bias should not be None when affine is True' + assert norm.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_join.py b/tests/nn/test_join.py new file mode 100644 index 000000000..f86647ac4 --- /dev/null +++ b/tests/nn/test_join.py @@ -0,0 +1,160 @@ +"""Tests for join modules.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.join import Add, Concat +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('mode', 'input_shapes', 'expected_shape'), + [ + ('crop', [(1, 3, 32, 32), (1, 5, 30, 30)], (1, 8, 30, 30)), + ('zero', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ('linear', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ('nearest', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ], +) +def test_concat_basic( + mode: Literal['crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'], + input_shapes: list[tuple[int, ...]], + expected_shape: tuple[int, ...], + device: str, +) -> None: + """Test Concat basic functionality.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).to(device).requires_grad_(True) for shape in input_shapes] + concat = Concat(mode=mode).to(device) + + output = concat(*xs) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + for x in xs: + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('mode', 'input_shapes', 'expected_shape'), + [ + ('crop', [(1, 3, 32, 32), (1, 3, 30, 30)], (1, 3, 30, 30)), + ('zero', [(1, 3, 32, 32), (1, 3, 34, 34)], (1, 3, 34, 34)), + ('replicate', [(1, 1, 1, 2), (1, 1, 1, 3)], (1, 1, 1, 3)), + ('circular', [(1, 1, 1, 2), (1, 1, 1, 4)], (1, 1, 1, 4)), + ], +) +def test_add_basic( + mode: Literal['crop', 'zero', 'replicate', 'circular'], + input_shapes: list[tuple[int, ...]], + expected_shape: tuple[int, ...], + device: str, +) -> None: + """Test Add basic functionality.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).to(device).requires_grad_(True) for shape in input_shapes] + add = Add(mode=mode).to(device) + + output = add(*xs) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + for x in xs: + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + +@pytest.mark.parametrize( + ('dim', 'input_shapes', 'expected_shape'), + [ + (0, [(1, 3, 32, 32), (1, 3, 32, 32)], (2, 3, 32, 32)), + (1, [(1, 3, 32, 32), (1, 5, 32, 32)], (1, 8, 32, 32)), + (2, [(1, 3, 32, 32), (1, 3, 32, 32)], (1, 3, 64, 32)), + ], +) +def test_concat_dimensions(dim: int, input_shapes: list[tuple[int, ...]], expected_shape: tuple[int, ...]) -> None: + """Test Concat with different concatenation dimensions.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).requires_grad_(True) for shape in input_shapes] + concat = Concat(mode='fail', dim=dim) + output = concat(*xs) + assert output.shape == expected_shape + + +def test_concat_values() -> None: + """Test that Concat preserves input values correctly.""" + x1 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]).requires_grad_(True) + x2 = torch.tensor([[[[5.0, 6.0], [7.0, 8.0]]]]).requires_grad_(True) + + concat = Concat(mode='fail') + output = concat(x1, x2) + + expected = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]) + torch.testing.assert_close(output, expected) + + +def test_add_values() -> None: + """Test that Add correctly sums input values.""" + x1 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]).requires_grad_(True) + x2 = torch.tensor([[[[5.0, 6.0], [7.0, 8.0]]]]).requires_grad_(True) + + add = Add(mode='fail') + output = add(x1, x2) + + expected = torch.tensor([[[[6.0, 8.0], [10.0, 12.0]]]]) + torch.testing.assert_close(output, expected) + + +def test_concat_mode_fail() -> None: + """Test Concat with mode='fail'.""" + rng = RandomGenerator(seed=42) + + x1 = rng.float32_tensor((1, 3, 32, 32)) + x2 = rng.float32_tensor((1, 5, 32, 32)) + concat = Concat(mode='fail') + output = concat(x1, x2) + assert output.shape == (1, 8, 32, 32) + + x3 = rng.float32_tensor((1, 3, 30, 30)) + with pytest.raises(RuntimeError): + concat(x1, x3) + + +def test_add_mode_fail() -> None: + """Test Add with mode='fail'.""" + rng = RandomGenerator(seed=42) + + x1 = rng.float32_tensor((1, 3, 32, 32)) + x2 = rng.float32_tensor((1, 3, 32, 32)) + add = Add(mode='fail') + output = add(x1, x2) + assert output.shape == (1, 3, 32, 32) + + x3 = rng.float32_tensor((1, 3, 30, 30)) + with pytest.raises(RuntimeError): + add(x1, x3) + + +@pytest.mark.parametrize('module_class', [Concat, Add]) +def test_invalid_mode(module_class: type) -> None: + """Test modules with invalid mode.""" + with pytest.raises(ValueError, match='mode must be one of'): + module_class(mode='invalid_mode') diff --git a/tests/nn/test_layernorm.py b/tests/nn/test_layernorm.py new file mode 100644 index 000000000..51d6ed030 --- /dev/null +++ b/tests/nn/test_layernorm.py @@ -0,0 +1,186 @@ +"""Tests for LayerNorm module.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'features_last', 'input_shape'), + [ + (32, False, (1, 32, 32, 32)), + (64, True, (2, 16, 16, 64)), + (None, False, (1, 32, 32, 32)), + (None, True, (2, 16, 16, 64)), + ], +) +def test_layernorm_basic( + n_channels: int | None, + features_last: bool, + input_shape: Sequence[int], + device: str, +) -> None: + """Test LayerNorm basic functionality.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = LayerNorm(n_channels=n_channels, features_last=features_last).to(device) + output = norm(x) + + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + if n_channels is not None: + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' + + +@pytest.mark.parametrize( + ('n_channels', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (32, 16, (1, 32, 32, 32), (1, 16)), + (64, 32, (2, 64, 16, 16), (2, 32)), + ], +) +def test_layernorm_with_conditioning( + n_channels: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int], +) -> None: + """Test LayerNorm with conditioning.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).requires_grad_(True) + norm = LayerNorm(n_channels=n_channels, cond_dim=cond_dim) + + output = norm(x, cond=cond) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert norm.cond_proj is not None, 'cond_proj should not be None when cond_dim > 0' + assert norm.cond_proj.weight.grad is not None, 'No gradient computed for cond_proj' + + +def test_layernorm_features_last() -> None: + """Test LayerNorm with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + norm_standard = LayerNorm(n_channels=3, features_last=False) + y_standard = norm_standard(x) + + norm_last = LayerNorm(n_channels=3, features_last=True) + y_last = norm_last(x.moveaxis(1, -1)) + + torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) + + +def test_layernorm_no_channels() -> None: + """Test LayerNorm without channels (pure normalization).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) + norm = LayerNorm(n_channels=None) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + # Check that normalization is applied over channel dim (dim=1 for features_last=False) + mean = output.mean(dim=1, keepdim=True) + var = (output * output).mean(dim=1, keepdim=True) - mean * mean + + assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-5), 'Mean not close to 0' + assert torch.allclose(var, torch.ones_like(var), atol=1e-3), 'Variance not close to 1' + + +def test_layernorm_conditioning_without_channels() -> None: + """Test LayerNorm with conditioning but no channels (should raise error).""" + with pytest.raises(ValueError, match='channels must be provided if cond_dim > 0'): + LayerNorm(n_channels=None, cond_dim=16) + + +def test_layernorm_invalid_cond_dim() -> None: + """Test LayerNorm with invalid cond_dim.""" + with pytest.raises(RuntimeError, match='Trying to create tensor with negative dimension'): + LayerNorm(n_channels=32, cond_dim=-1) + + +def test_layernorm_3d_input() -> None: + """Test LayerNorm with 3D input.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((2, 64, 128)).requires_grad_(True) + norm = LayerNorm(n_channels=64) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + + +def test_layernorm_5d_input() -> None: + """Test LayerNorm with 5D input.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 16, 16, 16)).requires_grad_(True) + norm = LayerNorm(n_channels=32) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + + +def test_layernorm_conditioning_features_last() -> None: + """Test LayerNorm with conditioning and features_last=True.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + cond = rng.float32_tensor((1, 8)).requires_grad_(True) + + norm = LayerNorm(n_channels=3, features_last=True, cond_dim=8) + output = norm(x.moveaxis(1, -1), cond=cond) + + assert output.shape == x.moveaxis(1, -1).shape, f'Output shape {output.shape} != expected shape' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + + +def test_layernorm_gradient_flow() -> None: + """Test that gradients flow properly through LayerNorm.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) + norm = LayerNorm(n_channels=32) + + output = norm(x) + loss = output.sum() + loss.backward() + + # Check that gradients are computed for all learnable parameters + assert x.grad is not None, 'Input gradients not computed' + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'Weight gradients not computed' + assert norm.bias.grad is not None, 'Bias gradients not computed' + + # Check that gradients are finite + assert torch.isfinite(x.grad).all(), 'Input gradients contain non-finite values' + assert torch.isfinite(norm.weight.grad).all(), 'Weight gradients contain non-finite values' + assert torch.isfinite(norm.bias.grad).all(), 'Bias gradients contain non-finite values' diff --git a/tests/nn/test_ndmodules.py b/tests/nn/test_ndmodules.py new file mode 100644 index 000000000..a0a77a98d --- /dev/null +++ b/tests/nn/test_ndmodules.py @@ -0,0 +1,76 @@ +"""Tests for the ndmodules module.""" + +import pytest +import torch +from mrpro.nn.ndmodules import ( + adaptiveAvgPoolND, + avgPoolND, + batchNormND, + convND, + convTransposeND, + instanceNormND, + maxPoolND, +) + + +def test_convnd() -> None: + """Test ConvND.""" + assert convND(1) is torch.nn.Conv1d + assert convND(2) is torch.nn.Conv2d + assert convND(3) is torch.nn.Conv3d + with pytest.raises(NotImplementedError): + convND(4) + + +def test_convtransposend() -> None: + """Test ConvTransposeND.""" + assert convTransposeND(1) is torch.nn.ConvTranspose1d + assert convTransposeND(2) is torch.nn.ConvTranspose2d + assert convTransposeND(3) is torch.nn.ConvTranspose3d + with pytest.raises(NotImplementedError): + convTransposeND(4) + + +def test_maxpoolnd() -> None: + """Test MaxPoolND.""" + assert maxPoolND(1) is torch.nn.MaxPool1d + assert maxPoolND(2) is torch.nn.MaxPool2d + assert maxPoolND(3) is torch.nn.MaxPool3d + with pytest.raises(NotImplementedError): + maxPoolND(4) + + +def test_avgpoolnd() -> None: + """Test AvgPoolND.""" + assert avgPoolND(1) is torch.nn.AvgPool1d + assert avgPoolND(2) is torch.nn.AvgPool2d + assert avgPoolND(3) is torch.nn.AvgPool3d + with pytest.raises(NotImplementedError): + avgPoolND(4) + + +def test_adaptiveavgpoolnd() -> None: + """Test AdaptiveAvgPoolND.""" + assert adaptiveAvgPoolND(1) is torch.nn.AdaptiveAvgPool1d + assert adaptiveAvgPoolND(2) is torch.nn.AdaptiveAvgPool2d + assert adaptiveAvgPoolND(3) is torch.nn.AdaptiveAvgPool3d + with pytest.raises(NotImplementedError): + adaptiveAvgPoolND(4) + + +def test_instancenormnd() -> None: + """Test InstanceNormND.""" + assert instanceNormND(1) is torch.nn.InstanceNorm1d + assert instanceNormND(2) is torch.nn.InstanceNorm2d + assert instanceNormND(3) is torch.nn.InstanceNorm3d + with pytest.raises(NotImplementedError): + instanceNormND(4) + + +def test_batchnormnd() -> None: + """Test BatchNormND.""" + assert batchNormND(1) is torch.nn.BatchNorm1d + assert batchNormND(2) is torch.nn.BatchNorm2d + assert batchNormND(3) is torch.nn.BatchNorm3d + with pytest.raises(NotImplementedError): + batchNormND(4) diff --git a/tests/nn/test_pixelshuffle.py b/tests/nn/test_pixelshuffle.py new file mode 100644 index 000000000..8d5917a83 --- /dev/null +++ b/tests/nn/test_pixelshuffle.py @@ -0,0 +1,92 @@ +"""Test PixelShuffle and PixelUnshuffle.""" + +from typing import cast + +import torch +from mrpro.nn.PixelShuffle import PixelShuffle, PixelShuffleUpsample, PixelUnshuffle, PixelUnshuffleDownsample +from mrpro.utils import RandomGenerator + + +def test_pixel_shuffle_2d() -> None: + """Test PixelUnshuffle's fast path for 2D images.""" + x = torch.arange(3 * 4 * 8).reshape(1, 3, 4, 8) + pixel_unshuffle = PixelUnshuffle(2) + y = pixel_unshuffle(x) + assert y.shape == (1, 3 * 4, 4 // 2, 8 // 2) + + pixel_shuffle = PixelShuffle(2) + z = pixel_shuffle(y) + assert z.shape == (1, 3, 4, 8) + assert (x == z).all() + + +def test_pixel_unshuffle_4d() -> None: + """Test PixelUnshuffle's general case.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) + pixel_unshuffle = PixelUnshuffle(2) + y = pixel_unshuffle(x) + assert y.shape == (1, 3 * 16, 4 // 2, 8 // 2, 10 // 2, 12 // 2) + + pixel_shuffle = PixelShuffle(2) + z = pixel_shuffle(y) + assert z.shape == (1, 3, 4, 8, 10, 12) + assert (x == z).all() + + +def test_pixelunshuffle_features_last() -> None: + """Test PixelUnshuffle with features_last.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) + pixel_unshuffle_last = PixelUnshuffle(2, features_last=True) + pixel_unshuffle = PixelUnshuffle(2, features_last=False) + y_last = pixel_unshuffle_last(x.moveaxis(1, -1)).moveaxis(-1, 1) + y_normal = pixel_unshuffle(x) + assert (y_last == y_normal).all() + + +def test_pixelshuffle_features_last() -> None: + """Test PixelShuffle with features_last.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, -1, 2, 4, 5, 6) + pixel_shuffle_last = PixelShuffle(2, features_last=True) + pixel_shuffle = PixelShuffle(2, features_last=False) + y_last = pixel_shuffle_last(x.moveaxis(1, -1)).moveaxis(-1, 1) + y_normal = pixel_shuffle(x) + assert (y_last == y_normal).all() + + +def test_unpixelshuffledownsample_residual() -> None: + """Test PixelUnshuffleDownsample with residual.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 2, 9, 12, 15)) + downsample = PixelUnshuffleDownsample(3, 2, 27, downscale_factor=3, residual=True) + y = downsample(x) + assert y.shape == (1, 27, 3, 4, 5) + + +def test_pixelshuffleupsample_residual() -> None: + """Test PixelShuffleUpsample with residual.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 2, 3, 4, 5)) + upsample = PixelShuffleUpsample(3, 2, 1, upscale_factor=3, residual=True) + y = upsample(x) + assert y.shape == (1, 1, 9, 12, 15) + + +def test_pixelshuffleupsample_pixelunshuffledownsample() -> None: + """Test if PixelUnshuffleDownsample is the inverse of PixelShuffleUpsample.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3**3, 3, 4, 5)) + # Only without residual, the upsample and downsample are inverses. + downsample = PixelUnshuffleDownsample(3, 1, 3**3, downscale_factor=3, residual=False) + upsample = PixelShuffleUpsample(3, 3**3, 1, upscale_factor=3, residual=False) + # Only if the convs are Identity, the upsample and downsample are inverses. + torch.nn.init.dirac_(cast(torch.Tensor, downsample.projection.weight)) + torch.nn.init.dirac_(cast(torch.Tensor, upsample.projection.weight)) + downsample_bias = cast(torch.Tensor | None, downsample.projection.bias) + upsample_bias = cast(torch.Tensor | None, upsample.projection.bias) + if downsample_bias is not None: + torch.nn.init.zeros_(downsample_bias) + if upsample_bias is not None: + torch.nn.init.zeros_(upsample_bias) + y = downsample(upsample(x)) + assert y.shape == (1, 3**3, 3, 4, 5) + torch.testing.assert_close(y, x, msg='Upsample and downsample are not inverses.') diff --git a/tests/nn/test_rmsnorm.py b/tests/nn/test_rmsnorm.py new file mode 100644 index 000000000..c8ddc0b69 --- /dev/null +++ b/tests/nn/test_rmsnorm.py @@ -0,0 +1,58 @@ +"""Tests for RMSNorm module.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.nn import RMSNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'features_last', 'input_shape'), + [ + (32, False, (1, 32, 32, 32)), + (64, True, (2, 16, 16, 64)), + (None, False, (1, 32, 32, 32)), + (None, True, (2, 16, 16, 64)), + ], +) +def test_rmsnorm_basic(n_channels: int | None, features_last: bool, input_shape: Sequence[int], device: str) -> None: + """Test RMSNorm basic functionality.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = RMSNorm(n_channels=n_channels, features_last=features_last).to(device) + output = norm(x) + + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + if n_channels is not None: + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' + + +def test_rmsnorm_features_last() -> None: + """Test RMSNorm with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + norm_standard = RMSNorm(n_channels=3, features_last=False) + y_standard = norm_standard(x) + + norm_last = RMSNorm(n_channels=3, features_last=True) + y_last = norm_last(x.moveaxis(1, -1)) + + torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py new file mode 100644 index 000000000..bdf81bf8d --- /dev/null +++ b/tests/nn/test_sequential.py @@ -0,0 +1,50 @@ +"""Tests for Sequential module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn import FiLM, Sequential +from mrpro.operators import FastFourierOp, MagnitudeOp +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('input_shape', 'cond_dim'), + [ + ((1, 32), (1, 16)), + ((2, 32), None), + ], +) +def test_sequential( + input_shape: Sequence[int], + cond_dim: Sequence[int] | None, + device: str, +) -> None: + """Test Sequential output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_dim).to(device).requires_grad_(True) if cond_dim else None + seq = Sequential( + Linear(input_shape[1], 64), + FastFourierOp(dim=(-1,)), + FiLM(channels=64, cond_dim=16), + MagnitudeOp(), + ).to(device) + output = seq(x, cond=cond) + assert output.shape == (input_shape[0], 64) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for cond' + assert not cond.grad.isnan().any(), 'NaN values in cond gradients' + assert seq[0].weight.grad is not None, 'No gradient computed for Linear' From 05db778511fd34dd272ac34d946f5def23db37c9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:26 +0100 Subject: [PATCH 192/205] add data consistency modules and tests ghstack-source-id: 555bc2286cf127ebd9411ff45f7f1bb6c68612e4 ghstack-comment-id: 3865650618 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/954 --- docker/minimal_requirements.txt | 2 +- src/mrpro/nn/__init__.py | 2 + .../data_consistency/AnalyticCartesianDC.py | 101 +++++++++++++++ .../data_consistency/ConjugateGradientDC.py | 117 ++++++++++++++++++ .../nn/data_consistency/GradientDescentDC.py | 99 +++++++++++++++ src/mrpro/nn/data_consistency/__init__.py | 5 + tests/nn/data_consistency/conftest.py | 46 +++++++ .../test_analyticcartesiandc.py | 21 ++++ .../test_conjugategradientdc.py | 21 ++++ .../test_gradientdescentdc.py | 21 ++++ 10 files changed, 434 insertions(+), 1 deletion(-) create mode 100644 src/mrpro/nn/data_consistency/AnalyticCartesianDC.py create mode 100644 src/mrpro/nn/data_consistency/ConjugateGradientDC.py create mode 100644 src/mrpro/nn/data_consistency/GradientDescentDC.py create mode 100644 src/mrpro/nn/data_consistency/__init__.py create mode 100644 tests/nn/data_consistency/conftest.py create mode 100644 tests/nn/data_consistency/test_analyticcartesiandc.py create mode 100644 tests/nn/data_consistency/test_conjugategradientdc.py create mode 100644 tests/nn/data_consistency/test_gradientdescentdc.py diff --git a/docker/minimal_requirements.txt b/docker/minimal_requirements.txt index c9be333c9..6723b723e 100644 --- a/docker/minimal_requirements.txt +++ b/docker/minimal_requirements.txt @@ -6,7 +6,7 @@ einops==0.7.0 pydicom==3.0.1 pypulseq==1.4.2 pytorch-finufft==0.1.0 -cufinufft==2.3.1 +cufinufft==2.4.1 scipy==1.12 ptwt==0.1.8 tqdm==4.60.0 diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index d6541f5c8..f988855e7 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -12,6 +12,7 @@ from mrpro.nn.RMSNorm import RMSNorm from mrpro.nn.Residual import Residual from mrpro.nn.Sequential import Sequential +from mrpro.nn import data_consistency from mrpro.nn.ndmodules import ( adaptiveAvgPoolND, avgPoolND, @@ -40,6 +41,7 @@ 'batchNormND', 'convND', 'convTransposeND', + 'data_consistency', 'instanceNormND', 'maxPoolND', ] diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py new file mode 100644 index 000000000..5b2848fca --- /dev/null +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -0,0 +1,101 @@ +"""Analytic Cartesian data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.KData import KData +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.IdentityOp import IdentityOp + + +class AnalyticCartesianDC(Module): + r"""Analytic Cartesian data consistency. + + Solves the following problem: + :math:`\min_x \|Ax - k\|_2^2 + \lambda \|x-p\|_2^2` + where :math:`A` is the acquisition operator and :math:`k` is the data, :math:`\lambda` is the regularization + parameter and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a + special case for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil + data or to apply data consistency to each coil image [NOSENSE]_. + + References + ---------- + .. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without + explicit sensitivity maps." STACOM@MICCAI 2023. https://arxiv.org/abs/2309.15608 + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + + + """ + + def __init__(self, initial_regularization_weight: torch.Tensor | float): + """Initialize the data consistency. + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + Must be a positive scalar. + """ + super().__init__() + weight = torch.as_tensor(initial_regularization_weight) + if weight.ndim != 0: + raise ValueError('Regularization weight must be a scalar') + if weight.item() <= 0: + raise ValueError('Regularization weight must be positive') + self.log_weight = Parameter(weight.log()) + + @overload + def __call__(self, image: torch.Tensor, data: KData, fourier_op: FourierOp | None = None) -> torch.Tensor: ... + + @overload + def __call__(self, image: torch.Tensor, data: torch.Tensor, fourier_op: FourierOp) -> torch.Tensor: ... + + def __call__( + self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate, i.e. the regularized image. + data + k-space data. + fourier_op + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, fourier_op) + + def forward( + self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None + ) -> torch.Tensor: + """Apply the data consistency.""" + if fourier_op is None: + if isinstance(data, KData): + fourier_op = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or a FourierOp is required') + + if not isinstance(fourier_op, FourierOp) or fourier_op._nufft_dims or fourier_op._fast_fourier_op is None: + raise ValueError('Only Cartesian acquisitions are supported') + + data_ = data.data if isinstance(data, KData) else data + fft_op = fourier_op._fast_fourier_op + sampling_op = fourier_op._cart_sampling_op if fourier_op._cart_sampling_op is not None else IdentityOp() + (zero_filled,) = sampling_op.adjoint(data_) + (k_pred,) = fft_op(image) + regularization_weight = self.log_weight.exp() + (k,) = sampling_op.gram((zero_filled - k_pred) / (1 + regularization_weight)) + (delta,) = fft_op.H(k) + return image + delta diff --git a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py new file mode 100644 index 000000000..f81b24f98 --- /dev/null +++ b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py @@ -0,0 +1,117 @@ +"""Conjugate gradient data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.CsmData import CsmData +from mrpro.data.KData import KData +from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.LinearOperator import LinearOperator +from mrpro.operators.SensitivityOp import SensitivityOp + + +class ConjugateGradientDC(Module): + """Conjugate gradient data consistency.""" + + def __init__(self, initial_regularization_weight: torch.Tensor | float): + """Initialize the conjugate gradient data consistency. + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + Must be a positive scalar. + """ + super().__init__() + weight = torch.as_tensor(initial_regularization_weight) + if weight.ndim != 0: + raise ValueError('Regularization weight must be a scalar') + if weight.item() <= 0: + raise ValueError('Regularization weight must be positive') + self.log_weight = Parameter(weight.log()) + + @overload + def __call__( + self, + image: torch.Tensor, + data: KData, + fourier_op: FourierOp | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: ... + + @overload + def __call__( + self, + image: torch.Tensor, + data: torch.Tensor, + fourier_op: FourierOp, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: ... + + def __call__( + self, + image: torch.Tensor, + data: KData | torch.Tensor, + fourier_op: LinearOperator | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate. + data + k-space data. + fourier_op + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. + This operator can already include the coil sensitivity weighting, if gradients wrt the coil sensitivity maps + NOT required. Otherwise, they should be given as an additional argument. + csm + Coil sensitivity maps. If None, no coil sensitivity weighting is applied. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, fourier_op, csm) + + def forward( + self, + image: torch.Tensor, + data: torch.Tensor | KData, + fourier_op: FourierOp | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: + """Apply the data consistency.""" + if fourier_op is None: + if isinstance(data, KData): + fourier_op = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or a FourierOp is required') + + data_ = data.data if isinstance(data, KData) else data + + if csm is None: + csm = torch.tensor(()) + elif isinstance(csm, CsmData): + csm = csm.data + + def operator_factory(csm: torch.Tensor, weight: torch.Tensor, *_): + op = fourier_op.gram + if csm.numel(): + csm_op = SensitivityOp(csm) + op = csm_op.H @ op @ csm_op + op = op + weight + return op + + def rhs_factory(_csm: torch.Tensor, weight: torch.Tensor, zero_filled: torch.Tensor, image: torch.Tensor): + return (zero_filled + weight * image,) + + cg_op = ConjugateGradientOp(operator_factory=operator_factory, rhs_factory=rhs_factory) + (result,) = cg_op(csm, self.log_weight.exp(), fourier_op.adjoint(data_)[0], image) + return result diff --git a/src/mrpro/nn/data_consistency/GradientDescentDC.py b/src/mrpro/nn/data_consistency/GradientDescentDC.py new file mode 100644 index 000000000..579e5181a --- /dev/null +++ b/src/mrpro/nn/data_consistency/GradientDescentDC.py @@ -0,0 +1,99 @@ +"""Gradient descent data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.KData import KData +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.LinearOperator import LinearOperator + + +class GradientDescentDC(Module): + r"""Gradient descent data consistency. + + Performs gradient descent steps on + :math:`\|Ax - k\|_2^2` where :math:`A` is the acquisition operator and :math:`k` is the data. + + Parameters + ---------- + initial_stepsize + Initial stepsize. The stepsize is a trainable parameter. + Must be a positive scalar. + n_steps + Number of gradient descent steps. + + Returns + ------- + The updated image. + """ + + def __init__(self, initial_stepsize: float | torch.Tensor, n_steps: int = 1) -> None: + """Initialize the gradient descent data consistency. + + Parameters + ---------- + initial_stepsize + Initial stepsize. The stepsize is a trainable parameter. + Must be a positive scalar. + n_steps + Number of gradient descent steps. + """ + super().__init__() + stepsize = torch.as_tensor(initial_stepsize) + if stepsize.ndim != 0: + raise ValueError('Stepsize must be a scalar') + if stepsize.item() <= 0: + raise ValueError('Stepsize must be positive') + self.log_stepsize = Parameter(stepsize.log()) + self.n_steps = n_steps + + @overload + def __call__( + self, image: torch.Tensor, data: KData, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: ... + + @overload + def __call__( + self, image: torch.Tensor, data: torch.Tensor, acquisition_operator: LinearOperator + ) -> torch.Tensor: ... + + def __call__( + self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate. + data + k-space data. + acquisition_operator + Acquisition operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` + object, the Fourier operator is automatically created from the data. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, acquisition_operator) + + def forward( + self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: + """Apply the data consistency.""" + if acquisition_operator is None: + if isinstance(data, KData): + acquisition_operator = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or an acquisition operator is required') + + data_ = data.data if isinstance(data, KData) else data + stepsize = self.log_stepsize.exp() + x = image + for _ in range(self.n_steps): + residual = acquisition_operator(x)[0] - data_ + x = x - stepsize * acquisition_operator.adjoint(residual)[0] + return x diff --git a/src/mrpro/nn/data_consistency/__init__.py b/src/mrpro/nn/data_consistency/__init__.py new file mode 100644 index 000000000..cea955c85 --- /dev/null +++ b/src/mrpro/nn/data_consistency/__init__.py @@ -0,0 +1,5 @@ +from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC +from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC +from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC + +__all__ = ["AnalyticCartesianDC", "ConjugateGradientDC", "GradientDescentDC"] \ No newline at end of file diff --git a/tests/nn/data_consistency/conftest.py b/tests/nn/data_consistency/conftest.py new file mode 100644 index 000000000..49fca4cf3 --- /dev/null +++ b/tests/nn/data_consistency/conftest.py @@ -0,0 +1,46 @@ +"""Test fixtures for data consistency tests.""" + +import pytest +from mrpro.data.KData import KData +from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian +from mrpro.operators.FourierOp import FourierOp +from mrpro.phantoms.EllipsePhantom import EllipsePhantom +from mrpro.utils import RandomGenerator + + +@pytest.fixture +def kdata(): + matrix = SpatialDimension(x=128, y=128, z=1) + kdata = EllipsePhantom().kdata(KTrajectoryCartesian.fullysampled(matrix), matrix) + return kdata + + +@pytest.fixture +def kdata_noisy(kdata: KData): + kdata_noisy = kdata.clone() + kdata_noisy.data += 0.1 * RandomGenerator(123).randn_like(kdata_noisy.data) + return kdata_noisy + + +@pytest.fixture +def kdata_us(kdata: KData): + return kdata[..., ::2, :].clone() + + +@pytest.fixture +def image_noisy(kdata_noisy: KData): + fourier_op = FourierOp.from_kdata(kdata_noisy) + return fourier_op.adjoint(kdata_noisy.data)[0] + + +@pytest.fixture +def image(kdata: KData): + fourier_op = FourierOp.from_kdata(kdata) + return fourier_op.adjoint(kdata.data)[0] + + +@pytest.fixture +def image_us(kdata_us: KData): + fourier_op = FourierOp.from_kdata(kdata_us) + return fourier_op.adjoint(kdata_us.data)[0] diff --git a/tests/nn/data_consistency/test_analyticcartesiandc.py b/tests/nn/data_consistency/test_analyticcartesiandc.py new file mode 100644 index 000000000..9020a8326 --- /dev/null +++ b/tests/nn/data_consistency/test_analyticcartesiandc.py @@ -0,0 +1,21 @@ +"""Tests for AnalyticCartesianDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC + + +def test_analytic_cartesian_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = AnalyticCartesianDC(initial_regularization_weight=1e-6) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_weight.grad is not None + assert not dc.log_weight.grad.isnan() + assert not image_noisy.grad.isnan().any() diff --git a/tests/nn/data_consistency/test_conjugategradientdc.py b/tests/nn/data_consistency/test_conjugategradientdc.py new file mode 100644 index 000000000..99de6f4f5 --- /dev/null +++ b/tests/nn/data_consistency/test_conjugategradientdc.py @@ -0,0 +1,21 @@ +"""Tests for ConjugateGradientDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC + + +def test_conjugate_gradient_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = ConjugateGradientDC(initial_regularization_weight=1.0) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_weight.grad is not None + assert not dc.log_weight.grad.isnan() + assert not image_noisy.grad.isnan().any() diff --git a/tests/nn/data_consistency/test_gradientdescentdc.py b/tests/nn/data_consistency/test_gradientdescentdc.py new file mode 100644 index 000000000..00a7648f9 --- /dev/null +++ b/tests/nn/data_consistency/test_gradientdescentdc.py @@ -0,0 +1,21 @@ +"""Tests for GradientDescentDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC + + +def test_gradient_descent_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = GradientDescentDC(initial_stepsize=1.0) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_stepsize.grad is not None + assert not dc.log_stepsize.grad.isnan() + assert not image_noisy.grad.isnan().any() From d8eb662bdcc1afd9af2919e0547458182abdf38b Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:27 +0100 Subject: [PATCH 193/205] add positional encodings and attention modules ghstack-source-id: 1ef9b0579a1d6a662dc8d306ff8bd373b259a253 ghstack-comment-id: 3865650808 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/955 --- pyproject.toml | 2 + src/mrpro/nn/AbsolutePositionEncoding.py | 74 ++++++ src/mrpro/nn/AxialRoPE.py | 106 ++++++++ src/mrpro/nn/__init__.py | 6 + src/mrpro/nn/attention/AttentionGate.py | 74 ++++++ src/mrpro/nn/attention/LinearSelfAttention.py | 98 ++++++++ src/mrpro/nn/attention/MultiHeadAttention.py | 106 ++++++++ .../nn/attention/NeighborhoodSelfAttention.py | 237 ++++++++++++++++++ .../nn/attention/ShiftedWindowAttention.py | 131 ++++++++++ .../nn/attention/SpatialTransformerBlock.py | 216 ++++++++++++++++ src/mrpro/nn/attention/SqueezeExcitation.py | 57 +++++ src/mrpro/nn/attention/TransposedAttention.py | 76 ++++++ src/mrpro/nn/attention/__init__.py | 17 ++ tests/conftest.py | 6 + tests/nn/test_ape.py | 27 ++ tests/nn/test_attentiongate.py | 51 ++++ tests/nn/test_linearselfattention.py | 58 +++++ tests/nn/test_neighborhoodselfattention.py | 149 +++++++++++ tests/nn/test_rope.py | 36 +++ tests/nn/test_shiftedwindowattention.py | 61 +++++ tests/nn/test_spatialtransformerblock.py | 142 +++++++++++ tests/nn/test_squeezeexcitation.py | 32 +++ tests/nn/test_transposedattention.py | 44 ++++ 23 files changed, 1806 insertions(+) create mode 100644 src/mrpro/nn/AbsolutePositionEncoding.py create mode 100644 src/mrpro/nn/AxialRoPE.py create mode 100644 src/mrpro/nn/attention/AttentionGate.py create mode 100644 src/mrpro/nn/attention/LinearSelfAttention.py create mode 100644 src/mrpro/nn/attention/MultiHeadAttention.py create mode 100644 src/mrpro/nn/attention/NeighborhoodSelfAttention.py create mode 100644 src/mrpro/nn/attention/ShiftedWindowAttention.py create mode 100644 src/mrpro/nn/attention/SpatialTransformerBlock.py create mode 100644 src/mrpro/nn/attention/SqueezeExcitation.py create mode 100644 src/mrpro/nn/attention/TransposedAttention.py create mode 100644 src/mrpro/nn/attention/__init__.py create mode 100644 tests/nn/test_ape.py create mode 100644 tests/nn/test_attentiongate.py create mode 100644 tests/nn/test_linearselfattention.py create mode 100644 tests/nn/test_neighborhoodselfattention.py create mode 100644 tests/nn/test_rope.py create mode 100644 tests/nn/test_shiftedwindowattention.py create mode 100644 tests/nn/test_spatialtransformerblock.py create mode 100644 tests/nn/test_squeezeexcitation.py create mode 100644 tests/nn/test_transposedattention.py diff --git a/pyproject.toml b/pyproject.toml index 9863b3856..a94d9b3c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,7 +124,9 @@ filterwarnings = [ "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 + "ignore:The \\.grad attribute of a Tensor that is not a leaf Tensor is being accessed:UserWarning", # torch dynamo bug in flex attention "ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10 + "ignore:`torch.jit.script_method` is deprecated", # torch 2.10 ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] diff --git a/src/mrpro/nn/AbsolutePositionEncoding.py b/src/mrpro/nn/AbsolutePositionEncoding.py new file mode 100644 index 000000000..f6093d6ba --- /dev/null +++ b/src/mrpro/nn/AbsolutePositionEncoding.py @@ -0,0 +1,74 @@ +"""Absolute position encoding (APE).""" + +from itertools import combinations +from math import ceil + +import torch +from torch.nn import Module + +from mrpro.utils.reshape import unsqueeze_right + + +class AbsolutePositionEncoding(Module): + """Absolute position encoding layer. + + Encodes absolute positions in a grid. Has no learnable parameters. + """ + + encoding: torch.Tensor + + def __init__(self, n_dim: int, n_features: int, include_radii: bool = True, base_resolution: int = 128): + """Initialize absolute position encoding layer. + + Parameters + ---------- + n_dim + Dimensions of the input space (1, 2, or 3) + n_features + Number of features to encode. The input to the forward pass needs to have at least + this many features/channels. + include_radii + Whether to include radius features + base_resolution + Base resolution for position encoding. + Encodings are generated at this resolution and interpolated to the input shape in the forward pass. + """ + super().__init__() + + coords = [unsqueeze_right(torch.linspace(-1, 1, base_resolution), i) for i in range(n_dim)] + if include_radii: + for n in range(2, n_dim + 1): + for combination in combinations(coords, n): + coords.append((2 * sum([c**2 for c in combination])) ** 0.5 - 1) + n_freqs = ceil(n_features / len(coords) / 2) + freqs = unsqueeze_right((base_resolution) ** torch.linspace(0, 1, n_freqs), n_dim) + encoding = [] + for coord in coords: + encoding.append(torch.sin(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * n_dim))) + encoding.append(torch.cos(coord * freqs).broadcast_to(1, -1, *((base_resolution,) * n_dim))) + self.register_buffer('encoding', torch.cat(encoding, dim=1)[:, :n_features]) + self.interpolation_mode = ['linear', 'bilinear', 'trilinear'][n_dim - 1] + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply absolute position encoding to a tensor. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Encoded tensor with absolute position information + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply absolute position encoding to a tensor.""" + features = self.encoding.shape[1] + if features > x.shape[1]: + raise ValueError(f'x has {x.shape[1]} features, but {features} are required') + + x_enc, x_unenc = x.split([features, x.shape[1] - features], dim=1) + encoding = torch.nn.functional.interpolate(self.encoding, size=x_unenc.shape[2:], mode=self.interpolation_mode) + return torch.cat((x_enc + encoding, x_unenc), dim=1) diff --git a/src/mrpro/nn/AxialRoPE.py b/src/mrpro/nn/AxialRoPE.py new file mode 100644 index 000000000..7d76a86d1 --- /dev/null +++ b/src/mrpro/nn/AxialRoPE.py @@ -0,0 +1,106 @@ +"""Rotary Position Embedding (RoPE).""" + +from collections.abc import Sequence + +import torch +from einops import rearrange +from torch.nn import Module + + +@torch.compile +def get_theta( + shape: Sequence[int], n_embedding_channels: int, device: torch.device +) -> torch.Tensor: # pragma: no cover + """Get rotation angles. + + Parameters + ---------- + shape + Spatial shape of the input tensor to use for the position embedding, + i.e. the shape excluding batch and channel dimensions. + n_embedding_channels + Number of embedding channels per head + device + Device to create the rotation angles on + + Returns + ------- + Rotation angles + """ + position = torch.stack( + torch.meshgrid([torch.arange(s, device=device) - s // 2 for s in shape], indexing='ij'), dim=-1 + ) + log_min = torch.log(torch.tensor(torch.pi)) + log_max = torch.log(torch.tensor(10000.0)) + freqs = torch.exp(torch.linspace(log_min, log_max, n_embedding_channels // (2 * position.shape[-1]), device=device)) + return rearrange(freqs * position[..., None], '... dim freqs ->... (dim freqs)') + + +class AxialRoPE(Module): + """Axial Rotary Position Embedding. + + Applies rotary position embeddings along each axis independently. + """ + + embed_fraction: float + freqs: torch.Tensor # explicit annotation kept for static type checking + + def __init__( + self, + embed_fraction: float = 1.0, + ): + """Initialize AxialRoPE. + + Parameters + ---------- + embed_fraction + Fraction of channels used for embedding + """ + super().__init__() + self.embed_fraction: float = float(embed_fraction) + if embed_fraction < 0 or embed_fraction > 1: + raise ValueError('embed_fraction must be between 0 and 1') + + def forward(self, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Apply rotary embeddings to input tensors. + + Parameters + ---------- + *tensors + Tensors to apply rotary embeddings to. + Shape must be `(batch, heads, *spatial_dims, channels)`. + """ + if self.embed_fraction == 0.0: + return tensors + + shape = tensors[0].shape + if not all(t.shape == shape for t in tensors): + raise ValueError('All tensors must have the same shape') + device = tensors[0].device + if not all(t.device == device for t in tensors): + raise ValueError('All tensors must be on the same device') + + shape, n_channels_per_head = shape[2:-1], shape[-1] + n_embedding_channels = int(n_channels_per_head * self.embed_fraction) + theta = get_theta(shape, n_embedding_channels, device) + return tuple(self.apply_rotary_emb(t, theta) for t in tensors) + + @staticmethod + def apply_rotary_emb(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: + """Add rotary embedding to the input tensor. + + Parameters + ---------- + x + Input tensor to modify + theta + Rotation angles + """ + n_emb = theta.shape[-1] * 2 + if n_emb > x.shape[-1]: + raise ValueError(f'Embedding dimension {n_emb} is larger than input dimension {x.shape[-1]}') + (x1, x2), x_unembed = x[..., :n_emb].chunk(2, dim=-1), x[..., n_emb:] + result = torch.cat( + [x1 * theta.cos() - x2 * theta.sin(), x2 * theta.cos() + x1 * theta.sin(), x_unembed], dim=-1 + ) + return result diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index f988855e7..ffb98843d 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -1,6 +1,8 @@ """Neural network modules and utilities.""" from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.nn.AbsolutePositionEncoding import AbsolutePositionEncoding +from mrpro.nn.AxialRoPE import AxialRoPE from mrpro.nn.CondMixin import CondMixin from mrpro.nn.DropPath import DropPath from mrpro.nn.FiLM import FiLM @@ -12,6 +14,7 @@ from mrpro.nn.RMSNorm import RMSNorm from mrpro.nn.Residual import Residual from mrpro.nn.Sequential import Sequential +from mrpro.nn import attention from mrpro.nn import data_consistency from mrpro.nn.ndmodules import ( adaptiveAvgPoolND, @@ -24,6 +27,8 @@ ) __all__ = [ + 'AbsolutePositionEncoding', + 'AxialRoPE', 'ComplexAsChannel', 'CondMixin', 'DropPath', @@ -37,6 +42,7 @@ 'Residual', 'Sequential', 'adaptiveAvgPoolND', + 'attention', 'avgPoolND', 'batchNormND', 'convND', diff --git a/src/mrpro/nn/attention/AttentionGate.py b/src/mrpro/nn/attention/AttentionGate.py new file mode 100644 index 000000000..d7fdfeab4 --- /dev/null +++ b/src/mrpro/nn/attention/AttentionGate.py @@ -0,0 +1,74 @@ +"""Attention gate from Attention UNet.""" + +import torch +from torch.nn import Module, ReLU, Sequential, Sigmoid + +from mrpro.nn.ndmodules import convND + + +class AttentionGate(Module): + """Attention gate from Attention UNet. + + The attention mechanism from the attention UNet [OKT18]_. + + References + ---------- + ..[OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + def __init__( + self, n_dim: int, channels_gate: int, channels_in: int, channels_hidden: int, concatenate: bool = False + ): + """Initialize the attention gate. + + Parameters + ---------- + n_dim + The dimension, i.e. 1, 2 or 3. + channels_gate + The number of channels in the gate tensor. + channels_in + The number of channels in the input tensor. + channels_hidden + The number of internal, hidden channels. + concatenate + Whether to concatenate the gated signal with the gate signal in the channel dimension (1) + """ + super().__init__() + self.project_gate = convND(n_dim)(channels_gate, channels_hidden, kernel_size=1) + self.project_x = convND(n_dim)(channels_in, channels_hidden, kernel_size=1) + self.psi = Sequential( + ReLU(), + convND(n_dim)(channels_hidden, 1, kernel_size=1), + Sigmoid(), + ) + self.concatenate = concatenate + + def __call__(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Apply the attention gate. + + Parameters + ---------- + x + The input tensor. + gate + The gate tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, gate) + + def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Apply the attention gate.""" + projected_gate = self.project_gate(gate) + projected_x = self.project_x(x) + projected_gate = torch.nn.functional.interpolate(projected_gate, size=x.shape[2:], mode='nearest') + alpha = self.psi(projected_gate + projected_x) + x = x * alpha + if self.concatenate: + gate = torch.nn.functional.interpolate(gate, size=x.shape[2:], mode='nearest') + x = torch.cat([x, gate], dim=1) + return x diff --git a/src/mrpro/nn/attention/LinearSelfAttention.py b/src/mrpro/nn/attention/LinearSelfAttention.py new file mode 100644 index 000000000..2bab08930 --- /dev/null +++ b/src/mrpro/nn/attention/LinearSelfAttention.py @@ -0,0 +1,98 @@ +"""Linear self-attention.""" + +import torch +from einops import rearrange +from torch import Tensor +from torch.nn import Linear, Module, ReLU + + +class LinearSelfAttention(Module): + """Linear multi-head self-attention via kernel trick. + + Uses a ReLU kernel to compute attention in O(N) [KAT20]_ time and space. + + + References + ---------- + .. [KAT20] Katharopoulos, Angelos, et al. Transformers are RNNs: Fast autoregressive transformers with linear + attention. ICML 2020. https://arxiv.org/abs/2006.16236 + """ + + def __init__( + self, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + eps: float = 1e-6, + features_last: bool = False, + ): + """Initialize linear self-attention layer. + + Parameters + ---------- + n_channels_in + Input channel dimension. + n_channels_out + Output channel dimension. + n_heads + Number of attention heads. + eps + Small epsilon for numerical stability in normalization. + features_last + Whether the channel dimension is the last dimension, as common in transformer models, + or the second dimension, as common in image models. + """ + super().__init__() + self.features_last = features_last + self.eps = eps + self.n_heads = n_heads + channels_per_head = n_channels_in // n_heads + self.to_qkv = Linear(n_channels_in, 3 * channels_per_head * n_heads) + self.kernel_function = ReLU() + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) + + def __call__(self, x: Tensor) -> Tensor: + """Apply linear self-attention. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or (`batch, *spatial_dims, channels` if `features_last`) + + Returns + ------- + Tensor after attention, same shape as input. + """ + return super().__call__(x) + + def forward(self, x: Tensor) -> Tensor: + """Apply linear self-attention.""" + orig_dtype = x.dtype + if x.dtype == torch.float16: + x = x.float() + if not self.features_last: + x = x.moveaxis(1, -1) + spatial_shape = x.shape[1:-1] + + qkv = self.to_qkv(x) + query, key, value = rearrange( + qkv, 'batch ... (qkv head channels) -> qkv batch head (...) channels', qkv=3, head=self.n_heads + ) + + query = self.kernel_function(query) + key = self.kernel_function(key) + + # trick to avoid second attention calculation: add normalization slot + value = torch.nn.functional.pad(value, (0, 0, 0, 1), mode='constant', value=1.0) + + value_key = value @ key.transpose(-1, -2) + value_key_query = value_key @ query + normalization = value_key_query[..., -1:, :] + self.eps + attn = value_key_query[..., :-1, :] / normalization + attn = attn.moveaxis(1, -1).flatten(-2) # join heads and channels + out = self.to_out(attn) + out = out.to(orig_dtype) + out = out.unflatten(-2, spatial_shape) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/attention/MultiHeadAttention.py b/src/mrpro/nn/attention/MultiHeadAttention.py new file mode 100644 index 000000000..069b695df --- /dev/null +++ b/src/mrpro/nn/attention/MultiHeadAttention.py @@ -0,0 +1,106 @@ +"""Multi-head Attention.""" + +import torch +from einops import rearrange +from torch.nn import Linear, Module + +from mrpro.nn.AxialRoPE import AxialRoPE + + +class MultiHeadAttention(Module): + """Multi-head Attention. + + Implements multihead scaled dot-product attention and supports "image-like" inputs, + i.e. `batch, channels, *spatial_dims` as well as "transformer-like" inputs, `batch, sequence, features`. + """ + + def __init__( + self, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + features_last: bool = False, + p_dropout: float = 0.0, + n_channels_cross: int | None = None, + rope_embed_fraction: float = 0.0, + ): + """Initialize the Multi-head Attention. + + Parameters + ---------- + n_channels_in + Number of input channels. + n_channels_out + Number of output channels. + n_heads + number of attention heads + features_last + Whether the features dimension is the last dimension, as common in transformer models, + or the second dimension, as common in image models. + p_dropout + Dropout probability. + n_channels_cross + Number of channels for cross-attention. If `None`, use `n_channels_in`. + rope_embed_fraction + Fraction of channels to embed with RoPE. + """ + super().__init__() + n_channels_kv = n_channels_cross if n_channels_cross is not None else n_channels_in + channels_per_head_q = n_channels_in // n_heads + channels_per_head_kv = n_channels_kv // n_heads + self.to_q = Linear(n_channels_in, channels_per_head_q * n_heads) + self.to_kv = Linear(n_channels_kv, channels_per_head_kv * n_heads * 2) + self.p_dropout = p_dropout + self.features_last = features_last + self.to_out = Linear(n_channels_in, n_channels_out) + self.n_heads = n_heads + self.rope = AxialRoPE(rope_embed_fraction) + + def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + """Apply multi-head attention. + + Parameters + ---------- + x + The input tensor. + cross_attention + The key and value tensors for cross-attention. If `None`, self-attention is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cross_attention) + + def _reshape(self, x: torch.Tensor) -> torch.Tensor: + if not self.features_last: + x = x.moveaxis(1, -1) + return x.flatten(1, -2) + + def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + """Apply multi-head attention.""" + if cross_attention is None: + cross_attention = x + if not self.features_last: + x = x.moveaxis(1, -1) + cross_attention = cross_attention.moveaxis(1, -1) + + query = rearrange(self.to_q(x), 'batch ... (heads channels) -> batch heads ... channels ', heads=self.n_heads) + key, value = rearrange( + self.to_kv(cross_attention), + 'batch ... (kv heads channels) -> kv batch heads ... channels ', + heads=self.n_heads, + kv=2, + ) + query, key = self.rope(query, key) # NO-OP if rope_embed_fraction is 0.0 + query, key, value = query.flatten(2, -2), key.flatten(2, -2), value.flatten(2, -2) + y = torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=self.p_dropout, is_causal=False + ) + y = rearrange(y, '... heads L channels -> ... L (heads channels)') + out = self.to_out(y).reshape(x.shape) + + if not self.features_last: + out = out.moveaxis(-1, 1) + + return out diff --git a/src/mrpro/nn/attention/NeighborhoodSelfAttention.py b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py new file mode 100644 index 000000000..27916ee6f --- /dev/null +++ b/src/mrpro/nn/attention/NeighborhoodSelfAttention.py @@ -0,0 +1,237 @@ +"""Neighborhood Self Attention.""" + +from collections.abc import Sequence +from functools import cache, reduce +from typing import TYPE_CHECKING, TypeVar, cast + +import torch +from einops import rearrange +from packaging.version import parse as parse_version +from torch.nn import Linear, Module + +from mrpro.nn.AxialRoPE import AxialRoPE +from mrpro.utils.to_tuple import to_tuple + +T = TypeVar('T') + +if TYPE_CHECKING or parse_version(torch.__version__) >= parse_version('2.6'): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention +else: + + class BlockMask: + """Dummy class for older PyTorch versions.""" + + +_compiled_flex_attention = torch.compile( + lambda q, k, v, mask: flex_attention(q, k, v, block_mask=mask), + dynamic=False, +) + + +@torch.compiler.disable +@cache +def neighborhood_mask( + device: str, + input_size: torch.Size, + kernel_size: int | tuple[int, ...], + dilation: int | tuple[int, ...] = 1, + circular: bool | tuple[bool, ...] = False, +) -> BlockMask: # pragma: no cover + """Create a flex attention block mask for neighborhood attention. + + This function defines which key/value pairs a query can attend to based + on a local neighborhood. The neighborhood is defined by `kernel_size` + and `dilation` and can be circular (wrapping around edges). + + Parameters + ---------- + input_size + The dimensions of the input tensor (e.g., (H, W) for 2D). + kernel_size + The size of the attention neighborhood window. Can be a single + integer for a symmetric window or a sequence of integers for + each dimension. + dilation + The dilation factor for the neighborhood + Can be a single integer for a symmetric window or a sequence + of integers for each dimension. + circular + Whether the neighborhood wraps around the edges (circular padding). + Can be a single boolean or a sequence of booleans. + device + The device to create the mask on. + + Returns + ------- + A mask object suitable for `flex_attention` that defines the + allowed attention connections. + """ + kernel_size_tuple, dilation_tuple, circular_tuple = ( + to_tuple(len(input_size), x) for x in (kernel_size, dilation, circular) + ) + + def unravel_index(idx: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Convert a flat 1D index into multi-dimensional coordinates.""" + idx = idx.clone() + coords = [] + for dim in reversed(input_size): + coords.append(idx % dim) + idx = torch.div(idx, dim, rounding_mode='floor').long() + coords.reverse() + return tuple(coords) + + def mask( + _batch: torch.Tensor, + _head: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + """Determine if a query can attend to a key/value pair.""" + q_coord = unravel_index(q_idx) + kv_coord = unravel_index(kv_idx) + + masks = [] + for input_, kernel_, dilation_, circular_, q_, kv_ in zip( + input_size, + kernel_size_tuple, + dilation_tuple, + circular_tuple, + q_coord, + kv_coord, + strict=False, + ): + masks.append((q_ % dilation_) == (kv_ % dilation_)) + kernel_dilation = kernel_ * dilation_ + window_left = kernel_dilation // 2 + window_right = (kernel_dilation // 2) + ((kernel_dilation % 2) - 1) + if circular_: + left = (q_ - kv_ + input_) % input_ + right = (kv_ - q_ + input_) % input_ + masks.append((left <= window_left) | (right <= window_right)) + else: + center = q_.clamp(window_left, input_ - 1 - window_right) + left = center - kv_ + right = kv_ - center + masks.append(((left >= 0) & (left <= window_left)) | ((right >= 0) & (right <= window_right))) + return reduce(lambda x, y: x & y, masks) + + qkv_len = input_size.numel() + return create_block_mask(mask, B=None, H=None, Q_LEN=qkv_len, KV_LEN=qkv_len, device=torch.device(device)) + + +class NeighborhoodSelfAttention(Module): + """Attention where each query attends to a neighborhood of the key and value. + + Neighborhood attention is a type of attention where each query attends to a neighborhood of the key and value. + It is a more efficient alternative to regular attention, especially for large input sizes [NAT]_. + + This implementation uses `~torch.nn.attention.flex_attention`. For a more efficient implementation, + see also [NATTEN]_. + + + References + ---------- + .. [NAT] Hassani, A. et al. "Neighborhood Attention Transformer" CVPR, 2023, https://arxiv.org/abs/2204.07143 + .. [NATTEN] https://github.com/SHI-Labs/NATTEN/ + """ + + n_head: int + kernel_size: int | tuple[int, ...] + dilation: int | tuple[int, ...] + circular: bool | tuple[bool, ...] + features_last: bool + + def __init__( + self, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + kernel_size: int | Sequence[int], + dilation: int | Sequence[int] = 1, + circular: bool | Sequence[bool] = False, + features_last: bool = False, + rope_embed_fraction: float = 1.0, + ) -> None: + """Initialize a neighborhood attention module. + + The parameters `kernel_size`, `dilation`, and `circular` can either be sequences, interpreted as per-dimension + values, or scalars, interpreted as the same value for all dimensions. + + Parameters + ---------- + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + n_heads + The number of attention heads. + kernel_size + The size of the attention neighborhood window. + dilation + The dilation factor for the neighborhood. + circular + Whether the neighborhood wraps around the edges (circular padding) + features_last + Whether the channels are in the last dimension of the tensor, as common in visíon transformers. + Otherwise, assume the channels are in the second dimension, as common in CNN models. + rope_embed_fraction + Fraction of channels to embed with RoPE. + + """ + if parse_version(torch.__version__) < parse_version('2.6.0'): + raise NotImplementedError('NeighborhoodSelfAttention requires PyTorch 2.6.0 or higher') + super().__init__() + self.n_head = n_heads + self.kernel_size = kernel_size if isinstance(kernel_size, int) else tuple(kernel_size) + self.dilation = dilation if isinstance(dilation, int) else tuple(dilation) + self.circular = circular if isinstance(circular, bool) else tuple(circular) + self.features_last = features_last + channels_per_head = n_channels_in // n_heads + self.to_qkv = Linear(n_channels_in, 3 * channels_per_head * n_heads) + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) + self.rope = AxialRoPE(rope_embed_fraction) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply neighborhood attention to the input tensor. + + Parameters + ---------- + x + The input tensor, with shape `(batch, channels, *spatial_dims)` + or `(batch, *spatial_dims, channels)` (if `features_last`). + + Returns + ------- + The output tensor after attention, with the same shape as the input tensor. + """ + if not self.features_last: + x = x.moveaxis(1, -1) + spatial_shape = x.shape[1:-1] + qkv = self.to_qkv(x) + query, key, value = rearrange( + qkv, + 'batch ... (qkv heads channels) -> qkv batch heads (...) channels', + qkv=3, + heads=self.n_head, + ) + query, key = self.rope(query, key) + query, key, value = query.contiguous(), key.contiguous(), value.contiguous() + device = str(qkv.device) + mask = neighborhood_mask( + device=device, + input_size=spatial_shape, + kernel_size=self.kernel_size, + dilation=self.dilation, + circular=self.circular, + ) + mask = torch.compiler.assume_constant_result(mask) + if torch.compiler.is_compiling(): + out = cast(torch.Tensor, flex_attention(query, key, value, block_mask=mask)) + else: + out = cast(torch.Tensor, _compiled_flex_attention(query, key, value, mask)) + out = rearrange(out, 'batch head sequence channels -> batch sequence (head channels)') + out = self.to_out(out) + out = out.unflatten(-2, spatial_shape) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/attention/ShiftedWindowAttention.py b/src/mrpro/nn/attention/ShiftedWindowAttention.py new file mode 100644 index 000000000..0935ff63d --- /dev/null +++ b/src/mrpro/nn/attention/ShiftedWindowAttention.py @@ -0,0 +1,131 @@ +"""Shifted Window Attention.""" + +import warnings + +import torch +from einops import rearrange +from torch.nn import Linear, Module + +from mrpro.utils.reshape import ravel_multi_index +from mrpro.utils.sliding_window import sliding_window + + +class ShiftedWindowAttention(Module): + """Shifted Window Attention. + + (Shifted) Window Attention calculates attention over windows of the input. + It was introduced in Swin Transformer [SWIN]_ and is used in Uformer. + + References + ---------- + .. [SWIN] Liu, Ze, et al. "Swin transformer: Hierarchical vision transformer using shifted windows." ICCV 2021. + """ + + rel_position_index: torch.Tensor + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_heads: int, + window_size: int = 7, + shifted: bool = True, + features_last: bool = False, + ): + """Initialize the ShiftedWindowAttention module. + + Parameters + ---------- + n_dim + The dimension of the input. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + n_heads + The number of attention heads. The number if channels per head is ``channels // n_heads``. + window_size + The size of the window. + shifted + Whether to shift the window. + features_last + Whether the features are last in the input tensor or in the second dimension. + """ + super().__init__() + self.n_heads = n_heads + self.window_size = window_size + self.shifted = shifted + self.features_last = features_last + channels_per_head = n_channels_in // n_heads + self.to_qkv = Linear(channels_per_head * n_heads, 3 * channels_per_head * n_heads) + self.to_out = Linear(channels_per_head * n_heads, n_channels_out) + self.n_dim = n_dim + coords_1d = torch.arange(window_size) + coords_nd = torch.stack(torch.meshgrid(*([coords_1d] * n_dim), indexing='ij'), 0).flatten(1) + rel_coords = coords_nd[:, :, None] - coords_nd[:, None, :] # (dim, window_size**dim, window_size**dim) + rel_coords += window_size - 1 # shift to >=0 + rel_position_index = ravel_multi_index(tuple(rel_coords), (2 * window_size - 1,) * n_dim) + self.register_buffer('rel_position_index', rel_position_index) + + self.relative_position_bias_table = torch.nn.Parameter(torch.empty((2 * window_size - 1) ** n_dim, n_heads)) + torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02, a=-0.04, b=0.04) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the ShiftedWindowAttention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the ShiftedWindowAttention.""" + if not self.features_last: + x = x.moveaxis(1, -1) # now it is features last + if self.shifted: + x = torch.roll(x, (-(self.window_size // 2),) * self.n_dim, dims=tuple(range(-self.n_dim - 1, -1))) + + padding = [] + for s in x.shape[-self.n_dim - 1 : -1]: + target = ((s + self.window_size - 1) // self.window_size) * self.window_size + padding.extend([target - s, 0]) + x_padded = torch.nn.functional.pad(x, (0, 0, *padding[::-1]), mode='circular') if any(padding) else x + + qkv = self.to_qkv(x_padded) + windowed = sliding_window( + qkv, window_shape=self.window_size, stride=self.window_size, dim=range(-self.n_dim - 1, -1) + ) + q, k, v = rearrange( + windowed.flatten(-self.n_dim - 1, -2), + '... sequence (qkv heads channels)->qkv ... heads sequence channels', + heads=self.n_heads, + qkv=3, + ) + bias = rearrange(self.relative_position_bias_table[self.rel_position_index], 'wd1 wd2 heads -> 1 heads wd1 wd2') + with warnings.catch_warnings(): + # Inductor in torch 2.6 warns for small batch*n_patches*n_heads about suboptimal softmax compilation. + warnings.filterwarnings('ignore', message='.*softmax.*') + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias) + attention = rearrange(attention, '... head sequence channels->... sequence (head channels)') + attention = attention.unflatten(-2, windowed.shape[-self.n_dim - 1 : -1]) + # permute (in 3d) batch channels z y x wz wy wx -> batch channels wz z wy y wx x + attention = attention.moveaxis(list(range(self.n_dim)), list(range(2, 2 + 2 * self.n_dim, 2))) + attention = attention.reshape(x_padded.shape) + if any(padding): + crop_idx = (Ellipsis, *[slice(0, s) for s in x.shape[-self.n_dim - 1 : -1]], slice(None)) + attention = attention[crop_idx] + if self.shifted: + attention = torch.roll( + attention, (self.window_size // 2,) * self.n_dim, dims=tuple(range(-self.n_dim - 1, -1)) + ) + out = self.to_out(attention) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/attention/SpatialTransformerBlock.py b/src/mrpro/nn/attention/SpatialTransformerBlock.py new file mode 100644 index 000000000..1fcd7c534 --- /dev/null +++ b/src/mrpro/nn/attention/SpatialTransformerBlock.py @@ -0,0 +1,216 @@ +"""Spatial transformer block.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Dropout, Linear, Module + +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.GEGLU import GEGLU +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Sequential import Sequential + + +def zero_init(m: Module) -> Module: + """Initialize module weights and bias to zero.""" + if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor): + torch.nn.init.zeros_(m.weight) + if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): + torch.nn.init.zeros_(m.bias) + return m + + +class BasicTransformerBlock(CondMixin, Module): + """Basic vision transformer block.""" + + def __init__( + self, + channels: int, + n_heads: int, + p_dropout: float = 0.0, + cond_dim: int = 0, + mlp_ratio: float = 4, + features_last: bool = False, + rope_embed_fraction: float = 0.0, + attention_neighborhood: int | None = None, + ): + """Initialize the basic transformer block. + + Parameters + ---------- + channels + Number of channels in the input and output. + n_heads + Number of attention heads. + p_dropout + Dropout probability. + cond_dim + Number of channels in the conditioning tensor. + mlp_ratio + Ratio of the hidden dimension to the input dimension. + features_last + Whether the features are last in the input tensor. + rope_embed_fraction + Fraction of channels to embed with RoPE. + attention_neighborhood + If not None, use neighborhood self attention with the given neighborhood size instead + of global self attention. + """ + super().__init__() + self.features_last = features_last + + if attention_neighborhood is None: + attention: Module = MultiHeadAttention( + n_channels_in=channels, + n_channels_out=channels, + n_heads=n_heads, + p_dropout=p_dropout, + features_last=True, + rope_embed_fraction=rope_embed_fraction, + ) + else: + if p_dropout > 0: + raise ValueError('p_dropout > 0 is not supported for neighborhood self attention') + attention = NeighborhoodSelfAttention( + n_channels_in=channels, + n_channels_out=channels, + n_heads=n_heads, + features_last=True, + kernel_size=attention_neighborhood, + circular=True, + rope_embed_fraction=rope_embed_fraction, + ) + self.selfattention = Sequential(LayerNorm(channels, features_last=True), attention) + hidden_dim = int(channels * mlp_ratio) + self.ff = Sequential( + LayerNorm(channels, features_last=True, cond_dim=cond_dim), + GEGLU(channels, hidden_dim, features_last=True), + Dropout(p_dropout), + Linear(hidden_dim, channels), + ) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the basic transformer block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the basic transformer block.""" + if not self.features_last: + x = x.moveaxis(1, -1).contiguous() + x = self.selfattention(x) + x + x = self.ff(x, cond=cond) + x + if not self.features_last: + x = x.moveaxis(-1, 1).contiguous() + return x + + +class SpatialTransformerBlock(CondMixin, Module): + """Spatial transformer block.""" + + def __init__( + self, + dim_groups: Sequence[tuple[int, ...]], + channels: int, + n_heads: int, + depth: int = 1, + p_dropout: float = 0.0, + cond_dim: int = 0, + rope_embed_fraction: float = 0.0, + attention_neighborhood: int | None = None, + features_last: bool = False, + norm: Literal['group', 'rms'] = 'group', + ): + """Initialize the spatial transformer block. + + Parameters + ---------- + dim_groups + Groups of spatial dimensions for separate attention mechanisms. + channels + Number of channels in the input and output. + n_heads + Number of attention heads for each group. + depth + Number of transformer blocks for each group. + p_dropout + Dropout probability. + cond_dim + Dimension of the conditioning tensor. + rope_embed_fraction + Fraction of channels to embed with RoPE. + attention_neighborhood + If not None, use NeighborhoodSelfAttention with the given neighborhood size instead of MultiHeadAttention. + features_last + Whether the features are last in the input tensor, as common in transformer models. + norm + Whether to use GroupNorm or RMSNorm. + """ + super().__init__() + hidden_dim = n_heads * (channels // n_heads) + match norm: + case 'group': + self.norm: Module = GroupNorm(channels, features_last=features_last) + case 'rms': + self.norm = RMSNorm(channels, features_last=features_last) + case _: + raise ValueError(f'Invalid norm: {norm}') + self.features_last = features_last + self.proj_in = Linear(channels, hidden_dim) + self.transformer_blocks = Sequential() + for group in (g for _ in range(depth) for g in dim_groups): + if not self.features_last: + group = tuple(g - 1 if g < 0 else g for g in group) + block = BasicTransformerBlock( + hidden_dim, + n_heads, + p_dropout=p_dropout, + cond_dim=cond_dim, + features_last=True, + rope_embed_fraction=rope_embed_fraction, + attention_neighborhood=attention_neighborhood, + ) + self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) + self.proj_out = zero_init(Linear(hidden_dim, channels)) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the spatial transformer block.""" + skip = x + h = self.norm(x) + if not self.features_last: + h = h.movedim(1, -1) + h = self.proj_in(h) + h = self.transformer_blocks(h, cond=cond) + h = self.proj_out(h) + if not self.features_last: + h = h.movedim(-1, 1) + return skip + h + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the spatial transformer block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor. + """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/attention/SqueezeExcitation.py b/src/mrpro/nn/attention/SqueezeExcitation.py new file mode 100644 index 000000000..5f7802c75 --- /dev/null +++ b/src/mrpro/nn/attention/SqueezeExcitation.py @@ -0,0 +1,57 @@ +"""Squeeze-and-Excitation block.""" + +import torch +from torch.nn import Module, ReLU, Sigmoid + +from mrpro.nn.ndmodules import adaptiveAvgPoolND, convND +from mrpro.nn.Sequential import Sequential + + +class SqueezeExcitation(Module): + """Squeeze-and-Excitation block. + + Sequeeze-and-Excitation block from [SE]_. + + References + ---------- + ..[SE] Hu, Jie, Li Shen, and Gang Sun. "Squeeze-and-excitation networks." CVPR 2018, https://arxiv.org/abs/1709.01507 + """ + + def __init__(self, n_dim: int, n_channels_input: int, n_channels_squeeze: int) -> None: + """Initialize SqueezeExcitation. + + Parameters + ---------- + n_dim + The dimension of the input tensor. + n_channels_input + The number of channels in the input tensor. + n_channels_squeeze + The number of channels in the squeeze tensor. + """ + super().__init__() + self.scale = Sequential( + adaptiveAvgPoolND(n_dim)(1), + convND(n_dim)(n_channels_input, n_channels_squeeze, kernel_size=1), + ReLU(), + convND(n_dim)(n_channels_squeeze, n_channels_input, kernel_size=1), + Sigmoid(), + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply SqueezeExcitation.""" + return x * self.scale(x) diff --git a/src/mrpro/nn/attention/TransposedAttention.py b/src/mrpro/nn/attention/TransposedAttention.py new file mode 100644 index 000000000..88e993c8f --- /dev/null +++ b/src/mrpro/nn/attention/TransposedAttention.py @@ -0,0 +1,76 @@ +"""Transposed Attention from Restormer.""" + +import torch +from einops import rearrange +from torch.nn import Module, Parameter + +from mrpro.nn.ndmodules import convND + + +class TransposedAttention(Module): + """Transposed Self Attention from Restormer. + + Implements the transposed self-attention, i.e. channel-wise multihead self-attention, + layer from Restormer [ZAM22]_. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + """ + + def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, n_heads: int): + """Initialize a TransposedAttention layer. + + Parameters + ---------- + n_dim + input dimension + n_channels_in + Number of channels in the input tensor. + n_channels_out + Number of channels in the output tensor. + n_heads + Number of attention heads. + """ + super().__init__() + self.n_heads = n_heads + self.temperature = Parameter(torch.ones(n_heads, 1, 1)) + channels_per_head = n_channels_in // n_heads + self.to_qkv = convND(n_dim)(n_channels_in, channels_per_head * n_heads * 3, kernel_size=1) + self.qkv_dwconv = convND(n_dim)( + channels_per_head * n_heads * 3, + channels_per_head * n_heads * 3, + kernel_size=3, + groups=n_channels_in * 3, + padding=1, + bias=False, + ) + self.to_out = convND(n_dim)(channels_per_head * n_heads, n_channels_out, kernel_size=1) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed attention. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply transposed attention.""" + qkv = self.qkv_dwconv(self.to_qkv(x)) + q, k, v = rearrange(qkv, 'b (qkv heads channels) ... -> qkv b heads (...) channels', heads=self.n_heads, qkv=3) + q = torch.nn.functional.normalize(q, dim=-1) * self.temperature + k = torch.nn.functional.normalize(k, dim=-1) + attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0) + out = rearrange(attention, '... heads points channels -> ... (heads channels) points').unflatten( + -1, x.shape[2:] + ) + out = self.to_out(out) + return out diff --git a/src/mrpro/nn/attention/__init__.py b/src/mrpro/nn/attention/__init__.py new file mode 100644 index 000000000..719ff1409 --- /dev/null +++ b/src/mrpro/nn/attention/__init__.py @@ -0,0 +1,17 @@ +from mrpro.nn.attention.AttentionGate import AttentionGate +from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.attention.SqueezeExcitation import SqueezeExcitation +from mrpro.nn.attention.TransposedAttention import TransposedAttention +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock + +__all__ = [ + "AttentionGate", + "LinearSelfAttention", + "NeighborhoodSelfAttention", + "ShiftedWindowAttention", + "SpatialTransformerBlock", + "SqueezeExcitation", + "TransposedAttention" +] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 8490674e9..b2aa1cba2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,11 +12,17 @@ from mrpro.data.enums import AcqFlags from mrpro.utils import RandomGenerator from mrpro.utils.reshape import unsqueeze_tensors_left +from packaging.version import parse as parse_version from xsdata.models.datatype import XmlDate, XmlTime from tests.data import IsmrmrdRawTestData from tests.phantoms import EllipsePhantomTestData +minimal_torch_26 = pytest.mark.xfail( + parse_version(torch.__version__) < parse_version('2.6'), + reason='Requires PyTorch >= 2.6', +) + def generate_random_encodingcounter_properties(rng: RandomGenerator) -> dict[str, Any]: return { diff --git a/tests/nn/test_ape.py b/tests/nn/test_ape.py new file mode 100644 index 000000000..9f71444fc --- /dev/null +++ b/tests/nn/test_ape.py @@ -0,0 +1,27 @@ +"""Tests for absolute position encoding""" + +import pytest +import torch +from mrpro.nn import AbsolutePositionEncoding +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_absolute_position_encodings(device: str) -> None: + """Test absolute position encoding.""" + n_features = 32 + shape = (1, 2 * n_features, 32, 32) + ape = AbsolutePositionEncoding(2, n_features, True, 128).to(device) + rng = RandomGenerator(444) + x1 = rng.float32_tensor(shape).to(device) + x2 = rng.float32_tensor(shape).to(device) + y1, y2 = ape(x1), ape(x2) + assert y1.shape == x1.shape + torch.testing.assert_close(y1 - x1, y2 - x2) + assert (x1[:, n_features:] == y1[:, n_features:]).all() # unembedded features diff --git a/tests/nn/test_attentiongate.py b/tests/nn/test_attentiongate.py new file mode 100644 index 000000000..10d30cb07 --- /dev/null +++ b/tests/nn/test_attentiongate.py @@ -0,0 +1,51 @@ +"""Tests for AttentionGate module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn.attention import AttentionGate +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_dim', 'n_channels_gate', 'n_channels_in', 'n_channels_hidden', 'input_shape', 'gate_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 32, 16, 16)), + (3, 32, 4, 8, (2, 4, 16, 16, 16), (2, 32, 16, 16, 16)), + ], +) +def test_attention_gate( + n_dim: int, + n_channels_gate: int, + n_channels_in: int, + n_channels_hidden: int, + input_shape: Sequence[int], + gate_shape: Sequence[int], + device: str, +) -> None: + """Test AttentionGate output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + gate = rng.float32_tensor(gate_shape).to(device).requires_grad_(True) + attn = AttentionGate( + n_dim=n_dim, channels_gate=n_channels_gate, channels_in=n_channels_in, channels_hidden=n_channels_hidden + ).to(device) + output = attn(x, gate) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert gate.grad is not None, 'No gradient computed for gate' + assert not output.isnan().any(), 'NaN values in output' + assert not gate.isnan().any(), 'NaN values in gate' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not gate.grad.isnan().any(), 'NaN values in gate gradients' + assert attn.project_gate.weight.grad is not None, 'No gradient computed for project_gate' + assert attn.project_x.weight.grad is not None, 'No gradient computed for project_x' + assert attn.psi[1].weight.grad is not None, 'No gradient computed for psi' diff --git a/tests/nn/test_linearselfattention.py b/tests/nn/test_linearselfattention.py new file mode 100644 index 000000000..11d17b301 --- /dev/null +++ b/tests/nn/test_linearselfattention.py @@ -0,0 +1,58 @@ +"""Tests for LinearSelfAttention module.""" + +import pytest +from mrpro.nn.attention import LinearSelfAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels_in', 'n_channels_out', 'n_heads', 'input_shape', 'features_last'), + [ + (32, 32, 4, (1, 32, 32, 32), False), + (64, 64, 8, (2, 64, 16, 16), False), + (16, 16, 2, (1, 16, 16, 16), True), + ], +) +def test_linear_self_attention( + n_channels_in: int, + n_channels_out: int, + n_heads: int, + input_shape: tuple[int, ...], + features_last: bool, + device: str, +) -> None: + """Test LinearSelfAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + + attn = LinearSelfAttention( + n_channels_in=n_channels_in, + n_channels_out=n_channels_out, + n_heads=n_heads, + features_last=features_last, + ).to(device) + + if features_last: + output = attn(x.moveaxis(1, -1)).moveaxis(-1, 1) + else: + output = attn(x) + + expected_shape = (x.shape[0], n_channels_out, *x.shape[2:]) + assert output.shape == expected_shape, f'Output shape {output.shape} != expected shape {expected_shape}' + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + assert attn.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert attn.to_qkv.bias.grad is not None, 'No gradient computed for to_qkv.bias' + assert attn.to_out.weight.grad is not None, 'No gradient computed for to_out.weight' + assert attn.to_out.bias.grad is not None, 'No gradient computed for to_out.bias' diff --git a/tests/nn/test_neighborhoodselfattention.py b/tests/nn/test_neighborhoodselfattention.py new file mode 100644 index 000000000..3b51e963d --- /dev/null +++ b/tests/nn/test_neighborhoodselfattention.py @@ -0,0 +1,149 @@ +"""Tests for NeighborhoodSelfAttention module.""" + +import pytest +import torch +from mrpro.nn.attention.NeighborhoodSelfAttention import NeighborhoodSelfAttention +from mrpro.utils import RandomGenerator +from tests.conftest import minimal_torch_26 + + +@minimal_torch_26 +@pytest.mark.parametrize( + 'device', + [ + pytest.param( + 'cpu', + id='cpu', + marks=pytest.mark.skip( + reason='Flex Attention backward not supported on CPU. https://github.com/pytorch/pytorch/issues/148752' + ), + ), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels_in', 'n_channels_out', 'n_heads', 'kernel_size', 'input_shape', 'features_last'), + [ + (2, 3, 1, 2, (1, 2, 16, 16), False), + (3, 2, 2, 4, (1, 3, 8, 8, 8, 8), True), + ], + ids=['2d_kernel2', '4d_features-last_kernel4'], +) +def test_neighborhood_self_attention_backward( + n_channels_in: int, + n_channels_out: int, + n_heads: int, + kernel_size: int, + input_shape: tuple[int, ...], + features_last: bool, + device: str, +) -> None: + """Test NeighborhoodSelfAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + + attention = NeighborhoodSelfAttention( + n_channels_in=n_channels_in, + n_channels_out=n_channels_out, + n_heads=n_heads, + kernel_size=kernel_size, + features_last=features_last, + ).to(device) + + if features_last: + output = attention(x.moveaxis(1, -1)).moveaxis(-1, 1) + else: + output = attention(x) + + expected_shape = (input_shape[0], n_channels_out, *input_shape[2:]) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + assert attention.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert attention.to_qkv.bias.grad is not None, 'No gradient computed for to_qkv.bias' + assert attention.to_out.weight.grad is not None, 'No gradient computed for to_out.weight' + assert attention.to_out.bias.grad is not None, 'No gradient computed for to_out.bias' + + +@minimal_torch_26 +@pytest.mark.cuda +@pytest.mark.parametrize( + ('kernel_size', 'dilation', 'circular', 'rope'), + [ + (3, 1, False, True), + (5, 2, True, False), + (7, 1, False, True), + ], +) +def test_neighborhood_attention_variants(kernel_size: int, dilation: int, circular: bool, rope: bool) -> None: + """Test NeighborhoodSelfAttention with different neighborhood configurations.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 16, 16)).cuda() + + attention = NeighborhoodSelfAttention( + n_channels_in=32, + n_channels_out=32, + n_heads=4, + kernel_size=kernel_size, + dilation=dilation, + circular=circular, + rope_embed_fraction=1.0 if rope else 0.0, + ) + output = attention(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + +@minimal_torch_26 +@pytest.mark.parametrize( + ('kernel_size', 'circular', 'input_shape'), + [ + (11, False, (1, 8, 32, 32)), + (3, True, (1, 8, 64, 64)), + ], + ids=['regular', 'circular'], +) +@torch.no_grad() +def test_neighborhood_constraint(kernel_size: int, circular: bool, input_shape: tuple[int, int, int, int]) -> None: + """Test that neighborhood attention only affects pixels within the kernel window.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape) + attention = NeighborhoodSelfAttention( + n_channels_in=8, + n_channels_out=8, + n_heads=2, + kernel_size=kernel_size, + dilation=1, + circular=circular, + ) + output_original = attention(x) + x_modified = x.clone() + test_point = (input_shape[-2] - 2, input_shape[-1] - 2) + x_modified[..., test_point[0], test_point[1]] += 1.0 + output_modified = attention(x_modified) + + diff = output_modified - output_original + changed_pixels = torch.abs(diff).sum(dim=(0, 1)) > 1e-6 + + half_kernel = kernel_size // 2 + h, w = input_shape[2], input_shape[3] + + i_coords, j_coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') + + if circular: + h_dist = torch.minimum((i_coords - test_point[0]) % h, (test_point[0] - i_coords) % h) + w_dist = torch.minimum((j_coords - test_point[1]) % w, (test_point[1] - j_coords) % w) + in_neighborhood = (h_dist <= half_kernel) & (w_dist <= half_kernel) + else: + h_min, h_max = max(0, test_point[0] - half_kernel), min(h, test_point[0] + half_kernel + 1) + w_min, w_max = max(0, test_point[1] - half_kernel), min(w, test_point[1] + half_kernel + 1) + in_neighborhood = (i_coords >= h_min) & (i_coords < h_max) & (j_coords >= w_min) & (j_coords < w_max) + + neighborhood_changed = changed_pixels[in_neighborhood].all() + outside_changed = changed_pixels[~in_neighborhood].any() + + assert neighborhood_changed, 'Not all pixels in the neighborhood changed, which indicates a problem' + assert not outside_changed, 'Pixels outside the neighborhood changed, which violates the constraint' diff --git a/tests/nn/test_rope.py b/tests/nn/test_rope.py new file mode 100644 index 000000000..665c4bed4 --- /dev/null +++ b/tests/nn/test_rope.py @@ -0,0 +1,36 @@ +"""Tests for AxialRoPE module.""" + +import pytest +import torch +from mrpro.nn import AxialRoPE +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_rope(device: torch.device) -> None: + """Test AxialRoPE rotation and embedding functionality.""" + shape = (10, 10) + n_heads = 2 + n_channels = 64 + n_embed = int(0.5 * n_channels) + q, k = RandomGenerator(seed=42).float32_tensor((2, 1, n_heads, *shape, n_channels), low=0.5).to(device) + + rope = AxialRoPE(embed_fraction=0.5) + (q_rope, k_rope) = rope(q, k) + + assert q_rope.shape == q.shape + assert k_rope.shape == k.shape + + # non embedded channels should be the same + torch.testing.assert_close(q[..., n_embed:], q_rope[..., n_embed:]) + torch.testing.assert_close(k[..., n_embed:], k_rope[..., n_embed:]) + + # other should change + assert not torch.isclose(q_rope[..., :n_embed], q[..., :n_embed]).all() + assert not torch.isclose(k_rope[..., :n_embed], k[..., :n_embed]).all() diff --git a/tests/nn/test_shiftedwindowattention.py b/tests/nn/test_shiftedwindowattention.py new file mode 100644 index 000000000..5d416efee --- /dev/null +++ b/tests/nn/test_shiftedwindowattention.py @@ -0,0 +1,61 @@ +"""Tests for ShiftedWindowAttention module.""" + +import pytest +from mrpro.nn.attention import ShiftedWindowAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_dim', 'window_size', 'shifted'), + [ + (2, 8, False), + (4, 4, True), + ], +) +def test_shifted_window_attention(n_dim: int, window_size: int, shifted: bool, device: str) -> None: + """Test ShiftedWindowAttention output shape and backpropagation.""" + n_batch, n_channels, n_heads = 2, 8, 2 + spatial_shape = (window_size * 4,) * n_dim + rng = RandomGenerator(13) + x = rng.float32_tensor((n_batch, n_channels, *spatial_shape)).to(device).requires_grad_(True) + swin = ShiftedWindowAttention( + n_dim=n_dim, + n_channels_in=n_channels, + n_channels_out=n_channels, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, + ).to(device) + out = swin(x) + assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' + assert not out.isnan().any(), 'NaN values in output' + out.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert swin.to_qkv.weight.grad is not None, 'No gradient computed for to_qkv.weight' + assert swin.relative_position_bias_table.grad is not None, 'No gradient computed for relative_position_bias_table' + + +@pytest.mark.parametrize('shifted', [True, False], ids=['shifted', 'non-shifted']) +def test_shifted_window_attention_size_mismatch(shifted: bool): + n_batch, n_channels, n_heads, n_dim, window_size = 3, 4, 2, 2, 7 + spatial_shape = (window_size * 4 + 1,) * n_dim + rng = RandomGenerator(13) + x = rng.float32_tensor((n_batch, n_channels, *spatial_shape)) + swin = ShiftedWindowAttention( + n_dim=n_dim, + n_channels_in=n_channels, + n_channels_out=n_channels, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, + ) + out = swin(x) + assert out.shape == x.shape, f'Output shape {out.shape} != input shape {x.shape}' diff --git a/tests/nn/test_spatialtransformerblock.py b/tests/nn/test_spatialtransformerblock.py new file mode 100644 index 000000000..4c78ceecd --- /dev/null +++ b/tests/nn/test_spatialtransformerblock.py @@ -0,0 +1,142 @@ +"""Test SpatialTransformerBlock""" + +from collections.abc import Sequence +from typing import Literal, cast + +import pytest +import torch +from mrpro.nn.attention import SpatialTransformerBlock +from mrpro.utils import RandomGenerator +from tests.conftest import minimal_torch_26 + + +@minimal_torch_26 +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + ('channels', 'cond_dim', 'attention_neighborhood', 'features_last', 'norm', 'input_shape'), + [ + pytest.param(32, 16, None, True, 'group', (16, 16), id='2d-cond-group-last-global'), + pytest.param(64, 16, 7, False, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-first-NA'), + pytest.param(64, 16, 5, True, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-last-NA'), + pytest.param(64, 0, 7, True, 'rms', (16, 8, 16), marks=minimal_torch_26, id='3d-nocond-rms-last-NA'), + ], +) +def test_spatialtransformerblock_backward( + channels: int, + cond_dim: int, + attention_neighborhood: int | None, + features_last: bool, + norm: Literal['group', 'rms'], + input_shape: Sequence[int], + device: str, + torch_compile: bool, +) -> None: + """Test SpatialTransformerBlock output shape and backpropagation.""" + if device == 'cpu' and attention_neighborhood is not None: + pytest.skip( + 'CompiledFlex Attention backward not supported on CPU. https://github.com/pytorch/pytorch/issues/148752' + ) + rng = RandomGenerator(seed=42) + + x = rng.float32_tensor((1, channels, *input_shape)).to(device).requires_grad_(True) + cond = rng.float32_tensor((1, cond_dim)).to(device).requires_grad_(True) if cond_dim else None + + if features_last: + dims = tuple(range(-len(input_shape) - 1, -1)) + else: + dims = tuple(range(-len(input_shape), 0)) + + block = SpatialTransformerBlock( + dim_groups=[dims], + channels=channels, + n_heads=4, + depth=1, + p_dropout=0, + cond_dim=cond_dim, + rope_embed_fraction=0.5, + attention_neighborhood=attention_neighborhood, + features_last=features_last, + norm=norm, + ).to(device) + if torch_compile: + block = cast(SpatialTransformerBlock, torch.compile(block, dynamic=False)) + if features_last: + output = block(x.moveaxis(1, -1), cond=cond).moveaxis(-1, 1) + else: + output = block(x, cond=cond) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' + + +@minimal_torch_26 +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize('torch_compile', [False, True]) +@pytest.mark.parametrize( + ('channels', 'cond_dim', 'attention_neighborhood', 'features_last', 'norm', 'input_shape'), + [ + pytest.param(32, 16, None, True, 'group', (16, 16), id='2d-cond-group-last-global'), + pytest.param(64, 16, 5, True, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-last-NA'), + pytest.param(64, 16, 7, False, 'group', (16, 16), marks=minimal_torch_26, id='2d-cond-group-first-NA'), + pytest.param(64, 0, 7, True, 'rms', (16, 8, 16), marks=minimal_torch_26, id='3d-nocond-rms-last-NA'), + ], +) +def test_spatialtransformerblock_forward( + channels: int, + cond_dim: int, + attention_neighborhood: int | None, + features_last: bool, + norm: Literal['group', 'rms'], + input_shape: Sequence[int], + device: str, + torch_compile: bool, +) -> None: + """Test SpatialTransformerBlock output shape and backpropagation.""" + + rng = RandomGenerator(seed=42) + + x = rng.float32_tensor((1, channels, *input_shape)).to(device).requires_grad_(True) + cond = rng.float32_tensor((1, cond_dim)).to(device).requires_grad_(True) if cond_dim else None + + if features_last: + dims = tuple(range(-len(input_shape) - 1, -1)) + else: + dims = tuple(range(-len(input_shape), 0)) + + block = SpatialTransformerBlock( + dim_groups=[dims], + channels=channels, + n_heads=4, + depth=1, + p_dropout=0, + cond_dim=cond_dim, + rope_embed_fraction=0.5, + attention_neighborhood=attention_neighborhood, + features_last=features_last, + norm=norm, + ).to(device) + if torch_compile: + block = cast(SpatialTransformerBlock, torch.compile(block, dynamic=False)) + with torch.no_grad(): + if features_last: + output = block(x.moveaxis(1, -1), cond=cond).moveaxis(-1, 1) + else: + output = block(x, cond=cond) + assert output.shape == x.shape + assert not output.isnan().any(), 'NaN values in output' diff --git a/tests/nn/test_squeezeexcitation.py b/tests/nn/test_squeezeexcitation.py new file mode 100644 index 000000000..879dd71ca --- /dev/null +++ b/tests/nn/test_squeezeexcitation.py @@ -0,0 +1,32 @@ +"""Tests for SqueezeExcitation module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn.attention import SqueezeExcitation +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + ('dim', 'input_shape', 'squeeze_channels'), + [ + (2, (1, 64, 32, 32), 16), + (3, (1, 64, 16, 16, 16), 16), + ], +) +def test_squeeze_excitation( + dim: int, + input_shape: Sequence[int], + squeeze_channels: int, +) -> None: + """Test SqueezeExcitation output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + se = SqueezeExcitation(n_dim=dim, n_channels_input=input_shape[1], n_channels_squeeze=squeeze_channels) + output = se(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert se.scale[1].weight.grad is not None, 'No gradient computed for Conv' diff --git a/tests/nn/test_transposedattention.py b/tests/nn/test_transposedattention.py new file mode 100644 index 000000000..afbe53494 --- /dev/null +++ b/tests/nn/test_transposedattention.py @@ -0,0 +1,44 @@ +"""Tests for TransposedAttention module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn.attention import TransposedAttention +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels', 'num_heads', 'input_shape'), + [ + (2, 32, 4, (1, 32, 32, 32)), + (3, 64, 8, (2, 64, 16, 16, 16)), + ], +) +def test_transposed_attention( + dim: int, + channels: int, + num_heads: int, + input_shape: Sequence[int], + device: str, +) -> None: + """Test TransposedAttention output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + attn = TransposedAttention(n_dim=dim, n_channels_in=channels, n_channels_out=channels, n_heads=num_heads).to(device) + output = attn(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert attn.to_qkv.weight.grad is not None, 'No gradient computed for qkv' + assert attn.qkv_dwconv.weight.grad is not None, 'No gradient computed for qkv_dwconv' + assert attn.to_out.weight.grad is not None, 'No gradient computed for project_out' + assert attn.temperature.grad is not None, 'No gradient computed for temperature' From 88e72b7af0ea274f4523a6977bf96719fd8abf57 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:27 +0100 Subject: [PATCH 194/205] add unet, basic cnn, and residual blocks ghstack-source-id: c62dfe4bce3e9b11589f24ad6b0e01e86958832b ghstack-comment-id: 3865651070 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/956 --- src/mrpro/nn/ResBlock.py | 70 ++++++ src/mrpro/nn/SeparableResBlock.py | 89 +++++++ src/mrpro/nn/__init__.py | 6 + src/mrpro/nn/nets/BasicCNN.py | 105 +++++++++ src/mrpro/nn/nets/UNet.py | 364 +++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 8 + tests/nn/nets/test_cnn.py | 58 +++++ tests/nn/nets/test_unet.py | 116 +++++++++ tests/nn/test_resblock.py | 56 +++++ tests/nn/test_separableresblock.py | 58 +++++ 10 files changed, 930 insertions(+) create mode 100644 src/mrpro/nn/ResBlock.py create mode 100644 src/mrpro/nn/SeparableResBlock.py create mode 100644 src/mrpro/nn/nets/BasicCNN.py create mode 100644 src/mrpro/nn/nets/UNet.py create mode 100644 src/mrpro/nn/nets/__init__.py create mode 100644 tests/nn/nets/test_cnn.py create mode 100644 tests/nn/nets/test_unet.py create mode 100644 tests/nn/test_resblock.py create mode 100644 tests/nn/test_separableresblock.py diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py new file mode 100644 index 000000000..32870979f --- /dev/null +++ b/src/mrpro/nn/ResBlock.py @@ -0,0 +1,70 @@ +"""Residual convolution block with two convolutions.""" + +import torch +from torch.nn import Identity, Module, SiLU + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import convND +from mrpro.nn.Sequential import Sequential + + +class ResBlock(CondMixin, Module): + """Residual convolution block with two convolutions.""" + + def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, cond_dim: int) -> None: + """Initialize the ResBlock. + + Parameters + ---------- + n_dim + The number of dimensions, i.e. 1, 2 or 3. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + cond_dim + The number of features in the conditioning tensor used in a FiLM. + If set to 0 no FiLM is used. + + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + self.block = Sequential( + GroupNorm(n_channels_in), + SiLU(), + convND(n_dim)(n_channels_in, n_channels_out, kernel_size=3, padding=1), + GroupNorm(n_channels_out), + SiLU(), + convND(n_dim)(n_channels_out, n_channels_out, kernel_size=3, padding=1), + ) + if cond_dim > 0: + self.block.insert(-3, FiLM(n_channels_out, cond_dim)) + + if n_channels_out == n_channels_in: + self.skip_connection: Module = Identity() + else: + self.skip_connection = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the ResBlock. + + Parameters + ---------- + x + The input tensor. + cond + A conditioning tensor to be used for FiLM. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the ResBlock.""" + h = self.block(x, cond=cond) + x = self.skip_connection(x) + self.rezero * h + return x diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py new file mode 100644 index 000000000..a12fda16f --- /dev/null +++ b/src/mrpro/nn/SeparableResBlock.py @@ -0,0 +1,89 @@ +"""Residual block with separable convolutions.""" + +from collections.abc import Sequence + +import torch +from torch.nn import Module, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import convND +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.Sequential import Sequential + + +class SeparableResBlock(Module): + """Residual block with separable convolutions.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + n_channels_in: int, + n_channels_out: int, + cond_dim: int, + ) -> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SilU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + n_channels_in + Number of input channels. + n_channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, convND(len(dims))(channels_in, n_channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, n_channels_in if i == 0 else n_channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(n_channels_out, cond_dim)) + blocks.extend(block(d, n_channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if n_channels_in != n_channels_out: + self.skip_connection = torch.nn.Linear(n_channels_in, n_channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index ffb98843d..94f28625e 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -12,10 +12,13 @@ from mrpro.nn.LayerNorm import LayerNorm from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.ResBlock import ResBlock from mrpro.nn.Residual import Residual +from mrpro.nn.SeparableResBlock import SeparableResBlock from mrpro.nn.Sequential import Sequential from mrpro.nn import attention from mrpro.nn import data_consistency +from mrpro.nn import nets from mrpro.nn.ndmodules import ( adaptiveAvgPoolND, avgPoolND, @@ -39,7 +42,9 @@ 'LayerNorm', 'PermutedBlock', 'RMSNorm', + 'ResBlock', 'Residual', + 'SeparableResBlock', 'Sequential', 'adaptiveAvgPoolND', 'attention', @@ -50,4 +55,5 @@ 'data_consistency', 'instanceNormND', 'maxPoolND', + 'nets', ] diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py new file mode 100644 index 000000000..5bf911eef --- /dev/null +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -0,0 +1,105 @@ +"""Basic CNN.""" + +from collections.abc import Sequence +from itertools import pairwise +from typing import Literal + +import torch +from torch.nn import LeakyReLU, ReLU, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import batchNormND, convND +from mrpro.nn.Sequential import Sequential + + +class BasicCNN(Sequential): + """Basic CNN. + + A series of convolutions (window 3, stride 1, padding 1), normalization and activation. + Allows to use FiLM conditioning. + Order is Conv -> Norm (optional) -> FiLM (optional) -> Activation. + + If you need more flexibility, use `~mrpro.nn.Sequential` directly. + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + norm: Literal['batch', 'group', 'instance', 'none', 'layer'] = 'none', + activation: Literal['relu', 'silu', 'leaky_relu'] = 'relu', + n_features: Sequence[int] = (64, 64, 64), + cond_dim: int = 0, + ): + """Initialize a basic CNN. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + norm + The type of normalization to use. If 'batch', use batch normalization. If 'group', use group normalization, + if 'instance', use instance normalization, and if `layer`, use layer normalization. + If 'none', use no normalization. + activation + The type of activation to use. If 'relu', use ReLU. If 'silu', use SiLU. If 'leaky_relu', use LeakyReLU. + n_features + The number of features in the hidden layers. The length of this sequence determines the number of hidden + layers. The total number of convolutions is `len(n_features) + 1`. + cond_dim + The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + Otherwise, between convolutions, after normalization, FiLM conditioning is applied. + """ + super().__init__() + use_film = cond_dim > 0 + + self.append(convND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding='same')) + + for c_in, c_out in pairwise((*n_features, n_channels_out)): + if norm.lower() == 'batch': + self.append(batchNormND(n_dim)(c_in, affine=not use_film)) + elif norm.lower() == 'group': + self.append(GroupNorm(c_in, affine=not use_film)) + elif norm.lower() == 'instance': + self.append(GroupNorm(c_in, n_groups=c_in, affine=not use_film)) # is instance norm + elif norm.lower() == 'layer': + self.append(GroupNorm(c_in, n_groups=1, affine=not use_film)) # is layer norm + elif norm.lower() != 'none': + raise ValueError(f'Invalid normalization type: {norm}') + + if use_film: + self.append(FiLM(c_in, cond_dim)) + + if activation.lower() == 'relu': + self.append(ReLU(True)) + elif activation.lower() == 'silu': + self.append(SiLU(inplace=True)) + elif activation.lower() == 'leaky_relu': + self.append(LeakyReLU(inplace=True)) + else: + raise ValueError(f'Invalid activation type: {activation}') + + self.append(convND(n_dim)(c_in, c_out, kernel_size=3, padding='same')) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] + """Apply the basic CNN to the input tensor. + + Parameters + ---------- + x + The input tensor. Should be of shape `(batch_size, channels_in, *spatial dimensions)` + with `spatial dimensions` being of length `dim`. + cond + The condition tensor. If None, no FiLM conditioning is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py new file mode 100644 index 000000000..7b92b1c73 --- /dev/null +++ b/src/mrpro/nn/nets/UNet.py @@ -0,0 +1,364 @@ +"""UNet variants.""" + +from collections.abc import Sequence +from functools import partial +from itertools import pairwise + +import torch +from torch.nn import Identity, Module, ModuleList, ReLU, SiLU + +from mrpro.nn.attention.AttentionGate import AttentionGate +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.CondMixin import call_with_cond +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import convND, maxPoolND +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.Upsample import Upsample + + +class UNetEncoder(Module): + """Encoder.""" + + def __init__( + self, + first_block: Module, + blocks: Sequence[Module], + down_blocks: Sequence[Module], + middle_block: Module, + ) -> None: + """Initialize the UNetEncoder.""" + super().__init__() + self.first = first_block + """The first block. Should expand from the number of input channels.""" + + self.blocks = ModuleList(blocks) + """The encoder blocks. Order is highest resolution to lowest resolution.""" + + self.down_blocks = ModuleList(down_blocks) + """The downsampling blocks""" + + self.middle_block = middle_block + """Also called bottleneck block""" + + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.down_blocks) + 1 + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = call(self.first, x) + + xs = [] + for block, down in zip(self.blocks, self.down_blocks, strict=True): + x = call(block, x) + xs.append(x) + x = call(down, x) + + x = call(self.middle_block, x) + + return (*xs, x) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The tensors at the different resolutions, highest resolution first. + """ + return super().__call__(x, cond=cond) + + +class UNetDecoder(Module): + """Decoder.""" + + def __init__( + self, + blocks: Sequence[Module], + up_blocks: Sequence[Module], + concat_blocks: Sequence[Module], + last_block: Module, + ) -> None: + """Initialize the UNetDecoder.""" + super().__init__() + self.blocks = ModuleList(blocks) + """The decoder blocks. Order is lowest resolution to highest resolution.""" + + self.up_blocks = ModuleList(up_blocks) + """The upsampling blocks""" + + self.concat_blocks = ModuleList(concat_blocks) + """Joins the skip connections with the upsampled features from a lower resolution level""" + + self.last_block = last_block + """The last block. Should reduce to the number of output channels.""" + + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.up_blocks) + 1 + + def forward(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = hs[-1] # lowest resolution, from middle block + for block, up, concat, h in zip(self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True): + x = call(up, x) + x = concat(h, x) + x = call(block, x) + x = call(self.last_block, x) + return x + + def __call__(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + hs + The tensors at the different resolutions, highest resolution first. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(hs, cond=cond) + + +class UNetBase(Module): + """Base class for U-shaped networks.""" + + def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequence[Module] | None = None) -> None: + """Initialize the UNetBase.""" + super().__init__() + self.encoder = encoder + """The encoder.""" + + self.decoder = decoder + """The decoder.""" + + self.skip_blocks = ModuleList() + """Modifications of the skip connections.""" + + if len(decoder) != len(encoder): + raise ValueError( + 'The number of resolutions in the encoder and decoder must be the same, ' + f'got {len(decoder)} and {len(encoder)}' + ) + + if skip_blocks is None: + self.skip_blocks.extend(Identity() for _ in range(len(decoder))) + elif len(skip_blocks) != len(decoder): + raise ValueError( + f'The number of skip blocks must be the same as the number of resolutions, ' + f'got {len(skip_blocks)} and {len(encoder)}' + ) + else: + self.skip_blocks.extend(skip_blocks) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network.""" + xs = self.encoder(x, cond=cond) + xs = tuple( + call_with_cond(self.skip_blocks[i], x, cond=cond) if i < len(self.skip_blocks) else x + for i, x in enumerate(xs) + ) + x = self.decoder(xs, cond=cond) + return x + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + +class UNet(UNetBase): + """UNet. + + U-shaped convolutional network with optional patch attention. + Inspired by [NOSENSE_] and the OpenAi DDPM UNet/Latent Diffusion UNet [LDM]_. + significant differences to the vanilla UNet [UNET]_ include: + - Spatial transformer blocks + - Convolutional downsampling, nearest neighbor upsampling + - Residual convolution blocks with pre-act group normalization and SiLU activation + + + References + ---------- + .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image + segmentation MICCAI 2015. https://arxiv.org/abs/1505.04597 + .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py + .. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without + explicit sensitivity maps." STACOM 2023. https://github.com/fzimmermann89/CMRxRecon/blob/master/src/cmrxrecon/nets/unet.py + + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + attention_depths: Sequence[int] = (-1,), + n_features: Sequence[int] = (64, 128, 192, 256), + n_heads: int = 8, + cond_dim: int = 0, + encoder_blocks_per_scale: int = 2, + ) -> None: + """Initialize the UNet. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + attention_depths + The depths at which to apply attention. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + n_heads + Number of attention heads. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + encoder_blocks_per_scale + Number of encoder blocks per resolution level. The number of decoder blocks is one more. + """ + depth = len(n_features) + if not all(-depth <= d < depth for d in attention_depths): + raise ValueError( + f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' + ) + attention_depths = tuple(d % depth for d in attention_depths) + if len(attention_depths) != len(set(attention_depths)): + raise ValueError(f'attention_depths must be unique, got {attention_depths=}') + + def attention_block(channels: int) -> Module: + dim_groups = (tuple(range(-n_dim, 0)),) + return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) + + def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: + blocks = Sequential() + for _ in range(encoder_blocks_per_scale): + blocks.append(ResBlock(n_dim, channels_in, channels_out, cond_dim)) + if attention: + blocks.append(attention_block(channels_out)) + channels_in = channels_out + return blocks + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [] + + for i_level, (n_feat, n_feat_next) in enumerate(pairwise(n_features)): + encoder_blocks.append(blocks(n_feat, n_feat, i_level in attention_depths)) + down_blocks.append(convND(n_dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) + decoder_blocks.append(blocks(n_feat_next + n_feat, n_feat, i_level in attention_depths)) + up_blocks.append(Upsample(tuple(range(-n_dim, 0)), scale_factor=2)) + + middle_block = Sequential( + ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim), + ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim), + ) + if depth - 1 in attention_depths: + middle_block.insert(1, attention_block(n_feat_next)) + first_block = convND(n_dim)(n_channels_in, n_features[0], 3, padding=1) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + + decoder_blocks, up_blocks = decoder_blocks[::-1], up_blocks[::-1] + last_block = Sequential( + SiLU(), + convND(n_dim)(n_features[0], n_channels_out, 3, padding=1), + ) + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) + + +class AttentionGatedUNet(UNetBase): + """UNet with attention gates. + + Basic UNet with attention gating of the skip signals by the lower resolution features [OKT18]_. + + References + ---------- + .. [OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + def __init__( + self, n_dim: int, n_channels_in: int, n_channels_out: int, n_features: Sequence[int], cond_dim: int = 0 + ): + """Initialize the AttentionGatedUNet. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + + def block(channels_in: int, channels_out: int) -> Module: + block = Sequential( + convND(n_dim)(channels_in, channels_out, 3, padding=1), + ReLU(True), + convND(n_dim)(channels_out, channels_out, 3, padding=1), + ReLU(True), + ) + if cond_dim > 0: + block.insert(2, FiLM(channels_out, cond_dim)) + return block + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + n_feat_old = n_channels_in + for n_feat in n_features[:-1]: + encoder_blocks.append(block(n_feat_old, n_feat)) + down_blocks.append(maxPoolND(n_dim)(2)) + n_feat_old = n_feat + middle_block = block(n_features[-2], n_features[-1]) + encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) + + concat_blocks = [] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [] + for n_feat, n_feat_skip in pairwise(n_features[::-1]): + concat_blocks.append(AttentionGate(n_dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) + decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) + up_blocks.append(Upsample(range(-n_dim, 0), scale_factor=2)) + last_block = convND(n_dim)(n_features[0], n_channels_out, 1) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py new file mode 100644 index 000000000..06271d970 --- /dev/null +++ b/src/mrpro/nn/nets/__init__.py @@ -0,0 +1,8 @@ +from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet + +__all__ = [ + 'AttentionGatedUNet', + 'BasicCNN', + 'UNet', +] diff --git a/tests/nn/nets/test_cnn.py b/tests/nn/nets/test_cnn.py new file mode 100644 index 000000000..5dbbd5ad6 --- /dev/null +++ b/tests/nn/nets/test_cnn.py @@ -0,0 +1,58 @@ +"""Tests for BasicCNN network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import BasicCNN + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_cnn_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the cnn.""" + cnn = BasicCNN( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + norm='layer', + n_features=(8, 8), + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cnn = cnn.to(device) + x = x.to(device) + if torch_compile: + cnn = cast(BasicCNN, torch.compile(cnn)) + y = cnn(x) + assert y.shape == (1, 1, 16, 16) + + +def test_cnn_backward() -> None: + cnn = BasicCNN( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + norm='instance', + activation='silu', + n_features=(8, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = cnn(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in cnn.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py new file mode 100644 index 000000000..fdf2f5250 --- /dev/null +++ b/tests/nn/nets/test_unet.py @@ -0,0 +1,116 @@ +"""Tests for UNet and AttentionGatedUNet networks.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import AttentionGatedUNet, UNet + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_unet_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the UNet.""" + unet = UNet( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + attention_depths=(-1,), + n_features=(4, 6, 8), + n_heads=2, + cond_dim=32, + encoder_blocks_per_scale=1, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + unet = unet.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + unet = cast(UNet, torch.compile(unet)) + y = unet(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_unet_backward() -> None: + unet = UNet( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + attention_depths=(-1,), + n_features=(4, 6, 8), + n_heads=2, + cond_dim=32, + encoder_blocks_per_scale=1, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = unet(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in unet.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_gated_unet_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the AttentionGatedUNet.""" + unet = AttentionGatedUNet( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_features=(4, 6, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + unet = unet.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + unet = cast(AttentionGatedUNet, torch.compile(unet)) + y = unet(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_gated_unet_backward() -> None: + """Test the backward pass of the AttentionGatedUNet.""" + unet = AttentionGatedUNet( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_features=(4, 6, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = unet(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in unet.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py new file mode 100644 index 000000000..ea4356173 --- /dev/null +++ b/tests/nn/test_resblock.py @@ -0,0 +1,56 @@ +"""Tests for ResBlock module.""" + +from collections.abc import Sequence +from typing import cast + +import pytest +import torch +from mrpro.nn import ResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'eager']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (3, 64, 32, 0, (2, 64, 16, 16, 16), None), + ], +) +def test_resblock( + dim: int, + channels_in: int, + channels_out: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int] | None, + device: str, + torch_compile: bool, +) -> None: + """Test ResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None + block = ResBlock(n_dim=dim, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim).to(device) + if torch_compile: + block = cast(ResBlock, torch.compile(block, dynamic=False)) + output = block(x, cond=cond) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert block.block[2].weight.grad is not None, 'No gradient computed for first Conv' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' diff --git a/tests/nn/test_separableresblock.py b/tests/nn/test_separableresblock.py new file mode 100644 index 000000000..25b6c65e0 --- /dev/null +++ b/tests/nn/test_separableresblock.py @@ -0,0 +1,58 @@ +"""Tests for SeparableResBlock module.""" + +from collections.abc import Sequence +from typing import cast + +import pytest +import torch +from mrpro.nn import SeparableResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'eager']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim_groups', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (((-1, -2),), 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (((-1, -2), (-3,)), 64, 32, 0, (2, 64, 16, 16, 16), None), # 2D + 1D + ], +) +def test_separable_resblock( + dim_groups: Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int] | None, + device: str, + torch_compile: bool, +) -> None: + """Test SeparableResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None + block = SeparableResBlock( + dim_groups=dim_groups, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim + ).to(device) + if torch_compile: + block = cast(SeparableResBlock, torch.compile(block, dynamic=False)) + output = block(x, cond=cond) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert block.block[0][2].module.weight.grad is not None, 'No gradient computed for first Conv' # type: ignore[union-attr] + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' From 774db723066af8f5423b8b31989d38ad4cdb80c9 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:28 +0100 Subject: [PATCH 195/205] Add MLP network and tests ghstack-source-id: 658a1037cd47063cd0a58dca83f346420d4b1087 ghstack-comment-id: 3874745395 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/963 --- src/mrpro/nn/nets/MLP.py | 119 ++++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 8 ++- tests/nn/test_mlp.py | 89 +++++++++++++++++++++++++ 3 files changed, 213 insertions(+), 3 deletions(-) create mode 100644 src/mrpro/nn/nets/MLP.py create mode 100644 tests/nn/test_mlp.py diff --git a/src/mrpro/nn/nets/MLP.py b/src/mrpro/nn/nets/MLP.py new file mode 100644 index 000000000..524a985cd --- /dev/null +++ b/src/mrpro/nn/nets/MLP.py @@ -0,0 +1,119 @@ +"""Multi-layer perceptron.""" + +from collections.abc import Sequence +from itertools import pairwise +from typing import Literal + +import torch +from torch.nn import GELU, LeakyReLU, Linear, ReLU, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.Sequential import Sequential + + +class MLP(Sequential): + """Multi-layer perceptron. + + A series of linear layers, normalization and activation. + Allows FiLM conditioning. + Order is Linear -> Norm (optional) -> FiLM (optional) -> Activation. + + If you need more flexibility, use `~mrpro.nn.Sequential` directly. + """ + + features_last: bool + + def __init__( + self, + n_channels_in: int, + n_channels_out: int, + norm: Literal['layer', 'none'] = 'none', + activation: Literal['gelu', 'relu', 'silu', 'leaky_relu'] = 'gelu', + n_features: Sequence[int] = (256, 256), + cond_dim: int = 0, + features_last: bool = True, + ): + """Initialize a MLP. + + Parameters + ---------- + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + norm + The type of normalization to use. If `layer`, use layer normalization. + If `none`, use no normalization. + activation + The type of activation to use. If `gelu`, use GELU. + If `relu`, use ReLU. If `silu`, use SiLU. If `leaky_relu`, use LeakyReLU. + n_features + The number of features in the hidden layers. The length of this sequence determines the number of hidden + layers. The total number of linear layers is `len(n_features) + 1`. + cond_dim + The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + Otherwise, between linear layers, after normalization, FiLM conditioning is applied. + features_last + Whether the features are in the last dimension, as common in transformer models, + or in the second dimension, as common in image models. + """ + super().__init__() + use_film = cond_dim > 0 + self.features_last = features_last + + if len(n_features) == 0: + self.append(Linear(n_channels_in, n_channels_out)) + return + + self.append(Linear(n_channels_in, n_features[0])) + + for c_in, c_out in pairwise((*n_features, n_channels_out)): + if norm.lower() == 'layer': + self.append(LayerNorm(c_in, features_last=True)) + elif norm.lower() != 'none': + raise ValueError(f'Invalid normalization type: {norm}') + + if use_film: + self.append(FiLM(c_in, cond_dim, features_last=True)) + + if activation.lower() == 'gelu': + self.append(GELU(approximate='tanh')) + elif activation.lower() == 'relu': + self.append(ReLU()) + elif activation.lower() == 'silu': + self.append(SiLU()) + elif activation.lower() == 'leaky_relu': + self.append(LeakyReLU()) + else: + raise ValueError(f'Invalid activation type: {activation}') + + self.append(Linear(c_in, c_out)) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] + """Apply the MLP to the input tensor. + + Parameters + ---------- + x + The input tensor. + cond + The condition tensor. If None, no FiLM conditioning is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the MLP to the input tensor.""" + if len(x) != 1: + raise ValueError(f'Mlp expects exactly one input tensor, got {len(x)}') + tensor = x[0] + if not self.features_last: + tensor = tensor.moveaxis(1, -1) + out = super().forward(tensor, cond=cond) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 06271d970..af7dbdc5e 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,8 +1,10 @@ from mrpro.nn.nets.BasicCNN import BasicCNN from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet +from mrpro.nn.nets.MLP import MLP __all__ = [ - 'AttentionGatedUNet', - 'BasicCNN', - 'UNet', + "AttentionGatedUNet", + "BasicCNN", + "MLP", + "UNet", ] diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py new file mode 100644 index 000000000..22f5b9ca1 --- /dev/null +++ b/tests/nn/test_mlp.py @@ -0,0 +1,89 @@ +"""Tests for Mlp module.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import MLP +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_mlp_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the Mlp.""" + mlp = MLP( + n_channels_in=8, + n_channels_out=4, + norm='layer', + activation='gelu', + n_features=(16,), + cond_dim=12, + features_last=False, + ).to(device) + x = torch.zeros(1, 8, 9, 7, device=device) + cond = torch.zeros(1, 12, device=device) + if torch_compile: + mlp = cast(MLP, torch.compile(mlp)) + y = mlp(x, cond=cond) + assert y.shape == (1, 4, 9, 7) + + +def test_mlp_backward() -> None: + """Test the backward pass of the Mlp.""" + mlp = MLP( + n_channels_in=6, + n_channels_out=3, + norm='none', + activation='silu', + n_features=(12, 12), + cond_dim=10, + features_last=True, + ) + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 20, 6)).requires_grad_(True) + cond = rng.float32_tensor((1, 10)).requires_grad_(True) + y = mlp(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in mlp.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +def test_mlp_features_last() -> None: + """Test Mlp with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + mlp_last = MLP( + n_channels_in=3, + n_channels_out=4, + norm='layer', + activation='relu', + n_features=(6,), + cond_dim=0, + features_last=True, + ) + mlp = MLP( + n_channels_in=3, + n_channels_out=4, + norm='layer', + activation='relu', + n_features=(6,), + cond_dim=0, + features_last=False, + ) + mlp.load_state_dict(mlp_last.state_dict()) + y_last = mlp_last(x.moveaxis(1, -1)) + y = mlp(x) + torch.testing.assert_close(y, y_last.moveaxis(-1, 1)) From 7ea4450471d2e641ede6ec16915c8891444eb06f Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:28 +0100 Subject: [PATCH 196/205] add restormer architecture and tests ghstack-source-id: 5e8d1409bb31a7858f66a7087b7bb6c450c49c89 ghstack-comment-id: 3865651248 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/957 --- src/mrpro/nn/nets/Restormer.py | 223 ++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 2 + tests/nn/nets/test_restormer.py | 62 +++++++++ 3 files changed, 287 insertions(+) create mode 100644 src/mrpro/nn/nets/Restormer.py create mode 100644 tests/nn/nets/test_restormer.py diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py new file mode 100644 index 000000000..357bebf51 --- /dev/null +++ b/src/mrpro/nn/nets/Restormer.py @@ -0,0 +1,223 @@ +"""Restormer implementation.""" + +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import Module + +from mrpro.nn.attention.TransposedAttention import TransposedAttention +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import convND, instanceNormND +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Sequential import Sequential + + +class GDFN(Module): + """Gated depthwise feed forward network. + + Feed-forward block used in Restormer [ZAM22]_. It first expands channels, + applies a depthwise convolution, then uses a gated interaction between two + channel splits before projecting back to the input width. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for + high-resolution image restoration." CVPR 2022. + """ + + def __init__(self, n_dim: int, n_channels: int, mlp_ratio: float): + """Initialize GDFN. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + mlp_ratio + Ratio for hidden dimension expansion + """ + super().__init__() + + hidden_features = int(n_channels * mlp_ratio) + self.project_in = convND(n_dim)(n_channels, hidden_features * 2, kernel_size=1) + self.depthwise_conv = convND(n_dim)( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + ) + self.project_out = convND(n_dim)(hidden_features, n_channels, kernel_size=1) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the gated depthwise feed forward network. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Output tensor + """ + x = self.project_in(x) + x1, x2 = self.depthwise_conv(x).chunk(2, dim=1) + x = x1 * torch.sigmoid(x2) + x = self.project_out(x) + return x + + +class RestormerBlock(CondMixin, Module): + """Transformer block with transposed attention and gated depthwise feed forward network.""" + + def __init__(self, n_dim: int, n_channels: int, n_heads: int, mlp_ratio: float, cond_dim: int = 0): + """Initialize RestormerBlock. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads + Number of attention heads + mlp_ratio + Ratio for hidden dimension expansion + cond_dim + Dimension of conditioning input. If 0, no conditioning is applied. + """ + super().__init__() + self.norm1 = Sequential(instanceNormND(n_dim)(n_channels)) + self.attn = TransposedAttention(n_dim, n_channels, n_channels, n_heads) + self.norm2 = Sequential(instanceNormND(n_dim)(n_channels)) + self.ffn = GDFN(n_dim, n_channels, mlp_ratio) + if cond_dim > 0: + self.norm2.append(FiLM(channels=n_channels, cond_dim=cond_dim)) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply Restormer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor. If None, no conditioning is applied. + + Returns + ------- + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Forward pass for RestormerBlock.""" + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x, cond=cond)) + return x + + +class Restormer(UNetBase): + """Restormer architecture. + + Implements the Restormer [ZAM22]_ network, which is a U-shaped transformer + with channel wise attention and depthwise convolutions in the feed forward network. + + References + ---------- + .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration." + CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_blocks: Sequence[int] = (4, 6, 6, 8), + n_refinement_blocks: int = 4, + n_heads: Sequence[int] = (1, 2, 4, 8), + n_channels_per_head: int = 48, + mlp_ratio: float = 2.66, + cond_dim: int = 0, + ): + """Initialize Restormer. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + n_blocks + Number of blocks in each stage + n_refinement_blocks + Number of refinement blocks + n_heads + Number of attention heads in each stage + n_channels_per_head + Number of channels per attention head + mlp_ratio + Ratio for hidden dimension expansion + cond_dim + Dimension of conditioning input. If 0, no conditioning is applied. + """ + if len(n_blocks) != len(n_heads): + raise ValueError('n_blocks and n_heads must have the same length.') + + def blocks(n_heads: int, n_blocks: int): + layers = Sequential( + *(RestormerBlock(n_dim, n_channels_per_head * n_heads, n_heads, mlp_ratio) for _ in range(n_blocks)) + ) + + if cond_dim > 0 and n_blocks > 1: + layers.insert(1, FiLM(channels=n_channels_per_head * n_heads, cond_dim=cond_dim)) + return layers + + first_block = convND(n_dim)(n_channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False) + encoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)] + down_blocks = [ + PixelUnshuffleDownsample(n_dim, n_channels_per_head * head_current, n_channels_per_head * head_next) + for head_current, head_next in pairwise(n_heads) + ] + middle_block = blocks(n_heads[-1], n_blocks[-1]) + encoder = UNetEncoder( + first_block=first_block, + blocks=encoder_blocks, + down_blocks=down_blocks, + middle_block=middle_block, + ) + + up_blocks = [ + PixelShuffleUpsample(n_dim, n_channels_per_head * head_next, n_channels_per_head * head_current) + for head_current, head_next in pairwise(n_heads) + ][::-1] + concat_blocks = [ + Sequential( + Concat(), + convND(n_dim)(2 * n_channels_per_head * head, n_channels_per_head * head, kernel_size=1), + ) + for head in n_heads[-2::-1] + ] + decoder_blocks = [blocks(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1] + last_block = Sequential( + *(RestormerBlock(n_dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)), + convND(n_dim)(n_channels_per_head, n_channels_out, kernel_size=3, stride=1, padding=1), + ) + decoder = UNetDecoder( + blocks=decoder_blocks, + up_blocks=up_blocks, + concat_blocks=concat_blocks, + last_block=last_block, + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index af7dbdc5e..a0a9a6ad4 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,4 +1,5 @@ from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet from mrpro.nn.nets.MLP import MLP @@ -6,5 +7,6 @@ "AttentionGatedUNet", "BasicCNN", "MLP", + "Restormer", "UNet", ] diff --git a/tests/nn/nets/test_restormer.py b/tests/nn/nets/test_restormer.py new file mode 100644 index 000000000..68c84a689 --- /dev/null +++ b/tests/nn/nets/test_restormer.py @@ -0,0 +1,62 @@ +"""Tests for Restormer network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import Restormer + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_restormer_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the restormer.""" + restormer = Restormer( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2, 4), + n_blocks=(2, 1, 1), + cond_dim=32, + n_channels_per_head=2, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + restormer = restormer.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + restormer = cast(Restormer, torch.compile(restormer)) + y = restormer(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_restormer_backward() -> None: + restormer = Restormer( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2), + n_blocks=(2, 2), + cond_dim=32, + n_channels_per_head=4, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = restormer(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in restormer.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 650bfe5c3072df6957fe8b1211bec38f723f8721 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:29 +0100 Subject: [PATCH 197/205] add swinir architecture and tests ghstack-source-id: 9b04f9fb1f47e38ab120aacd06524fc97b9e9e2c ghstack-comment-id: 3865651450 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/958 --- src/mrpro/nn/nets/SwinIR.py | 247 ++++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 2 + tests/nn/nets/test_swinir.py | 56 ++++++++ 3 files changed, 305 insertions(+) create mode 100644 src/mrpro/nn/nets/SwinIR.py create mode 100644 tests/nn/nets/test_swinir.py diff --git a/src/mrpro/nn/nets/SwinIR.py b/src/mrpro/nn/nets/SwinIR.py new file mode 100644 index 000000000..bc064848d --- /dev/null +++ b/src/mrpro/nn/nets/SwinIR.py @@ -0,0 +1,247 @@ +"""SwinIR implementation.""" + +import torch +from torch.nn import GELU, Module + +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import convND, instanceNormND +from mrpro.nn.Sequential import Sequential + + +class SwinTransformerLayer(Module): + """Swin Transformer layer. + + Implements a single layer of the Swin Transformer architecture. + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + n_heads: int, + window_size: int, + mlp_ratio: int = 4, + emb_dim: int = 0, + p_droppath: float = 0.0, + ): + """Initialize SwinTransformerLayer. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads + Number of attention heads + window_size + Size of the attention window + mlp_ratio + Ratio for hidden dimension expansion in MLP + emb_dim + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath + Droppath probability for MLP + """ + super().__init__() + self.norm1 = instanceNormND(n_dim)(n_channels) + self.attn = ShiftedWindowAttention(n_dim, n_channels, n_channels, n_heads, window_size) + self.norm2 = Sequential(instanceNormND(n_dim)(n_channels)) + if emb_dim > 0: + self.norm2.append(FiLM(channels=n_channels, cond_dim=emb_dim)) + self.mlp = Sequential( + convND(n_dim)(n_channels, n_channels * mlp_ratio, 1), + GELU('tanh'), + convND(n_dim)(n_channels * mlp_ratio, n_channels, 1), + DropPath(p_droppath), + ) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the Swin Transformer layer. + + Parameters + ---------- + x + Input tensor + cond + Conditioning input + + Returns + ------- + torch.Tensor + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the Swin Transformer layer.""" + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x, cond=cond)) + return x + + +class ResidualSwinTransformerBlock(Module): + """Residual Swin Transformer block (RSTB). + + Combines a Swin Transformer layer with a residual connection, + as used in the SwinIR architecture. + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + n_heads: int, + window_size: int, + depth: int, + emb_dim: int = 0, + p_droppath: float = 0.0, + mlp_ratio: int = 4, + ): + """Initialize ResidualSwinTransformerBlock. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads + Number of attention heads + window_size + Size of the attention window + depth + Number of Swin Transformer layers + emb_dim + Dimension of conditioning input. If 0, no FiLM conditioning is used. + p_droppath + Droppath probability for MLP. + mlp_ratio + Ratio for hidden dimension expansion in MLP + """ + super().__init__() + self.layers = Sequential( + *( + SwinTransformerLayer( + n_dim, n_channels, n_heads, window_size, emb_dim=emb_dim, p_droppath=p_droppath, mlp_ratio=mlp_ratio + ) + for _ in range(depth) + ) + ) + self.conv = convND(n_dim)(n_channels, n_channels, 3, padding=1) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the residual Swin Transformer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning input. If None, no FiLM conditioning is used. + + Returns + ------- + torch.Tensor + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the residual Swin Transformer block.""" + return x + self.conv(self.layers(x, cond=cond)) + + +class SwinIR(Module): + """SwinIR architecture. + + Implements the SwinIR [LZL21]_ network, which is a Swin Transformer based + image restoration network. + + References + ---------- + .. [LZL21] Liang, Jie, et al. "SwinIR: Image restoration using swin transformer." + ICCVW 2021, https://arxiv.org/pdf/2108.10257.pdf + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_channels_per_head: int = 16, + n_heads: int = 6, + window_size: int = 64, + n_blocks: int = 6, + n_attn_per_block: int = 6, + emb_dim: int = 0, + p_droppath: float = 0.0, + mlp_ratio: int = 4, + ): + """Initialize SwinIR. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + n_channels_per_head + The number of channels per attention head. + n_heads + The number of attention heads. + window_size + The size of the attention window. Inputs sizes must be divisible by this value. + n_blocks + The number of residual blocks. + n_attn_per_block + The number of attention layers per block. + emb_dim + The dimension of the conditioning input. If 0, no FiLM conditioning is used. + p_droppath + The droppath probability for MLP. + mlp_ratio + The ratio for hidden dimension expansion in MLP. + """ + super().__init__() + self.first = convND(n_dim)(n_channels_in, n_channels_per_head * n_heads, kernel_size=3, padding=1) + self.blocks = Sequential( + *( + ResidualSwinTransformerBlock( + n_dim, + n_channels_per_head * n_heads, + n_heads, + window_size, + n_attn_per_block, + emb_dim, + p_droppath, + mlp_ratio, + ) + for _ in range(n_blocks) + ) + ) + self.last = convND(n_dim)(n_channels_per_head * n_heads, n_channels_out, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply SwinIR. + + Parameters + ---------- + x + Input tensor + cond + Conditioning input. If None, no FiLM conditioning is used. + + Returns + ------- + torch.Tensor + Output tensor + """ + x = self.first(x) + x = self.blocks(x, cond=cond) + x = self.last(x) + return x diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index a0a9a6ad4..582f2be46 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,5 +1,6 @@ from mrpro.nn.nets.BasicCNN import BasicCNN from mrpro.nn.nets.Restormer import Restormer +from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet from mrpro.nn.nets.MLP import MLP @@ -8,5 +9,6 @@ "BasicCNN", "MLP", "Restormer", + "SwinIR", "UNet", ] diff --git a/tests/nn/nets/test_swinir.py b/tests/nn/nets/test_swinir.py new file mode 100644 index 000000000..c8dbed58c --- /dev/null +++ b/tests/nn/nets/test_swinir.py @@ -0,0 +1,56 @@ +"""Tests for SwinIR network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import SwinIR + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_swinir_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the UNet.""" + swinir = SwinIR( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_heads=2, + n_channels_per_head=4, + n_blocks=2, + window_size=4, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + swinir = swinir.to(device) + if torch_compile: + swinir = cast(SwinIR, torch.compile(swinir)) + y = swinir(x) + assert y.shape == (1, 1, 16, 16) + + +def test_swinir_backward() -> None: + swinir = SwinIR( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_heads=2, + n_channels_per_head=4, + n_blocks=2, + window_size=4, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + y = swinir(x) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in swinir.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 9ce3038417912f235098cbdcd4bfdaecaa83de53 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:29 +0100 Subject: [PATCH 198/205] add uformer architecture and tests ghstack-source-id: 50681e2c086c097d02a8e8d882c3080662ee0b66 ghstack-comment-id: 3865651637 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/959 --- src/mrpro/nn/nets/Uformer.py | 230 ++++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 4 +- tests/nn/nets/test_uformer.py | 56 +++++++++ 3 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 src/mrpro/nn/nets/Uformer.py create mode 100644 tests/nn/nets/test_uformer.py diff --git a/src/mrpro/nn/nets/Uformer.py b/src/mrpro/nn/nets/Uformer.py new file mode 100644 index 000000000..02f1d1cc1 --- /dev/null +++ b/src/mrpro/nn/nets/Uformer.py @@ -0,0 +1,230 @@ +"""Uformer: U-Net with window attention.""" + +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import GELU, LeakyReLU, Module + +from mrpro.nn.attention.ShiftedWindowAttention import ShiftedWindowAttention +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import convND, convTransposeND, instanceNormND +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.Sequential import Sequential + + +class LeWinTransformerBlock(CondMixin, Module): + """Locally-enhanced windowed attention transformer block. + + Part of the Uformer architecture. + """ + + def __init__( + self, + n_dim: int, + n_channels_per_head: int, + n_heads: int, + window_size: int = 8, + shifted: bool = False, + mlp_ratio: float = 4.0, + p_droppath: float = 0.0, + cond_dim: int = 0, + ) -> None: + """Initialize the LeWinTransformerBlock module. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_per_head + The number of features per head. + n_heads + Number of attention heads + window_size + Size of the attention window + shifted + Whether to use shifted variant of the attention + mlp_ratio + Ratio of the hidden dimension to the input dimension + p_droppath + Dropout probability for the drop path. + cond_dim + Dimension of a conditioning tensor. If `0`, no FiLM layers are added. + """ + super().__init__() + channels = n_channels_per_head * n_heads + hidden_dim = int(channels * mlp_ratio) + self.norm1 = instanceNormND(n_dim)(channels) + self.attn = ShiftedWindowAttention( + n_dim=n_dim, + n_channels_in=channels, + n_channels_out=channels, + n_heads=n_heads, + window_size=window_size, + shifted=shifted, + ) + self.norm2 = instanceNormND(n_dim)(channels) + self.ff = Sequential( + convND(n_dim)(channels, hidden_dim, 1), + GELU(), + convND(n_dim)(hidden_dim, hidden_dim, kernel_size=3, groups=hidden_dim, stride=1, padding=1), + GELU(), + convND(n_dim)(hidden_dim, channels, 1), + ) + if cond_dim > 0: + self.ff.append(FiLM(channels, cond_dim)) + self.modulator = torch.nn.Parameter(torch.empty(channels, *((window_size,) * n_dim))) + torch.nn.init.trunc_normal_(self.modulator) + self.drop_path = DropPath(droprate=p_droppath) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the transformer block. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor + + Returns + ------- + Output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the transformer block.""" + modulator = self.modulator.tile([t // s for t, s in zip(x.shape[1:], self.modulator.shape, strict=False)]) + x_mod = self.norm1(x) + modulator + x_attn = self.attn(x_mod) + x_ff = self.ff(self.norm2(x_attn), cond=cond) + return x + self.drop_path(x_ff) + + +class Uformer(UNetBase): + """Uformer: U-Net with window attention. + + Implements the Uformer network proposed in [WANG21]_ + It is SWin-Transformer/U-Net hybrid consisting of (shifted) windows attention transformer layers at different + resolution levels, extended by FiLM layers for conditioning. + + References + ---------- + .. [WANG21] Wang, Z., Cun, X., Bao, J., Zhou, W., Liu, J., & Li, H. Uformer: A general u-shaped transformer for + image restoration. CVPR 2022. https://doi.org/10.48550/arXiv.2106.03106 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_channels_per_head: int = 32, + n_heads: Sequence[int] = (1, 2, 4, 8), + n_blocks: int = 2, + cond_dim: int = 0, + window_size: int = 8, + mlp_ratio: float = 4.0, + max_droppath_rate: float = 0.1, + ): + """Initialize the Uformer module. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + n_channels_per_head + The number of features per head. The number of features at a resolution level is given by + `n_channels_per_head * n_heads`. + n_heads + Number of attention heads at each resolution level. + n_blocks + The number of transformer blocks at each resolution level in the input and output path + cond_dim + Dimension of a conditioning tensor. If `0`, no FiLM layers are added. + window_size + The size of the attention windows in the (shifted) window attention layers. + mlp_ratio + Ratio of the hidden dimension to the input dimension in the feed-forward blocks + max_droppath_rate + Maximum drop path rate. As in the original implementation, the drop path rate in the input path + is linearly increased from `0` to `max_droppath_rate` with decreasing resolution. The rate in output + blocks is fixed to `max_droppath_rate`. + """ + + def blocks(n_heads: int, p_droppath: float = 0.0): + return Sequential( + *( + LeWinTransformerBlock( + n_dim=n_dim, + n_heads=n_heads, + n_channels_per_head=n_channels_per_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + shifted=bool(i % 2), + p_droppath=p_droppath, + cond_dim=cond_dim, + ) + for i in range(n_blocks) + ) + ) + + first_block = torch.nn.Sequential( + convND(n_dim)(n_channels_in, n_channels_per_head * n_heads[0], kernel_size=3, stride=1, padding='same'), + LeakyReLU(), + ) + drop_path_rates = torch.linspace(0, max_droppath_rate, len(n_heads)).tolist() + encoder_blocks = [ + blocks(n_heads=n_head, p_droppath=p_droppath_input) + for n_head, p_droppath_input in zip(n_heads[:-1], drop_path_rates[:-1], strict=True) + ] + down_blocks = [ + convND(n_dim)( + n_channels_per_head * n_head_current, + n_channels_per_head * n_head_next, + kernel_size=4, + stride=2, + padding=1, + ) + for n_head_current, n_head_next in pairwise(n_heads) + ] + middle_block = blocks(n_heads=n_heads[-1], p_droppath=max_droppath_rate) + encoder = UNetEncoder( + first_block=first_block, + blocks=encoder_blocks, + down_blocks=down_blocks, + middle_block=middle_block, + ) + + decoder_blocks = [blocks(n_heads=2 * n_head, p_droppath=max_droppath_rate) for n_head in reversed(n_heads[:-1])] + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] + up_blocks = [ + convTransposeND(n_dim)( + n_channels_per_head * n_heads[-1], n_channels_per_head * n_heads[-2], kernel_size=2, stride=2 + ) + ] + for n_head_current, n_head_next in pairwise(reversed(n_heads[:-1])): + up_blocks.append( + convTransposeND(n_dim)( + 2 * n_channels_per_head * n_head_current, n_channels_per_head * n_head_next, kernel_size=2, stride=2 + ) + ) + last_block = convND(n_dim)( + 2 * n_channels_per_head * n_heads[0], n_channels_out, kernel_size=3, stride=1, padding='same' + ) + decoder = UNetDecoder( + blocks=decoder_blocks, + concat_blocks=concat_blocks, + up_blocks=up_blocks, + last_block=last_block, + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 582f2be46..a610364a7 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -2,6 +2,7 @@ from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet +from mrpro.nn.nets.Uformer import Uformer from mrpro.nn.nets.MLP import MLP __all__ = [ @@ -11,4 +12,5 @@ "Restormer", "SwinIR", "UNet", -] + "Uformer" +] \ No newline at end of file diff --git a/tests/nn/nets/test_uformer.py b/tests/nn/nets/test_uformer.py new file mode 100644 index 000000000..f4315702e --- /dev/null +++ b/tests/nn/nets/test_uformer.py @@ -0,0 +1,56 @@ +"""Tests for Uformer network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import Uformer + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_uformer_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the uformer.""" + uformer = Uformer( + n_dim=2, n_channels_in=1, n_channels_out=1, n_heads=(1, 2), cond_dim=32, n_channels_per_head=8, window_size=2 + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + uformer = uformer.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + uformer = cast(Uformer, torch.compile(uformer)) + y = uformer(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_uformer_backward() -> None: + uformer = Uformer( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_heads=(1, 2, 4), + cond_dim=32, + n_channels_per_head=8, + window_size=2, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = uformer(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in uformer.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 6b51894d4ccf5d196b997cccad4d062792b4c75a Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:30 +0100 Subject: [PATCH 199/205] add hourglass transformer architecture and tests ghstack-source-id: 0e60fc37ef7ff7fdec243648ff308fea95bb350b ghstack-comment-id: 3865651799 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/960 --- src/mrpro/nn/nets/HourglassTransformer.py | 147 ++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 2 + tests/nn/nets/test_hourglass.py | 65 ++++++++++ 3 files changed, 214 insertions(+) create mode 100644 src/mrpro/nn/nets/HourglassTransformer.py create mode 100644 tests/nn/nets/test_hourglass.py diff --git a/src/mrpro/nn/nets/HourglassTransformer.py b/src/mrpro/nn/nets/HourglassTransformer.py new file mode 100644 index 000000000..ee4d9d44c --- /dev/null +++ b/src/mrpro/nn/nets/HourglassTransformer.py @@ -0,0 +1,147 @@ +"""Hourglass Transformer.""" + +from collections.abc import Sequence +from itertools import pairwise + +from torch.nn import Module + +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.join import Interpolate +from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Sequential import Sequential +from mrpro.operators.RearrangeOp import RearrangeOp +from mrpro.utils.to_tuple import to_tuple + + +class HourglassTransformer(UNetBase): + """Hourglass Transformer. + + A U-shaped transformer [CK]_ with neighborhood self-attention [NAT]_. + + References + ---------- + .. [CK] Crowson, Katherine, et al. "Scalable high-resolution pixel-space image synthesis with + hourglass diffusion transformers." ICML 2024, https://arxiv.org/abs/2401.11605 + .. [NAT] Hassani, A. et al. "Neighborhood Attention Transformer" CVPR, 2023, https://arxiv.org/abs/2204.07143 + + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + n_features: Sequence[int] | int, + depths: Sequence[int] | int = 3, + attention_neighborhood: Sequence[None | int] | int | None = 11, + n_heads: int | Sequence[int] = 4, + cond_dim: int = 0, + ): + """Initialize the Hourglass Transformer. + + Parameters + ---------- + n_dim + Number of (spatial)dimensions of the input data. + n_channels_in + Number of channels in the input data. + n_channels_out + Number of channels in the output data. + n_features + Number of features in each stage. + depths + Number of layers in each stage. + attention_neighborhood + Neighborhood size for the neighborhood self-attention. If None, use global attention + for that stage. + n_heads + Number of heads in each stage. + cond_dim + Number of dimensions of the conditioning tensor. + """ + n_layers_ = [ + len(x) + for x in (n_features, depths, attention_neighborhood, n_heads) + if (x is not None and not isinstance(x, int)) + ] + n_layers = n_layers_[0] + + if any(x != n_layers_[0] for x in n_layers_): + raise ValueError('All arguments must have the same length or be scalars') + + n_features_ = to_tuple(n_layers, n_features) + depths_ = to_tuple(n_layers, depths) + attention_neighborhood_ = to_tuple(n_layers, attention_neighborhood) + n_heads_ = to_tuple(n_layers, n_heads) + + move_channels_last = RearrangeOp('batch channels ... -> batch ... channels') + first_block = Sequential( + move_channels_last, + PixelUnshuffleDownsample(n_dim, n_channels_in, n_features_[0], downscale_factor=2, features_last=True), + ) + dim_group = (tuple(range(-n_dim - 1, -1)),) + encoder_blocks: list[Module] = [] + decoder_blocks: list[Module] = [] + merge_blocks: list[Module] = [] + down_blocks: list[Module] = [] + up_blocks: list[Module] = [] + for channels, depth, neighborhood, head in zip( + n_features_[:-1], + depths_[:-1], + attention_neighborhood_[:-1], + n_heads_[:-1], + strict=True, + ): + encoder_blocks.append( + SpatialTransformerBlock( + dim_groups=dim_group, + channels=channels, + depth=depth, + attention_neighborhood=neighborhood, + n_heads=head, + rope_embed_fraction=1.0, + cond_dim=cond_dim, + features_last=True, + norm='rms', + ) + ) + decoder_blocks.append( + SpatialTransformerBlock( + dim_groups=dim_group, + channels=channels, + depth=depth, + attention_neighborhood=neighborhood, + n_heads=head, + rope_embed_fraction=1.0, + cond_dim=cond_dim, + features_last=True, + norm='rms', + ) + ) + merge_blocks.append(Interpolate()) + for channels, channels_next in pairwise(n_features_): + down_blocks.append( + PixelUnshuffleDownsample(n_dim, channels, channels_next, downscale_factor=2, features_last=True) + ) + up_blocks.append(PixelShuffleUpsample(n_dim, channels_next, channels, upscale_factor=2, features_last=True)) + + last_block = Sequential( + PixelShuffleUpsample(n_dim, n_features_[-1], n_channels_out, upscale_factor=2, features_last=True), + move_channels_last.H, # moves channels back to front + ) + middle_block = SpatialTransformerBlock( + dim_groups=dim_group, + channels=n_features_[-1], + depth=depths_[-1], + attention_neighborhood=attention_neighborhood_[-1], + n_heads=n_heads_[-1], + rope_embed_fraction=1.0, + cond_dim=cond_dim, + features_last=True, + norm='rms', + ) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + decoder = UNetDecoder(decoder_blocks, up_blocks, merge_blocks, last_block) + + super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index a610364a7..9a83b31aa 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,4 +1,5 @@ from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.HourglassTransformer import HourglassTransformer from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet @@ -8,6 +9,7 @@ __all__ = [ "AttentionGatedUNet", "BasicCNN", + "HourglassTransformer", "MLP", "Restormer", "SwinIR", diff --git a/tests/nn/nets/test_hourglass.py b/tests/nn/nets/test_hourglass.py new file mode 100644 index 000000000..717e7e203 --- /dev/null +++ b/tests/nn/nets/test_hourglass.py @@ -0,0 +1,65 @@ +"""Test Hourglass Transformer""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import HourglassTransformer +from tests.conftest import minimal_torch_26 + + +@minimal_torch_26 +@torch.no_grad() +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_hourglass_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the hourglass.""" + hourglass = HourglassTransformer( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_features=64, + attention_neighborhood=(7, 7, None), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + hourglass = hourglass.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + hourglass = cast(HourglassTransformer, torch.compile(hourglass, dynamic=False)) + y = hourglass(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +@minimal_torch_26 +@pytest.mark.cuda +def test_hourglass_backward() -> None: + hourglass = HourglassTransformer( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_features=64, + attention_neighborhood=(7, 7, None), + cond_dim=32, + ).cuda() + + x = torch.zeros(1, 1, 16, requires_grad=True).cuda() + cond = torch.zeros(1, 32, requires_grad=True).cuda() + y = hourglass(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in hourglass.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From fccecc4dd7a54206c2872beb4982dfb08aa4c5ff Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:30 +0100 Subject: [PATCH 200/205] add vae ghstack-source-id: 0254d8c637580665808507d6e73ab05ab1b35a59 ghstack-comment-id: 3865652021 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/961 --- src/mrpro/nn/nets/VAE.py | 142 ++++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 4 +- tests/nn/nets/test_vae.py | 79 +++++++++++++++++++ 3 files changed, 224 insertions(+), 1 deletion(-) create mode 100644 src/mrpro/nn/nets/VAE.py create mode 100644 tests/nn/nets/test_vae.py diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py new file mode 100644 index 000000000..74fd7a67f --- /dev/null +++ b/src/mrpro/nn/nets/VAE.py @@ -0,0 +1,142 @@ +"""Variational Autoencoder with a Gaussian latent space.""" + +from collections.abc import Sequence +from itertools import pairwise + +import torch +from torch.nn import Module, SiLU + +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import convND +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.Upsample import Upsample + + +class VAEBase(Module): + """Basic Variational Autoencoder. + + Consists of an encoder to transform the input into a latent space and a decoder to transform the latent space back + into the original space. The encoder should return twice the number of channels as the decoder needs to reconstruct + the input: half of the channels are the mean and the other half the log variance of the latent space. + The reparameterization trick is used to sample from the latent space. + The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard + normal distribution. + """ + + def __init__(self, encoder: Module, decoder: Module): + """Initialize the VAE. + + Parameters + ---------- + encoder + Encoder module. Should return double the number of channels of the latent space. + decoder + Decoder module + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE. + + Calculates the reconstruction as well as the KL divergence between the latent space and the + standard normal distribution. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + tuple of the reconstructed image and + the KL divergence between the latent space and the standard normal distribution. + """ + return self.forward(x) + + def mode(self, x: torch.Tensor) -> torch.Tensor: + """Mode of the VAE.""" + z = self.encoder(x) + mean, _ = z.chunk(2, dim=1) + return self.decoder(mean) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE.""" + z = self.encoder(x) + mean, logvar = z.chunk(2, dim=1) + std = torch.exp(0.5 * logvar) + sample = mean + torch.randn_like(std) * std + reconstruction = self.decoder(sample) + kl = (-0.5 / len(z)) * torch.sum(1 + logvar - mean.square() - std.square()) + return reconstruction, kl + + +class VAE(VAEBase): + """Variational autoencoder with convolutional encoder and decoder.""" + + def __init__( + self, + n_dim: int = 2, + n_channels_in: int = 2, + latent_channels: int = 8, + n_features: Sequence[int] = (32, 64, 128), + n_res_blocks: int = 2, + ) -> None: + """Initialize the VAE. + + Parameters + ---------- + n_dim + The number of dimensions, i.e. 1, 2 or 3. + n_channels_in + The number of channels in the input tensor. + latent_channels + The number of channels in the latent space. + n_features + The number of features at each resolution level. + n_res_blocks + Number of residual blocks per resolution level. + """ + encoder = Sequential(convND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding=1)) + + for n_feat, n_feat_next in pairwise(n_features): + for _ in range(n_res_blocks): + encoder.append(ResBlock(n_dim, n_feat, n_feat, cond_dim=0)) + encoder.append(convND(n_dim)(n_feat, n_feat_next, kernel_size=3, stride=2, padding=1)) + + for _ in range(n_res_blocks): + encoder.append(ResBlock(n_dim, n_features[-1], n_features[-1], cond_dim=0)) + + encoder.extend( + [ + GroupNorm(n_features[-1]), + SiLU(), + convND(n_dim)(n_features[-1], 2 * latent_channels, kernel_size=3, padding=1), + ] + ) + + decoder = Sequential(convND(n_dim)(latent_channels, n_features[-1], kernel_size=3, padding=1)) + for _ in range(n_res_blocks): + decoder.append(ResBlock(n_dim, n_features[-1], n_features[-1], cond_dim=0)) + + for n_feat, n_feat_next in pairwise(reversed(n_features)): + decoder.append( + Sequential( + Upsample(dim=range(-n_dim, 0), scale_factor=2, mode='linear'), + convND(n_dim)(n_feat, n_feat_next, kernel_size=3, padding=1), + ) + ) + for _ in range(n_res_blocks): + decoder.append(ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim=0)) + + decoder.extend( + [ + GroupNorm(n_features[0]), + SiLU(), + convND(n_dim)(n_features[0], n_channels_in, kernel_size=3, padding=1), + ] + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 9a83b31aa..a343b0be8 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,4 +1,5 @@ from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.VAE import VAE from mrpro.nn.nets.HourglassTransformer import HourglassTransformer from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR @@ -14,5 +15,6 @@ "Restormer", "SwinIR", "UNet", - "Uformer" + "Uformer", + "VAE" ] \ No newline at end of file diff --git a/tests/nn/nets/test_vae.py b/tests/nn/nets/test_vae.py new file mode 100644 index 000000000..621a505d7 --- /dev/null +++ b/tests/nn/nets/test_vae.py @@ -0,0 +1,79 @@ +"""Tests for VAE network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import VAE + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_vae_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the VAE.""" + vae = VAE( + n_dim=2, + n_channels_in=1, + latent_channels=4, + n_features=(6, 8, 10), + n_res_blocks=2, + ) + + x = torch.zeros(1, 1, 8, 8, device=device) + vae = vae.to(device) + x = x.to(device) + if torch_compile: + vae = cast(VAE, torch.compile(vae)) + y, kl = vae(x) + assert y.shape == (1, 1, 8, 8) + assert kl.shape == () + latent = vae.encoder(x) + assert latent.shape == (1, 2 * 4, 2, 2) # 2 because of mean and logvar + + +def test_vae_backward_kl() -> None: + """Test the backward pass of the VAE wrt kl.""" + vae = VAE( + n_dim=1, + n_channels_in=1, + latent_channels=4, + n_features=(6, 8, 10), + n_res_blocks=2, + ) + + x = torch.zeros(1, 1, 8, requires_grad=True) + + _, kl = vae(x) + kl.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in vae.encoder.named_parameters(): # only the encoder parameters can influence kl + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +def test_vae_backward_y() -> None: + """Test the backward pass of the VAE wrt y.""" + vae = VAE( + n_dim=1, + n_channels_in=1, + latent_channels=4, + n_features=(6, 8, 10), + n_res_blocks=2, + ) + + x = torch.zeros(1, 1, 8, requires_grad=True) + + y, _ = vae(x) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in vae.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 512c4c96f7c67a240e9cfb5ba065031140d770ff Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:37:31 +0100 Subject: [PATCH 201/205] add dit ghstack-source-id: 450699e11434243ee232f29eaa47cf18157c3d57 ghstack-comment-id: 3874745738 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/964 --- src/mrpro/nn/nets/DiT.py | 266 ++++++++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 2 + tests/nn/nets/test_dit.py | 115 +++++++++++++++ 3 files changed, 383 insertions(+) create mode 100644 src/mrpro/nn/nets/DiT.py create mode 100644 tests/nn/nets/test_dit.py diff --git a/src/mrpro/nn/nets/DiT.py b/src/mrpro/nn/nets/DiT.py new file mode 100644 index 000000000..28ae78b9a --- /dev/null +++ b/src/mrpro/nn/nets/DiT.py @@ -0,0 +1,266 @@ +"""Diffusion Transformer (DiT).""" + +from collections.abc import Sequence +from math import prod + +import torch +from torch.nn import Linear, Module, Parameter, SiLU + +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.nets.MLP import MLP +from mrpro.nn.Sequential import Sequential +from mrpro.operators.PatchOp import PatchOp +from mrpro.utils.to_tuple import to_tuple + + +class DiTBlock(CondMixin, Module): + """DiT block with adaptive layer normalization and residual gating. + + References + ---------- + .. [DiT] Peebles, W., & Xie, S. Scalable Diffusion Models with Transformers. + ICCV 2023, https://arxiv.org/abs/2212.09748 + """ + + features_last: bool + + def __init__( + self, + n_channels: int, + n_heads: int, + cond_dim: int, + mlp_ratio: float = 4.0, + features_last: bool = True, + ): + """Initialize a DiT block. + + Parameters + ---------- + n_channels + Number of channels in the input and output. + n_heads + Number of attention heads. + cond_dim + Number of channels in the conditioning tensor. + mlp_ratio + Ratio of hidden MLP channels to input channels. + features_last + Whether the features are in the last dimension of the input tensor. + """ + super().__init__() + self.features_last = features_last + self.norm1 = LayerNorm(n_channels, features_last=True, cond_dim=cond_dim) + self.attn = MultiHeadAttention( + n_channels_in=n_channels, + n_channels_out=n_channels, + n_heads=n_heads, + features_last=True, + ) + self.norm2 = LayerNorm(n_channels, features_last=True, cond_dim=cond_dim) + self.mlp = MLP( + n_channels_in=n_channels, + n_channels_out=n_channels, + n_features=(int(n_channels * mlp_ratio),), + norm='none', + activation='gelu', + cond_dim=0, + features_last=True, + ) + self.gate = Sequential( + SiLU(), + Linear(cond_dim, 2 * n_channels), + ) + linear = self.gate[-1] + if isinstance(linear, Linear): + torch.nn.init.zeros_(linear.weight) + torch.nn.init.zeros_(linear.bias) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the DiT block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the DiT block.""" + if not self.features_last: + x = x.moveaxis(1, -1) + + gate_msa, gate_mlp = self.gate(cond).unsqueeze(-2).chunk(2, dim=-1) if cond is not None else (1.0, 1.0) + x = x + gate_msa * self.attn(self.norm1(x, cond=cond)) + x = x + gate_mlp * self.mlp(self.norm2(x, cond=cond)) + + if not self.features_last: + x = x.moveaxis(-1, 1) + + return x + + +class DiT(Module): + """DiT model. + + DiT is a vision transformer popularized by [DiT]_. + Often used for latent diffusion models, but also suitable for image restoration etc. + + References + ---------- + .. [DiT] Peebles, W., & Xie, S. Scalable Diffusion Models with Transformers. + ICCV 2023, https://arxiv.org/abs/2212.09748 + + """ + + grid_size: tuple[int, ...] + patch_size: tuple[int, ...] + n_channels_out: int + + def __init__( + self, + n_dim: int, + n_channels_in: int, + cond_dim: int, + input_size: int | Sequence[int] = 32, + patch_size: int | Sequence[int] = 2, + n_channels_out: int | None = None, + hidden_dim: int = 1152, + depth: int = 28, + n_heads: int = 16, + mlp_ratio: float = 4.0, + ) -> None: + """Initialize DiT. + + Parameters + ---------- + n_dim + Number of spatial dimensions. + n_channels_in + Number of channels in the input tensor. + cond_dim + Dimension of the conditioning tensor. + input_size + Input spatial size. If scalar, the same size is used for all spatial dimensions. + patch_size + Patch size. If scalar, the same patch size is used for all spatial dimensions. + n_channels_out + Number of output channels. If `None`, use `n_channels_in`. + hidden_dim + Transformer hidden dimension. + depth + Number of transformer blocks. + n_heads + Number of attention heads. + mlp_ratio + Ratio of hidden MLP channels to input channels. + """ + super().__init__() + self.n_dim = n_dim + self.input_size = to_tuple(n_dim, input_size) + self.patch_size = to_tuple(n_dim, patch_size) + + if any(s % p != 0 for s, p in zip(self.input_size, self.patch_size, strict=True)): + raise ValueError(f'Input size {self.input_size} must be divisible by patch size {self.patch_size}.') + if hidden_dim % (2 * n_dim) != 0: + raise ValueError(f'Hidden dimension {hidden_dim} must be divisible by 2 * {n_dim=}.') + + self.grid_size = tuple(s // p for s, p in zip(self.input_size, self.patch_size, strict=True)) + self.n_patches = prod(self.grid_size) + self.hidden_dim = hidden_dim + + self.n_channels_in = n_channels_in + self.n_channels_out = n_channels_in if n_channels_out is None else n_channels_out + + spatial_dim = tuple(range(2, 2 + n_dim)) + self.patch_op = PatchOp( + dim=spatial_dim, + patch_size=self.patch_size, + stride=self.patch_size, + dilation=1, + domain_size=self.input_size, + ) + + patch_volume = prod(self.patch_size) + self.in_proj = Linear(n_channels_in * patch_volume, hidden_dim) + self.pos_embed = Parameter(torch.zeros(self.n_patches, hidden_dim), requires_grad=False) + + self.blocks = Sequential( + *( + DiTBlock( + n_channels=hidden_dim, + n_heads=n_heads, + cond_dim=cond_dim, + mlp_ratio=mlp_ratio, + features_last=True, + ) + for _ in range(depth) + ) + ) + + self.final_layer = Sequential( + LayerNorm(hidden_dim, features_last=True, cond_dim=cond_dim), + Linear(hidden_dim, patch_volume * self.n_channels_out), + ) + + self.initialize_weights() + + def initialize_weights(self) -> None: + """Initialize network weights.""" + + def _basic_init(module: Module) -> None: + if isinstance(module, Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + self.apply(_basic_init) + + w = self.in_proj.weight.data + torch.nn.init.xavier_uniform_(w.reshape(w.shape[0], -1)) + if self.in_proj.bias is not None: + torch.nn.init.zeros_(self.in_proj.bias) + + for block in self.blocks: + if isinstance(block, DiTBlock): + gate_linear = block.gate[-1] + if isinstance(gate_linear, Linear): + torch.nn.init.zeros_(gate_linear.weight) + torch.nn.init.zeros_(gate_linear.bias) + + w = 1.0 / (10000 ** torch.linspace(0, 1, self.hidden_dim // (2 * len(self.grid_size)))) + x = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in self.grid_size], indexing='ij'), dim=-1) + wx = w * x.unsqueeze(-1) + pos_embed = torch.cat([torch.sin(wx), torch.cos(wx)], dim=-1).reshape(-1, self.hidden_dim) + self.pos_embed.data.copy_(pos_embed.to(self.pos_embed.data)) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply DiT. + + Parameters + ---------- + x + Input tensor with shape `batch, channels, *spatial_dims`. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with shape `batch, out_channels, *spatial_dims`. + """ + x = self.patch_op(x)[0].swapaxes(0, 1).flatten(2) + x = self.in_proj(x) + x = x + self.pos_embed + x = self.blocks(x, cond=cond) + x = self.final_layer(x, cond=cond) + x = x.unflatten(-1, (self.n_channels_out, *self.patch_size)).swapaxes(0, 1) + (x,) = self.patch_op.adjoint(x) + return x diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index a343b0be8..3fab59405 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,5 +1,6 @@ from mrpro.nn.nets.BasicCNN import BasicCNN from mrpro.nn.nets.VAE import VAE +from mrpro.nn.nets.DiT import DiT from mrpro.nn.nets.HourglassTransformer import HourglassTransformer from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR @@ -10,6 +11,7 @@ __all__ = [ "AttentionGatedUNet", "BasicCNN", + "DiT", "HourglassTransformer", "MLP", "Restormer", diff --git a/tests/nn/nets/test_dit.py b/tests/nn/nets/test_dit.py new file mode 100644 index 000000000..aed95599b --- /dev/null +++ b/tests/nn/nets/test_dit.py @@ -0,0 +1,115 @@ +"""Tests for DiT network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import DiT +from mrpro.nn.nets.DiT import DiTBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_ditblock_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of DiTBlock.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 64, 32)).to(device).requires_grad_(True) + cond = rng.float32_tensor((1, 16)).to(device).requires_grad_(True) + block = DiTBlock(n_channels=32, n_heads=4, cond_dim=16, mlp_ratio=2.0, features_last=True).to(device) + if torch_compile: + block = cast(DiTBlock, torch.compile(block, dynamic=False)) + y = block(x, cond=cond) + assert y.shape == x.shape + assert not y.isnan().any(), 'NaN values in output' + + +def test_ditblock_backward() -> None: + """Test the backward pass of DiTBlock.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 8, 8)).requires_grad_(True) + cond = rng.float32_tensor((1, 12)).requires_grad_(True) + block = DiTBlock(n_channels=32, n_heads=4, cond_dim=12, mlp_ratio=2.0, features_last=False) + y = block(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in block.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +@pytest.mark.parametrize('input_size', [(16, 32), (4, 8, 16)], ids=['2d', '3d']) +def test_dit_forward(torch_compile: bool, device: str, input_size: tuple[int, ...]) -> None: + """Test the forward pass of DiT.""" + n_channels_in = 3 + n_channels_out = 2 + n_batch = 1 + hidden_dim = 12 + cond_dim = 32 + dit = DiT( + n_dim=len(input_size), + n_channels_in=n_channels_in, + cond_dim=cond_dim, + input_size=input_size, + patch_size=2, + n_channels_out=n_channels_out, + hidden_dim=hidden_dim, + depth=2, + n_heads=4, + mlp_ratio=2.0, + ) + + x = torch.zeros(n_batch, n_channels_in, *input_size, device=device) + cond = torch.zeros(n_batch, cond_dim, device=device) + dit = dit.to(device) + if torch_compile: + dit = cast(DiT, torch.compile(dit)) + y = dit(x, cond=cond) + assert y.shape == (n_batch, n_channels_out, *input_size) + + +def test_dit_backward() -> None: + """Test the backward pass of DiT.""" + dit = DiT( + n_dim=2, + n_channels_in=1, + cond_dim=24, + input_size=16, + patch_size=2, + n_channels_out=1, + hidden_dim=32, + depth=2, + n_heads=4, + mlp_ratio=2.0, + ) + + x = torch.zeros(1, 1, 16, 16, requires_grad=True) + cond = torch.zeros(1, 24, requires_grad=True) + y = dit(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in dit.named_parameters(): + if name == 'pos_embed': + continue # embedding is fixed + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From 495d916ca91dfa7540a87442ddc72a2d04a1eb78 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:44:57 +0100 Subject: [PATCH 202/205] Squashed commit of the following: commit 512c4c96f7c67a240e9cfb5ba065031140d770ff Author: Felix Zimmermann Date: Tue Feb 10 14:37:31 2026 +0100 add dit ghstack-source-id: 450699e11434243ee232f29eaa47cf18157c3d57 ghstack-comment-id: 3874745738 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/964 commit fccecc4dd7a54206c2872beb4982dfb08aa4c5ff Author: Felix Zimmermann Date: Tue Feb 10 14:37:30 2026 +0100 add vae ghstack-source-id: 0254d8c637580665808507d6e73ab05ab1b35a59 ghstack-comment-id: 3865652021 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/961 commit 6b51894d4ccf5d196b997cccad4d062792b4c75a Author: Felix Zimmermann Date: Tue Feb 10 14:37:30 2026 +0100 add hourglass transformer architecture and tests ghstack-source-id: 0e60fc37ef7ff7fdec243648ff308fea95bb350b ghstack-comment-id: 3865651799 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/960 commit 9ce3038417912f235098cbdcd4bfdaecaa83de53 Author: Felix Zimmermann Date: Tue Feb 10 14:37:29 2026 +0100 add uformer architecture and tests ghstack-source-id: 50681e2c086c097d02a8e8d882c3080662ee0b66 ghstack-comment-id: 3865651637 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/959 commit 650bfe5c3072df6957fe8b1211bec38f723f8721 Author: Felix Zimmermann Date: Tue Feb 10 14:37:29 2026 +0100 add swinir architecture and tests ghstack-source-id: 9b04f9fb1f47e38ab120aacd06524fc97b9e9e2c ghstack-comment-id: 3865651450 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/958 commit 7ea4450471d2e641ede6ec16915c8891444eb06f Author: Felix Zimmermann Date: Tue Feb 10 14:37:28 2026 +0100 add restormer architecture and tests ghstack-source-id: 5e8d1409bb31a7858f66a7087b7bb6c450c49c89 ghstack-comment-id: 3865651248 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/957 commit 774db723066af8f5423b8b31989d38ad4cdb80c9 Author: Felix Zimmermann Date: Tue Feb 10 14:37:28 2026 +0100 Add MLP network and tests ghstack-source-id: 658a1037cd47063cd0a58dca83f346420d4b1087 ghstack-comment-id: 3874745395 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/963 commit 88e72b7af0ea274f4523a6977bf96719fd8abf57 Author: Felix Zimmermann Date: Tue Feb 10 14:37:27 2026 +0100 add unet, basic cnn, and residual blocks ghstack-source-id: c62dfe4bce3e9b11589f24ad6b0e01e86958832b ghstack-comment-id: 3865651070 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/956 commit d8eb662bdcc1afd9af2919e0547458182abdf38b Author: Felix Zimmermann Date: Tue Feb 10 14:37:27 2026 +0100 add positional encodings and attention modules ghstack-source-id: 1ef9b0579a1d6a662dc8d306ff8bd373b259a253 ghstack-comment-id: 3865650808 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/955 commit 05db778511fd34dd272ac34d946f5def23db37c9 Author: Felix Zimmermann Date: Tue Feb 10 14:37:26 2026 +0100 add data consistency modules and tests ghstack-source-id: 555bc2286cf127ebd9411ff45f7f1bb6c68612e4 ghstack-comment-id: 3865650618 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/954 commit c0be007732a13b58aa9ce9ce631ea9bff1af85de Author: Felix Zimmermann Date: Tue Feb 10 14:37:26 2026 +0100 add core nn foundations, layers, and resize blocks ghstack-source-id: 17c69fc80ce0e6e8390cb563c732b4f2b8aea912 ghstack-comment-id: 3865650347 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/953 commit 3e85c27ebcc7a90a90a466ddb548b4d6de95077e Author: Felix Zimmermann Date: Tue Feb 10 14:37:25 2026 +0100 Add fast path to PatchOp adjoint ghstack-source-id: ce486f1c161829a2109abd998ea46e6446587643 ghstack-comment-id: 3874745101 Pull-Request-resolved: https://github.com/PTB-MR/mrpro/pull/962 commit 966707c032d7acb7691c1dabffc456b6c98737ac Author: Mara Guastini <112558042+guastinimara@users.noreply.github.com> Date: Tue Feb 10 02:09:40 2026 +0100 Fix GridSamplingOperator from inputs on the gpu (#824) Co-authored-by: Felix Zimmermann --- src/mrpro/nn/FiLM.py | 30 +- src/mrpro/nn/LayerNorm.py | 38 ++- src/mrpro/nn/attention/MultiHeadAttention.py | 31 +- .../nn/attention/SpatialTransformerBlock.py | 42 ++- src/mrpro/nn/nets/DiT.py | 266 ++++++++++++++++++ src/mrpro/nn/nets/MLP.py | 119 ++++++++ src/mrpro/nn/nets/VAE.py | 92 +++++- src/mrpro/nn/nets/__init__.py | 24 +- src/mrpro/operators/GridSamplingOp.py | 15 +- src/mrpro/operators/PatchOp.py | 88 ++++-- tests/nn/nets/test_dit.py | 115 ++++++++ tests/nn/nets/test_vae.py | 79 ++++++ tests/nn/test_film.py | 52 +++- tests/nn/test_layernorm.py | 111 +++++--- tests/nn/test_mlp.py | 89 ++++++ tests/operators/test_grid_sampling_op.py | 108 +++++-- tests/operators/test_patch_op.py | 9 +- 17 files changed, 1135 insertions(+), 173 deletions(-) create mode 100644 src/mrpro/nn/nets/DiT.py create mode 100644 src/mrpro/nn/nets/MLP.py create mode 100644 tests/nn/nets/test_dit.py create mode 100644 tests/nn/nets/test_vae.py create mode 100644 tests/nn/test_mlp.py diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py index 92780aae3..4ac313caf 100644 --- a/src/mrpro/nn/FiLM.py +++ b/src/mrpro/nn/FiLM.py @@ -19,7 +19,11 @@ class FiLM(CondMixin, Module): conditioning layer." AAAI (2018). https://arxiv.org/abs/1709.07871 """ - def __init__(self, channels: int, cond_dim: int) -> None: + features_last: bool + + def __init__( + self, channels: int, cond_dim: int, features_last: bool = False + ) -> None: """Initialize FiLM. Parameters @@ -28,11 +32,17 @@ def __init__(self, channels: int, cond_dim: int) -> None: The number of channels in the input tensor. cond_dim The dimension of the conditioning tensor. + features_last + Whether the features are in the last dimension of the input tensor (e.g. transformer tokens) + or in the second dimension (e.g. image tensors). """ super().__init__() self.project = Linear(cond_dim, 2 * channels) if cond_dim > 0 else None + self.features_last = features_last - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply FiLM. Parameters @@ -44,11 +54,21 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc """ return super().__call__(x, cond=cond) - def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply FiLM.""" if cond is None or self.project is None: return x - scale, shift = self.project(cond).chunk(2, dim=1) + if self.features_last: + x = x.moveaxis(-1, 1) + + scale, shift = self.project(cond).chunk(2, dim=1) scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) - return x * (1 + scale) + shift + x = x * (1 + scale) + shift + + if self.features_last: + x = x.moveaxis(1, -1) + + return x diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py index 7c35eee96..23e059fe6 100644 --- a/src/mrpro/nn/LayerNorm.py +++ b/src/mrpro/nn/LayerNorm.py @@ -10,7 +10,9 @@ class LayerNorm(CondMixin, Module): """Layer normalization.""" - def __init__(self, n_channels: int | None, features_last: bool = False, cond_dim: int = 0) -> None: + def __init__( + self, n_channels: int | None, features_last: bool = False, cond_dim: int = 0 + ) -> None: """Initialize the layer normalization. Parameters @@ -29,7 +31,7 @@ def __init__(self, n_channels: int | None, features_last: bool = False, cond_dim self.bias: Parameter | None = None self.cond_proj: Linear | None = None elif n_channels is None and cond_dim > 0: - raise ValueError('channels must be provided if cond_dim > 0') + raise ValueError("channels must be provided if cond_dim > 0") elif n_channels is not None and cond_dim == 0: self.weight = Parameter(torch.ones(n_channels)) self.bias = Parameter(torch.zeros(n_channels)) @@ -39,11 +41,13 @@ def __init__(self, n_channels: int | None, features_last: bool = False, cond_dim self.bias = None self.cond_proj = Linear(cond_dim, 2 * n_channels) else: - raise ValueError('cond_dim must be zero or positive.') + raise ValueError("cond_dim must be zero or positive.") self.features_last = features_last - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply layer normalization to the input tensor. Parameters @@ -59,25 +63,35 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc """ return super().__call__(x, cond=cond) - def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply layer normalization to the input tensor.""" - dims = tuple(range(1, x.ndim)) - mean = x.mean(dim=dims, keepdim=True) - std = x.std(dim=dims, keepdim=True, unbiased=False) - x = (x - mean) / (std + 1e-5) + dim = -1 if self.features_last else 1 + dtype = x.dtype + x = x.float() + var, mean = torch.var_mean(x, dim=dim, unbiased=False, keepdim=True) + x = (x - mean) * (var + 1e-5).rsqrt() + x = x.to(dtype) if self.weight is not None and self.bias is not None: if self.features_last: x = x * self.weight + self.bias else: - x = x * unsqueeze_right(self.weight, x.ndim - 2) + unsqueeze_right(self.bias, x.ndim - 2) + x = x * unsqueeze_right(self.weight, x.ndim - 2) + unsqueeze_right( + self.bias, x.ndim - 2 + ) if self.cond_proj is not None and cond is not None: scale, shift = self.cond_proj(cond).chunk(2, dim=-1) scale = 1 + scale if self.features_last: - x = x * unsqueeze_at(scale, 1, x.ndim - 2) + unsqueeze_at(shift, 1, x.ndim - 2) + x = x * unsqueeze_at(scale, 1, x.ndim - 2) + unsqueeze_at( + shift, 1, x.ndim - 2 + ) else: - x = x * unsqueeze_right(scale, x.ndim - 2) + unsqueeze_right(shift, x.ndim - 2) + x = x * unsqueeze_right(scale, x.ndim - 2) + unsqueeze_right( + shift, x.ndim - 2 + ) return x diff --git a/src/mrpro/nn/attention/MultiHeadAttention.py b/src/mrpro/nn/attention/MultiHeadAttention.py index 212bc68eb..917c34d70 100644 --- a/src/mrpro/nn/attention/MultiHeadAttention.py +++ b/src/mrpro/nn/attention/MultiHeadAttention.py @@ -45,17 +45,22 @@ def __init__( Fraction of channels to embed with RoPE. """ super().__init__() + n_channels_kv = ( + n_channels_cross if n_channels_cross is not None else n_channels_in + ) channels_per_head_q = n_channels_in // n_heads - channels_per_head_kv = n_channels_cross // n_heads if n_channels_cross is not None else n_channels_in // n_heads + channels_per_head_kv = n_channels_kv // n_heads self.to_q = Linear(n_channels_in, channels_per_head_q * n_heads) - self.to_kv = Linear(n_channels_in, channels_per_head_kv * n_heads * 2) + self.to_kv = Linear(n_channels_kv, channels_per_head_kv * n_heads * 2) self.p_dropout = p_dropout self.features_last = features_last self.to_out = Linear(n_channels_in, n_channels_out) self.n_heads = n_heads self.rope = AxialRoPE(rope_embed_fraction) - def __call__(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + def __call__( + self, x: torch.Tensor, cross_attention: torch.Tensor | None = None + ) -> torch.Tensor: """Apply multi-head attention. Parameters @@ -76,7 +81,9 @@ def _reshape(self, x: torch.Tensor) -> torch.Tensor: x = x.moveaxis(1, -1) return x.flatten(1, -2) - def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, cross_attention: torch.Tensor | None = None + ) -> torch.Tensor: """Apply multi-head attention.""" if cross_attention is None: cross_attention = x @@ -84,19 +91,27 @@ def forward(self, x: torch.Tensor, cross_attention: torch.Tensor | None = None) x = x.moveaxis(1, -1) cross_attention = cross_attention.moveaxis(1, -1) - query = rearrange(self.to_q(x), 'batch ... (heads channels) -> batch heads ... channels ', heads=self.n_heads) + query = rearrange( + self.to_q(x), + "batch ... (heads channels) -> batch heads ... channels ", + heads=self.n_heads, + ) key, value = rearrange( self.to_kv(cross_attention), - 'batch ... (kv heads channels) -> kv batch heads ... channels ', + "batch ... (kv heads channels) -> kv batch heads ... channels ", heads=self.n_heads, kv=2, ) query, key = self.rope(query, key) # NO-OP if rope_embed_fraction is 0.0 - query, key, value = query.flatten(2, -2), key.flatten(2, -2), value.flatten(2, -2) + query, key, value = ( + query.flatten(2, -2), + key.flatten(2, -2), + value.flatten(2, -2), + ) y = torch.nn.functional.scaled_dot_product_attention( query, key, value, dropout_p=self.p_dropout, is_causal=False ) - y = rearrange(y, '... heads L channels -> ... L (heads channels)') + y = rearrange(y, "... heads L channels -> ... L (heads channels)") out = self.to_out(y).reshape(x.shape) if not self.features_last: diff --git a/src/mrpro/nn/attention/SpatialTransformerBlock.py b/src/mrpro/nn/attention/SpatialTransformerBlock.py index 18817b26f..fd1e11cc7 100644 --- a/src/mrpro/nn/attention/SpatialTransformerBlock.py +++ b/src/mrpro/nn/attention/SpatialTransformerBlock.py @@ -19,9 +19,9 @@ def zero_init(m: Module) -> Module: """Initialize module weights and bias to zero.""" - if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor): + if hasattr(m, "weight") and isinstance(m.weight, torch.Tensor): torch.nn.init.zeros_(m.weight) - if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): + if hasattr(m, "bias") and m.bias is not None and isinstance(m.bias, torch.Tensor): torch.nn.init.zeros_(m.bias) return m @@ -76,7 +76,9 @@ def __init__( ) else: if p_dropout > 0: - raise ValueError('p_dropout > 0 is not supported for neighborhood self attention') + raise ValueError( + "p_dropout > 0 is not supported for neighborhood self attention" + ) attention = NeighborhoodSelfAttention( n_channels_in=channels, n_channels_out=channels, @@ -86,7 +88,9 @@ def __init__( circular=True, rope_embed_fraction=rope_embed_fraction, ) - self.selfattention = Sequential(LayerNorm(channels, features_last=True), attention) + self.selfattention = Sequential( + LayerNorm(channels, features_last=True), attention + ) hidden_dim = int(channels * mlp_ratio) self.ff = Sequential( LayerNorm(channels, features_last=True, cond_dim=cond_dim), @@ -95,7 +99,9 @@ def __init__( Linear(hidden_dim, channels), ) - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply the basic transformer block. Parameters @@ -107,7 +113,9 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc """ return super().__call__(x, cond=cond) - def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply the basic transformer block.""" if not self.features_last: x = x.moveaxis(1, -1).contiguous() @@ -132,7 +140,7 @@ def __init__( rope_embed_fraction: float = 0.0, attention_neighborhood: int | None = None, features_last: bool = False, - norm: Literal['group', 'rms'] = 'group', + norm: Literal["group", "rms"] = "group", ): """Initialize the spatial transformer block. @@ -162,12 +170,12 @@ def __init__( super().__init__() hidden_dim = n_heads * (channels // n_heads) match norm: - case 'group': + case "group": self.norm: Module = GroupNorm(channels, features_last=features_last) - case 'rms': + case "rms": self.norm = RMSNorm(channels, features_last=features_last) case _: - raise ValueError(f'Invalid norm: {norm}') + raise ValueError(f"Invalid norm: {norm}") self.features_last = features_last self.proj_in = Linear(channels, hidden_dim) self.transformer_blocks = Sequential() @@ -183,10 +191,14 @@ def __init__( rope_embed_fraction=rope_embed_fraction, attention_neighborhood=attention_neighborhood, ) - self.transformer_blocks.append(PermutedBlock(group, block, features_last=True)) - self.proj_out = Linear(hidden_dim, channels) + self.transformer_blocks.append( + PermutedBlock(group, block, features_last=True) + ) + self.proj_out = zero_init(Linear(hidden_dim, channels)) - def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply the spatial transformer block.""" skip = x h = self.norm(x) @@ -199,7 +211,9 @@ def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch h = h.movedim(-1, 1) return skip + h - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + def __call__( + self, x: torch.Tensor, *, cond: torch.Tensor | None = None + ) -> torch.Tensor: """Apply the spatial transformer block. Parameters diff --git a/src/mrpro/nn/nets/DiT.py b/src/mrpro/nn/nets/DiT.py new file mode 100644 index 000000000..28ae78b9a --- /dev/null +++ b/src/mrpro/nn/nets/DiT.py @@ -0,0 +1,266 @@ +"""Diffusion Transformer (DiT).""" + +from collections.abc import Sequence +from math import prod + +import torch +from torch.nn import Linear, Module, Parameter, SiLU + +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.nets.MLP import MLP +from mrpro.nn.Sequential import Sequential +from mrpro.operators.PatchOp import PatchOp +from mrpro.utils.to_tuple import to_tuple + + +class DiTBlock(CondMixin, Module): + """DiT block with adaptive layer normalization and residual gating. + + References + ---------- + .. [DiT] Peebles, W., & Xie, S. Scalable Diffusion Models with Transformers. + ICCV 2023, https://arxiv.org/abs/2212.09748 + """ + + features_last: bool + + def __init__( + self, + n_channels: int, + n_heads: int, + cond_dim: int, + mlp_ratio: float = 4.0, + features_last: bool = True, + ): + """Initialize a DiT block. + + Parameters + ---------- + n_channels + Number of channels in the input and output. + n_heads + Number of attention heads. + cond_dim + Number of channels in the conditioning tensor. + mlp_ratio + Ratio of hidden MLP channels to input channels. + features_last + Whether the features are in the last dimension of the input tensor. + """ + super().__init__() + self.features_last = features_last + self.norm1 = LayerNorm(n_channels, features_last=True, cond_dim=cond_dim) + self.attn = MultiHeadAttention( + n_channels_in=n_channels, + n_channels_out=n_channels, + n_heads=n_heads, + features_last=True, + ) + self.norm2 = LayerNorm(n_channels, features_last=True, cond_dim=cond_dim) + self.mlp = MLP( + n_channels_in=n_channels, + n_channels_out=n_channels, + n_features=(int(n_channels * mlp_ratio),), + norm='none', + activation='gelu', + cond_dim=0, + features_last=True, + ) + self.gate = Sequential( + SiLU(), + Linear(cond_dim, 2 * n_channels), + ) + linear = self.gate[-1] + if isinstance(linear, Linear): + torch.nn.init.zeros_(linear.weight) + torch.nn.init.zeros_(linear.bias) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the DiT block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the DiT block.""" + if not self.features_last: + x = x.moveaxis(1, -1) + + gate_msa, gate_mlp = self.gate(cond).unsqueeze(-2).chunk(2, dim=-1) if cond is not None else (1.0, 1.0) + x = x + gate_msa * self.attn(self.norm1(x, cond=cond)) + x = x + gate_mlp * self.mlp(self.norm2(x, cond=cond)) + + if not self.features_last: + x = x.moveaxis(-1, 1) + + return x + + +class DiT(Module): + """DiT model. + + DiT is a vision transformer popularized by [DiT]_. + Often used for latent diffusion models, but also suitable for image restoration etc. + + References + ---------- + .. [DiT] Peebles, W., & Xie, S. Scalable Diffusion Models with Transformers. + ICCV 2023, https://arxiv.org/abs/2212.09748 + + """ + + grid_size: tuple[int, ...] + patch_size: tuple[int, ...] + n_channels_out: int + + def __init__( + self, + n_dim: int, + n_channels_in: int, + cond_dim: int, + input_size: int | Sequence[int] = 32, + patch_size: int | Sequence[int] = 2, + n_channels_out: int | None = None, + hidden_dim: int = 1152, + depth: int = 28, + n_heads: int = 16, + mlp_ratio: float = 4.0, + ) -> None: + """Initialize DiT. + + Parameters + ---------- + n_dim + Number of spatial dimensions. + n_channels_in + Number of channels in the input tensor. + cond_dim + Dimension of the conditioning tensor. + input_size + Input spatial size. If scalar, the same size is used for all spatial dimensions. + patch_size + Patch size. If scalar, the same patch size is used for all spatial dimensions. + n_channels_out + Number of output channels. If `None`, use `n_channels_in`. + hidden_dim + Transformer hidden dimension. + depth + Number of transformer blocks. + n_heads + Number of attention heads. + mlp_ratio + Ratio of hidden MLP channels to input channels. + """ + super().__init__() + self.n_dim = n_dim + self.input_size = to_tuple(n_dim, input_size) + self.patch_size = to_tuple(n_dim, patch_size) + + if any(s % p != 0 for s, p in zip(self.input_size, self.patch_size, strict=True)): + raise ValueError(f'Input size {self.input_size} must be divisible by patch size {self.patch_size}.') + if hidden_dim % (2 * n_dim) != 0: + raise ValueError(f'Hidden dimension {hidden_dim} must be divisible by 2 * {n_dim=}.') + + self.grid_size = tuple(s // p for s, p in zip(self.input_size, self.patch_size, strict=True)) + self.n_patches = prod(self.grid_size) + self.hidden_dim = hidden_dim + + self.n_channels_in = n_channels_in + self.n_channels_out = n_channels_in if n_channels_out is None else n_channels_out + + spatial_dim = tuple(range(2, 2 + n_dim)) + self.patch_op = PatchOp( + dim=spatial_dim, + patch_size=self.patch_size, + stride=self.patch_size, + dilation=1, + domain_size=self.input_size, + ) + + patch_volume = prod(self.patch_size) + self.in_proj = Linear(n_channels_in * patch_volume, hidden_dim) + self.pos_embed = Parameter(torch.zeros(self.n_patches, hidden_dim), requires_grad=False) + + self.blocks = Sequential( + *( + DiTBlock( + n_channels=hidden_dim, + n_heads=n_heads, + cond_dim=cond_dim, + mlp_ratio=mlp_ratio, + features_last=True, + ) + for _ in range(depth) + ) + ) + + self.final_layer = Sequential( + LayerNorm(hidden_dim, features_last=True, cond_dim=cond_dim), + Linear(hidden_dim, patch_volume * self.n_channels_out), + ) + + self.initialize_weights() + + def initialize_weights(self) -> None: + """Initialize network weights.""" + + def _basic_init(module: Module) -> None: + if isinstance(module, Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + self.apply(_basic_init) + + w = self.in_proj.weight.data + torch.nn.init.xavier_uniform_(w.reshape(w.shape[0], -1)) + if self.in_proj.bias is not None: + torch.nn.init.zeros_(self.in_proj.bias) + + for block in self.blocks: + if isinstance(block, DiTBlock): + gate_linear = block.gate[-1] + if isinstance(gate_linear, Linear): + torch.nn.init.zeros_(gate_linear.weight) + torch.nn.init.zeros_(gate_linear.bias) + + w = 1.0 / (10000 ** torch.linspace(0, 1, self.hidden_dim // (2 * len(self.grid_size)))) + x = torch.stack(torch.meshgrid(*[torch.arange(s).float() for s in self.grid_size], indexing='ij'), dim=-1) + wx = w * x.unsqueeze(-1) + pos_embed = torch.cat([torch.sin(wx), torch.cos(wx)], dim=-1).reshape(-1, self.hidden_dim) + self.pos_embed.data.copy_(pos_embed.to(self.pos_embed.data)) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply DiT. + + Parameters + ---------- + x + Input tensor with shape `batch, channels, *spatial_dims`. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with shape `batch, out_channels, *spatial_dims`. + """ + x = self.patch_op(x)[0].swapaxes(0, 1).flatten(2) + x = self.in_proj(x) + x = x + self.pos_embed + x = self.blocks(x, cond=cond) + x = self.final_layer(x, cond=cond) + x = x.unflatten(-1, (self.n_channels_out, *self.patch_size)).swapaxes(0, 1) + (x,) = self.patch_op.adjoint(x) + return x diff --git a/src/mrpro/nn/nets/MLP.py b/src/mrpro/nn/nets/MLP.py new file mode 100644 index 000000000..524a985cd --- /dev/null +++ b/src/mrpro/nn/nets/MLP.py @@ -0,0 +1,119 @@ +"""Multi-layer perceptron.""" + +from collections.abc import Sequence +from itertools import pairwise +from typing import Literal + +import torch +from torch.nn import GELU, LeakyReLU, Linear, ReLU, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.Sequential import Sequential + + +class MLP(Sequential): + """Multi-layer perceptron. + + A series of linear layers, normalization and activation. + Allows FiLM conditioning. + Order is Linear -> Norm (optional) -> FiLM (optional) -> Activation. + + If you need more flexibility, use `~mrpro.nn.Sequential` directly. + """ + + features_last: bool + + def __init__( + self, + n_channels_in: int, + n_channels_out: int, + norm: Literal['layer', 'none'] = 'none', + activation: Literal['gelu', 'relu', 'silu', 'leaky_relu'] = 'gelu', + n_features: Sequence[int] = (256, 256), + cond_dim: int = 0, + features_last: bool = True, + ): + """Initialize a MLP. + + Parameters + ---------- + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + norm + The type of normalization to use. If `layer`, use layer normalization. + If `none`, use no normalization. + activation + The type of activation to use. If `gelu`, use GELU. + If `relu`, use ReLU. If `silu`, use SiLU. If `leaky_relu`, use LeakyReLU. + n_features + The number of features in the hidden layers. The length of this sequence determines the number of hidden + layers. The total number of linear layers is `len(n_features) + 1`. + cond_dim + The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + Otherwise, between linear layers, after normalization, FiLM conditioning is applied. + features_last + Whether the features are in the last dimension, as common in transformer models, + or in the second dimension, as common in image models. + """ + super().__init__() + use_film = cond_dim > 0 + self.features_last = features_last + + if len(n_features) == 0: + self.append(Linear(n_channels_in, n_channels_out)) + return + + self.append(Linear(n_channels_in, n_features[0])) + + for c_in, c_out in pairwise((*n_features, n_channels_out)): + if norm.lower() == 'layer': + self.append(LayerNorm(c_in, features_last=True)) + elif norm.lower() != 'none': + raise ValueError(f'Invalid normalization type: {norm}') + + if use_film: + self.append(FiLM(c_in, cond_dim, features_last=True)) + + if activation.lower() == 'gelu': + self.append(GELU(approximate='tanh')) + elif activation.lower() == 'relu': + self.append(ReLU()) + elif activation.lower() == 'silu': + self.append(SiLU()) + elif activation.lower() == 'leaky_relu': + self.append(LeakyReLU()) + else: + raise ValueError(f'Invalid activation type: {activation}') + + self.append(Linear(c_in, c_out)) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] + """Apply the MLP to the input tensor. + + Parameters + ---------- + x + The input tensor. + cond + The condition tensor. If None, no FiLM conditioning is applied. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the MLP to the input tensor.""" + if len(x) != 1: + raise ValueError(f'Mlp expects exactly one input tensor, got {len(x)}') + tensor = x[0] + if not self.features_last: + tensor = tensor.moveaxis(1, -1) + out = super().forward(tensor, cond=cond) + if not self.features_last: + out = out.moveaxis(-1, 1) + return out diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py index e0b1bfc58..b1d48ef33 100644 --- a/src/mrpro/nn/nets/VAE.py +++ b/src/mrpro/nn/nets/VAE.py @@ -1,10 +1,19 @@ """Variational Autoencoder with a Gaussian latent space.""" +from collections.abc import Sequence +from itertools import pairwise + import torch -from torch.nn import Module +from torch.nn import Module, SiLU + +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import convND +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.Upsample import Upsample -class VAE(Module): +class VAEBase(Module): """Basic Variational Autoencoder. Consists of an encoder to transform the input into a latent space and a decoder to transform the latent space back @@ -60,5 +69,82 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: std = torch.exp(0.5 * logvar) sample = mean + torch.randn_like(std) * std reconstruction = self.decoder(sample) - kl = -0.5 * torch.sum(1 + logvar - mean.square() - std.square()) + kl = (-0.5 / len(z)) * torch.sum(1 + logvar - mean.square() - std.square()) return reconstruction, kl + + +class VAE(VAEBase): + """Variational autoencoder with convolutional encoder and decoder.""" + + def __init__( + self, + n_dim: int = 2, + n_channels_in: int = 2, + latent_channels: int = 8, + n_features: Sequence[int] = (32, 64, 128), + n_res_blocks: int = 2, + ) -> None: + """Initialize the VAE. + + Parameters + ---------- + n_dim + The number of dimensions, i.e. 1, 2 or 3. + n_channels_in + The number of channels in the input tensor. + latent_channels + The number of channels in the latent space. + n_features + The number of features at each resolution level. + n_res_blocks + Number of residual blocks per resolution level. + """ + encoder = Sequential( + convND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding=1) + ) + + for n_feat, n_feat_next in pairwise(n_features): + for _ in range(n_res_blocks): + encoder.append(ResBlock(n_dim, n_feat, n_feat, cond_dim=0)) + encoder.append( + convND(n_dim)(n_feat, n_feat_next, kernel_size=3, stride=2, padding=1) + ) + + for _ in range(n_res_blocks): + encoder.append(ResBlock(n_dim, n_features[-1], n_features[-1], cond_dim=0)) + + encoder.extend( + [ + GroupNorm(n_features[-1]), + SiLU(), + convND(n_dim)( + n_features[-1], 2 * latent_channels, kernel_size=3, padding=1 + ), + ] + ) + + decoder = Sequential( + convND(n_dim)(latent_channels, n_features[-1], kernel_size=3, padding=1) + ) + for _ in range(n_res_blocks): + decoder.append(ResBlock(n_dim, n_features[-1], n_features[-1], cond_dim=0)) + + for n_feat, n_feat_next in pairwise(reversed(n_features)): + decoder.append( + Sequential( + Upsample(dim=range(-n_dim, 0), scale_factor=2, mode="linear"), + convND(n_dim)(n_feat, n_feat_next, kernel_size=3, padding=1), + ) + ) + for _ in range(n_res_blocks): + decoder.append(ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim=0)) + + decoder.extend( + [ + GroupNorm(n_features[0]), + SiLU(), + convND(n_dim)(n_features[0], n_channels_in, kernel_size=3, padding=1), + ] + ) + + super().__init__(encoder=encoder, decoder=decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index 87d9075f7..a3752e60f 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,20 +1,22 @@ from mrpro.nn.nets.BasicCNN import BasicCNN -from mrpro.nn.nets.DCVAE import DCVAE +from mrpro.nn.nets.VAE import VAE +from mrpro.nn.nets.DiT import DiT from mrpro.nn.nets.HourglassTransformer import HourglassTransformer from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet from mrpro.nn.nets.Uformer import Uformer -from mrpro.nn.nets.VAE import VAE +from mrpro.nn.nets.MLP import MLP __all__ = [ - 'AttentionGatedUNet', - 'BasicCNN', - 'DCVAE', - 'HourglassTransformer', - 'Restormer', - 'SwinIR', - 'UNet', - 'Uformer', - 'VAE', + "AttentionGatedUNet", + "BasicCNN", + "DiT", + "HourglassTransformer", + "MLP", + "Restormer", + "SwinIR", + "UNet", + "Uformer", + "VAE", ] diff --git a/src/mrpro/operators/GridSamplingOp.py b/src/mrpro/operators/GridSamplingOp.py index fd25aa832..9cc0daa01 100644 --- a/src/mrpro/operators/GridSamplingOp.py +++ b/src/mrpro/operators/GridSamplingOp.py @@ -293,11 +293,14 @@ def from_displacement( f'Got shapes {displacement_z.shape}, {displacement_y.shape}, {displacement_x.shape}.' ) from None grid_z, grid_y, grid_x = torch.meshgrid( - torch.linspace(-1, 1, n_z), torch.linspace(-1, 1, n_y), torch.linspace(-1, 1, n_x), indexing='ij' + torch.linspace(-1, 1, n_z), + torch.linspace(-1, 1, n_y), + torch.linspace(-1, 1, n_x), + indexing='ij', ) - grid_z = grid_z + displacement_z * 2 / (n_z - 1) - grid_y = grid_y + displacement_y * 2 / (n_y - 1) - grid_x = grid_x + displacement_x * 2 / (n_x - 1) + grid_z = grid_z.to(displacement_z) + displacement_z * 2 / (n_z - 1) + grid_y = grid_y.to(displacement_y) + displacement_y * 2 / (n_y - 1) + grid_x = grid_x.to(displacement_x) + displacement_x * 2 / (n_x - 1) else: # 2D if displacement_x.ndim < 3 or displacement_y.ndim < 3: raise ValueError( @@ -312,8 +315,8 @@ def from_displacement( f'Got shapes {displacement_y.shape}, {displacement_x.shape}.' ) from None grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, n_y), torch.linspace(-1, 1, n_x), indexing='ij') - grid_y = grid_y + displacement_y * 2 / (n_y - 1) - grid_x = grid_x + displacement_x * 2 / (n_x - 1) + grid_y = grid_y.to(displacement_y) + displacement_y * 2 / (n_y - 1) + grid_x = grid_x.to(displacement_x) + displacement_x * 2 / (n_x - 1) grid_z = None return cls(grid_z, grid_y, grid_x, None, interpolation_mode, padding_mode, align_corners=True) diff --git a/src/mrpro/operators/PatchOp.py b/src/mrpro/operators/PatchOp.py index 190bf28fb..9339f74d0 100644 --- a/src/mrpro/operators/PatchOp.py +++ b/src/mrpro/operators/PatchOp.py @@ -21,6 +21,7 @@ def __init__( stride: Sequence[int] | int | None = None, dilation: Sequence[int] | int = 1, domain_size: int | Sequence[int] | None = None, + flatten_patches: bool = True, ) -> None: """Initialize the PatchOp. @@ -38,6 +39,9 @@ def __init__( Size of the domain in the dimnsions `dim`. If None, it is inferred from the input tensor on the first call. This is only used in the adjoint method. + flatten_patches + If True, flatten the leading grid dimensions to a single patch dimension. + If False, keep shape ``(*grid_size, ...)`` for the forward output. """ super().__init__() self.dim = (dim,) if isinstance(dim, int) else dim @@ -60,6 +64,7 @@ def check(param: int | Sequence[int], name: str) -> tuple[int, ...]: self.stride = check(stride, 'stride') if stride is not None else self.patch_size self.dilation = check(dilation, 'dilation') self.domain_size = check(domain_size, 'domain_size') if domain_size is not None else None + self.flatten_patches = flatten_patches def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]: """Extract N-dimensional patches from an input tensor using a sliding window. @@ -100,9 +105,59 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: stride=self.stride, dilation=self.dilation, ) - patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) + if self.flatten_patches: + patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) return (patches,) + def _adjoint_fast(self, patches: torch.Tensor) -> torch.Tensor: + """Adjoint via reshape/permute for non-overlapping patches.""" + assert self.domain_size is not None # mypy # noqa: S101 + grid = tuple(s // p for s, p in zip(self.domain_size, self.patch_size, strict=True)) + n_dim = len(grid) + if self.flatten_patches: + patches = patches.unflatten(0, grid) + permutation: list[int] = [] + reshape: list[int] = [] + dim = [d % (patches.ndim - n_dim) for d in self.dim] + for i, size in enumerate(patches.shape[n_dim:]): + if i in dim: + j = dim.index(i) + permutation.extend([j, n_dim + i]) + reshape.append(grid[j] * self.patch_size[j]) + else: + permutation.append(n_dim + i) + reshape.append(size) + return patches.permute(*permutation).reshape(reshape) + + def _adjoint_scatter(self, patches: torch.Tensor) -> torch.Tensor: + """Adjoint via scatter for overlapping patches.""" + assert self.domain_size is not None # mypy # noqa: S101 + k = len(self.dim) + if not self.flatten_patches: + patches = patches.flatten(start_dim=0, end_dim=k - 1) + output_shape_ = list(patches.shape[1:]) + for dim, size in zip(self.dim, self.domain_size, strict=True): + output_shape_[dim] = size + output_shape = torch.Size(output_shape_) + indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_) + windowed_indices = sliding_window( + x=indices, + window_shape=self.patch_size, + dim=self.dim, + stride=self.stride, + dilation=self.dilation, + ).flatten(start_dim=0, end_dim=k - 1) + if windowed_indices.shape[0] != patches.shape[0]: + raise ValueError( + f'Number of patches {patches.shape[0]} does not match the number of ' + f'expected patches {windowed_indices.shape[0]}' + ) + + assembled = patches.new_zeros(output_shape.numel()) + assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten()) + assembled = assembled.reshape(output_shape) + return assembled + def adjoint( self, patches: torch.Tensor, @@ -127,26 +182,11 @@ def adjoint( """ if self.domain_size is None: raise ValueError('Domain size is not set. Please call forward first or set it at initialization.') - - output_shape_ = list(patches.shape[1:]) - for dim, size in zip(self.dim, self.domain_size, strict=True): - output_shape_[dim] = size - output_shape = torch.Size(output_shape_) - indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_) - windowed_indices = sliding_window( - x=indices, - window_shape=self.patch_size, - dim=self.dim, - stride=self.stride, - dilation=self.dilation, - ).flatten(start_dim=0, end_dim=len(self.dim) - 1) - if windowed_indices.shape[0] != patches.shape[0]: - raise ValueError( - f'Number of patches {patches.shape[0]} does not match the number of ' - f'expected patches {windowed_indices.shape[0]}' - ) - - assembled = patches.new_zeros(output_shape.numel()) - assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten()) - assembled = assembled.reshape(output_shape) - return (assembled,) + if ( + self.stride == self.patch_size # no overlap + and all(d == 1 for d in self.dilation) # no dilation + and all(s % p == 0 for s, p in zip(self.domain_size, self.patch_size, strict=True)) # divisible + ): + return (self._adjoint_fast(patches),) + else: + return (self._adjoint_scatter(patches),) diff --git a/tests/nn/nets/test_dit.py b/tests/nn/nets/test_dit.py new file mode 100644 index 000000000..aed95599b --- /dev/null +++ b/tests/nn/nets/test_dit.py @@ -0,0 +1,115 @@ +"""Tests for DiT network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import DiT +from mrpro.nn.nets.DiT import DiTBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_ditblock_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of DiTBlock.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 64, 32)).to(device).requires_grad_(True) + cond = rng.float32_tensor((1, 16)).to(device).requires_grad_(True) + block = DiTBlock(n_channels=32, n_heads=4, cond_dim=16, mlp_ratio=2.0, features_last=True).to(device) + if torch_compile: + block = cast(DiTBlock, torch.compile(block, dynamic=False)) + y = block(x, cond=cond) + assert y.shape == x.shape + assert not y.isnan().any(), 'NaN values in output' + + +def test_ditblock_backward() -> None: + """Test the backward pass of DiTBlock.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 8, 8)).requires_grad_(True) + cond = rng.float32_tensor((1, 12)).requires_grad_(True) + block = DiTBlock(n_channels=32, n_heads=4, cond_dim=12, mlp_ratio=2.0, features_last=False) + y = block(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in block.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +@pytest.mark.parametrize('input_size', [(16, 32), (4, 8, 16)], ids=['2d', '3d']) +def test_dit_forward(torch_compile: bool, device: str, input_size: tuple[int, ...]) -> None: + """Test the forward pass of DiT.""" + n_channels_in = 3 + n_channels_out = 2 + n_batch = 1 + hidden_dim = 12 + cond_dim = 32 + dit = DiT( + n_dim=len(input_size), + n_channels_in=n_channels_in, + cond_dim=cond_dim, + input_size=input_size, + patch_size=2, + n_channels_out=n_channels_out, + hidden_dim=hidden_dim, + depth=2, + n_heads=4, + mlp_ratio=2.0, + ) + + x = torch.zeros(n_batch, n_channels_in, *input_size, device=device) + cond = torch.zeros(n_batch, cond_dim, device=device) + dit = dit.to(device) + if torch_compile: + dit = cast(DiT, torch.compile(dit)) + y = dit(x, cond=cond) + assert y.shape == (n_batch, n_channels_out, *input_size) + + +def test_dit_backward() -> None: + """Test the backward pass of DiT.""" + dit = DiT( + n_dim=2, + n_channels_in=1, + cond_dim=24, + input_size=16, + patch_size=2, + n_channels_out=1, + hidden_dim=32, + depth=2, + n_heads=4, + mlp_ratio=2.0, + ) + + x = torch.zeros(1, 1, 16, 16, requires_grad=True) + cond = torch.zeros(1, 24, requires_grad=True) + y = dit(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in dit.named_parameters(): + if name == 'pos_embed': + continue # embedding is fixed + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_vae.py b/tests/nn/nets/test_vae.py new file mode 100644 index 000000000..621a505d7 --- /dev/null +++ b/tests/nn/nets/test_vae.py @@ -0,0 +1,79 @@ +"""Tests for VAE network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import VAE + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_vae_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the VAE.""" + vae = VAE( + n_dim=2, + n_channels_in=1, + latent_channels=4, + n_features=(6, 8, 10), + n_res_blocks=2, + ) + + x = torch.zeros(1, 1, 8, 8, device=device) + vae = vae.to(device) + x = x.to(device) + if torch_compile: + vae = cast(VAE, torch.compile(vae)) + y, kl = vae(x) + assert y.shape == (1, 1, 8, 8) + assert kl.shape == () + latent = vae.encoder(x) + assert latent.shape == (1, 2 * 4, 2, 2) # 2 because of mean and logvar + + +def test_vae_backward_kl() -> None: + """Test the backward pass of the VAE wrt kl.""" + vae = VAE( + n_dim=1, + n_channels_in=1, + latent_channels=4, + n_features=(6, 8, 10), + n_res_blocks=2, + ) + + x = torch.zeros(1, 1, 8, requires_grad=True) + + _, kl = vae(x) + kl.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in vae.encoder.named_parameters(): # only the encoder parameters can influence kl + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +def test_vae_backward_y() -> None: + """Test the backward pass of the VAE wrt y.""" + vae = VAE( + n_dim=1, + n_channels_in=1, + latent_channels=4, + n_features=(6, 8, 10), + n_res_blocks=2, + ) + + x = torch.zeros(1, 1, 8, requires_grad=True) + + y, _ = vae(x) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in vae.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py index 0e564c675..95835aea0 100644 --- a/tests/nn/test_film.py +++ b/tests/nn/test_film.py @@ -3,26 +3,31 @@ from collections.abc import Sequence import pytest +import torch from mrpro.nn.FiLM import FiLM from mrpro.utils import RandomGenerator @pytest.mark.parametrize( - 'device', + "device", [ - pytest.param('cpu', id='cpu'), - pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + pytest.param("cpu", id="cpu"), + pytest.param("cuda", id="cuda", marks=pytest.mark.cuda), ], ) @pytest.mark.parametrize( - ('n_channels', 'n_channels_cond', 'input_shape', 'cond_shape'), + ("n_channels", "n_channels_cond", "input_shape", "cond_shape"), [ (64, 32, (1, 64, 32, 32), (1, 32)), (32, 16, (2, 32, 16, 16), (2, 16)), ], ) def test_film( - n_channels: int, n_channels_cond: int, input_shape: Sequence[int], cond_shape: Sequence[int], device: str + n_channels: int, + n_channels_cond: int, + input_shape: Sequence[int], + cond_shape: Sequence[int], + device: str, ) -> None: """Test FiLM output shape and backpropagation.""" rng = RandomGenerator(seed=42) @@ -30,13 +35,32 @@ def test_film( cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) film = FiLM(channels=n_channels, cond_dim=n_channels_cond).to(device) output = film(x, cond=cond) - assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.shape == x.shape, ( + f"Output shape {output.shape} != input shape {x.shape}" + ) output.sum().backward() - assert x.grad is not None, 'No gradient computed for input' - assert cond.grad is not None, 'No gradient computed for conditioning' - assert not output.isnan().any(), 'NaN values in output' - assert not cond.isnan().any(), 'NaN values in conditioning' - assert not x.grad.isnan().any(), 'NaN values in input gradients' - assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' - assert film.project is not None, 'Linear layer is not initialized' - assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' + assert x.grad is not None, "No gradient computed for input" + assert cond.grad is not None, "No gradient computed for conditioning" + assert not output.isnan().any(), "NaN values in output" + assert not cond.isnan().any(), "NaN values in conditioning" + assert not x.grad.isnan().any(), "NaN values in input gradients" + assert not cond.grad.isnan().any(), "NaN values in conditioning gradients" + assert film.project is not None, "Linear layer is not initialized" + assert next(film.project.parameters()).grad is not None, ( + "No gradient computed for Linear layer" + ) + + +def test_film_features_last() -> None: + """Test FiLM with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)) + cond = rng.float32_tensor((1, 8)) + + film_last = FiLM(channels=3, cond_dim=8, features_last=True) + film = FiLM(channels=3, cond_dim=8, features_last=False) + film.load_state_dict(film_last.state_dict()) + + y_last = film_last(x.moveaxis(1, -1), cond=cond) + y = film(x, cond=cond) + torch.testing.assert_close(y, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_layernorm.py b/tests/nn/test_layernorm.py index ebc11ccb8..0123aff0a 100644 --- a/tests/nn/test_layernorm.py +++ b/tests/nn/test_layernorm.py @@ -9,14 +9,14 @@ @pytest.mark.parametrize( - 'device', + "device", [ - pytest.param('cpu', id='cpu'), - pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + pytest.param("cpu", id="cpu"), + pytest.param("cuda", id="cuda", marks=pytest.mark.cuda), ], ) @pytest.mark.parametrize( - ('n_channels', 'features_last', 'input_shape'), + ("n_channels", "features_last", "input_shape"), [ (32, False, (1, 32, 32, 32)), (64, True, (2, 16, 16, 64)), @@ -36,21 +36,27 @@ def test_layernorm_basic( norm = LayerNorm(n_channels=n_channels, features_last=features_last).to(device) output = norm(x) - assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.shape == x.shape, ( + f"Output shape {output.shape} != input shape {x.shape}" + ) output.sum().backward() - assert x.grad is not None, 'No gradient computed for input' - assert not output.isnan().any(), 'NaN values in output' - assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert x.grad is not None, "No gradient computed for input" + assert not output.isnan().any(), "NaN values in output" + assert not x.grad.isnan().any(), "NaN values in input gradients" if n_channels is not None: - assert norm.weight is not None, 'Weight should not be None when n_channels is provided' - assert norm.bias is not None, 'Bias should not be None when n_channels is provided' - assert norm.weight.grad is not None, 'No gradient computed for weight' - assert norm.bias.grad is not None, 'No gradient computed for bias' + assert norm.weight is not None, ( + "Weight should not be None when n_channels is provided" + ) + assert norm.bias is not None, ( + "Bias should not be None when n_channels is provided" + ) + assert norm.weight.grad is not None, "No gradient computed for weight" + assert norm.bias.grad is not None, "No gradient computed for bias" @pytest.mark.parametrize( - ('n_channels', 'cond_dim', 'input_shape', 'cond_shape'), + ("n_channels", "cond_dim", "input_shape", "cond_shape"), [ (32, 16, (1, 32, 32, 32), (1, 16)), (64, 32, (2, 64, 16, 16), (2, 32)), @@ -69,13 +75,15 @@ def test_layernorm_with_conditioning( norm = LayerNorm(n_channels=n_channels, cond_dim=cond_dim) output = norm(x, cond=cond) - assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.shape == x.shape, ( + f"Output shape {output.shape} != input shape {x.shape}" + ) output.sum().backward() - assert x.grad is not None, 'No gradient computed for input' - assert cond.grad is not None, 'No gradient computed for conditioning' - assert norm.cond_proj is not None, 'cond_proj should not be None when cond_dim > 0' - assert norm.cond_proj.weight.grad is not None, 'No gradient computed for cond_proj' + assert x.grad is not None, "No gradient computed for input" + assert cond.grad is not None, "No gradient computed for conditioning" + assert norm.cond_proj is not None, "cond_proj should not be None when cond_dim > 0" + assert norm.cond_proj.weight.grad is not None, "No gradient computed for cond_proj" def test_layernorm_features_last() -> None: @@ -99,26 +107,33 @@ def test_layernorm_no_channels() -> None: norm = LayerNorm(n_channels=None) output = norm(x) - assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.shape == x.shape, ( + f"Output shape {output.shape} != input shape {x.shape}" + ) - # Check that normalization is applied (mean close to 0, std close to 1) - dims = tuple(range(1, x.ndim)) - mean = output.mean(dim=dims) - std = output.std(dim=dims) + # Check that normalization is applied over channel dim (dim=1 for features_last=False) + mean = output.mean(dim=1, keepdim=True) + var = (output * output).mean(dim=1, keepdim=True) - mean * mean - assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-6), 'Mean not close to 0' - assert torch.allclose(std, torch.ones_like(std), atol=1e-5), 'Std not close to 1' + assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-5), ( + "Mean not close to 0" + ) + assert torch.allclose(var, torch.ones_like(var), atol=1e-3), ( + "Variance not close to 1" + ) def test_layernorm_conditioning_without_channels() -> None: """Test LayerNorm with conditioning but no channels (should raise error).""" - with pytest.raises(ValueError, match='channels must be provided if cond_dim > 0'): + with pytest.raises(ValueError, match="channels must be provided if cond_dim > 0"): LayerNorm(n_channels=None, cond_dim=16) def test_layernorm_invalid_cond_dim() -> None: """Test LayerNorm with invalid cond_dim.""" - with pytest.raises(RuntimeError, match='Trying to create tensor with negative dimension'): + with pytest.raises( + RuntimeError, match="Trying to create tensor with negative dimension" + ): LayerNorm(n_channels=32, cond_dim=-1) @@ -129,10 +144,12 @@ def test_layernorm_3d_input() -> None: norm = LayerNorm(n_channels=64) output = norm(x) - assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.shape == x.shape, ( + f"Output shape {output.shape} != input shape {x.shape}" + ) output.sum().backward() - assert x.grad is not None, 'No gradient computed for input' + assert x.grad is not None, "No gradient computed for input" def test_layernorm_5d_input() -> None: @@ -142,10 +159,12 @@ def test_layernorm_5d_input() -> None: norm = LayerNorm(n_channels=32) output = norm(x) - assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.shape == x.shape, ( + f"Output shape {output.shape} != input shape {x.shape}" + ) output.sum().backward() - assert x.grad is not None, 'No gradient computed for input' + assert x.grad is not None, "No gradient computed for input" def test_layernorm_conditioning_features_last() -> None: @@ -157,11 +176,13 @@ def test_layernorm_conditioning_features_last() -> None: norm = LayerNorm(n_channels=3, features_last=True, cond_dim=8) output = norm(x.moveaxis(1, -1), cond=cond) - assert output.shape == x.moveaxis(1, -1).shape, f'Output shape {output.shape} != expected shape' + assert output.shape == x.moveaxis(1, -1).shape, ( + f"Output shape {output.shape} != expected shape" + ) output.sum().backward() - assert x.grad is not None, 'No gradient computed for input' - assert cond.grad is not None, 'No gradient computed for conditioning' + assert x.grad is not None, "No gradient computed for input" + assert cond.grad is not None, "No gradient computed for conditioning" def test_layernorm_gradient_flow() -> None: @@ -175,13 +196,19 @@ def test_layernorm_gradient_flow() -> None: loss.backward() # Check that gradients are computed for all learnable parameters - assert x.grad is not None, 'Input gradients not computed' - assert norm.weight is not None, 'Weight should not be None when n_channels is provided' - assert norm.bias is not None, 'Bias should not be None when n_channels is provided' - assert norm.weight.grad is not None, 'Weight gradients not computed' - assert norm.bias.grad is not None, 'Bias gradients not computed' + assert x.grad is not None, "Input gradients not computed" + assert norm.weight is not None, ( + "Weight should not be None when n_channels is provided" + ) + assert norm.bias is not None, "Bias should not be None when n_channels is provided" + assert norm.weight.grad is not None, "Weight gradients not computed" + assert norm.bias.grad is not None, "Bias gradients not computed" # Check that gradients are finite - assert torch.isfinite(x.grad).all(), 'Input gradients contain non-finite values' - assert torch.isfinite(norm.weight.grad).all(), 'Weight gradients contain non-finite values' - assert torch.isfinite(norm.bias.grad).all(), 'Bias gradients contain non-finite values' + assert torch.isfinite(x.grad).all(), "Input gradients contain non-finite values" + assert torch.isfinite(norm.weight.grad).all(), ( + "Weight gradients contain non-finite values" + ) + assert torch.isfinite(norm.bias.grad).all(), ( + "Bias gradients contain non-finite values" + ) diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py new file mode 100644 index 000000000..22f5b9ca1 --- /dev/null +++ b/tests/nn/test_mlp.py @@ -0,0 +1,89 @@ +"""Tests for Mlp module.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import MLP +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_mlp_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the Mlp.""" + mlp = MLP( + n_channels_in=8, + n_channels_out=4, + norm='layer', + activation='gelu', + n_features=(16,), + cond_dim=12, + features_last=False, + ).to(device) + x = torch.zeros(1, 8, 9, 7, device=device) + cond = torch.zeros(1, 12, device=device) + if torch_compile: + mlp = cast(MLP, torch.compile(mlp)) + y = mlp(x, cond=cond) + assert y.shape == (1, 4, 9, 7) + + +def test_mlp_backward() -> None: + """Test the backward pass of the Mlp.""" + mlp = MLP( + n_channels_in=6, + n_channels_out=3, + norm='none', + activation='silu', + n_features=(12, 12), + cond_dim=10, + features_last=True, + ) + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 20, 6)).requires_grad_(True) + cond = rng.float32_tensor((1, 10)).requires_grad_(True) + y = mlp(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in mlp.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +def test_mlp_features_last() -> None: + """Test Mlp with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + mlp_last = MLP( + n_channels_in=3, + n_channels_out=4, + norm='layer', + activation='relu', + n_features=(6,), + cond_dim=0, + features_last=True, + ) + mlp = MLP( + n_channels_in=3, + n_channels_out=4, + norm='layer', + activation='relu', + n_features=(6,), + cond_dim=0, + features_last=False, + ) + mlp.load_state_dict(mlp_last.state_dict()) + y_last = mlp_last(x.moveaxis(1, -1)) + y = mlp(x) + torch.testing.assert_close(y, y_last.moveaxis(-1, 1)) diff --git a/tests/operators/test_grid_sampling_op.py b/tests/operators/test_grid_sampling_op.py index 9430a426d..f0adf255b 100644 --- a/tests/operators/test_grid_sampling_op.py +++ b/tests/operators/test_grid_sampling_op.py @@ -1,6 +1,7 @@ """Tests for grid sampling operator.""" import contextlib +from typing import Any, Literal import pytest import torch @@ -13,7 +14,7 @@ @pytest.mark.parametrize('dtype', ['float32', 'float64', 'complex64']) -def test_grid_sampling_op_dtype(dtype): +def test_grid_sampling_op_dtype(dtype: str) -> None: """Test for different data types.""" _test_grid_sampling_op_adjoint(dtype=dtype) @@ -22,33 +23,33 @@ def test_grid_sampling_op_dtype(dtype): @pytest.mark.parametrize('batched', ['batched', 'non_batched']) @pytest.mark.parametrize('channel', ['multi_channel', 'single_channel']) @pytest.mark.parametrize('dtype', ['float32', 'complex64']) -def test_grid_sampling_op_dim_batch_channel(dim_str, batched, channel, dtype): +def test_grid_sampling_op_dim_batch_channel(dim_str: str, batched: str, channel: str, dtype: str) -> None: """Test for different dimensions.""" _test_grid_sampling_op_adjoint(dim=int(dim_str[0]), batched=batched, channel=channel, dtype=dtype) @pytest.mark.parametrize('interpolation_mode', ['bilinear', 'nearest', 'bicubic']) -def test_grid_sampling_op_interpolation_mode(interpolation_mode): +def test_grid_sampling_op_interpolation_mode(interpolation_mode: str) -> None: """Test for different interpolation_modes.""" # bicubic only supports 2D _test_grid_sampling_op_adjoint(dim=2, interpolation_mode=interpolation_mode) @pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection']) -def test_grid_sampling_op_padding_mode(padding_mode): +def test_grid_sampling_op_padding_mode(padding_mode: str) -> None: """Test for different padding_modes.""" _test_grid_sampling_op_adjoint(padding_mode=padding_mode) @pytest.mark.parametrize('align_corners', ['no_align', 'align']) -def test_grid_sampling_op_align_mode(align_corners): +def test_grid_sampling_op_align_mode(align_corners: str) -> None: """Test for different align modes .""" _test_grid_sampling_op_adjoint(align_corners=align_corners) def _test_grid_sampling_op_adjoint( dtype='float32', - dim=2, + dim: int = 2, interpolation_mode='bilinear', padding_mode='zeros', align_corners='no_align', @@ -80,7 +81,9 @@ def _test_grid_sampling_op_adjoint( @pytest.mark.parametrize('interpolation_mode', ['bilinear', 'nearest', 'bicubic']) -def test_grid_sampling_op_interpolation_mode_backward_is_adjoint(interpolation_mode): +def test_grid_sampling_op_interpolation_mode_backward_is_adjoint( + interpolation_mode: Literal['bilinear', 'nearest', 'bicubic'], +) -> None: """Test for different interpolation_modes.""" # bicubic only supports 2D dim = 2 if interpolation_mode == 'bicubic' else 3 @@ -88,18 +91,25 @@ def test_grid_sampling_op_interpolation_mode_backward_is_adjoint(interpolation_m @pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection']) -def test_grid_sampling_op_padding_mode_backward_is_adjoint(padding_mode): +def test_grid_sampling_op_padding_mode_backward_is_adjoint( + padding_mode: Literal['zeros', 'border', 'reflection'], +) -> None: """Test for different padding_modes.""" _test_grid_sampling_op_x_backward(padding_mode=padding_mode) @pytest.mark.parametrize('align_corners', ['no_align', 'align']) -def test_grid_sampling_op_align_mode_backward_is_adjoint(align_corners): +def test_grid_sampling_op_align_mode_backward_is_adjoint(align_corners: Literal['no_align', 'align']) -> None: """Test for different align modes .""" _test_grid_sampling_op_x_backward(align_corners=align_corners == 'align') -def _test_grid_sampling_op_x_backward(dim=3, interpolation_mode='bilinear', padding_mode='zeros', align_corners=False): +def _test_grid_sampling_op_x_backward( + dim: int = 3, + interpolation_mode: Literal['bilinear', 'nearest', 'bicubic'] = 'bilinear', + padding_mode: Literal['zeros', 'border', 'reflection'] = 'zeros', + align_corners: bool = False, +) -> None: """Used in the tests above.""" rng = RandomGenerator(0).float32_tensor batch = (2, 3) @@ -127,7 +137,7 @@ def _test_grid_sampling_op_x_backward(dim=3, interpolation_mode='bilinear', padd torch.testing.assert_close(v.grad, forward_u) -def test_grid_sampling_op_gradcheck_x_forward(): +def test_grid_sampling_op_gradcheck_x_forward() -> None: """Gradient check for forward wrt x.""" rng = RandomGenerator(0).float64_tensor grid = rng((2, 1, 2, 2), -0.8, 0.8) @@ -141,7 +151,7 @@ def test_grid_sampling_op_gradcheck_x_forward(): ) -def test_grid_sampling_op_gradcheck_grid_forward(): +def test_grid_sampling_op_gradcheck_grid_forward() -> None: """Gradient check for forward wrt grid.""" rng = RandomGenerator(0).float64_tensor grid = rng((2, 1, 2, 2), -0.8, 0.8).requires_grad_(True) @@ -155,7 +165,7 @@ def test_grid_sampling_op_gradcheck_grid_forward(): ) -def test_grid_sampling_op_gradcheck_x_adjoint(): +def test_grid_sampling_op_gradcheck_x_adjoint() -> None: """Gradient check for adjoint wrt x.""" rng = RandomGenerator(0).float64_tensor grid = rng((2, 1, 2, 2), -0.8, 0.8) @@ -169,7 +179,7 @@ def test_grid_sampling_op_gradcheck_x_adjoint(): ) -def test_grid_sampling_op_gradcheck_grid_adjoint(): +def test_grid_sampling_op_gradcheck_grid_adjoint() -> None: """Gradient check for adjoint wrt grid.""" rng = RandomGenerator(0).float64_tensor grid = rng((2, 1, 2, 2), -0.8, 0.8).requires_grad_(True) @@ -183,7 +193,7 @@ def test_grid_sampling_op_gradcheck_grid_adjoint(): ) -def test_grid_sampling_op_errormsg_gridshape_3d(): +def test_grid_sampling_op_errormsg_gridshape_3d() -> None: """Test if error message on mismatch of grid shape is raised.""" with pytest.raises(ValueError, match='should have the same shape'): _ = GridSamplingOp( @@ -193,13 +203,13 @@ def test_grid_sampling_op_errormsg_gridshape_3d(): ) -def test_grid_sampling_op_errormsg_gridshape_2d(): +def test_grid_sampling_op_errormsg_gridshape_2d() -> None: """Test if error message on mismatch of grid shape is raised.""" with pytest.raises(ValueError, match='should have the same shape'): _ = GridSamplingOp(grid_z=None, grid_y=torch.ones(1, 3, 1), grid_x=torch.ones(1, 1, 1)) -def test_grid_sampling_op_errormsg_gridndims_3d(): +def test_grid_sampling_op_errormsg_gridndims_3d() -> None: """Test if error message on missing batch dim is raised.""" with pytest.raises(ValueError, match='batch z y x'): _ = GridSamplingOp( @@ -209,13 +219,13 @@ def test_grid_sampling_op_errormsg_gridndims_3d(): ) -def test_grid_sampling_op_errormsg_gridndims_2d(): +def test_grid_sampling_op_errormsg_gridndims_2d() -> None: """Test if error message on missing batch dim is raised.""" with pytest.raises(ValueError, match='batch y x'): _ = GridSamplingOp(grid_z=None, grid_y=torch.ones(1, 1), grid_x=torch.ones(1, 1)) -def test_grid_sampling_op_errormsg_cubic3d(): +def test_grid_sampling_op_errormsg_cubic3d() -> None: """Test if error for 3D cubic is raised.""" grid = torch.ones(1, 1, 1, 1, 3) # 3d with pytest.raises(NotImplementedError, match='cubic'): @@ -228,7 +238,7 @@ def test_grid_sampling_op_errormsg_cubic3d(): ) -def test_grid_sampling_op_errormsg_complexgrid(): +def test_grid_sampling_op_errormsg_complexgrid() -> None: """Test if error for complex grid is raised.""" grid = torch.ones(1, 1, 1, 1, 3) + 0j with pytest.raises(ValueError, match='real'): @@ -241,7 +251,7 @@ def test_grid_sampling_op_errormsg_complexgrid(): ('value', 'error_message'), [(1.0001, 'values outside range'), (-1.0001, 'values outside range'), (1.0, None), (-1.0, None)], ) -def test_grid_sampling_op_warning_gridrange(value, error_message): +def test_grid_sampling_op_warning_gridrange(value: float, error_message: str | None) -> None: """Test if warning for grid values outside [-1,1] is raised""" grid = torch.zeros(1, 1, 1, 1, 3) grid[..., 1] = value @@ -254,7 +264,7 @@ def test_grid_sampling_op_warning_gridrange(value, error_message): ) -def test_grid_sampling_op_errormsg_inputdim_3d(): +def test_grid_sampling_op_errormsg_inputdim_3d() -> None: """Test if error for wrong input dimensions is raised.""" grid = torch.ones(1, 1, 1, 1, 3) input_shape = SpatialDimension(2, 3, 4) @@ -264,7 +274,7 @@ def test_grid_sampling_op_errormsg_inputdim_3d(): _ = operator(u) -def test_grid_sampling_op_warningmsg_inputshape_3d(): +def test_grid_sampling_op_warningmsg_inputshape_3d() -> None: """Test if warning for wrong input_shape is raised in forward""" grid = torch.ones(1, 1, 1, 1, 3) input_shape = SpatialDimension(2, 3, 4) @@ -274,7 +284,7 @@ def test_grid_sampling_op_warningmsg_inputshape_3d(): _ = operator(u) -def test_grid_sampling_op_errormsg_inputdim_2d(): +def test_grid_sampling_op_errormsg_inputdim_2d() -> None: """Test if error for wrong input dimensions is raised.""" grid = torch.ones(1, 1, 1, 2) input_shape = SpatialDimension(2, 3, 4) @@ -284,7 +294,7 @@ def test_grid_sampling_op_errormsg_inputdim_2d(): _ = operator(u) -def test_grid_sampling_op_warningmsg_inputshape_2d(): +def test_grid_sampling_op_warningmsg_inputshape_2d() -> None: """Test if warning for wrong input_shape is raised in forward""" grid = torch.ones(1, 1, 1, 2) input_shape = SpatialDimension(2, 3, 4) @@ -294,7 +304,7 @@ def test_grid_sampling_op_warningmsg_inputshape_2d(): _ = operator(u) -def test_grid_sampling_op_errormsg_inputdim_z_2d(): +def test_grid_sampling_op_errormsg_inputdim_z_2d() -> None: """Test if no error for wrong input dimensions is raised if only z is wrong for 2d.""" grid = torch.ones(1, 1, 1, 2) input_shape = SpatialDimension(2, 3, 4) @@ -313,7 +323,12 @@ def test_grid_sampling_op_errormsg_inputdim_z_2d(): ((7, 1, 2), (2,), (4,), 'not broadcastable'), ], ) -def test_grid_sampling_op_batchdims(grid_batch, u_batch, channel, expected_output): +def test_grid_sampling_op_batchdims( + grid_batch: tuple[int, ...], + u_batch: tuple[int, ...], + channel: tuple[int, ...], + expected_output: tuple[int, ...] | str, +) -> None: """Test if error for wrong input dimensions is raised.""" grid = torch.ones(*grid_batch, 7, 8, 9, 3) # 3d input_shape = SpatialDimension(2, 3, 4) @@ -330,7 +345,7 @@ def test_grid_sampling_op_batchdims(grid_batch, u_batch, channel, expected_outpu # MRpro uses (z,y,x)-convention for grid sampling # PyTorch uses (x,y,z)-convention for grid sampling @pytest.mark.parametrize(('dim', 'grid_sample_dim'), [(-1, -3), (-2, -2), (-3, -1)]) -def test_grid_sampling_op_orientation(dim, grid_sample_dim): +def test_grid_sampling_op_orientation(dim: int, grid_sample_dim: int) -> None: """Test orientation of transformation.""" phantom = torch.zeros(1, 1, 20, 30, 40) phantom[..., 5:15, 10:20, 10:30] = 1 @@ -352,7 +367,7 @@ def test_grid_sampling_op_orientation(dim, grid_sample_dim): torch.testing.assert_close(phantom_shifted, operator(phantom)[0]) -def test_grid_sampling_op_from_displacement_3d(): +def test_grid_sampling_op_from_displacement_3d() -> None: """Test transformation created from displacement.""" phantom = torch.zeros(3, 4, 20, 30, 40) phantom[..., 6:10, 10:20, 10:30] = 1 @@ -377,7 +392,7 @@ def test_grid_sampling_op_from_displacement_3d(): torch.testing.assert_close(phantom_shifted, operator(phantom)[0]) -def test_grid_sampling_op_from_displacement_2d(): +def test_grid_sampling_op_from_displacement_2d() -> None: """Test transformation created from displacement.""" phantom = torch.zeros(3, 4, 20, 30, 40) phantom[..., 6:10, 10:20, 10:30] = 1 @@ -399,3 +414,36 @@ def test_grid_sampling_op_from_displacement_2d(): ) torch.testing.assert_close(phantom_shifted, operator(phantom)[0]) + + +@pytest.mark.cuda +@pytest.mark.parametrize('dim', [3, 2]) +def test_grid_sampling_op_from_displacement_cuda(dim: int) -> None: + """Test operator grid on cuda if the input displacement on cuda.""" + batch, coil = (2, 3), 3 + if dim == 3: + zyx = (2, 4, 8) + displacement_cuda: Any = torch.zeros(dim, *batch, *zyx, device='cuda').unbind(0) + displacement_cpu: Any = torch.zeros(dim, *batch, *zyx, device='cpu').unbind(0) + elif dim == 2: + zyx = (1, 4, 8) + displacement_cuda = (None, *torch.zeros(dim, *batch, *zyx, device='cuda').unbind(0)) + displacement_cpu = (None, *torch.zeros(dim, *batch, *zyx, device='cpu').unbind(0)) + + image = torch.ones(*batch, coil, *zyx) + + operator_cuda = GridSamplingOp.from_displacement(*displacement_cuda) + (result,) = operator_cuda(image.cuda()) + assert result.is_cuda + + operator_cpu = operator_cuda.cpu() + (result,) = operator_cpu(image) + assert result.is_cpu + + operator_cpu = GridSamplingOp.from_displacement(*displacement_cpu) + (result,) = operator_cpu(image) + assert result.is_cpu + + operator_cuda = operator_cpu.cuda() + (result,) = operator_cuda(image.cuda()) + assert result.is_cuda diff --git a/tests/operators/test_patch_op.py b/tests/operators/test_patch_op.py index 46b856178..49395f10c 100644 --- a/tests/operators/test_patch_op.py +++ b/tests/operators/test_patch_op.py @@ -1,6 +1,7 @@ """Tests for Rearrange Operator.""" from collections.abc import Sequence +from typing import Any import pytest import torch @@ -14,13 +15,15 @@ [ ((3, 4, 5), {'dim': (0, 1), 'patch_size': (1, 3), 'stride': (3, 1), 'dilation': (2, 1)}, (2, 1, 3, 5)), ((1, 20), {'dim': -1, 'patch_size': 3, 'stride': 3, 'dilation': 5}, (4, 1, 3)), + ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': (2, 3), 'dilation': 1}, (24, 3, 2)), + ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': None, 'flatten_patches': False}, (8, 3, 3, 2)), ], ) @TESTCASES def test_patch_op_adjointness( - input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int] + input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int] ) -> None: """Test adjointness and shape of Rearrange Op.""" rng = RandomGenerator(seed=0) @@ -36,9 +39,7 @@ def test_patch_op_adjointness( @TESTCASES -def test_patch_op_autodiff( - input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int] -) -> None: +def test_patch_op_autodiff(input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int]) -> None: """Test autodiff works for PatchOp.""" rng = RandomGenerator(seed=0) u = rng.complex64_tensor(size=input_shape) From 3ed29c9ce9362ce7c5143b8b2f885f3829d478cc Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:49:26 +0100 Subject: [PATCH 203/205] update --- examples/notebooks/apply_pinqi.ipynb | 30 ++-- examples/notebooks/modl.ipynb | 8 +- examples/notebooks/train_pinqi.ipynb | 242 ++++++++++----------------- examples/scripts/train_pinqi.py | 24 ++- 4 files changed, 113 insertions(+), 191 deletions(-) diff --git a/examples/notebooks/apply_pinqi.ipynb b/examples/notebooks/apply_pinqi.ipynb index 0e8ed1b52..9d1afb4d5 100644 --- a/examples/notebooks/apply_pinqi.ipynb +++ b/examples/notebooks/apply_pinqi.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "82f66a37", + "id": "0", "metadata": { "lines_to_next_cell": 0 }, @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19a028c0", + "id": "1", "metadata": { "tags": [ "remove-cell" @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "08651577", + "id": "2", "metadata": {}, "source": [ "# End-to-end physics informed network for quantitative MRI (PINQI)\n", @@ -58,7 +58,7 @@ }, { "cell_type": "markdown", - "id": "275fcb68", + "id": "3", "metadata": {}, "source": [ "## Dataset\n", @@ -69,7 +69,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d60b908c", + "id": "4", "metadata": { "lines_to_next_cell": 2 }, @@ -90,7 +90,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2927cf00", + "id": "5", "metadata": { "lines_to_next_cell": 2 }, @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "8b611057", + "id": "6", "metadata": { "lines_to_next_cell": 2 }, @@ -205,7 +205,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2777b221", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -368,7 +368,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08494939", + "id": "8", "metadata": { "lines_to_next_cell": 0 }, @@ -382,7 +382,7 @@ { "cell_type": "code", "execution_count": null, - "id": "249b9f7f", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -412,7 +412,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0e88d174", + "id": "10", "metadata": { "lines_to_next_cell": 0 }, @@ -450,7 +450,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e804c074", + "id": "11", "metadata": { "lines_to_next_cell": 0 }, @@ -480,7 +480,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ae2c56b", + "id": "12", "metadata": { "lines_to_next_cell": 0 }, @@ -500,7 +500,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1f460d1", + "id": "13", "metadata": { "lines_to_next_cell": 2 }, @@ -518,7 +518,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e89d5e0", + "id": "14", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/notebooks/modl.ipynb b/examples/notebooks/modl.ipynb index e385e03a2..54743423d 100644 --- a/examples/notebooks/modl.ipynb +++ b/examples/notebooks/modl.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "ca8663d2", + "id": "0", "metadata": { "lines_to_next_cell": 0 }, @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bc3bb31f", + "id": "1", "metadata": { "tags": [ "remove-cell" @@ -30,7 +30,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9ed7a66e", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -240,7 +240,7 @@ { "cell_type": "code", "execution_count": null, - "id": "601b0ff9", + "id": "3", "metadata": {}, "outputs": [], "source": [] diff --git a/examples/notebooks/train_pinqi.ipynb b/examples/notebooks/train_pinqi.ipynb index 9a012816c..3261d7550 100644 --- a/examples/notebooks/train_pinqi.ipynb +++ b/examples/notebooks/train_pinqi.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "a79be4b8", + "id": "0", "metadata": { "lines_to_next_cell": 0 }, @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f9a50a13", + "id": "1", "metadata": { "tags": [ "remove-cell" @@ -27,35 +27,19 @@ " %pip install mrpro[notebooks]" ] }, - { - "cell_type": "markdown", - "id": "7d5f4c31", - "metadata": {}, - "source": [ - "ruff: noqa: D102, ANN201" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "5fced8aa", + "id": "2", "metadata": {}, "outputs": [], "source": [ - "import collections\n", + "# ruff: noqa: D102, ANN201\n", "from collections.abc import Sequence\n", "from copy import deepcopy\n", "from pathlib import Path\n", - "from typing import Any, Literal, TypedDict, cast" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec1c97c4", - "metadata": {}, - "outputs": [], - "source": [ + "from typing import Any, Literal, TypedDict\n", + "\n", "import einops\n", "import matplotlib.pyplot as plt\n", "import mrpro\n", @@ -65,26 +49,8 @@ "import torch.utils.data._utils\n", "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint # type:ignore[import-not-found]\n", "from pytorch_lightning.loggers import NeptuneLogger # type:ignore[import-not-found]\n", - "from pytorch_lightning.strategies import DDPStrategy # type:ignore[import-not-found]" - ] - }, - { - "cell_type": "markdown", - "id": "06c40aff", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61c14780", - "metadata": {}, - "outputs": [], - "source": [ + "\n", + "\n", "class BatchType(TypedDict):\n", " \"\"\"Typehint for a batch of data.\"\"\"\n", "\n", @@ -92,16 +58,9 @@ " csm: mrpro.data.CsmData\n", " m0: torch.Tensor\n", " t1: torch.Tensor\n", - " mask: torch.Tensor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fbeed14e", - "metadata": {}, - "outputs": [], - "source": [ + " mask: torch.Tensor\n", + "\n", + "\n", "class Dataset(torch.utils.data.Dataset):\n", " \"\"\"A brainweb based cartesian qMRI dataset.\"\"\"\n", "\n", @@ -160,7 +119,7 @@ " seed=seed,\n", " acceleration=self.acceleration,\n", " fwhm_ratio=1.5,\n", - " n_center=10,\n", + " n_center=12,\n", " n_other=(self._n_images,),\n", " )\n", " header = mrpro.data.KHeader(\n", @@ -176,42 +135,29 @@ " header.ti = self.signalmodel.ti.tolist()\n", "\n", " fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj)\n", - " csm = mrpro.data.CsmData(\n", - " mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix),\n", - " header,\n", - " )\n", + " if self.n_coils > 1:\n", + " csm_tensor = mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix)\n", + " else:\n", + " csm_tensor = torch.ones(1, 1, *self.encoding_matrix.zyx)\n", + " csm = mrpro.data.CsmData(csm_tensor, header)\n", " images = einops.rearrange(images, 't y x -> t 1 1 y x')\n", " (data,) = (fourier_op @ csm.as_operator())(images)\n", " data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std()\n", " kdata = mrpro.data.KData(header, data, traj)\n", - " return {'kdata': kdata, 'csm': csm, **phantom}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ed276d85", - "metadata": {}, - "outputs": [], - "source": [ + " return {'kdata': kdata, 'csm': csm, **phantom}\n", + "\n", + "\n", "def collate_fn(batch: Any): # noqa: ANN401\n", " \"\"\"Join dataclasses to a batch.\"\"\"\n", " return torch.utils.data._utils.collate.collate(\n", " batch,\n", " collate_fn_map={\n", - " mrpro.data.Dataclass: lambda batch, *, _collate_fn_map: batch[0].stack(*batch[1:]),\n", + " mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), # noqa: ARG005\n", " **torch.utils.data._utils.collate.default_collate_fn_map,\n", " },\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "446bad6e", - "metadata": {}, - "outputs": [], - "source": [ + " )\n", + "\n", + "\n", "class PINQI(torch.nn.Module):\n", " \"\"\"PINQI model.\"\"\"\n", "\n", @@ -346,16 +292,9 @@ " parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg))\n", " if self.constraints_op is not None:\n", " parameters = [self.constraints_op(*p) for p in parameters]\n", - " return images, parameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "831f3559", - "metadata": {}, - "outputs": [], - "source": [ + " return images, parameters\n", + "\n", + "\n", "class DataModule(pl.LightningDataModule):\n", " \"\"\"Data module for training the PINQI model.\"\"\"\n", "\n", @@ -423,22 +362,15 @@ " def val_dataloader(self):\n", " return torch.utils.data.DataLoader(\n", " self.val_dataset,\n", - " batch_size=1,\n", + " batch_size=4,\n", " shuffle=False,\n", " num_workers=self.num_workers,\n", " pin_memory=False,\n", " persistent_workers=self.num_workers > 0,\n", " collate_fn=collate_fn,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "71c8de37", - "metadata": {}, - "outputs": [], - "source": [ + " )\n", + "\n", + "\n", "class PinqiModule(pl.LightningModule):\n", " \"\"\"Module for training the PINQI model.\"\"\"\n", "\n", @@ -472,7 +404,7 @@ " n_features_image_net=n_features_image_net,\n", " )\n", "\n", - " self.validation_step_outputs: dict[str, list] = collections.defaultdict(list)\n", + " self.validation_step_outputs: dict[str, list] = {}\n", " self.baseline = Baseline(signalmodel, constraints_op, parameter_is_complex)\n", "\n", " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData):\n", @@ -518,7 +450,7 @@ " loss = self.loss(parameters, batch)\n", "\n", " pred_m0, pred_t1 = parameters[-1]\n", - " target_t1, target_m0 = batch['t1'], batch['m0']\n", + " target_t1, target_m0 = batch['t1'][:, None, None], batch['m0'][:, None, None]\n", " mask = batch['mask']\n", " batch_size = len(batch['mask'])\n", " (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1)\n", @@ -529,31 +461,27 @@ " self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", " self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", "\n", - " if batch_idx == 0:\n", - " self.validation_step_outputs['target_t1'].append(batch['t1'])\n", - " self.validation_step_outputs['pred_t1'].append(pred_t1)\n", - " self.validation_step_outputs['pred_m0'].append(pred_m0)\n", - " self.validation_step_outputs['target_m0'].append(target_m0)\n", - " self.validation_step_outputs['mask'].append(batch['mask'])\n", + " if batch_idx == 0 and self.trainer.is_global_zero:\n", + " self.validation_step_outputs['target_t1'] = batch['t1'].cpu()\n", + " self.validation_step_outputs['pred_t1'] = pred_t1.cpu()\n", + " self.validation_step_outputs['pred_m0'] = pred_m0.cpu()\n", + " self.validation_step_outputs['target_m0'] = target_m0.cpu()\n", + " self.validation_step_outputs['mask'] = batch['mask'].cpu()\n", " baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm'])\n", - " self.validation_step_outputs['baseline_t1'].append(baseline_t1)\n", - " self.validation_step_outputs['baseline_m0'].append(baseline_m0)\n", + " self.validation_step_outputs['baseline_t1'] = baseline_t1.cpu()\n", + " self.validation_step_outputs['baseline_m0'] = baseline_m0.cpu()\n", "\n", " def on_validation_epoch_end(self):\n", " \"\"\"Validate.\n", "\n", " Needs to be adapted for other signal models than Saturation Recovery.\n", " \"\"\"\n", - " outputs = {k: torch.cat(v) for k, v in self.validation_step_outputs.items()}\n", - " self.validation_step_outputs.clear()\n", - " outputs = cast(dict[str, torch.Tensor], self.all_gather(outputs))\n", - "\n", " if not self.trainer.is_global_zero:\n", " return\n", - " outputs = {k: v.flatten(0, 1).cpu() for k, v in outputs.items()}\n", + " outputs = self.validation_step_outputs\n", "\n", " samples = len(outputs['mask'])\n", - " fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16))\n", + " fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16), squeeze=False)\n", "\n", " for i in range(samples):\n", " self.result_plot(\n", @@ -581,6 +509,7 @@ " fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}')\n", " self.logger.run['val/images/m0'].log(fig)\n", " plt.close(fig)\n", + " self.validation_step_outputs.clear()\n", "\n", " def result_plot(\n", " self,\n", @@ -596,7 +525,6 @@ " pred = pred.squeeze().detach().cpu()\n", " mask = mask.squeeze().detach().bool().cpu()\n", " baseline = baseline.squeeze().detach().cpu()\n", - "\n", " target[~mask] = torch.nan\n", " pred[~mask] = torch.nan\n", " baseline[~mask] = torch.nan\n", @@ -662,16 +590,9 @@ " return {\n", " 'optimizer': optimizer,\n", " 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'},\n", - " }" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b17c48e1", - "metadata": {}, - "outputs": [], - "source": [ + " }\n", + "\n", + "\n", "class Baseline(torch.nn.Module):\n", " \"\"\"Baseline solution using SENSE + Regression.\"\"\"\n", "\n", @@ -689,7 +610,9 @@ "\n", " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]:\n", " \"\"\"Compute the baseline solution.\"\"\"\n", - " sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm)\n", + " sense = mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction(\n", + " kdata, csm=csm, regularization_weight=0.01, n_iterations=3\n", + " )\n", " images = sense(kdata).rearrange('batch time ...-> time batch ...')\n", "\n", " objective = mrpro.operators.functionals.L2NormSquared(images.data) @ self.signalmodel @ self.constraints_op\n", @@ -702,16 +625,9 @@ " for is_complex in self.parameter_is_complex\n", " )\n", " solution = self.constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values))\n", - " return solution" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "198ef2c0", - "metadata": {}, - "outputs": [], - "source": [ + " return solution\n", + "\n", + "\n", "class LogLambdasCallback(pl.Callback):\n", " \"\"\"Log the lambdas.\"\"\"\n", "\n", @@ -734,28 +650,30 @@ " },\n", " on_step=True,\n", " on_epoch=False,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "632cc485", - "metadata": {}, - "outputs": [], - "source": [ + " )\n", + "\n", + "\n", "if __name__ == '__main__':\n", + " import os\n", + "\n", + " os.environ['NEPTUNE_API_TOKEN'] = (\n", + " 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIyOTdlYTM3NS0wMWU1LTRlMzMtYWU1Ny01MzMzN2ExNTcwMDcifQ=='\n", + " )\n", + " os.environ['NEPTUNE_PROJECT'] = 'ptb/pinqi'\n", " torch.multiprocessing.set_sharing_strategy('file_system')\n", " torch.set_float32_matmul_precision('high')\n", " torch._inductor.config.compile_threads = 4\n", " torch._inductor.config.worker_start_method = 'fork'\n", " torch._dynamo.config.capture_scalar_outputs = True\n", " torch._dynamo.config.cache_size_limit = 256\n", - " torch._functorch.config.activation_memory_budget = 0.95\n", + " torch._functorch.config.activation_memory_budget = 0.5\n", "\n", - " data_folder = Path('/scratch/zimmer08/brainweb')\n", + " data_folder = Path(' /echo/zimmer08/brainweb')\n", + " if not data_folder.exists():\n", + " data_folder.mkdir(parents=True, exist_ok=True)\n", + " mrpro.phantoms.brainweb.download_brainweb(output_directory=data_folder, workers=2, progress=True)\n", "\n", - " signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0))\n", + " signalmodel = mrpro.operators.models.SaturationRecovery((0.2, 0.8, 4.0))\n", " constraints_op = mrpro.operators.ConstraintsOp(\n", " bounds=(\n", " (-2, 2), # M0 in [-2, 2]\n", @@ -769,12 +687,12 @@ " folder=data_folder,\n", " signalmodel=signalmodel,\n", " n_images=n_images,\n", - " batch_size=16,\n", - " num_workers=16,\n", + " batch_size=8,\n", + " num_workers=8,\n", " size=192,\n", - " acceleration=8,\n", - " n_coils=8,\n", - " max_noise=0.1,\n", + " acceleration=6,\n", + " n_coils=1,\n", + " max_noise=0.3,\n", " )\n", "\n", " model = PinqiModule(\n", @@ -799,11 +717,11 @@ " save_last=True,\n", " )\n", "\n", - " strategy = DDPStrategy(find_unused_parameters=False)\n", + " strategy = 'auto' # DDPStrategy(find_unused_parameters=False)\n", " trainer = pl.Trainer(\n", " max_epochs=100,\n", " accelerator='gpu',\n", - " devices=4,\n", + " devices=1,\n", " strategy=strategy,\n", " logger=neptune_logger,\n", " callbacks=[\n", @@ -816,8 +734,16 @@ " gradient_clip_val=5.0,\n", " )\n", "\n", - " trainer.fit(model, datamodule=dm)" + " # trainer.fit(model, datamodule=dm)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py index d640b8c18..d1c3dd5fd 100644 --- a/examples/scripts/train_pinqi.py +++ b/examples/scripts/train_pinqi.py @@ -84,7 +84,7 @@ def __getitem__(self, index: int): seed=seed, acceleration=self.acceleration, fwhm_ratio=1.5, - n_center=10, + n_center=12, n_other=(self._n_images,), ) header = mrpro.data.KHeader( @@ -575,7 +575,9 @@ def __init__( def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]: """Compute the baseline solution.""" - sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm) + sense = mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction( + kdata, csm=csm, regularization_weight=0.01, n_iterations=3 + ) images = sense(kdata).rearrange('batch time ...-> time batch ...') objective = mrpro.operators.functionals.L2NormSquared(images.data) @ self.signalmodel @ self.constraints_op @@ -617,26 +619,20 @@ def on_train_batch_end( if __name__ == '__main__': - import os - - os.environ['NEPTUNE_API_TOKEN'] = ( - 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIyOTdlYTM3NS0wMWU1LTRlMzMtYWU1Ny01MzMzN2ExNTcwMDcifQ==' - ) - os.environ['NEPTUNE_PROJECT'] = 'ptb/pinqi' torch.multiprocessing.set_sharing_strategy('file_system') torch.set_float32_matmul_precision('high') torch._inductor.config.compile_threads = 4 torch._inductor.config.worker_start_method = 'fork' torch._dynamo.config.capture_scalar_outputs = True torch._dynamo.config.cache_size_limit = 256 - torch._functorch.config.activation_memory_budget = 0.8 + torch._functorch.config.activation_memory_budget = 0.5 data_folder = Path(' /echo/zimmer08/brainweb') if not data_folder.exists(): data_folder.mkdir(parents=True, exist_ok=True) mrpro.phantoms.brainweb.download_brainweb(output_directory=data_folder, workers=2, progress=True) - signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0)) + signalmodel = mrpro.operators.models.SaturationRecovery((0.2, 0.8, 4.0)) constraints_op = mrpro.operators.ConstraintsOp( bounds=( (-2, 2), # M0 in [-2, 2] @@ -650,10 +646,10 @@ def on_train_batch_end( folder=data_folder, signalmodel=signalmodel, n_images=n_images, - batch_size=4, - num_workers=4, + batch_size=8, + num_workers=8, size=192, - acceleration=8, + acceleration=6, n_coils=1, max_noise=0.3, ) @@ -697,6 +693,6 @@ def on_train_batch_end( gradient_clip_val=5.0, ) - trainer.fit(model, datamodule=dm) + # trainer.fit(model, datamodule=dm) # %% From 6431fc47adc042fbbdb2297bc50a85def2838991 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:51:09 +0100 Subject: [PATCH 204/205] cleanup --- src/mrpro/nn/GluMBConvResBlock.py | 118 ------------ src/mrpro/nn/nets/DCVAE.py | 302 ------------------------------ tests/nn/nets/test_dcvae.py | 82 -------- 3 files changed, 502 deletions(-) delete mode 100644 src/mrpro/nn/GluMBConvResBlock.py delete mode 100644 src/mrpro/nn/nets/DCVAE.py delete mode 100644 tests/nn/nets/test_dcvae.py diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py deleted file mode 100644 index c17bc019f..000000000 --- a/src/mrpro/nn/GluMBConvResBlock.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Gateded MBConv Residual Block.""" - -import torch -from torch.nn import Identity, Module, Sequential, SiLU - -from mrpro.nn.CondMixin import CondMixin -from mrpro.nn.FiLM import FiLM -from mrpro.nn.ndmodules import convND -from mrpro.nn.RMSNorm import RMSNorm - - -class GluMBConvResBlock(CondMixin, Module): - """Gated MBConv residual block. - - Gated variant [DCAE]_ of the MBConv block [EffNet]_ with a residual connection and (optional) conditioning. - - References - ---------- - .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 - https://arxiv.org/abs/2410.10733 - .. [EffNet] Tan et al. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. ICML 2019 - https://arxiv.org/abs/1905.11946 - """ - - def __init__( - self, - n_dim: int, - n_channels_in: int, - n_channels_out: int, - expand_ratio: int = 6, - stride: int = 1, - kernel_size: int = 3, - cond_dim: int = 0, - ): - """Initialize MBConv block. - - Parameters - ---------- - n_dim - Number of spatial dimensions. - n_channels_in - Number of input channels. - n_channels_out - Number of output channels. - expand_ratio - Expansion ratio inside the block. - stride - Stride of the depthwise convolution. - kernel_size - Kernel size of the depthwise convolution. - cond_dim - Dimension of the conditioning tensor used in a FiLM. If 0, no FiLM is used. - """ - super().__init__() - channels_mid = n_channels_in * expand_ratio - if stride == 1 and n_channels_in == n_channels_out: - self.skip: Module = Identity() - else: - self.skip = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1, stride=stride) - self.inverted_conv = Sequential( - convND(n_dim)( - n_channels_in, - channels_mid * 2, - kernel_size=1, - ), - SiLU(), - ) - self.depth_conv = Sequential( - convND(n_dim)( - channels_mid * 2, - channels_mid * 2, - kernel_size=kernel_size, - stride=stride, - padding='same', - groups=channels_mid * 2, - ), - SiLU(), - ) - self.point_conv = Sequential( - convND(n_dim)( - channels_mid, - n_channels_out, - kernel_size=1, - ), - RMSNorm(n_channels_out), - SiLU(), - ) - if cond_dim > 0: - self.film: FiLM | None = FiLM(channels_mid, cond_dim) - else: - self.film = None - - def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: - """Apply MBConv block. - - Parameters - ---------- - x - Input tensor. - cond - Conditioning tensor. If `None`, no conditioning is applied. - - Returns - ------- - Output tensor. - """ - return super().__call__(x, cond=cond) - - def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: - """Apply MBConv block.""" - h = self.inverted_conv(x) - h = self.depth_conv(h) - h, gate = torch.chunk(h, 2, dim=1) - h = h * torch.nn.functional.silu(gate) - if self.film is not None: - h = self.film(h, cond=cond) - h = self.point_conv(h) - return self.skip(x) + h diff --git a/src/mrpro/nn/nets/DCVAE.py b/src/mrpro/nn/nets/DCVAE.py deleted file mode 100644 index 5faba554b..000000000 --- a/src/mrpro/nn/nets/DCVAE.py +++ /dev/null @@ -1,302 +0,0 @@ -"""Deep Compression Autoencoder.""" - -from collections.abc import Sequence -from typing import Literal - -import torch -from torch.nn import Module, ReLU, SiLU - -from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention -from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention -from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock -from mrpro.nn.ndmodules import convND -from mrpro.nn.nets.VAE import VAE -from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample -from mrpro.nn.Residual import Residual -from mrpro.nn.RMSNorm import RMSNorm -from mrpro.nn.Sequential import Sequential - - -class CNNBlock(Residual): - """Block with two convolutions and normalization. - - As used in the DCAE [DCAE]_. - - References - ---------- - .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder - for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 - """ - - def __init__( - self, - n_dim: int, - n_channels: int, - ): - """Initialize the CNNBlock. - - Parameters - ---------- - n_dim - The number of spatial dimensions of the input tensor. - n_channels - The number of channels in the input tensor. - """ - super().__init__( - Sequential( - convND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1), - SiLU(True), - convND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1, bias=False), - RMSNorm(n_channels), - ) - ) - - -class EfficientViTBlock(Module): - """Efficient Vision Transformer block with optional linear attention. - - As used in the DCAE [DCAE]_. - - References - ---------- - .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder - for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 - """ - - def __init__( - self, - n_dim: int, - n_channels: int, - n_heads: int, - expand_ratio: int = 4, - linear_attn: bool = False, - ): - """Initialize the EfficientViTBlock. - - Parameters - ---------- - n_dim - The number of spatial dimensions of the input tensor. - n_channels - The number of channels in the input tensor. - n_heads - The number of attention heads. - expand_ratio - The expansion ratio of the GluMBConvResBlock. - linear_attn - Whether to use linear attention instead of softmax attention with quadratic complexity. - """ - super().__init__() - if linear_attn: - attention: Module = LinearSelfAttention(n_channels, n_channels, n_heads) - else: - attention = MultiHeadAttention(n_channels, n_channels, n_heads, features_last=False) - self.context_module = Residual(Sequential(attention, RMSNorm(n_channels))) - self.local_module = GluMBConvResBlock( - n_dim=n_dim, - n_channels_in=n_channels, - n_channels_out=n_channels, - expand_ratio=expand_ratio, - ) - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - """Apply the EfficientViTBlock. - - Parameters - ---------- - x - Input tensor - - Returns - ------- - Output tensor - """ - return super().__call__(x) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass for EfficientViTBlock.""" - x = self.context_module(x) - x = self.local_module(x) - return x - - -class Encoder(Sequential): - """Encoder for DCAE. - - As used in the DC-Autoencoder [DCAE]_. - - References - ---------- - .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder - for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 - """ - - def __init__( - self, - n_dim: int = 2, - n_channels_in: int = 3, - n_channels_out: int = 32, - block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), - widths: Sequence[int] = (256, 512, 512, 1024, 1024), - depths: Sequence[int] = (4, 6, 2, 2, 2), - ): - """Initialize the Encoder. - - The length of the `block_types`, `widths`, and `depths` must be the same and determine - the number of stages in the encoder. Between the stages, downsampling is performed. - - Parameters - ---------- - n_dim - The number of spatial dimensions of the input tensor. - n_channels_in - The number of channels in the input tensor, i.e. the latent space - n_channels_out - The number of channels in the output tensor, i.e. the original space - block_types - The types of blocks to use in the decoder. - widths - The widths of the blocks in the decoder, i.e. the number of channels in the blocks - depths - The depths of the blocks in the decoder, i.e. the number blocks in the stage - """ - super().__init__() - self.append(PixelUnshuffleDownsample(n_dim, n_channels_in, widths[0], downscale_factor=2, residual=False)) - if len(block_types) != len(widths) or len(block_types) != len(depths): - raise ValueError('block_types, widths, and depths must have the same length') - for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): - match block_type: - case 'CNN': - stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] - case 'LinearViT': - stage = [ - EfficientViTBlock(n_dim, width, max(1, width // 32), linear_attn=True) for _ in range(depth) - ] - case 'ViT': - stage = [EfficientViTBlock(n_dim, width, max(1, width // 32)) for _ in range(depth)] - case _: - raise ValueError(f'Block type {block_type} not supported') - self.append(Sequential(*stage)) - if next_width: - self.append(PixelUnshuffleDownsample(n_dim, width, next_width, downscale_factor=2, residual=True)) - self.append( - Sequential( - RMSNorm(widths[-1]), - ReLU(), - PixelUnshuffleDownsample(n_dim, widths[-1], n_channels_out, downscale_factor=1, residual=True), - ) - ) - - -class Decoder(Sequential): - """Decoder for DCAE. - - As used in the DC-Autoencoder [DCAE]_. - - References - ---------- - .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder - for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 - """ - - def __init__( - self, - n_dim: int = 2, - n_channels_in: int = 32, - n_channels_out: int = 3, - block_types: Sequence[Literal['ViT', 'LinearViT', 'CNN']] = ('ViT', 'LinearViT', 'LinearViT', 'CNN', 'CNN'), - widths: Sequence[int] = (1024, 1024, 512, 512, 256), - depths: Sequence[int] = (2, 2, 2, 6, 4), - ): - """Initialize the Decoder. - - The length of the `block_types`, `widths`, and `depths` must be the same and determine - the number of stages in the decoder. Between the stages, upsampling is performed. - - Parameters - ---------- - n_dim - The number of spatial dimensions of the input tensor. - n_channels_in - The number of channels in the input tensor, i.e. the latent space - n_channels_out - The number of channels in the output tensor, i.e. the original space - block_types - The types of blocks to use in the decoder. - widths - The widths of the blocks in the decoder, i.e. the number of channels in the blocks - depths - The depths of the blocks in the decoder, i.e. the number blocks in the stage - """ - super().__init__() - if not (len(block_types) == len(widths) == len(depths)): - raise ValueError('block_types, widths, and depths must have the same length') - self.append(PixelShuffleUpsample(n_dim, n_channels_in, widths[0], upscale_factor=1, residual=True)) - - for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): - match block_type: - case 'CNN': - stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] - case 'LinearViT': - stage = [ - EfficientViTBlock(n_dim, width, n_heads=max(1, width // 32), linear_attn=True) - for _ in range(depth) - ] - case 'ViT': - stage = [ - EfficientViTBlock(n_dim, width, n_heads=max(1, width // 32), linear_attn=False) - for _ in range(depth) - ] - case _: - raise ValueError(f'Block type {block_type} not supported') - self.append(Sequential(*stage)) - if next_width: - self.append(PixelShuffleUpsample(n_dim, width, next_width, upscale_factor=2, residual=True)) - - self.append( - Sequential( - RMSNorm(widths[-1]), - ReLU(), - PixelShuffleUpsample(n_dim, widths[-1], n_channels_out, upscale_factor=2), - ) - ) - - -class DCVAE(VAE): - """Variational Autoencoder based on DCAE. - - References - ---------- - .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder - for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 - """ - - def __init__( - self, - n_dim: int, - n_channels: int, - latent_dim: int = 32, - block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), - widths: Sequence[int] = (256, 512, 512, 1024, 1024), - depths: Sequence[int] = (4, 6, 2, 2, 2), - ): - """Initialize the DCVAE. - - Parameters - ---------- - n_dim - The number of spatial dimensions of the input tensor. - n_channels - The number of channels in the input tensor. - latent_dim - The number of channels in the latent space. - block_types - The types of blocks to use in the encoder and decoder. - widths - The widths of the blocks in the encoder and decoder. - depths - The depths of the blocks in the encoder and decoder. - """ - encoder = Encoder(n_dim, n_channels, latent_dim * 2, block_types, widths, depths) - decoder = Decoder(n_dim, latent_dim, n_channels, block_types[::-1], widths[::-1], depths[::-1]) - super().__init__(encoder, decoder) diff --git a/tests/nn/nets/test_dcvae.py b/tests/nn/nets/test_dcvae.py deleted file mode 100644 index ff5371b7b..000000000 --- a/tests/nn/nets/test_dcvae.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Tests for DCVAE network.""" - -from typing import cast - -import pytest -import torch -from mrpro.nn.nets import DCVAE - - -@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) -@pytest.mark.parametrize( - 'device', - [ - pytest.param('cpu', id='cpu'), - pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), - ], -) -def test_dcvae_forward(torch_compile: bool, device: str) -> None: - """Test the forward pass of the DCVAE.""" - dcvae = DCVAE( - n_dim=2, - n_channels=1, - latent_dim=4, - block_types=('CNN', 'LinearViT', 'ViT'), - widths=(32, 64, 32), - depths=(1, 2, 2), - ) - - x = torch.zeros(1, 1, 16, 16, device=device) - dcvae = dcvae.to(device) - x = x.to(device) - if torch_compile: - dcvae = cast(DCVAE, torch.compile(dcvae)) - y, kl = dcvae(x) - assert y.shape == (1, 1, 16, 16) - assert kl.shape == () - latent = dcvae.encoder(x) - assert latent.shape == (1, 2 * 4, 2, 2) # 2 because of mean and logvar - - -def test_dcvae_backward_kl() -> None: - """Test the backward pass of the DCVAE wrt kl.""" - dcvae = DCVAE( - n_dim=1, - n_channels=1, - latent_dim=4, - block_types=('CNN', 'LinearViT', 'ViT'), - widths=(8, 12, 16), - depths=(2, 2, 3), - ) - - x = torch.zeros(1, 1, 16, requires_grad=True) - - _, kl = dcvae(x) - kl.sum().backward() - assert x.grad is not None, 'x.grad is None' - assert not x.grad.isnan().any(), 'x.grad is NaN' - for name, parameter in dcvae.encoder.named_parameters(): # only the encoder parameters can influence kl - assert parameter.grad is not None, f'{name}.grad is None' - assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' - - -def test_dcvae_backward_y() -> None: - """Test the backward pass of the DCVAE wrt y.""" - dcvae = DCVAE( - n_dim=1, - n_channels=1, - latent_dim=4, - block_types=('CNN', 'LinearViT', 'ViT'), - widths=(8, 12, 16), - depths=(2, 2, 3), - ) - - x = torch.zeros(1, 1, 16, requires_grad=True) - - y, _ = dcvae(x) - y.sum().backward() - assert x.grad is not None, 'x.grad is None' - assert not x.grad.isnan().any(), 'x.grad is NaN' - for name, parameter in dcvae.named_parameters(): - assert parameter.grad is not None, f'{name}.grad is None' - assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' From e55d0629b973d617309e3ebc9edeb212992c2f1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 13:53:15 +0000 Subject: [PATCH 205/205] [pre-commit] auto fixes from pre-commit hooks --- examples/notebooks/train_pinqi.ipynb | 6 -- tests/nn/test_layernorm.py | 104 ++++++++++----------------- 2 files changed, 38 insertions(+), 72 deletions(-) diff --git a/examples/notebooks/train_pinqi.ipynb b/examples/notebooks/train_pinqi.ipynb index 3261d7550..c64bfa584 100644 --- a/examples/notebooks/train_pinqi.ipynb +++ b/examples/notebooks/train_pinqi.ipynb @@ -654,12 +654,6 @@ "\n", "\n", "if __name__ == '__main__':\n", - " import os\n", - "\n", - " os.environ['NEPTUNE_API_TOKEN'] = (\n", - " 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIyOTdlYTM3NS0wMWU1LTRlMzMtYWU1Ny01MzMzN2ExNTcwMDcifQ=='\n", - " )\n", - " os.environ['NEPTUNE_PROJECT'] = 'ptb/pinqi'\n", " torch.multiprocessing.set_sharing_strategy('file_system')\n", " torch.set_float32_matmul_precision('high')\n", " torch._inductor.config.compile_threads = 4\n", diff --git a/tests/nn/test_layernorm.py b/tests/nn/test_layernorm.py index 0123aff0a..51d6ed030 100644 --- a/tests/nn/test_layernorm.py +++ b/tests/nn/test_layernorm.py @@ -9,14 +9,14 @@ @pytest.mark.parametrize( - "device", + 'device', [ - pytest.param("cpu", id="cpu"), - pytest.param("cuda", id="cuda", marks=pytest.mark.cuda), + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), ], ) @pytest.mark.parametrize( - ("n_channels", "features_last", "input_shape"), + ('n_channels', 'features_last', 'input_shape'), [ (32, False, (1, 32, 32, 32)), (64, True, (2, 16, 16, 64)), @@ -36,27 +36,21 @@ def test_layernorm_basic( norm = LayerNorm(n_channels=n_channels, features_last=features_last).to(device) output = norm(x) - assert output.shape == x.shape, ( - f"Output shape {output.shape} != input shape {x.shape}" - ) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() - assert x.grad is not None, "No gradient computed for input" - assert not output.isnan().any(), "NaN values in output" - assert not x.grad.isnan().any(), "NaN values in input gradients" + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' if n_channels is not None: - assert norm.weight is not None, ( - "Weight should not be None when n_channels is provided" - ) - assert norm.bias is not None, ( - "Bias should not be None when n_channels is provided" - ) - assert norm.weight.grad is not None, "No gradient computed for weight" - assert norm.bias.grad is not None, "No gradient computed for bias" + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' @pytest.mark.parametrize( - ("n_channels", "cond_dim", "input_shape", "cond_shape"), + ('n_channels', 'cond_dim', 'input_shape', 'cond_shape'), [ (32, 16, (1, 32, 32, 32), (1, 16)), (64, 32, (2, 64, 16, 16), (2, 32)), @@ -75,15 +69,13 @@ def test_layernorm_with_conditioning( norm = LayerNorm(n_channels=n_channels, cond_dim=cond_dim) output = norm(x, cond=cond) - assert output.shape == x.shape, ( - f"Output shape {output.shape} != input shape {x.shape}" - ) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() - assert x.grad is not None, "No gradient computed for input" - assert cond.grad is not None, "No gradient computed for conditioning" - assert norm.cond_proj is not None, "cond_proj should not be None when cond_dim > 0" - assert norm.cond_proj.weight.grad is not None, "No gradient computed for cond_proj" + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert norm.cond_proj is not None, 'cond_proj should not be None when cond_dim > 0' + assert norm.cond_proj.weight.grad is not None, 'No gradient computed for cond_proj' def test_layernorm_features_last() -> None: @@ -107,33 +99,25 @@ def test_layernorm_no_channels() -> None: norm = LayerNorm(n_channels=None) output = norm(x) - assert output.shape == x.shape, ( - f"Output shape {output.shape} != input shape {x.shape}" - ) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' # Check that normalization is applied over channel dim (dim=1 for features_last=False) mean = output.mean(dim=1, keepdim=True) var = (output * output).mean(dim=1, keepdim=True) - mean * mean - assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-5), ( - "Mean not close to 0" - ) - assert torch.allclose(var, torch.ones_like(var), atol=1e-3), ( - "Variance not close to 1" - ) + assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-5), 'Mean not close to 0' + assert torch.allclose(var, torch.ones_like(var), atol=1e-3), 'Variance not close to 1' def test_layernorm_conditioning_without_channels() -> None: """Test LayerNorm with conditioning but no channels (should raise error).""" - with pytest.raises(ValueError, match="channels must be provided if cond_dim > 0"): + with pytest.raises(ValueError, match='channels must be provided if cond_dim > 0'): LayerNorm(n_channels=None, cond_dim=16) def test_layernorm_invalid_cond_dim() -> None: """Test LayerNorm with invalid cond_dim.""" - with pytest.raises( - RuntimeError, match="Trying to create tensor with negative dimension" - ): + with pytest.raises(RuntimeError, match='Trying to create tensor with negative dimension'): LayerNorm(n_channels=32, cond_dim=-1) @@ -144,12 +128,10 @@ def test_layernorm_3d_input() -> None: norm = LayerNorm(n_channels=64) output = norm(x) - assert output.shape == x.shape, ( - f"Output shape {output.shape} != input shape {x.shape}" - ) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() - assert x.grad is not None, "No gradient computed for input" + assert x.grad is not None, 'No gradient computed for input' def test_layernorm_5d_input() -> None: @@ -159,12 +141,10 @@ def test_layernorm_5d_input() -> None: norm = LayerNorm(n_channels=32) output = norm(x) - assert output.shape == x.shape, ( - f"Output shape {output.shape} != input shape {x.shape}" - ) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' output.sum().backward() - assert x.grad is not None, "No gradient computed for input" + assert x.grad is not None, 'No gradient computed for input' def test_layernorm_conditioning_features_last() -> None: @@ -176,13 +156,11 @@ def test_layernorm_conditioning_features_last() -> None: norm = LayerNorm(n_channels=3, features_last=True, cond_dim=8) output = norm(x.moveaxis(1, -1), cond=cond) - assert output.shape == x.moveaxis(1, -1).shape, ( - f"Output shape {output.shape} != expected shape" - ) + assert output.shape == x.moveaxis(1, -1).shape, f'Output shape {output.shape} != expected shape' output.sum().backward() - assert x.grad is not None, "No gradient computed for input" - assert cond.grad is not None, "No gradient computed for conditioning" + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' def test_layernorm_gradient_flow() -> None: @@ -196,19 +174,13 @@ def test_layernorm_gradient_flow() -> None: loss.backward() # Check that gradients are computed for all learnable parameters - assert x.grad is not None, "Input gradients not computed" - assert norm.weight is not None, ( - "Weight should not be None when n_channels is provided" - ) - assert norm.bias is not None, "Bias should not be None when n_channels is provided" - assert norm.weight.grad is not None, "Weight gradients not computed" - assert norm.bias.grad is not None, "Bias gradients not computed" + assert x.grad is not None, 'Input gradients not computed' + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'Weight gradients not computed' + assert norm.bias.grad is not None, 'Bias gradients not computed' # Check that gradients are finite - assert torch.isfinite(x.grad).all(), "Input gradients contain non-finite values" - assert torch.isfinite(norm.weight.grad).all(), ( - "Weight gradients contain non-finite values" - ) - assert torch.isfinite(norm.bias.grad).all(), ( - "Bias gradients contain non-finite values" - ) + assert torch.isfinite(x.grad).all(), 'Input gradients contain non-finite values' + assert torch.isfinite(norm.weight.grad).all(), 'Weight gradients contain non-finite values' + assert torch.isfinite(norm.bias.grad).all(), 'Bias gradients contain non-finite values'