From b7ef62d784dc47dd2656c972bcfb68c41c48eac7 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 7 Feb 2026 23:25:40 +0100 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- src/mrpro/nn/ResBlock.py | 70 ++++++ src/mrpro/nn/SeparableResBlock.py | 89 +++++++ src/mrpro/nn/__init__.py | 6 + src/mrpro/nn/nets/BasicCNN.py | 105 +++++++++ src/mrpro/nn/nets/UNet.py | 364 +++++++++++++++++++++++++++++ src/mrpro/nn/nets/__init__.py | 8 + tests/nn/nets/test_cnn.py | 58 +++++ tests/nn/nets/test_unet.py | 116 +++++++++ tests/nn/test_resblock.py | 56 +++++ tests/nn/test_separableresblock.py | 58 +++++ 10 files changed, 930 insertions(+) create mode 100644 src/mrpro/nn/ResBlock.py create mode 100644 src/mrpro/nn/SeparableResBlock.py create mode 100644 src/mrpro/nn/nets/BasicCNN.py create mode 100644 src/mrpro/nn/nets/UNet.py create mode 100644 src/mrpro/nn/nets/__init__.py create mode 100644 tests/nn/nets/test_cnn.py create mode 100644 tests/nn/nets/test_unet.py create mode 100644 tests/nn/test_resblock.py create mode 100644 tests/nn/test_separableresblock.py diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py new file mode 100644 index 000000000..32870979f --- /dev/null +++ b/src/mrpro/nn/ResBlock.py @@ -0,0 +1,70 @@ +"""Residual convolution block with two convolutions.""" + +import torch +from torch.nn import Identity, Module, SiLU + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import convND +from mrpro.nn.Sequential import Sequential + + +class ResBlock(CondMixin, Module): + """Residual convolution block with two convolutions.""" + + def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, cond_dim: int) -> None: + """Initialize the ResBlock. + + Parameters + ---------- + n_dim + The number of dimensions, i.e. 1, 2 or 3. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. 
+ cond_dim + The number of features in the conditioning tensor used in a FiLM. + If set to 0 no FiLM is used. + + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + self.block = Sequential( + GroupNorm(n_channels_in), + SiLU(), + convND(n_dim)(n_channels_in, n_channels_out, kernel_size=3, padding=1), + GroupNorm(n_channels_out), + SiLU(), + convND(n_dim)(n_channels_out, n_channels_out, kernel_size=3, padding=1), + ) + if cond_dim > 0: + self.block.insert(-3, FiLM(n_channels_out, cond_dim)) + + if n_channels_out == n_channels_in: + self.skip_connection: Module = Identity() + else: + self.skip_connection = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the ResBlock. + + Parameters + ---------- + x + The input tensor. + cond + A conditioning tensor to be used for FiLM. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the ResBlock.""" + h = self.block(x, cond=cond) + x = self.skip_connection(x) + self.rezero * h + return x diff --git a/src/mrpro/nn/SeparableResBlock.py b/src/mrpro/nn/SeparableResBlock.py new file mode 100644 index 000000000..a12fda16f --- /dev/null +++ b/src/mrpro/nn/SeparableResBlock.py @@ -0,0 +1,89 @@ +"""Residual block with separable convolutions.""" + +from collections.abc import Sequence + +import torch +from torch.nn import Module, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import convND +from mrpro.nn.PermutedBlock import PermutedBlock +from mrpro.nn.Sequential import Sequential + + +class SeparableResBlock(Module): + """Residual block with separable convolutions.""" + + def __init__( + self, + dim_groups: Sequence[Sequence[int]], + n_channels_in: int, + n_channels_out: int, + cond_dim: int, + ) 
-> None: + """Initialize the SeparableResBlock. + + Applies convolutions as separable convolutions with SiLU activation and group normalization. + For example, if ``dim_groups = ((-1,-2), (-3,))`` then one 2D convolution is applied to the last two dimensions, + and one 1D convolution is applied to the third-last dimension. + The order within the block is Norm->Activation->Conv. + The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between. + So for two `dim_groups`, a total of 4 convolutions are applied. + + Parameters + ---------- + dim_groups + Sequence of dimension groups to use in the convolutions. + n_channels_in + Number of input channels. + n_channels_out + Number of output channels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + """ + super().__init__() + self.rezero = torch.nn.Parameter(torch.tensor(0.1)) + + def block(dims: Sequence[int], channels_in: int) -> Module: + return Sequential( + GroupNorm(channels_in), + SiLU(), + PermutedBlock(dims, convND(len(dims))(channels_in, n_channels_out, 3, padding=1)), + ) + + blocks = Sequential(*(block(d, n_channels_in if i == 0 else n_channels_out) for i, d in enumerate(dim_groups))) + if cond_dim > 0: + blocks.append(FiLM(n_channels_out, cond_dim)) + blocks.extend(block(d, n_channels_out) for d in dim_groups) + self.block = blocks + self.skip_connection = None + if n_channels_in != n_channels_out: + self.skip_connection = torch.nn.Linear(n_channels_in, n_channels_out) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. + + Returns + ------- + Output tensor with the same number and order of dimensions as the input. 
+ """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply the SeparableResBlock.""" + h = self.block(x, cond=cond) + if self.skip_connection is None: + skip = x + else: + skip = torch.moveaxis(x, 1, -1) + skip = self.skip_connection(skip) + skip = torch.moveaxis(skip, -1, 1) + return skip + self.rezero * h diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index ffb98843d..94f28625e 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -12,10 +12,13 @@ from mrpro.nn.LayerNorm import LayerNorm from mrpro.nn.PermutedBlock import PermutedBlock from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.ResBlock import ResBlock from mrpro.nn.Residual import Residual +from mrpro.nn.SeparableResBlock import SeparableResBlock from mrpro.nn.Sequential import Sequential from mrpro.nn import attention from mrpro.nn import data_consistency +from mrpro.nn import nets from mrpro.nn.ndmodules import ( adaptiveAvgPoolND, avgPoolND, @@ -39,7 +42,9 @@ 'LayerNorm', 'PermutedBlock', 'RMSNorm', + 'ResBlock', 'Residual', + 'SeparableResBlock', 'Sequential', 'adaptiveAvgPoolND', 'attention', @@ -50,4 +55,5 @@ 'data_consistency', 'instanceNormND', 'maxPoolND', + 'nets', ] diff --git a/src/mrpro/nn/nets/BasicCNN.py b/src/mrpro/nn/nets/BasicCNN.py new file mode 100644 index 000000000..5bf911eef --- /dev/null +++ b/src/mrpro/nn/nets/BasicCNN.py @@ -0,0 +1,105 @@ +"""Basic CNN.""" + +from collections.abc import Sequence +from itertools import pairwise +from typing import Literal + +import torch +from torch.nn import LeakyReLU, ReLU, SiLU + +from mrpro.nn.FiLM import FiLM +from mrpro.nn.GroupNorm import GroupNorm +from mrpro.nn.ndmodules import batchNormND, convND +from mrpro.nn.Sequential import Sequential + + +class BasicCNN(Sequential): + """Basic CNN. + + A series of convolutions (window 3, stride 1, padding 1), normalization and activation. 
+ Allows to use FiLM conditioning. + Order is Conv -> Norm (optional) -> FiLM (optional) -> Activation. + + If you need more flexibility, use `~mrpro.nn.Sequential` directly. + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + norm: Literal['batch', 'group', 'instance', 'none', 'layer'] = 'none', + activation: Literal['relu', 'silu', 'leaky_relu'] = 'relu', + n_features: Sequence[int] = (64, 64, 64), + cond_dim: int = 0, + ): + """Initialize a basic CNN. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of input channels. + n_channels_out + The number of output channels. + norm + The type of normalization to use. If 'batch', use batch normalization. If 'group', use group normalization, + if 'instance', use instance normalization, and if `layer`, use layer normalization. + If 'none', use no normalization. + activation + The type of activation to use. If 'relu', use ReLU. If 'silu', use SiLU. If 'leaky_relu', use LeakyReLU. + n_features + The number of features in the hidden layers. The length of this sequence determines the number of hidden + layers. The total number of convolutions is `len(n_features) + 1`. + cond_dim + The dimension of the condition tensor. If 0, no FiLM conditioning is applied. + Otherwise, between convolutions, after normalization, FiLM conditioning is applied. 
+ """ + super().__init__() + use_film = cond_dim > 0 + + self.append(convND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding='same')) + + for c_in, c_out in pairwise((*n_features, n_channels_out)): + if norm.lower() == 'batch': + self.append(batchNormND(n_dim)(c_in, affine=not use_film)) + elif norm.lower() == 'group': + self.append(GroupNorm(c_in, affine=not use_film)) + elif norm.lower() == 'instance': + self.append(GroupNorm(c_in, n_groups=c_in, affine=not use_film)) # is instance norm + elif norm.lower() == 'layer': + self.append(GroupNorm(c_in, n_groups=1, affine=not use_film)) # is layer norm + elif norm.lower() != 'none': + raise ValueError(f'Invalid normalization type: {norm}') + + if use_film: + self.append(FiLM(c_in, cond_dim)) + + if activation.lower() == 'relu': + self.append(ReLU(True)) + elif activation.lower() == 'silu': + self.append(SiLU(inplace=True)) + elif activation.lower() == 'leaky_relu': + self.append(LeakyReLU(inplace=True)) + else: + raise ValueError(f'Invalid activation type: {activation}') + + self.append(convND(n_dim)(c_in, c_out, kernel_size=3, padding='same')) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override] + """Apply the basic CNN to the input tensor. + + Parameters + ---------- + x + The input tensor. Should be of shape `(batch_size, channels_in, *spatial dimensions)` + with `spatial dimensions` being of length `n_dim`. + cond + The condition tensor. If None, no FiLM conditioning is applied. + + Returns + ------- + The output tensor. 
+ """ + return super().__call__(x, cond=cond) diff --git a/src/mrpro/nn/nets/UNet.py b/src/mrpro/nn/nets/UNet.py new file mode 100644 index 000000000..7b92b1c73 --- /dev/null +++ b/src/mrpro/nn/nets/UNet.py @@ -0,0 +1,364 @@ +"""UNet variants.""" + +from collections.abc import Sequence +from functools import partial +from itertools import pairwise + +import torch +from torch.nn import Identity, Module, ModuleList, ReLU, SiLU + +from mrpro.nn.attention.AttentionGate import AttentionGate +from mrpro.nn.attention.SpatialTransformerBlock import SpatialTransformerBlock +from mrpro.nn.CondMixin import call_with_cond +from mrpro.nn.FiLM import FiLM +from mrpro.nn.join import Concat +from mrpro.nn.ndmodules import convND, maxPoolND +from mrpro.nn.ResBlock import ResBlock +from mrpro.nn.Sequential import Sequential +from mrpro.nn.Upsample import Upsample + + +class UNetEncoder(Module): + """Encoder.""" + + def __init__( + self, + first_block: Module, + blocks: Sequence[Module], + down_blocks: Sequence[Module], + middle_block: Module, + ) -> None: + """Initialize the UNetEncoder.""" + super().__init__() + self.first = first_block + """The first block. Should expand from the number of input channels.""" + + self.blocks = ModuleList(blocks) + """The encoder blocks. 
Order is highest resolution to lowest resolution.""" + + self.down_blocks = ModuleList(down_blocks) + """The downsampling blocks""" + + self.middle_block = middle_block + """Also called bottleneck block""" + + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.down_blocks) + 1 + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = call(self.first, x) + + xs = [] + for block, down in zip(self.blocks, self.down_blocks, strict=True): + x = call(block, x) + xs.append(x) + x = call(down, x) + + x = call(self.middle_block, x) + + return (*xs, x) + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> tuple[torch.Tensor, ...]: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The tensors at the different resolutions, highest resolution first. + """ + return super().__call__(x, cond=cond) + + +class UNetDecoder(Module): + """Decoder.""" + + def __init__( + self, + blocks: Sequence[Module], + up_blocks: Sequence[Module], + concat_blocks: Sequence[Module], + last_block: Module, + ) -> None: + """Initialize the UNetDecoder.""" + super().__init__() + self.blocks = ModuleList(blocks) + """The decoder blocks. Order is lowest resolution to highest resolution.""" + + self.up_blocks = ModuleList(up_blocks) + """The upsampling blocks""" + + self.concat_blocks = ModuleList(concat_blocks) + """Joins the skip connections with the upsampled features from a lower resolution level""" + + self.last_block = last_block + """The last block. 
Should reduce to the number of output channels.""" + + def __len__(self): + """Get the number of resolutions levels.""" + return len(self.up_blocks) + 1 + + def forward(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network.""" + call = partial(call_with_cond, cond=cond) + + x = hs[-1] # lowest resolution, from middle block + for block, up, concat, h in zip(self.blocks, self.up_blocks, self.concat_blocks, hs[-2::-1], strict=True): + x = call(up, x) + x = concat(h, x) + x = call(block, x) + x = call(self.last_block, x) + return x + + def __call__(self, hs: tuple[torch.Tensor, ...], *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + hs + The tensors at the different resolutions, highest resolution first. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(hs, cond=cond) + + +class UNetBase(Module): + """Base class for U-shaped networks.""" + + def __init__(self, encoder: UNetEncoder, decoder: UNetDecoder, skip_blocks: Sequence[Module] | None = None) -> None: + """Initialize the UNetBase.""" + super().__init__() + self.encoder = encoder + """The encoder.""" + + self.decoder = decoder + """The decoder.""" + + self.skip_blocks = ModuleList() + """Modifications of the skip connections.""" + + if len(decoder) != len(encoder): + raise ValueError( + 'The number of resolutions in the encoder and decoder must be the same, ' + f'got {len(decoder)} and {len(encoder)}' + ) + + if skip_blocks is None: + self.skip_blocks.extend(Identity() for _ in range(len(decoder))) + elif len(skip_blocks) != len(decoder): + raise ValueError( + f'The number of skip blocks must be the same as the number of resolutions, ' + f'got {len(skip_blocks)} and {len(encoder)}' + ) + else: + self.skip_blocks.extend(skip_blocks) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to 
Network.""" + xs = self.encoder(x, cond=cond) + xs = tuple( + call_with_cond(self.skip_blocks[i], x, cond=cond) if i < len(self.skip_blocks) else x + for i, x in enumerate(xs) + ) + x = self.decoder(xs, cond=cond) + return x + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply to Network. + + Parameters + ---------- + x + The input tensor. + cond + The conditioning tensor. + + Returns + ------- + The output tensor. + """ + return super().__call__(x, cond=cond) + + +class UNet(UNetBase): + """UNet. + + U-shaped convolutional network with optional patch attention. + Inspired by [NOSENSE]_ and the OpenAI DDPM UNet/Latent Diffusion UNet [LDM]_. + Significant differences to the vanilla UNet [UNET]_ include: + - Spatial transformer blocks + - Convolutional downsampling, nearest neighbor upsampling + - Residual convolution blocks with pre-act group normalization and SiLU activation + + + References + ---------- + .. [UNET] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image + segmentation." MICCAI 2015. https://arxiv.org/abs/1505.04597 + .. [LDM] https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py + .. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without + explicit sensitivity maps." STACOM 2023. https://github.com/fzimmermann89/CMRxRecon/blob/master/src/cmrxrecon/nets/unet.py + + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + attention_depths: Sequence[int] = (-1,), + n_features: Sequence[int] = (64, 128, 192, 256), + n_heads: int = 8, + cond_dim: int = 0, + encoder_blocks_per_scale: int = 2, + ) -> None: + """Initialize the UNet. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor. 
+ n_channels_out + The number of channels in the output tensor. + attention_depths + The depths at which to apply attention. Negative values count from the lowest resolution. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + n_heads + Number of attention heads. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. + encoder_blocks_per_scale + Number of encoder blocks per resolution level. The decoder uses the same number of blocks. + """ + depth = len(n_features) + if not all(-depth <= d < depth for d in attention_depths): + raise ValueError( + f'attention_depths must be in the range [-depth, depth], got {attention_depths=} for {depth=}' + ) + attention_depths = tuple(d % depth for d in attention_depths) + if len(attention_depths) != len(set(attention_depths)): + raise ValueError(f'attention_depths must be unique, got {attention_depths=}') + + def attention_block(channels: int) -> Module: + dim_groups = (tuple(range(-n_dim, 0)),) + return SpatialTransformerBlock(dim_groups, channels, n_heads, cond_dim=cond_dim) + + def blocks(channels_in: int, channels_out: int, attention: bool) -> Module: + blocks = Sequential() + for _ in range(encoder_blocks_per_scale): + blocks.append(ResBlock(n_dim, channels_in, channels_out, cond_dim)) + if attention: + blocks.append(attention_block(channels_out)) + channels_in = channels_out + return blocks + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [] + + for i_level, (n_feat, n_feat_next) in enumerate(pairwise(n_features)): + encoder_blocks.append(blocks(n_feat, n_feat, i_level in attention_depths)) + down_blocks.append(convND(n_dim)(n_feat, n_feat_next, 3, stride=2, padding=1)) + decoder_blocks.append(blocks(n_feat_next + n_feat, n_feat, i_level in attention_depths)) + up_blocks.append(Upsample(tuple(range(-n_dim, 0)), scale_factor=2)) + + middle_block = Sequential( + ResBlock(n_dim, 
n_feat_next, n_feat_next, cond_dim), + ResBlock(n_dim, n_feat_next, n_feat_next, cond_dim), + ) + if depth - 1 in attention_depths: + middle_block.insert(1, attention_block(n_feat_next)) + first_block = convND(n_dim)(n_channels_in, n_features[0], 3, padding=1) + encoder = UNetEncoder(first_block, encoder_blocks, down_blocks, middle_block) + + decoder_blocks, up_blocks = decoder_blocks[::-1], up_blocks[::-1] + last_block = Sequential( + SiLU(), + convND(n_dim)(n_features[0], n_channels_out, 3, padding=1), + ) + concat_blocks = [Concat() for _ in range(len(decoder_blocks))] + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) + + +class AttentionGatedUNet(UNetBase): + """UNet with attention gates. + + Basic UNet with attention gating of the skip signals by the lower resolution features [OKT18]_. + + References + ---------- + .. [OKT18] Oktay, Ozan, et al. "Attention U-net: Learning where to look for the pancreas." MIDL (2018). + https://arxiv.org/abs/1804.03999 + """ + + def __init__( + self, n_dim: int, n_channels_in: int, n_channels_out: int, n_features: Sequence[int], cond_dim: int = 0 + ): + """Initialize the AttentionGatedUNet. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor. + n_channels_out + The number of channels in the output tensor. + n_features + Number of features at each resolution level. The length determines the number of resolution levels. + cond_dim + Number of channels in the conditioning tensor. If 0, no conditioning is applied. 
+ """ + + def block(channels_in: int, channels_out: int) -> Module: + block = Sequential( + convND(n_dim)(channels_in, channels_out, 3, padding=1), + ReLU(True), + convND(n_dim)(channels_out, channels_out, 3, padding=1), + ReLU(True), + ) + if cond_dim > 0: + block.insert(2, FiLM(channels_out, cond_dim)) + return block + + encoder_blocks: list[Module] = [] + down_blocks: list[Module] = [] + n_feat_old = n_channels_in + for n_feat in n_features[:-1]: + encoder_blocks.append(block(n_feat_old, n_feat)) + down_blocks.append(maxPoolND(n_dim)(2)) + n_feat_old = n_feat + middle_block = block(n_features[-2], n_features[-1]) + encoder = UNetEncoder(Identity(), encoder_blocks, down_blocks, middle_block) + + concat_blocks = [] + decoder_blocks: list[Module] = [] + up_blocks: list[Module] = [] + for n_feat, n_feat_skip in pairwise(n_features[::-1]): + concat_blocks.append(AttentionGate(n_dim, n_feat, n_feat_skip, n_feat_skip, concatenate=True)) + decoder_blocks.append(block(n_feat + n_feat_skip, n_feat_skip)) + up_blocks.append(Upsample(range(-n_dim, 0), scale_factor=2)) + last_block = convND(n_dim)(n_features[0], n_channels_out, 1) + decoder = UNetDecoder(decoder_blocks, up_blocks, concat_blocks, last_block) + + super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py new file mode 100644 index 000000000..06271d970 --- /dev/null +++ b/src/mrpro/nn/nets/__init__.py @@ -0,0 +1,8 @@ +from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet + +__all__ = [ + 'AttentionGatedUNet', + 'BasicCNN', + 'UNet', +] diff --git a/tests/nn/nets/test_cnn.py b/tests/nn/nets/test_cnn.py new file mode 100644 index 000000000..5dbbd5ad6 --- /dev/null +++ b/tests/nn/nets/test_cnn.py @@ -0,0 +1,58 @@ +"""Tests for BasicCNN network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import BasicCNN + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 
'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_cnn_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the cnn.""" + cnn = BasicCNN( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + norm='layer', + n_features=(8, 8), + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cnn = cnn.to(device) + x = x.to(device) + if torch_compile: + cnn = cast(BasicCNN, torch.compile(cnn)) + y = cnn(x) + assert y.shape == (1, 1, 16, 16) + + +def test_cnn_backward() -> None: + cnn = BasicCNN( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + norm='instance', + activation='silu', + n_features=(8, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = cnn(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in cnn.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/nets/test_unet.py b/tests/nn/nets/test_unet.py new file mode 100644 index 000000000..fdf2f5250 --- /dev/null +++ b/tests/nn/nets/test_unet.py @@ -0,0 +1,116 @@ +"""Tests for UNet and AttentionGatedUNet networks.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import AttentionGatedUNet, UNet + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_unet_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the UNet.""" + unet = UNet( + 
n_dim=2, + n_channels_in=1, + n_channels_out=1, + attention_depths=(-1,), + n_features=(4, 6, 8), + n_heads=2, + cond_dim=32, + encoder_blocks_per_scale=1, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + unet = unet.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + unet = cast(UNet, torch.compile(unet)) + y = unet(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def test_unet_backward() -> None: + unet = UNet( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + attention_depths=(-1,), + n_features=(4, 6, 8), + n_heads=2, + cond_dim=32, + encoder_blocks_per_scale=1, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = unet(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in unet.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_gated_unet_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the AttentionGatedUNet.""" + unet = AttentionGatedUNet( + n_dim=2, + n_channels_in=1, + n_channels_out=1, + n_features=(4, 6, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + cond = torch.zeros(1, 32, device=device) + unet = unet.to(device) + x = x.to(device) + cond = cond.to(device) + if torch_compile: + unet = cast(AttentionGatedUNet, torch.compile(unet)) + y = unet(x, cond=cond) + assert y.shape == (1, 1, 16, 16) + + +def 
test_gated_unet_backward() -> None: + """Test the backward pass of the AttentionGatedUNet.""" + unet = AttentionGatedUNet( + n_dim=1, + n_channels_in=1, + n_channels_out=1, + n_features=(4, 6, 8), + cond_dim=32, + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + cond = torch.zeros(1, 32, requires_grad=True) + y = unet(x, cond=cond) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + assert cond.grad is not None, 'cond.grad is None' + assert not cond.grad.isnan().any(), 'cond.grad is NaN' + for name, parameter in unet.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' diff --git a/tests/nn/test_resblock.py b/tests/nn/test_resblock.py new file mode 100644 index 000000000..ea4356173 --- /dev/null +++ b/tests/nn/test_resblock.py @@ -0,0 +1,56 @@ +"""Tests for ResBlock module.""" + +from collections.abc import Sequence +from typing import cast + +import pytest +import torch +from mrpro.nn import ResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'eager']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (2, 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (3, 64, 32, 0, (2, 64, 16, 16, 16), None), + ], +) +def test_resblock( + dim: int, + channels_in: int, + channels_out: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int] | None, + device: str, + torch_compile: bool, +) -> None: + """Test ResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = 
rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None + block = ResBlock(n_dim=dim, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim).to(device) + if torch_compile: + block = cast(ResBlock, torch.compile(block, dynamic=False)) + output = block(x, cond=cond) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert block.block[2].weight.grad is not None, 'No gradient computed for first Conv' + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' diff --git a/tests/nn/test_separableresblock.py b/tests/nn/test_separableresblock.py new file mode 100644 index 000000000..25b6c65e0 --- /dev/null +++ b/tests/nn/test_separableresblock.py @@ -0,0 +1,58 @@ +"""Tests for SeparableResBlock module.""" + +from collections.abc import Sequence +from typing import cast + +import pytest +import torch +from mrpro.nn import SeparableResBlock +from mrpro.utils import RandomGenerator + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'eager']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', id='cuda', marks=pytest.mark.cuda), + ], +) +@pytest.mark.parametrize( + ('dim_groups', 'channels_in', 'channels_out', 'cond_dim', 'input_shape', 'cond_shape'), + [ + (((-1, -2),), 32, 32, 16, (1, 32, 32, 32), (1, 16)), + (((-1, -2), (-3,)), 64, 32, 0, (2, 64, 16, 16, 16), None), # 2D + 1D + ], +) +def test_separable_resblock( + dim_groups: 
Sequence[Sequence[int]], + channels_in: int, + channels_out: int, + cond_dim: int, + input_shape: Sequence[int], + cond_shape: Sequence[int] | None, + device: str, + torch_compile: bool, +) -> None: + """Test SeparableResBlock output shape and backpropagation.""" + rng = RandomGenerator(seed=42) + x = rng.float32_tensor(input_shape).to(device).requires_grad_(True) + cond = rng.float32_tensor(cond_shape).to(device).requires_grad_(True) if cond_shape else None + block = SeparableResBlock( + dim_groups=dim_groups, n_channels_in=channels_in, n_channels_out=channels_out, cond_dim=cond_dim + ).to(device) + if torch_compile: + block = cast(SeparableResBlock, torch.compile(block, dynamic=False)) + output = block(x, cond=cond) + assert output.shape == (input_shape[0], channels_out, *input_shape[2:]), ( + f'Output shape {output.shape} != expected {(input_shape[0], channels_out, *input_shape[2:])}' + ) + output.sum().backward() + assert x.grad is not None, 'No gradient computed for input' + assert not output.isnan().any(), 'NaN values in output' + assert not x.grad.isnan().any(), 'NaN values in input gradients' + assert block.block[0][2].module.weight.grad is not None, 'No gradient computed for first Conv' # type: ignore[union-attr] + if cond is not None: + assert cond.grad is not None, 'No gradient computed for conditioning' + assert not cond.isnan().any(), 'NaN values in conditioning' + assert not cond.grad.isnan().any(), 'NaN values in conditioning gradients' From 0a6850056ea6a66cd562a6650ae8f56ed2436208 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Wed, 11 Feb 2026 14:31:48 +0100 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- src/mrpro/nn/ResBlock.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mrpro/nn/ResBlock.py b/src/mrpro/nn/ResBlock.py index bb6a3e81b..b84443f39 100644 --- a/src/mrpro/nn/ResBlock.py +++ b/src/mrpro/nn/ResBlock.py @@ -40,6 +40,7 @@ def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, cond_dim 
convND(n_dim)(n_channels_out, n_channels_out, kernel_size=3, padding=1), ) if cond_dim > 0: + self.block.insert(1, FiLM(n_channels_in, cond_dim)) self.block.insert(-2, FiLM(n_channels_out, cond_dim)) if n_channels_out == n_channels_in: