From aa7ba8be883920189eb0b84d52ed36727cf4cdad Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 7 Feb 2026 23:25:12 +0100 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- pyproject.toml | 10 +- src/mrpro/__init__.py | 6 +- src/mrpro/nn/ComplexAsChannel.py | 59 ++++++ src/mrpro/nn/CondMixin.py | 22 +++ src/mrpro/nn/DropPath.py | 55 ++++++ src/mrpro/nn/FiLM.py | 54 +++++ src/mrpro/nn/FourierFeatures.py | 50 +++++ src/mrpro/nn/GEGLU.py | 56 ++++++ src/mrpro/nn/GroupNorm.py | 71 +++++++ src/mrpro/nn/LayerNorm.py | 83 ++++++++ src/mrpro/nn/PermutedBlock.py | 75 +++++++ src/mrpro/nn/PixelShuffle.py | 285 +++++++++++++++++++++++++++ src/mrpro/nn/RMSNorm.py | 73 +++++++ src/mrpro/nn/Residual.py | 45 +++++ src/mrpro/nn/Sequential.py | 60 ++++++ src/mrpro/nn/Upsample.py | 66 +++++++ src/mrpro/nn/__init__.py | 45 +++++ src/mrpro/nn/convert_linear_conv.py | 100 ++++++++++ src/mrpro/nn/join.py | 176 +++++++++++++++++ src/mrpro/nn/ndmodules.py | 178 +++++++++++++++++ src/mrpro/utils/__init__.py | 5 +- src/mrpro/utils/to_tuple.py | 36 ++++ tests/nn/test_complexaschannel.py | 30 +++ tests/nn/test_convert_linear_conv.py | 150 ++++++++++++++ tests/nn/test_droppath.py | 30 +++ tests/nn/test_film.py | 42 ++++ tests/nn/test_fourierfeatures.py | 24 +++ tests/nn/test_geglu.py | 38 ++++ tests/nn/test_groupnorm.py | 45 +++++ tests/nn/test_join.py | 160 +++++++++++++++ tests/nn/test_layernorm.py | 187 ++++++++++++++++++ tests/nn/test_ndmodules.py | 76 +++++++ tests/nn/test_pixelshuffle.py | 92 +++++++++ tests/nn/test_rmsnorm.py | 58 ++++++ tests/nn/test_sequential.py | 50 +++++ 35 files changed, 2586 insertions(+), 6 deletions(-) create mode 100644 src/mrpro/nn/ComplexAsChannel.py create mode 100644 src/mrpro/nn/CondMixin.py create mode 100644 src/mrpro/nn/DropPath.py create mode 100644 src/mrpro/nn/FiLM.py create mode 100644 src/mrpro/nn/FourierFeatures.py create mode 100644 src/mrpro/nn/GEGLU.py create mode 100644 src/mrpro/nn/GroupNorm.py create mode 100644 src/mrpro/nn/LayerNorm.py create mode 100644 src/mrpro/nn/PermutedBlock.py create mode 100644 src/mrpro/nn/PixelShuffle.py create mode 100644 src/mrpro/nn/RMSNorm.py create mode 100644 src/mrpro/nn/Residual.py create mode 100644 src/mrpro/nn/Sequential.py create mode 100644 src/mrpro/nn/Upsample.py create mode 100644 src/mrpro/nn/__init__.py create mode 100644 src/mrpro/nn/convert_linear_conv.py create mode 100644 src/mrpro/nn/join.py create mode 100644 src/mrpro/nn/ndmodules.py create mode 100644 src/mrpro/utils/to_tuple.py create mode 100644 tests/nn/test_complexaschannel.py create mode 100644 tests/nn/test_convert_linear_conv.py create mode 100644 tests/nn/test_droppath.py create mode 100644 tests/nn/test_film.py create mode 100644 tests/nn/test_fourierfeatures.py create mode 100644 tests/nn/test_geglu.py create mode 100644 tests/nn/test_groupnorm.py create mode 100644 tests/nn/test_join.py create mode 100644 tests/nn/test_layernorm.py create mode 100644 tests/nn/test_ndmodules.py create mode 100644 tests/nn/test_pixelshuffle.py create mode 100644 tests/nn/test_rmsnorm.py create mode 100644 tests/nn/test_sequential.py diff --git a/pyproject.toml b/pyproject.toml index 67aeb1928..9863b3856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,7 @@ docs = [ "sphinx-autodoc-typehints>=3, <3.1", "sphinx-copybutton>=0.5, <0.6", "sphinx-last-updated-by-git>=0.3, <0.4", + "snowballstemmer>=2.2, <3.0", ] notebooks = [ "zenodo_get>=2.0", @@ -118,10 +119,12 @@ testpaths = ["tests"] filterwarnings = [ "error", "ignore:'write_like_original':DeprecationWarning:pydicom:", - "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd - "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 + "ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd + "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5 + "ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda "ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2 - "ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10 + "ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6 + "ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10 ] addopts = "-n auto --dist loadfile --maxprocesses=8" markers = ["cuda : Tests only to be run when cuda device is available"] @@ -230,6 +233,7 @@ iy = "iy" arange = "arange" # torch.arange Ba = "Ba" wht = "wht" # Brainweb tissue class +ND = "ND" # Short for N-dimensional [tool.typos.files] extend-exclude = [ diff --git a/src/mrpro/__init__.py b/src/mrpro/__init__.py index 729ae188c..bbd401f1f 100644 --- a/src/mrpro/__init__.py +++ b/src/mrpro/__init__.py @@ -1,10 +1,12 @@ from mrpro._version import __version__ -from mrpro import algorithms, operators, data, phantoms, utils +from mrpro import algorithms, operators, data, phantoms, utils, nn + __all__ = [ "__version__", "algorithms", "data", + "nn", "operators", "phantoms", "utils" -] +] \ No newline at end of file diff --git a/src/mrpro/nn/ComplexAsChannel.py b/src/mrpro/nn/ComplexAsChannel.py new file mode 100644 index 000000000..22e13458e --- /dev/null +++ b/src/mrpro/nn/ComplexAsChannel.py @@ -0,0 +1,59 @@ +"""ComplexAsChannel: handling complex-valued tensors as channels.""" + +import torch +from einops import rearrange +from torch.nn import Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class ComplexAsChannel(CondMixin, Module): + """Wrap module to treat complex numbers as a channel dimension.""" + + def __init__(self, module: Module, convert_back: bool = True): + """Initialize the ComplexAsChannel module. + + Wraps a module to treat complex numbers as a channel dimension. + For each complex tensor in the input, real and imaginary parts are concatenated along the channel dimension + before being passed to the wrapped module. + + + Parameters + ---------- + module + The module to wrap. Should output a single real tensor. + convert_back + If True, the output is converted back to a complex tensor. + The output should have a number of channels that is a multiple of 2. + """ + super().__init__() + self.module = module + self.convert_back = convert_back + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor (if used by the wrapped module) + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + x_real = [ + rearrange(torch.view_as_real(c), 'batch channel ... complex -> batch (channel complex) ...') + if c.is_complex() + else c + for c in x + ] + + y = call_with_cond(self.module, *x_real, cond=cond) + + if self.convert_back: + y = rearrange(y, 'b (channel complex) ... -> b channel ... complex', complex=2).contiguous() + y = torch.view_as_complex(y) + return y diff --git a/src/mrpro/nn/CondMixin.py b/src/mrpro/nn/CondMixin.py new file mode 100644 index 000000000..6a902c413 --- /dev/null +++ b/src/mrpro/nn/CondMixin.py @@ -0,0 +1,22 @@ +"""Base class for modules using a conditioning.""" + +import torch +from torch.nn import Module + + +def call_with_cond(module: Module, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Call a module with conditioning if it is a CondMixin.""" + if isinstance(module, CondMixin): + return module(*x, cond=cond) + return module(*x) + + +class CondMixin(Module): + """Mixin for modules using a conditioning. + + Used to determine if a module uses a conditioning within a Sequential container. + """ + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module to the input.""" + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/DropPath.py b/src/mrpro/nn/DropPath.py new file mode 100644 index 000000000..7262fd86c --- /dev/null +++ b/src/mrpro/nn/DropPath.py @@ -0,0 +1,55 @@ +"""DropPath (stochastic depth).""" + +import torch +from torch.nn import Module + + +class DropPath(Module): + """Drop path or stochastic depth. + + Drops full samples from batch with probability `droprate`. + Should be used in the main path of a Resblock. + + References + ---------- + .. [HUANG16] Huang, G., Sun, Y., Liu, Z., Sedra, D., & Weinberger, K. Q. Deep networks with stochastic depth. + ECCV 2016. https://link.springer.com/chapter/10.1007/978-3-319-46493-0_39 + """ + + def __init__(self, droprate: float = 0.0, scale_by_keep: bool = False): + """Initialize the DropPath module. + + Parameters + ---------- + droprate + Drop probability + scale_by_keep + If True, the kept samples are scaled by :math:`1/(1-droprate)` + """ + super().__init__() + self.droprate = droprate + self.scale_by_keep = scale_by_keep + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Tensor with batch samples randomly dropped + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply DropPath.""" + if self.droprate == 0 or not self.training: + return x + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = ((1 - self.droprate) + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_() + if self.scale_by_keep: + mask = mask.div_(1 - self.droprate) + return x * mask diff --git a/src/mrpro/nn/FiLM.py b/src/mrpro/nn/FiLM.py new file mode 100644 index 000000000..92780aae3 --- /dev/null +++ b/src/mrpro/nn/FiLM.py @@ -0,0 +1,54 @@ +"""Feature-wise Linear Modulation.""" + +import torch +from torch.nn import Linear, Module + +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_tensors_right + + +class FiLM(CondMixin, Module): + """Feature-wise Linear Modulation. + + Feature-wise Linear Modulation from [FiLM]_ to condition a network on a conditioning tensor. + + + References + ---------- + ..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "FiLM : Visual reasoning with a general + conditioning layer." AAAI (2018). https://arxiv.org/abs/1709.07871 + """ + + def __init__(self, channels: int, cond_dim: int) -> None: + """Initialize FiLM. + + Parameters + ---------- + channels + The number of channels in the input tensor. + cond_dim + The dimension of the conditioning tensor. + """ + super().__init__() + self.project = Linear(cond_dim, 2 * channels) if cond_dim > 0 else None + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply FiLM. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply FiLM.""" + if cond is None or self.project is None: + return x + scale, shift = self.project(cond).chunk(2, dim=1) + + scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim) + return x * (1 + scale) + shift diff --git a/src/mrpro/nn/FourierFeatures.py b/src/mrpro/nn/FourierFeatures.py new file mode 100644 index 000000000..847ae3c60 --- /dev/null +++ b/src/mrpro/nn/FourierFeatures.py @@ -0,0 +1,50 @@ +"""Random Fourier feature embedding.""" + +import torch +from torch.nn import Module + + +class FourierFeatures(Module): + """Fourier feature encoding layer. + + Projects input features into a higher dimensional space using random Fourier features. + Used in INRs and to embed the time or other continuous variables. + """ + + weight: torch.Tensor + + def __init__(self, n_features_in: int, n_features_out: int, std: float = 1.0): + """Initialize Fourier feature encoding layer. + + Parameters + ---------- + n_features_in + Number of input features + n_features_out + Number of output features (must be even) + std + Standard deviation for random initialization + """ + if n_features_out % 2 != 0: + raise ValueError('n_features_out must be even.') + super().__init__() + self.register_buffer('weight', torch.randn([n_features_out // 2, n_features_in]) * std) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding. + + Parameters + ---------- + x + Input tensor of shape (..., in_features) + + Returns + ------- + Encoded features of shape (..., out_features) + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Fourier feature encoding.""" + f = 2 * torch.pi * x @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) diff --git a/src/mrpro/nn/GEGLU.py b/src/mrpro/nn/GEGLU.py new file mode 100644 index 000000000..6151503d2 --- /dev/null +++ b/src/mrpro/nn/GEGLU.py @@ -0,0 +1,56 @@ +"""Gated linear unit activation function.""" + +import torch +from torch.nn import Linear, Module + + +class GEGLU(Module): + r"""Gated linear unit activation function. + + References + ---------- + ..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202 + """ + + def __init__(self, n_channels_in: int, n_channels_out: int | None = None, features_last: bool = False): + """Initialize the GEGLU activation function. + + Parameters + ---------- + n_channels_in + The number of input features/channels. + n_channels_out + The number of output features/channels. If None, the number of + output features is the same as the number of input features. + features_last + If True, the channel dimension is the last dimension, else in the second dimension. + """ + super().__init__() + out_channels_ = n_channels_in if n_channels_out is None else n_channels_out + self.proj = Linear(n_channels_in, out_channels_ * 2) # gate and output stacked + self.features_last = features_last + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation.""" + if not self.features_last: + x = x.moveaxis(1, -1) + h, gate = self.proj(x).chunk(2, dim=-1) + gate = torch.nn.functional.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + out = h * gate + if not self.features_last: + out = out.moveaxis(-1, 1) + return out + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the GEGLU activation. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Activated tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py new file mode 100644 index 000000000..e0bbb0a4b --- /dev/null +++ b/src/mrpro/nn/GroupNorm.py @@ -0,0 +1,71 @@ +"""GroupNorm with 32-bit precision.""" + +import torch + + +class GroupNorm(torch.nn.GroupNorm): + """A 32-bit GroupNorm with (optional) automatic group size selection. + + Casts to float32 before calling the parent class to avoid instabilities in mixed precision training. + + If `n_groups` is not provided, the number of groups is selected automatically as follows: + + - start from `1` group, + - try powers of two (`2, 4, 8, ...`), + - keep the largest candidate that divides `n_channels`, + - enforce at most `32` groups and at least `4` channels per group. + + This yields a stable default that stays close to common GroupNorm choices while + adapting to small channel counts. + """ + + features_last: bool + + def __init__(self, n_channels: int, n_groups: int | None = None, affine: bool = False, features_last: bool = False): + """Initialize GroupNorm. + + Parameters + ---------- + n_channels + The number of channels in the input tensor. + n_groups + The number of groups to use. If None, the number of groups is determined automatically as + the largest power of 2 that divides `n_channels`, is less than or equal to 32, + and leaves at least 4 channels per group. + affine + Whether to use learnable affine parameters. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + if n_groups is None: + groups_, candidate = 1, 2 + while (candidate <= min(32, n_channels // 4)) and (n_channels % candidate == 0): + groups_, candidate = candidate, groups_ * 2 + else: + groups_ = n_groups + self.features_last: bool = features_last + super().__init__(groups_, n_channels, affine=affine) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm32. + + Parameters + ---------- + x + The input tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x.float()).type(x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply GroupNorm.""" + if self.features_last: + x = x.moveaxis(-1, 1) + result = super().forward(x.float()).type(x.dtype) + if self.features_last: + result = result.moveaxis(1, -1) + return result diff --git a/src/mrpro/nn/LayerNorm.py b/src/mrpro/nn/LayerNorm.py new file mode 100644 index 000000000..7c35eee96 --- /dev/null +++ b/src/mrpro/nn/LayerNorm.py @@ -0,0 +1,83 @@ +"""Layer normalization.""" + +import torch +from torch.nn import Linear, Module, Parameter + +from mrpro.nn.CondMixin import CondMixin +from mrpro.utils.reshape import unsqueeze_at, unsqueeze_right + + +class LayerNorm(CondMixin, Module): + """Layer normalization.""" + + def __init__(self, n_channels: int | None, features_last: bool = False, cond_dim: int = 0) -> None: + """Initialize the layer normalization. + + Parameters + ---------- + n_channels + Number of channels in the input tensor. If `None`, the layer normalization does not do an elementwise + affine transformation. + features_last + If `True`, the channel dimension is the last dimension. + cond_dim + Number of channels in the conditioning tensor. If `0`, no adaptive scaling is applied. + """ + super().__init__() + if n_channels is None and cond_dim == 0: + self.weight: Parameter | None = None + self.bias: Parameter | None = None + self.cond_proj: Linear | None = None + elif n_channels is None and cond_dim > 0: + raise ValueError('channels must be provided if cond_dim > 0') + elif n_channels is not None and cond_dim == 0: + self.weight = Parameter(torch.ones(n_channels)) + self.bias = Parameter(torch.zeros(n_channels)) + self.cond_proj = None + elif n_channels is not None: + self.weight = None + self.bias = None + self.cond_proj = Linear(cond_dim, 2 * n_channels) + else: + raise ValueError('cond_dim must be zero or positive.') + + self.features_last = features_last + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply layer normalization to the input tensor. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor. If `None`, no conditioning is applied. + + Returns + ------- + Normalized output tensor + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply layer normalization to the input tensor.""" + dims = tuple(range(1, x.ndim)) + mean = x.mean(dim=dims, keepdim=True) + std = x.std(dim=dims, keepdim=True, unbiased=False) + x = (x - mean) / (std + 1e-5) + + if self.weight is not None and self.bias is not None: + if self.features_last: + x = x * self.weight + self.bias + else: + x = x * unsqueeze_right(self.weight, x.ndim - 2) + unsqueeze_right(self.bias, x.ndim - 2) + + if self.cond_proj is not None and cond is not None: + scale, shift = self.cond_proj(cond).chunk(2, dim=-1) + scale = 1 + scale + if self.features_last: + x = x * unsqueeze_at(scale, 1, x.ndim - 2) + unsqueeze_at(shift, 1, x.ndim - 2) + else: + x = x * unsqueeze_right(scale, x.ndim - 2) + unsqueeze_right(shift, x.ndim - 2) + + return x diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py new file mode 100644 index 000000000..935d114dc --- /dev/null +++ b/src/mrpro/nn/PermutedBlock.py @@ -0,0 +1,75 @@ +"""Block that applies a submodule along selected spatial dimensions.""" + +from collections.abc import Sequence + +import torch +from torch import nn + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class PermutedBlock(CondMixin, nn.Module): + """Apply a submodule along selected spatial dimensions.""" + + apply_along_dim: tuple[int, ...] + module: nn.Module + + def __init__(self, apply_along_dim: Sequence[int], module: nn.Module, features_last: bool = False): + """Initialize the PermutedBlock. + + Parameters + ---------- + apply_along_dim + Spatial dimension indices to use when applying the module. + These will be moved to the last dimensions. + module + Module to apply on the selected dims. + features_last + If True, the features dimension is assumed to be the last dimension, as common in transformer models. + """ + super().__init__() + self.apply_along_dim = tuple(sorted(apply_along_dim)) + self.module = module + self.features_last = features_last + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions. + + Parameters + ---------- + x + Input tensor + cond + Conditioning tensor, passed to the module if it supports conditioning + (that is, if it is a subclass of `~mrpro.nn.CondMixin`) + + Returns + ------- + Output tensor. + """ + return self.forward(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module along the selected dimensions.""" + keep = tuple(d % x.ndim for d in self.apply_along_dim) + if 0 in keep: + raise ValueError('Batch dimension should not be in apply_along_dim.') + if self.features_last: + if x.ndim - 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(1, x.ndim - 1) if d not in keep) + permute = (0, *batch_dim, *keep, x.ndim - 1) + else: + if 1 in keep: + raise ValueError('Features dimension should not be in apply_along_dim.') + batch_dim = tuple(d for d in range(2, x.ndim) if d not in keep) + permute = (0, *batch_dim, 1, *keep) + h = x.permute(permute) + batch_shape = h.shape[: 1 + len(batch_dim)] + h = h.flatten(0, len(batch_dim)) + h = call_with_cond(self.module, h, cond=cond) + h = h.unflatten(0, batch_shape) + permute_back = [0] * x.ndim + for i, p in enumerate(permute): + permute_back[p] = i + return h.permute(tuple(permute_back)) diff --git a/src/mrpro/nn/PixelShuffle.py b/src/mrpro/nn/PixelShuffle.py new file mode 100644 index 000000000..afcff9335 --- /dev/null +++ b/src/mrpro/nn/PixelShuffle.py @@ -0,0 +1,285 @@ +"""ND-version of PixelShuffle and PixelUnshuffle.""" + +from math import ceil + +import torch +from torch.nn import Linear, Module + +from mrpro.nn.ndmodules import convND + + +class PixelUnshuffle(Module): + """ND-version of PixelUnshuffle downscaling.""" + + def __init__(self, downscale_factor: int, features_last: bool = False): + """Initialize PixelUnshuffle. + + Reduces spatial dimensions and increases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are downscaled. + + See `mrpro.nn.PixelShuffle` for the inverse operation. + + Parameters + ---------- + downscale_factor + The factor by which to downscale the input tensor. + features_last + Whether the features/channels dimension is the last dimension as common in transformer models or the + second dimension as common in CNN models. + """ + super().__init__() + self.downscale_factor = downscale_factor + self.features_last = features_last + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or `batch, *spatial_dims, channels` (if `features_last`). + + Returns + ------- + Tensor of shape `batch, channels * downscale_factor**dim, *spatial_dims/downscale_factor` or + `batch, *spatial_dims/downscale_factor, channels * downscale_factor**dim` (if `features_last`). + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Downscale the input.""" + n_dim = x.ndim - 2 + if n_dim == 2 and not self.features_last: # fast path for 2D images + return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor) + + new_shape = list(x.shape[:1]) if self.features_last else list(x.shape[:2]) + source_positions = [] + for i, old in enumerate(x.shape[1:-1] if self.features_last else x.shape[2:]): + if old % self.downscale_factor: + raise ValueError('Spatial size must be divisible by downscale_factor.') + new_shape.append(old // self.downscale_factor) + new_shape.append(self.downscale_factor) + source_positions.append(2 + 2 * i) + if self.features_last: + new_shape.append(x.shape[-1]) + x = x.view(new_shape) + x = x.moveaxis(source_positions, tuple(range(-n_dim, 0))) + if self.features_last: + x = x.flatten(-n_dim - 1) + else: + x = x.flatten(1, -n_dim - 1) + + return x + + +class PixelUnshuffleDownsample(Module): + """PixelUnshuffle Downsampling. + + PixelUnshuffle followed by a convolution. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + downscale_factor: int = 2, + residual: bool = False, + features_last: bool = False, + ): + """Initialize a PixelUnshuffleDownsample layer. + + Parameters + ---------- + n_dim + Dimension of the input tensor. + n_channels_in + Number of channels in the input tensor. + n_channels_out + Number of channels in the output tensor. + downscale_factor + Factor by which to downscale the input tensor. + residual + Whether to use a residual connection as proposed in [DCAE]_. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + super().__init__() + out_ratio = downscale_factor**n_dim + if n_channels_out % out_ratio != 0: + raise ValueError(f'channels_out must be divisible by downscale_factor**{n_dim}.') + if features_last: + self.projection: Module = Linear(n_channels_in, n_channels_out // out_ratio) + else: + self.projection = convND(n_dim)(n_channels_in, n_channels_out // out_ratio, kernel_size=3, padding='same') + self.features_last = features_last + self.residual = residual + self.pixel_unshuffle = PixelUnshuffle(downscale_factor, features_last) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims/downscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply downsampling.""" + h = self.projection(x) + h = self.pixel_unshuffle(h) + + if self.residual: + x = self.pixel_unshuffle(x) + if self.features_last: + n = (x.shape[-1] // h.shape[-1]) * h.shape[-1] + h = h + x[..., :n].unflatten(-1, (h.shape[-1], -1)).mean(-1) + else: + n = (x.shape[1] // h.shape[1]) * h.shape[1] + h = h + x[:, :n].unflatten(1, (h.shape[1], -1)).mean(2) + return h + + +class PixelShuffleUpsample(Module): + """PixelShuffle Upsampling. + + Convolution followed by PixelShuffle. Optionally uses a residual connection [DCAE]_ + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + upscale_factor: int = 2, + residual: bool = False, + features_last: bool = False, + ): + """Initialize a PixelShuffleUpsample layer. + + Parameters + ---------- + n_dim + Dimension of the input tensor. + n_channels_in + Number of channels in the input tensor. + n_channels_out + Number of channels in the output tensor. + upscale_factor + Factor by which to upscale the input tensor. + residual + Whether to use a residual connection as proposed in [DCAE]_. + features_last + Whether the features are last in the input tensor, as common in transformer models, + or in the second dimension, as common in CNNs. + """ + super().__init__() + if features_last: + self.projection: Module = Linear(n_channels_in, n_channels_out * upscale_factor**n_dim) + else: + self.projection = convND(n_dim)( + n_channels_in, n_channels_out * upscale_factor**n_dim, kernel_size=3, padding='same' + ) + self.features_last = features_last + self.pixel_shuffle = PixelShuffle(upscale_factor, features_last) + self.residual = residual + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling. + + Parameters + ---------- + x + Tensor of shape `batch, channels_in, *spatial_dims` + + Returns + ------- + Tensor of shape `batch, channels_out, *spatial_dims * upscale_factor` + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply upsampling.""" + h = self.projection(x) + if self.residual: + if self.features_last: + h = h + x.repeat_interleave(ceil(h.shape[-1] / x.shape[-1]), dim=-1)[..., : h.shape[-1]] + else: + h = h + x.repeat_interleave(ceil(h.shape[1] / x.shape[1]), dim=1)[:, : h.shape[1]] + out = self.pixel_shuffle(h) + return out + + +class PixelShuffle(Module): + """ND-version of PixelShuffle upscaling.""" + + def __init__(self, upscale_factor: int, features_last: bool = False): + """Initialize PixelShuffle. + + Upscales spatial dimensions and decreases the channel number by reshaping. + The first dimension is considered a batch dimension, the second dimension + the channel dimension, and the remaining dimensions the spatial dimensions that are upscaled. + + See `mrpro.nn.PixelUnshuffle` for the inverse operation. + + Parameters + ---------- + upscale_factor + The factor by which to upscale the spatial dimensions. + features_last + Whether the features/channels dimension is the last dimension as common in transformer models or the + second dimension as common in CNN models. + """ + super().__init__() + self.upscale_factor = upscale_factor + self.features_last = features_last + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input. + + Parameters + ---------- + x + Tensor of shape `batch, channels, *spatial_dims` or `batch, *spatial_dims, channels` (if `features_last`). + + Returns + ------- + Tensor of shape `batch, channels / upscale_factor**n_dim, *spatial_dims * upscale_factor` or + `batch, *spatial_dims * upscale_factor, channels / upscale_factor**n_dim` (if `features_last`). + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upscale the input.""" + n_dim = x.ndim - 2 + if n_dim == 2 and not self.features_last: # fast path for 2D + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + if self.features_last: + new_shape = (x.shape[0], *(old * self.upscale_factor for old in x.shape[-n_dim - 1 : -1]), -1) + x = x.unflatten(-1, (-1, *(self.upscale_factor,) * n_dim)) + x = x.moveaxis(tuple(range(-n_dim, 0)), tuple(range(-2 * n_dim, 0, 2))) + else: + new_shape = (x.shape[0], -1, *(old * self.upscale_factor for old in x.shape[-n_dim:])) + x = x.unflatten(1, (-1, *(self.upscale_factor,) * n_dim)) + x = x.moveaxis(tuple(range(2, 2 + n_dim)), tuple(range(-2 * n_dim + 1, 0, 2))) + x = x.reshape(new_shape) + return x diff --git a/src/mrpro/nn/RMSNorm.py b/src/mrpro/nn/RMSNorm.py new file mode 100644 index 000000000..b97641545 --- /dev/null +++ b/src/mrpro/nn/RMSNorm.py @@ -0,0 +1,73 @@ +"""RMSNorm over the channel dimension.""" + +import torch +from torch.nn import Module, Parameter + + +class RMSNorm(Module): + """RMSNorm over the channel dimension. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_channels: int | None = None, + eps: float = 1e-8, + features_last: bool = False, + ): + """Initialize RMSNorm. + + Includes a learnable weight and bias if n_channels is provided. + + Parameters + ---------- + n_channels + Number of channels. If `None`, no learnable weight and bias are included. + eps + Epsilon value to avoid division by zero. + features_last + If True, the channel dimension is the last dimension. + """ + super().__init__() + if n_channels is not None: + self.weight: Parameter | None = Parameter(torch.zeros(n_channels)) + self.bias: Parameter | None = Parameter(torch.zeros(n_channels)) + else: + self.weight = None + self.bias = None + self.eps = eps + self.channel_dim = -1 if features_last else 1 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension. + + Parameters + ---------- + x + Input tensor. + + Returns + ------- + Normalized tensor. + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply RMSNorm over the channel dimension.""" + x32 = x.to(torch.float32) # normalization in float32 to stabilize mixed precision training + mean_square = x32.square().mean(dim=self.channel_dim, keepdim=True) + scale = (mean_square + self.eps).rsqrt() + x32 = x32 * scale + if self.weight is not None and self.bias is not None: + shape = [1] * x.ndim + shape[self.channel_dim] = -1 + weight = (self.weight.to(x32.dtype) + 1).view(shape) + bias = self.bias.view(shape) + x32 = x32 * weight + bias + return x32.to(x.dtype) diff --git a/src/mrpro/nn/Residual.py b/src/mrpro/nn/Residual.py new file mode 100644 index 000000000..e524fe169 --- /dev/null +++ b/src/mrpro/nn/Residual.py @@ -0,0 +1,45 @@ +"""Residual connection.""" + +import torch +from torch.nn import Identity, Module + +from mrpro.nn.CondMixin import CondMixin, call_with_cond + + +class Residual(CondMixin, Module): + """Residual connection.""" + + def __init__(self, module: Module, skip: Module | None = None): + """Initialize the residual connection. + + Parameters + ---------- + module + The main path of the residual connection. + skip + The skip path of the residual connection. If None, the identity function is used. + """ + super().__init__() + self.module = module + self.skip = Identity() if skip is None else skip + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module. + + Parameters + ---------- + x + The input tensor. + cond + The optional conditioning tensor. If the modules are an instance of `CondMixin`, + the conditioning is passed to the modules. + + Returns + ------- + The output tensor. + """ + return super().__call__(*x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the module.""" + return call_with_cond(self.module, *x, cond=cond) + call_with_cond(self.skip, *x, cond=cond) diff --git a/src/mrpro/nn/Sequential.py b/src/mrpro/nn/Sequential.py new file mode 100644 index 000000000..fed22c3a2 --- /dev/null +++ b/src/mrpro/nn/Sequential.py @@ -0,0 +1,60 @@ +"""Sequential container with support for conditioning and Operators.""" + +from collections import OrderedDict +from typing import cast + +import torch + +from mrpro.nn.CondMixin import CondMixin +from mrpro.operators import Operator + + +class Sequential(CondMixin, torch.nn.Sequential): + """Sequential container with support for conditioning and Operators. + + Allows multiple input tensors and a single output tensor of the sequential block. + + """ + + def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input. + + Parameters + ---------- + x + The input tensor. + cond + The (optional) conditioning tensor. + + Returns + ------- + The output tensor. + """ + return torch.nn.Sequential.__call__(self, *x, cond=cond) + + def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply all modules in series to the input.""" + for module in self: + if isinstance(module, Operator): + x = cast(tuple[torch.Tensor, ...], module(*x)) # always tuple + else: + ret: torch.Tensor | tuple[torch.Tensor, ...] + if isinstance(module, CondMixin): + ret = module(*x, cond=cond) + else: + ret = module(*x) + if isinstance(ret, tuple): + x = ret + else: + x = (ret,) + return x[0] + + def __getitem__(self, idx: slice | int) -> 'Sequential': + """Get a slice or item from the Sequential container. + + Subclasses will decompose to `Sequential` on indexing. + """ + if isinstance(idx, slice): + return Sequential(OrderedDict(list(self._modules.items())[idx])) + else: + return cast(Sequential, self._get_item_by_idx(self._modules.values(), idx)) diff --git a/src/mrpro/nn/Upsample.py b/src/mrpro/nn/Upsample.py new file mode 100644 index 000000000..acced8d48 --- /dev/null +++ b/src/mrpro/nn/Upsample.py @@ -0,0 +1,66 @@ +"""Upsampling by interpolation.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module, Sequential + +from mrpro.nn.PermutedBlock import PermutedBlock + + +class Upsample(Module): + """Upsampling by interpolation.""" + + def __init__( + self, dim: Sequence[int], scale_factor: int = 2, mode: Literal['nearest', 'linear', 'cubic'] = 'linear' + ): + """Initialize the upsampling layer. + + Parameters + ---------- + dim + Dimensions which to upsample + scale_factor + Factor by which to upsample + mode + Interpolation mode. See `torch.nn.functional.interpolate` for details. + """ + super().__init__() + self.scale_factor = scale_factor + if mode == 'nearest': + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = ['nearest'] * len(dim) + elif mode == 'linear': + dims = [d.tolist() for d in torch.tensor(dim).split(3)] + modes = [{1: 'linear', 2: 'bilinear', 3: 'trilinear'}[len(d)] for d in dims] + elif mode == 'cubic': + if not len(dim) == 2: + raise ValueError('Cubic interpolation is only supported for 2D images.') + dims = [tuple(dim)] + modes = ['bicubic'] + + self.blocks = Sequential( + *[ + PermutedBlock(d, torch.nn.Upsample(scale_factor=len(d) * (scale_factor,), mode=m)) + for d, m in zip(dims, modes, strict=False) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor.""" + return self.blocks(x) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the input tensor. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Upsampled tensor + """ + return super().__call__(x) diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py new file mode 100644 index 000000000..d6541f5c8 --- /dev/null +++ b/src/mrpro/nn/__init__.py @@ -0,0 +1,45 @@ +"""Neural network modules and utilities.""" + +from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.DropPath import DropPath +from mrpro.nn.FiLM import FiLM +from mrpro.nn.FourierFeatures import FourierFeatures +from mrpro.nn.GEGLU import GEGLU +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Residual import Residual +from mrpro.nn.Sequential import Sequential +from mrpro.nn.ndmodules import ( + adaptiveAvgPoolND, + avgPoolND, + batchNormND, + convND, + convTransposeND, + instanceNormND, + maxPoolND, +) + +__all__ = [ + 'ComplexAsChannel', + 'CondMixin', + 'DropPath', + 'FiLM', + 'FourierFeatures', + 'GEGLU', + 'GroupNorm', + 'LayerNorm', + 'PermutedBlock', + 'RMSNorm', + 'Residual', + 'Sequential', + 'adaptiveAvgPoolND', + 'avgPoolND', + 'batchNormND', + 'convND', + 'convTransposeND', + 'instanceNormND', + 'maxPoolND', +] diff --git a/src/mrpro/nn/convert_linear_conv.py b/src/mrpro/nn/convert_linear_conv.py new file mode 100644 index 000000000..767a419ff --- /dev/null +++ b/src/mrpro/nn/convert_linear_conv.py @@ -0,0 +1,100 @@ +"""Convert Linear layers to kernel size 1 ConvNd layers and vice versa.""" + +from typing import Literal, overload + +import torch +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + +from mrpro.nn.ndmodules import convND + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[1]) -> Conv1d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[2]) -> Conv2d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: Literal[3]) -> Conv3d: ... + + +@overload +def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d: ... + + +def linear_to_conv(linear_layer: Linear, n_dim: int) -> Conv1d | Conv2d | Conv3d: + """Convert a Linear layer to a ConvNd layer with kernel size 1. + + Rearranging the spatial dimensions to the batch dimension, + applying the linear layer and rearranging the spatial dimensions back + is equivalent to applying a kernel size 1 ConvNd layer. + + This function will create the Conv1d, Conv2d, or Conv3d with the correct weights and bias. + + See :func:`conv_to_linear` for the reverse operation. + + + + Parameters + ---------- + linear_layer + The linear layer to convert. + n_dim + The convolution dimension (1, 2, or 3). + + Returns + ------- + A Conv layer with equivalent weights and bias. + """ + conv = convND(n_dim)( + in_channels=linear_layer.in_features, + out_channels=linear_layer.out_features, + kernel_size=1, + bias=linear_layer.bias is not None, + device=linear_layer.weight.device, + dtype=linear_layer.weight.dtype, + ) + + with torch.no_grad(): + conv.weight.copy_(linear_layer.weight.view_as(conv.weight)) + if conv.bias is not None and linear_layer.bias is not None: + conv.bias.copy_(linear_layer.bias) + + return conv + + +def conv_to_linear(conv_layer: Conv1d | Conv2d | Conv3d) -> Linear: + """ + Convert a Conv1d, Conv2d, or Conv3d layer with kernel size 1 to a Linear layer. + + Applying a kernel size 1 ConvNd layer is equivalent to applying a Linear layer to each voxel. + This function will create the Linear layer with the correct weights and bias. + + See :func:`linear_to_conv` for the reverse operation. + + Parameters + ---------- + conv_layer : nn.Module + The convolutional layer to convert. Must have kernel size 1. + + Returns + ------- + A linear layer with equivalent weights and bias. + """ + if not all(k == 1 for k in conv_layer.kernel_size): + raise ValueError('Kernel size must be 1 for conversion.') + linear = Linear( + conv_layer.in_channels, + conv_layer.out_channels, + bias=conv_layer.bias is not None, + device=conv_layer.weight.device, + dtype=conv_layer.weight.dtype, + ) + with torch.no_grad(): + linear.weight.copy_(conv_layer.weight.view_as(linear.weight)) + if linear.bias is not None and conv_layer.bias is not None: + linear.bias.copy_(conv_layer.bias) + + return linear diff --git a/src/mrpro/nn/join.py b/src/mrpro/nn/join.py new file mode 100644 index 000000000..d98aeb7b2 --- /dev/null +++ b/src/mrpro/nn/join.py @@ -0,0 +1,176 @@ +"""Modules for concatenating or adding tensors.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module + +from mrpro.utils.interpolate import interpolate +from mrpro.utils.pad_or_crop import pad_or_crop + + +def _fix_shapes( + xs: Sequence[torch.Tensor], + mode: str, + dim: Sequence[int], +) -> tuple[torch.Tensor, ...]: + """Fix shapes of input tensors by padding or cropping.""" + if mode == 'fail': + return tuple(xs) + + shapes = [[x.shape[d] for d in dim] for x in xs] + if mode == 'crop': # smallest as target + target = tuple(min(s) for s in zip(*shapes, strict=True)) + else: # largest as target + target = tuple(max(s) for s in zip(*shapes, strict=True)) + if mode == 'linear' or mode == 'nearest': + return tuple(interpolate(x, target, dim=dim, mode=mode) for x in xs) # type: ignore[arg-type] + if mode == 'zero' or mode == 'crop': + return tuple(pad_or_crop(x, target, dim=dim, mode='constant', value=0.0) for x in xs) + else: + return tuple(pad_or_crop(x, target, dim=dim, mode=mode) for x in xs) # type: ignore[arg-type] + + +class Concat(Module): + """Concatenate tensors along the channel dimension.""" + + def __init__( + self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'] = 'fail', dim: int = 1 + ) -> None: + """Initialize Concat. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + - 'linear': linear interpolation to largest spatial size + - 'nearest': nearest neighbor interpolation to largest spatial size + dim + Dimension along which to concatenate. + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.dim = dim + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Concatenate input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=[i for i in range(max(x.ndim for x in xs)) if i != self.dim]) + return torch.cat(xs, dim=self.dim) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Concatenate input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Concatenated tensor + """ + return super().__call__(*xs) + + +class Add(Module): + """Add tensors.""" + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: + """Initialize Add. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Add input tensors.""" + xs = _fix_shapes(xs, self.mode, dim=range(max(x.ndim for x in xs))) + return sum(xs, start=torch.tensor(0.0)) + + def __call__(self, *xs: torch.Tensor) -> torch.Tensor: + """ + Add input tensors. + + Parameters + ---------- + xs + Input tensors + + Returns + ------- + Summed tensor + """ + return super().__call__(*xs) + + +class Interpolate(Module): + """Linear interpolate between two tensors. + + As suggestions for the Hourglass Transformer [CR]_ + + References + ---------- + .. [CK] Crowson, Katherine, et al. "Scalable high-resolution pixel-space image synthesis with + hourglass diffusion transformers." ICML 2024, https://arxiv.org/abs/2401.11605 + """ + + def __init__(self, mode: Literal['fail', 'crop', 'zero', 'replicate', 'circular'] = 'fail') -> None: + """Initialize learned linear interpolation. + + Parameters + ---------- + mode + How to handle mismatched dimensions: + - 'fail': do not align, raise error if shapes mismatch + - 'crop': center-crop to smallest spatial size + - 'zero': zero-pad to largest spatial size + - 'replicate': pad by edge value replication + - 'circular': circular padding + """ + super().__init__() + modes = {'fail', 'crop', 'zero', 'replicate', 'circular'} + if mode not in modes: + raise ValueError(f'mode must be one of {modes}') + self.mode = mode + self.weight = torch.nn.Parameter(torch.tensor(0.5)) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Linear interpolate between two tensors.""" + x1, x2 = _fix_shapes((x1, x2), self.mode, dim=range(max(x.ndim for x in (x1, x2)))) + return x1 * self.weight + x2 * (1 - self.weight) + + def __call__(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Linear interpolate between two tensors. + + Parameters + ---------- + x1, x2 + Input tensors + + Returns + ------- + Interpolated tensor + """ + return super().__call__(x1, x2) diff --git a/src/mrpro/nn/ndmodules.py b/src/mrpro/nn/ndmodules.py new file mode 100644 index 000000000..3fb2d894c --- /dev/null +++ b/src/mrpro/nn/ndmodules.py @@ -0,0 +1,178 @@ +"""Helper functions to get the correct N-dimensional module.""" + +import torch + + +def convND(n_dim: int) -> type[torch.nn.Conv1d] | type[torch.nn.Conv2d] | type[torch.nn.Conv3d]: # noqa: N802 + """Get the `n_dim`-dimensional convolution class. + + Parameters + ---------- + n_dim + The dimension of the convolution. + + Returns + ------- + The convolution class. + """ + match n_dim: + case 1: + return torch.nn.Conv1d + case 2: + return torch.nn.Conv2d + case 3: + return torch.nn.Conv3d + case _: + raise NotImplementedError(f'ConvND for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def convTransposeND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.ConvTranspose1d] | type[torch.nn.ConvTranspose2d] | type[torch.nn.ConvTranspose3d]: + """Get the `n_dim`-dimensional transposed convolution class. + + Parameters + ---------- + n_dim + The dimension of the transposed convolution. + + Returns + ------- + The transposed convolution class. + """ + match n_dim: + case 1: + return torch.nn.ConvTranspose1d + case 2: + return torch.nn.ConvTranspose2d + case 3: + return torch.nn.ConvTranspose3d + case _: + raise NotImplementedError( + f'ConvTransposeND for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def maxPoolND(n_dim: int) -> type[torch.nn.MaxPool1d] | type[torch.nn.MaxPool2d] | type[torch.nn.MaxPool3d]: # noqa: N802 + """Get the `n_dim`-dimensional max pooling class. + + Parameters + ---------- + n_dim + The dimension of the max pooling. + + Returns + ------- + The max pooling class. + """ + match n_dim: + case 1: + return torch.nn.MaxPool1d + case 2: + return torch.nn.MaxPool2d + case 3: + return torch.nn.MaxPool3d + case _: + raise NotImplementedError(f'MaxPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def avgPoolND(n_dim: int) -> type[torch.nn.AvgPool1d] | type[torch.nn.AvgPool2d] | type[torch.nn.AvgPool3d]: # noqa: N802 + """Get the `n_dim`-dimensional average pooling class. + + Parameters + ---------- + n_dim + The dimension of the average pooling. + + Returns + ------- + The average pooling class. + """ + match n_dim: + case 1: + return torch.nn.AvgPool1d + case 2: + return torch.nn.AvgPool2d + case 3: + return torch.nn.AvgPool3d + case _: + raise NotImplementedError(f'AvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.') + + +def adaptiveAvgPoolND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.AdaptiveAvgPool1d] | type[torch.nn.AdaptiveAvgPool2d] | type[torch.nn.AdaptiveAvgPool3d]: + """Get the `n_dim`-dimensional adaptive average pooling class. + + Parameters + ---------- + n_dim + The dimension of the adaptive average pooling. + + Returns + ------- + The adaptive average pooling class. + """ + match n_dim: + case 1: + return torch.nn.AdaptiveAvgPool1d + case 2: + return torch.nn.AdaptiveAvgPool2d + case 3: + return torch.nn.AdaptiveAvgPool3d + case _: + raise NotImplementedError( + f'AdaptiveAvgPoolNd for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def instanceNormND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.InstanceNorm1d] | type[torch.nn.InstanceNorm2d] | type[torch.nn.InstanceNorm3d]: + """Get the `n_dim`-dimensional instance normalization class. + + Parameters + ---------- + n_dim + The dimension of the instance normalization. + + Returns + ------- + The instance normalization class. + """ + match n_dim: + case 1: + return torch.nn.InstanceNorm1d + case 2: + return torch.nn.InstanceNorm2d + case 3: + return torch.nn.InstanceNorm3d + case _: + raise NotImplementedError( + f'InstanceNormNd for dim {n_dim} not implemented. Raise an issue if you need this.' + ) + + +def batchNormND( # noqa: N802 + n_dim: int, +) -> type[torch.nn.BatchNorm1d] | type[torch.nn.BatchNorm2d] | type[torch.nn.BatchNorm3d]: + """Get the `n_dim`-dimensional batch normalization class. + + Parameters + ---------- + n_dim + The dimension of the batch normalization. + + Returns + ------- + The batch normalization class. + """ + match n_dim: + case 1: + return torch.nn.BatchNorm1d + case 2: + return torch.nn.BatchNorm2d + case 3: + return torch.nn.BatchNorm3d + case _: + raise NotImplementedError(f'BatchNormNd for dim {n_dim} not implemented. Raise an issue if you need this.') diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index 345883e12..2d4eceb2f 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -15,8 +15,10 @@ from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin from mrpro.utils.interpolate import interpolate, apply_lowres from mrpro.utils.RandomGenerator import RandomGenerator - +from mrpro.utils.to_tuple import to_tuple +from mrpro.utils.ema import EMADict __all__ = [ + "EMADict", "Indexer", "RandomGenerator", "TensorAttributeMixin", @@ -38,6 +40,7 @@ "split_idx", "summarize_object", "summarize_values", + "to_tuple", "typing", "unit_conversion", "unsqueeze_at", diff --git a/src/mrpro/utils/to_tuple.py b/src/mrpro/utils/to_tuple.py new file mode 100644 index 000000000..657d7bf56 --- /dev/null +++ b/src/mrpro/utils/to_tuple.py @@ -0,0 +1,36 @@ +"""Standardize an argument to a fixed-length tuple.""" + +from collections.abc import Sequence +from typing import TypeVar + +T = TypeVar('T') + + +def to_tuple(length: int, arg: Sequence[T] | T) -> tuple[T, ...]: + """Standardize an argument to a fixed-length tuple. + + If the argument is a sequence, it checks if its length matches the + specified dimension. If it's a single value, it replicates it `dim` times. + + Parameters + ---------- + length + The expected length of the sequence. + arg + The argument to check. Can be a single value of type T or a + sequence of T. + + Returns + ------- + A tuple of length `dim` containing elements of type T. + + Raises + ------ + ValueError + If `arg` is a sequence and its length does not match `length`. + """ + if isinstance(arg, Sequence): + if not len(arg) == length: + raise ValueError(f'The arguments must be either single values or have length {length}. Got {arg}.') + return tuple(arg) + return (arg,) * length diff --git a/tests/nn/test_complexaschannel.py b/tests/nn/test_complexaschannel.py new file mode 100644 index 000000000..37889f654 --- /dev/null +++ b/tests/nn/test_complexaschannel.py @@ -0,0 +1,30 @@ +"""Tests for ComplexAsChannel module.""" + +import pytest +from mrpro.nn.ComplexAsChannel import ComplexAsChannel +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_complexaschannel(device: str) -> None: + """Test ComplexAsChannel output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + input_shape = (1, 32) + x = rng.complex64_tensor(input_shape).to(device).requires_grad_(True) + module = ComplexAsChannel(Linear(input_shape[1] * 2, input_shape[1] * 2)).to(device) + output = module(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + assert output.is_complex(), 'Output is not complex' + output.sum().abs().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert module.module.weight.grad is not None, 'No gradient computed for weight' + assert module.module.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_convert_linear_conv.py b/tests/nn/test_convert_linear_conv.py new file mode 100644 index 000000000..19438b9d9 --- /dev/null +++ b/tests/nn/test_convert_linear_conv.py @@ -0,0 +1,150 @@ +"""Tests for converting between Linear and Conv layers.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.convert_linear_conv import conv_to_linear, linear_to_conv +from mrpro.utils import RandomGenerator +from torch.nn import Conv1d, Conv2d, Conv3d, Linear + +DEVICES = pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +SHAPES = pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'bias'), + [ + (1, 32, 64, True), + (2, 16, 32, True), + (3, 8, 16, True), + (3, 1, 1, False), + ], + ids=['1d', '2d', '3d', '3d_no_bias'], +) + + +@SHAPES +@DEVICES +def test_linear_to_conv(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Linear to Conv layer.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias).to(device) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + assert isinstance(conv, (Conv1d, Conv2d, Conv3d)[dim - 1]) + + assert conv.in_channels == channels_in + assert conv.out_channels == channels_out + assert conv.kernel_size == (1,) * dim + assert conv.bias is not None if bias else conv.bias is None + + assert conv.weight.device.type == device + if conv.bias is not None: + assert conv.bias.device.type == device + + +@SHAPES +def test_linear_to_conv_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test functional equivalence of Linear to Conv conversion.""" + rng = RandomGenerator(seed=42) + linear = Linear(channels_in, channels_out, bias=bias) + linear.weight.data = rng.rand_like(linear.weight) + if bias: + linear.bias.data = rng.rand_like(linear.bias) + + conv = linear_to_conv(linear, dim) + spatial_shape = (4,) * dim + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) + + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + torch.testing.assert_close(y_conv, y_linear) + + +@SHAPES +@DEVICES +def test_conv_to_linear(device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test converting Conv layer to Linear.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias).to(device) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + + assert isinstance(linear, Linear) + assert linear.in_features == channels_in + assert linear.out_features == channels_out + assert linear.bias is not None if bias else linear.bias is None + + assert linear.weight.device.type == device + if bias: + assert linear.bias.device.type == device + + +@SHAPES +def test_conv_to_linear_functional(dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool) -> None: + """Test functional equivalence of Conv to Linear conversion.""" + rng = RandomGenerator(seed=42) + conv_class = (Conv1d, Conv2d, Conv3d)[dim - 1] + conv = conv_class(channels_in, channels_out, kernel_size=1, bias=bias) + conv.weight.data = rng.rand_like(conv.weight) + if conv.bias is not None: + conv.bias.data = rng.rand_like(conv.bias) + + linear = conv_to_linear(conv) + spatial_shape = (4,) * dim + + x = rng.randn_tensor((2, channels_in, *spatial_shape), torch.float32) + y_conv = conv(x) + y_conv = y_conv.moveaxis(1, -1).flatten(0, -2) + + x_reshaped = x.moveaxis(1, -1).flatten(0, -2) + y_linear = linear(x_reshaped) + + torch.testing.assert_close(y_conv, y_linear) + + +def test_conv_to_linear_invalid_kernel() -> None: + """Test conv_to_linear with invalid kernel size.""" + conv = Conv2d(32, 64, kernel_size=3, bias=True) + with pytest.raises(ValueError, match='Kernel size must be 1'): + conv_to_linear(conv) + + +@SHAPES +@DEVICES +def test_round_trip_conversion( + device: str, dim: Literal[1, 2, 3], channels_in: int, channels_out: int, bias: bool +) -> None: + """Test round-trip conversion between Linear and Conv layers.""" + rng = RandomGenerator(seed=42) + + linear1 = Linear(channels_in, channels_out, bias=bias).to(device) + linear1.weight.data = rng.rand_like(linear1.weight) + if bias: + linear1.bias.data = rng.rand_like(linear1.bias) + + conv = linear_to_conv(linear1, dim) + linear2 = conv_to_linear(conv) + + assert linear2.in_features == channels_in + assert linear2.out_features == channels_out + assert linear2.bias is not None if bias else linear2.bias is None + + torch.testing.assert_close(linear2.weight, linear1.weight) + if bias: + torch.testing.assert_close(linear2.bias, linear1.bias) diff --git a/tests/nn/test_droppath.py b/tests/nn/test_droppath.py new file mode 100644 index 000000000..b4ac7f5d7 --- /dev/null +++ b/tests/nn/test_droppath.py @@ -0,0 +1,30 @@ +"""Test DropPath.""" + +import pytest +from mrpro.nn.DropPath import DropPath +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_droppath_no_drop(device: str) -> None: + """Test DropPath with zero drop rate (should pass through unchanged).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).to(device) + droppath = DropPath(0).to(device) + y = droppath(x) + assert (y == x).all() + + +def test_droppath_drop_all() -> None: + """Test DropPath with full drop rate (should output zeros).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)) + droppath = DropPath(1.0) + y = droppath(x) + assert (y == 0).all() diff --git a/tests/nn/test_film.py b/tests/nn/test_film.py new file mode 100644 index 000000000..0e564c675 --- /dev/null +++ b/tests/nn/test_film.py @@ -0,0 +1,42 @@ +"""Tests for FiLM module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn.FiLM import FiLM +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'n_channels_cond', 'input_shape', 'cond_shape'), + [ + (64, 32, (1, 64, 32, 32), (1, 32)), + (32, 16, (2, 32, 16, 16), (2, 16)), + ], +) +def test_film( + n_channels: int, n_channels_cond: int, input_shape: Sequence[int], cond_shape: Sequence[int], device: str +) -> None: + """Test FiLM output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) + film = FiLM(channels=n_channels, cond_dim=n_channels_cond).to(device) + output = film(x, cond=cond) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not output.isnan().any(), 'NaN values in output' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' + assert film.project is not None, 'Linear layer is not initialized' + assert next(film.project.parameters()).grad is not None, 'No gradient computed for Linear layer' diff --git a/tests/nn/test_fourierfeatures.py b/tests/nn/test_fourierfeatures.py new file mode 100644 index 000000000..9452a369f --- /dev/null +++ b/tests/nn/test_fourierfeatures.py @@ -0,0 +1,24 @@ +"""Test for random fourier features""" + +import pytest +from mrpro.nn import FourierFeatures +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_fourierfeatures(device: str) -> None: + """Test FourierFeatures.""" + n_features_in = 1 + n_features_out = 16 + std = 1.0 + rng = RandomGenerator(444) + x = rng.float32_tensor((1, n_features_in)).to(device) + ff = FourierFeatures(n_features_in, n_features_out, std).to(device) + y = ff(x) + assert y.shape == (1, n_features_out) diff --git a/tests/nn/test_geglu.py b/tests/nn/test_geglu.py new file mode 100644 index 000000000..061837e51 --- /dev/null +++ b/tests/nn/test_geglu.py @@ -0,0 +1,38 @@ +"""Test GEGLU.""" + +import pytest +import torch +from mrpro.nn.GEGLU import GEGLU +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +def test_geglu(device: str) -> None: + """Test GEGLU output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).to(device).requires_grad_(True) + gelu = GEGLU(3, 4).to(device) + y = gelu(x) + assert y.shape == (1, 4, 4, 5) + + y.sum().backward() + assert x.grad is not None + assert gelu.proj.weight.grad is not None + + +def test_geglu_features_last() -> None: + """Test GEGLU with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + gelu_last = GEGLU(3, 4, features_last=True) + gelu = GEGLU(3, 4, features_last=False) + gelu.proj = gelu_last.proj # need to set the same weights + y_last = gelu_last(x.moveaxis(1, -1)) + y = gelu(x) + torch.testing.assert_close(y, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_groupnorm.py b/tests/nn/test_groupnorm.py new file mode 100644 index 000000000..044de17f7 --- /dev/null +++ b/tests/nn/test_groupnorm.py @@ -0,0 +1,45 @@ +"""Tests for GroupNorm module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn import GroupNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'n_groups', 'input_shape', 'affine'), + [ + (32, None, (1, 32, 32, 32), True), + (64, 8, (2, 64, 16, 16, 16), False), + ], +) +def test_groupnorm( + n_channels: int, + n_groups: int | None, + input_shape: Sequence[int], + device: str, + affine: bool, +) -> None: + """Test GroupNorm output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = GroupNorm(n_channels=n_channels, n_groups=n_groups, affine=affine).to(device) + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if affine: + assert norm.weight is not None, 'Weight should not be None when affine is True' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias is not None, 'Bias should not be None when affine is True' + assert norm.bias.grad is not None, 'No gradient computed for bias' diff --git a/tests/nn/test_join.py b/tests/nn/test_join.py new file mode 100644 index 000000000..f86647ac4 --- /dev/null +++ b/tests/nn/test_join.py @@ -0,0 +1,160 @@ +"""Tests for join modules.""" + +from typing import Literal + +import pytest +import torch +from mrpro.nn.join import Add, Concat +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('mode', 'input_shapes', 'expected_shape'), + [ + ('crop', [(1, 3, 32, 32), (1, 5, 30, 30)], (1, 8, 30, 30)), + ('zero', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ('linear', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ('nearest', [(1, 3, 32, 32), (1, 5, 34, 34)], (1, 8, 34, 34)), + ], +) +def test_concat_basic( + mode: Literal['crop', 'zero', 'replicate', 'circular', 'linear', 'nearest'], + input_shapes: list[tuple[int, ...]], + expected_shape: tuple[int, ...], + device: str, +) -> None: + """Test Concat basic functionality.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).to(device).requires_grad_(True) for shape in input_shapes] + concat = Concat(mode=mode).to(device) + + output = concat(*xs) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + for x in xs: + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('mode', 'input_shapes', 'expected_shape'), + [ + ('crop', [(1, 3, 32, 32), (1, 3, 30, 30)], (1, 3, 30, 30)), + ('zero', [(1, 3, 32, 32), (1, 3, 34, 34)], (1, 3, 34, 34)), + ('replicate', [(1, 1, 1, 2), (1, 1, 1, 3)], (1, 1, 1, 3)), + ('circular', [(1, 1, 1, 2), (1, 1, 1, 4)], (1, 1, 1, 4)), + ], +) +def test_add_basic( + mode: Literal['crop', 'zero', 'replicate', 'circular'], + input_shapes: list[tuple[int, ...]], + expected_shape: tuple[int, ...], + device: str, +) -> None: + """Test Add basic functionality.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).to(device).requires_grad_(True) for shape in input_shapes] + add = Add(mode=mode).to(device) + + output = add(*xs) + assert output.shape == expected_shape + assert not output.isnan().any(), 'NaN values in output' + + output.sum().backward() + for x in xs: + assert x.grad is not None, 'No gradient computed for input' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + +@pytest.mark.parametrize( + ('dim', 'input_shapes', 'expected_shape'), + [ + (0, [(1, 3, 32, 32), (1, 3, 32, 32)], (2, 3, 32, 32)), + (1, [(1, 3, 32, 32), (1, 5, 32, 32)], (1, 8, 32, 32)), + (2, [(1, 3, 32, 32), (1, 3, 32, 32)], (1, 3, 64, 32)), + ], +) +def test_concat_dimensions(dim: int, input_shapes: list[tuple[int, ...]], expected_shape: tuple[int, ...]) -> None: + """Test Concat with different concatenation dimensions.""" + rng = RandomGenerator(seed=42) + xs = [rng.float32_tensor(shape).requires_grad_(True) for shape in input_shapes] + concat = Concat(mode='fail', dim=dim) + output = concat(*xs) + assert output.shape == expected_shape + + +def test_concat_values() -> None: + """Test that Concat preserves input values correctly.""" + x1 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]).requires_grad_(True) + x2 = torch.tensor([[[[5.0, 6.0], [7.0, 8.0]]]]).requires_grad_(True) + + concat = Concat(mode='fail') + output = concat(x1, x2) + + expected = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]) + torch.testing.assert_close(output, expected) + + +def test_add_values() -> None: + """Test that Add correctly sums input values.""" + x1 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]).requires_grad_(True) + x2 = torch.tensor([[[[5.0, 6.0], [7.0, 8.0]]]]).requires_grad_(True) + + add = Add(mode='fail') + output = add(x1, x2) + + expected = torch.tensor([[[[6.0, 8.0], [10.0, 12.0]]]]) + torch.testing.assert_close(output, expected) + + +def test_concat_mode_fail() -> None: + """Test Concat with mode='fail'.""" + rng = RandomGenerator(seed=42) + + x1 = rng.float32_tensor((1, 3, 32, 32)) + x2 = rng.float32_tensor((1, 5, 32, 32)) + concat = Concat(mode='fail') + output = concat(x1, x2) + assert output.shape == (1, 8, 32, 32) + + x3 = rng.float32_tensor((1, 3, 30, 30)) + with pytest.raises(RuntimeError): + concat(x1, x3) + + +def test_add_mode_fail() -> None: + """Test Add with mode='fail'.""" + rng = RandomGenerator(seed=42) + + x1 = rng.float32_tensor((1, 3, 32, 32)) + x2 = rng.float32_tensor((1, 3, 32, 32)) + add = Add(mode='fail') + output = add(x1, x2) + assert output.shape == (1, 3, 32, 32) + + x3 = rng.float32_tensor((1, 3, 30, 30)) + with pytest.raises(RuntimeError): + add(x1, x3) + + +@pytest.mark.parametrize('module_class', [Concat, Add]) +def test_invalid_mode(module_class: type) -> None: + """Test modules with invalid mode.""" + with pytest.raises(ValueError, match='mode must be one of'): + module_class(mode='invalid_mode') diff --git a/tests/nn/test_layernorm.py b/tests/nn/test_layernorm.py new file mode 100644 index 000000000..ebc11ccb8 --- /dev/null +++ b/tests/nn/test_layernorm.py @@ -0,0 +1,187 @@ +"""Tests for LayerNorm module.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.nn.LayerNorm import LayerNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'features_last', 'input_shape'), + [ + (32, False, (1, 32, 32, 32)), + (64, True, (2, 16, 16, 64)), + (None, False, (1, 32, 32, 32)), + (None, True, (2, 16, 16, 64)), + ], +) +def test_layernorm_basic( + n_channels: int | None, + features_last: bool, + input_shape: Sequence[int], + device: str, +) -> None: + """Test LayerNorm basic functionality.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = LayerNorm(n_channels=n_channels, features_last=features_last).to(device) + output = norm(x) + + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + if n_channels is not None: + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' + + +@pytest.mark.parametrize( + ('n_channels', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (32, 16, (1, 32, 32, 32), (1, 16)), + (64, 32, (2, 64, 16, 16), (2, 32)), + ], +) +def test_layernorm_with_conditioning( + n_channels: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int], +) -> None: + """Test LayerNorm with conditioning.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).requires_grad_(True) + norm = LayerNorm(n_channels=n_channels, cond_dim=cond_dim) + + output = norm(x, cond=cond) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + assert norm.cond_proj is not None, 'cond_proj should not be None when cond_dim > 0' + assert norm.cond_proj.weight.grad is not None, 'No gradient computed for cond_proj' + + +def test_layernorm_features_last() -> None: + """Test LayerNorm with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + norm_standard = LayerNorm(n_channels=3, features_last=False) + y_standard = norm_standard(x) + + norm_last = LayerNorm(n_channels=3, features_last=True) + y_last = norm_last(x.moveaxis(1, -1)) + + torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) + + +def test_layernorm_no_channels() -> None: + """Test LayerNorm without channels (pure normalization).""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) + norm = LayerNorm(n_channels=None) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + # Check that normalization is applied (mean close to 0, std close to 1) + dims = tuple(range(1, x.ndim)) + mean = output.mean(dim=dims) + std = output.std(dim=dims) + + assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-6), 'Mean not close to 0' + assert torch.allclose(std, torch.ones_like(std), atol=1e-5), 'Std not close to 1' + + +def test_layernorm_conditioning_without_channels() -> None: + """Test LayerNorm with conditioning but no channels (should raise error).""" + with pytest.raises(ValueError, match='channels must be provided if cond_dim > 0'): + LayerNorm(n_channels=None, cond_dim=16) + + +def test_layernorm_invalid_cond_dim() -> None: + """Test LayerNorm with invalid cond_dim.""" + with pytest.raises(RuntimeError, match='Trying to create tensor with negative dimension'): + LayerNorm(n_channels=32, cond_dim=-1) + + +def test_layernorm_3d_input() -> None: + """Test LayerNorm with 3D input.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((2, 64, 128)).requires_grad_(True) + norm = LayerNorm(n_channels=64) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + + +def test_layernorm_5d_input() -> None: + """Test LayerNorm with 5D input.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 16, 16, 16)).requires_grad_(True) + norm = LayerNorm(n_channels=32) + + output = norm(x) + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + + +def test_layernorm_conditioning_features_last() -> None: + """Test LayerNorm with conditioning and features_last=True.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + cond = rng.float32_tensor((1, 8)).requires_grad_(True) + + norm = LayerNorm(n_channels=3, features_last=True, cond_dim=8) + output = norm(x.moveaxis(1, -1), cond=cond) + + assert output.shape == x.moveaxis(1, -1).shape, f'Output shape {output.shape} != expected shape' + + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert cond.grad is not None, 'No gradient computed for conditioning' + + +def test_layernorm_gradient_flow() -> None: + """Test that gradients flow properly through LayerNorm.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 32, 32, 32)).requires_grad_(True) + norm = LayerNorm(n_channels=32) + + output = norm(x) + loss = output.sum() + loss.backward() + + # Check that gradients are computed for all learnable parameters + assert x.grad is not None, 'Input gradients not computed' + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'Weight gradients not computed' + assert norm.bias.grad is not None, 'Bias gradients not computed' + + # Check that gradients are finite + assert torch.isfinite(x.grad).all(), 'Input gradients contain non-finite values' + assert torch.isfinite(norm.weight.grad).all(), 'Weight gradients contain non-finite values' + assert torch.isfinite(norm.bias.grad).all(), 'Bias gradients contain non-finite values' diff --git a/tests/nn/test_ndmodules.py b/tests/nn/test_ndmodules.py new file mode 100644 index 000000000..a0a77a98d --- /dev/null +++ b/tests/nn/test_ndmodules.py @@ -0,0 +1,76 @@ +"""Tests for the ndmodules module.""" + +import pytest +import torch +from mrpro.nn.ndmodules import ( + adaptiveAvgPoolND, + avgPoolND, + batchNormND, + convND, + convTransposeND, + instanceNormND, + maxPoolND, +) + + +def test_convnd() -> None: + """Test ConvND.""" + assert convND(1) is torch.nn.Conv1d + assert convND(2) is torch.nn.Conv2d + assert convND(3) is torch.nn.Conv3d + with pytest.raises(NotImplementedError): + convND(4) + + +def test_convtransposend() -> None: + """Test ConvTransposeND.""" + assert convTransposeND(1) is torch.nn.ConvTranspose1d + assert convTransposeND(2) is torch.nn.ConvTranspose2d + assert convTransposeND(3) is torch.nn.ConvTranspose3d + with pytest.raises(NotImplementedError): + convTransposeND(4) + + +def test_maxpoolnd() -> None: + """Test MaxPoolND.""" + assert maxPoolND(1) is torch.nn.MaxPool1d + assert maxPoolND(2) is torch.nn.MaxPool2d + assert maxPoolND(3) is torch.nn.MaxPool3d + with pytest.raises(NotImplementedError): + maxPoolND(4) + + +def test_avgpoolnd() -> None: + """Test AvgPoolND.""" + assert avgPoolND(1) is torch.nn.AvgPool1d + assert avgPoolND(2) is torch.nn.AvgPool2d + assert avgPoolND(3) is torch.nn.AvgPool3d + with pytest.raises(NotImplementedError): + avgPoolND(4) + + +def test_adaptiveavgpoolnd() -> None: + """Test AdaptiveAvgPoolND.""" + assert adaptiveAvgPoolND(1) is torch.nn.AdaptiveAvgPool1d + assert adaptiveAvgPoolND(2) is torch.nn.AdaptiveAvgPool2d + assert adaptiveAvgPoolND(3) is torch.nn.AdaptiveAvgPool3d + with pytest.raises(NotImplementedError): + adaptiveAvgPoolND(4) + + +def test_instancenormnd() -> None: + """Test InstanceNormND.""" + assert instanceNormND(1) is torch.nn.InstanceNorm1d + assert instanceNormND(2) is torch.nn.InstanceNorm2d + assert instanceNormND(3) is torch.nn.InstanceNorm3d + with pytest.raises(NotImplementedError): + instanceNormND(4) + + +def test_batchnormnd() -> None: + """Test BatchNormND.""" + assert batchNormND(1) is torch.nn.BatchNorm1d + assert batchNormND(2) is torch.nn.BatchNorm2d + assert batchNormND(3) is torch.nn.BatchNorm3d + with pytest.raises(NotImplementedError): + batchNormND(4) diff --git a/tests/nn/test_pixelshuffle.py b/tests/nn/test_pixelshuffle.py new file mode 100644 index 000000000..8d5917a83 --- /dev/null +++ b/tests/nn/test_pixelshuffle.py @@ -0,0 +1,92 @@ +"""Test PixelShuffle and PixelUnshuffle.""" + +from typing import cast + +import torch +from mrpro.nn.PixelShuffle import PixelShuffle, PixelShuffleUpsample, PixelUnshuffle, PixelUnshuffleDownsample +from mrpro.utils import RandomGenerator + + +def test_pixel_shuffle_2d() -> None: + """Test PixelUnshuffle's fast path for 2D images.""" + x = torch.arange(3 * 4 * 8).reshape(1, 3, 4, 8) + pixel_unshuffle = PixelUnshuffle(2) + y = pixel_unshuffle(x) + assert y.shape == (1, 3 * 4, 4 // 2, 8 // 2) + + pixel_shuffle = PixelShuffle(2) + z = pixel_shuffle(y) + assert z.shape == (1, 3, 4, 8) + assert (x == z).all() + + +def test_pixel_unshuffle_4d() -> None: + """Test PixelUnshuffle's general case.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) + pixel_unshuffle = PixelUnshuffle(2) + y = pixel_unshuffle(x) + assert y.shape == (1, 3 * 16, 4 // 2, 8 // 2, 10 // 2, 12 // 2) + + pixel_shuffle = PixelShuffle(2) + z = pixel_shuffle(y) + assert z.shape == (1, 3, 4, 8, 10, 12) + assert (x == z).all() + + +def test_pixelunshuffle_features_last() -> None: + """Test PixelUnshuffle with features_last.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, 3, 4, 8, 10, 12) + pixel_unshuffle_last = PixelUnshuffle(2, features_last=True) + pixel_unshuffle = PixelUnshuffle(2, features_last=False) + y_last = pixel_unshuffle_last(x.moveaxis(1, -1)).moveaxis(-1, 1) + y_normal = pixel_unshuffle(x) + assert (y_last == y_normal).all() + + +def test_pixelshuffle_features_last() -> None: + """Test PixelShuffle with features_last.""" + x = torch.arange(3 * 4 * 8 * 10 * 12).reshape(1, -1, 2, 4, 5, 6) + pixel_shuffle_last = PixelShuffle(2, features_last=True) + pixel_shuffle = PixelShuffle(2, features_last=False) + y_last = pixel_shuffle_last(x.moveaxis(1, -1)).moveaxis(-1, 1) + y_normal = pixel_shuffle(x) + assert (y_last == y_normal).all() + + +def test_unpixelshuffledownsample_residual() -> None: + """Test PixelUnshuffleDownsample with residual.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 2, 9, 12, 15)) + downsample = PixelUnshuffleDownsample(3, 2, 27, downscale_factor=3, residual=True) + y = downsample(x) + assert y.shape == (1, 27, 3, 4, 5) + + +def test_pixelshuffleupsample_residual() -> None: + """Test PixelShuffleUpsample with residual.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 2, 3, 4, 5)) + upsample = PixelShuffleUpsample(3, 2, 1, upscale_factor=3, residual=True) + y = upsample(x) + assert y.shape == (1, 1, 9, 12, 15) + + +def test_pixelshuffleupsample_pixelunshuffledownsample() -> None: + """Test if PixelUnshuffleDownsample is the inverse of PixelShuffleUpsample.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3**3, 3, 4, 5)) + # Only without residual, the upsample and downsample are inverses. + downsample = PixelUnshuffleDownsample(3, 1, 3**3, downscale_factor=3, residual=False) + upsample = PixelShuffleUpsample(3, 3**3, 1, upscale_factor=3, residual=False) + # Only if the convs are Identity, the upsample and downsample are inverses. + torch.nn.init.dirac_(cast(torch.Tensor, downsample.projection.weight)) + torch.nn.init.dirac_(cast(torch.Tensor, upsample.projection.weight)) + downsample_bias = cast(torch.Tensor | None, downsample.projection.bias) + upsample_bias = cast(torch.Tensor | None, upsample.projection.bias) + if downsample_bias is not None: + torch.nn.init.zeros_(downsample_bias) + if upsample_bias is not None: + torch.nn.init.zeros_(upsample_bias) + y = downsample(upsample(x)) + assert y.shape == (1, 3**3, 3, 4, 5) + torch.testing.assert_close(y, x, msg='Upsample and downsample are not inverses.') diff --git a/tests/nn/test_rmsnorm.py b/tests/nn/test_rmsnorm.py new file mode 100644 index 000000000..c8ddc0b69 --- /dev/null +++ b/tests/nn/test_rmsnorm.py @@ -0,0 +1,58 @@ +"""Tests for RMSNorm module.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.nn import RMSNorm +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('n_channels', 'features_last', 'input_shape'), + [ + (32, False, (1, 32, 32, 32)), + (64, True, (2, 16, 16, 64)), + (None, False, (1, 32, 32, 32)), + (None, True, (2, 16, 16, 64)), + ], +) +def test_rmsnorm_basic(n_channels: int | None, features_last: bool, input_shape: Sequence[int], device: str) -> None: + """Test RMSNorm basic functionality.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + norm = RMSNorm(n_channels=n_channels, features_last=features_last).to(device) + output = norm(x) + + assert output.shape == x.shape, f'Output shape {output.shape} != input shape {x.shape}' + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + + if n_channels is not None: + assert norm.weight is not None, 'Weight should not be None when n_channels is provided' + assert norm.bias is not None, 'Bias should not be None when n_channels is provided' + assert norm.weight.grad is not None, 'No gradient computed for weight' + assert norm.bias.grad is not None, 'No gradient computed for bias' + + +def test_rmsnorm_features_last() -> None: + """Test RMSNorm with features_last=True vs features_last=False.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor((1, 3, 4, 5)).requires_grad_(True) + + norm_standard = RMSNorm(n_channels=3, features_last=False) + y_standard = norm_standard(x) + + norm_last = RMSNorm(n_channels=3, features_last=True) + y_last = norm_last(x.moveaxis(1, -1)) + + torch.testing.assert_close(y_standard, y_last.moveaxis(-1, 1)) diff --git a/tests/nn/test_sequential.py b/tests/nn/test_sequential.py new file mode 100644 index 000000000..bdf81bf8d --- /dev/null +++ b/tests/nn/test_sequential.py @@ -0,0 +1,50 @@ +"""Tests for Sequential module.""" + +from collections.abc import Sequence + +import pytest +from mrpro.nn import FiLM, Sequential +from mrpro.operators import FastFourierOp, MagnitudeOp +from mrpro.utils import RandomGenerator +from torch.nn import Linear + + +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('input_shape', 'cond_dim'), + [ + ((1, 32), (1, 16)), + ((2, 32), None), + ], +) +def test_sequential( + input_shape: Sequence[int], + cond_dim: Sequence[int] | None, + device: str, +) -> None: + """Test Sequential output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_dim).to(device).requires_grad_(True) if cond_dim else None + seq = Sequential( + Linear(input_shape[1], 64), + FastFourierOp(dim=(-1,)), + FiLM(channels=64, cond_dim=16), + MagnitudeOp(), + ).to(device) + output = seq(x, cond=cond) + assert output.shape == (input_shape[0], 64) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for cond' + assert not cond.grad.isnan().any(), 'NaN values in cond gradients' + assert seq[0].weight.grad is not None, 'No gradient computed for Linear' From add38112b782b0aabcaf963c03e4ba9bd982a52e Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 11 Feb 2026 10:19:19 +0100 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- src/mrpro/nn/GroupNorm.py | 2 +- src/mrpro/nn/PermutedBlock.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py index e0bbb0a4b..a660e3beb 100644 --- a/src/mrpro/nn/GroupNorm.py +++ b/src/mrpro/nn/GroupNorm.py @@ -65,7 +65,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply GroupNorm.""" if self.features_last: x = x.moveaxis(-1, 1) - result = super().forward(x.float()).type(x.dtype) + result = super().__call__(x.float()).type(x.dtype) if self.features_last: result = result.moveaxis(1, -1) return result diff --git a/src/mrpro/nn/PermutedBlock.py b/src/mrpro/nn/PermutedBlock.py index 935d114dc..f01d761ac 100644 --- a/src/mrpro/nn/PermutedBlock.py +++ b/src/mrpro/nn/PermutedBlock.py @@ -47,7 +47,7 @@ def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torc ------- Output tensor. """ - return self.forward(x, cond=cond) + return super().__call__(x, cond=cond) def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: """Apply the module along the selected dimensions.""" From 2b0bab3d9397a4c23135d9a2c7f7e0d7ce47c217 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 11 Feb 2026 13:09:49 +0100 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- src/mrpro/nn/GroupNorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/nn/GroupNorm.py b/src/mrpro/nn/GroupNorm.py index a660e3beb..e0bbb0a4b 100644 --- a/src/mrpro/nn/GroupNorm.py +++ b/src/mrpro/nn/GroupNorm.py @@ -65,7 +65,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply GroupNorm.""" if self.features_last: x = x.moveaxis(-1, 1) - result = super().__call__(x.float()).type(x.dtype) + result = super().forward(x.float()).type(x.dtype) if self.features_last: result = result.moveaxis(1, -1) return result