diff --git a/src/mrpro/nn/nets/Restormer.py b/src/mrpro/nn/nets/Restormer.py
new file mode 100644
index 000000000..357bebf51
--- /dev/null
+++ b/src/mrpro/nn/nets/Restormer.py
@@ -0,0 +1,238 @@
+"""Restormer implementation."""
+
+from collections.abc import Sequence
+from itertools import pairwise
+
+import torch
+from torch.nn import Module
+
+from mrpro.nn.attention.TransposedAttention import TransposedAttention
+from mrpro.nn.CondMixin import CondMixin
+from mrpro.nn.FiLM import FiLM
+from mrpro.nn.join import Concat
+from mrpro.nn.ndmodules import convND, instanceNormND
+from mrpro.nn.nets.UNet import UNetBase, UNetDecoder, UNetEncoder
+from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample
+from mrpro.nn.Sequential import Sequential
+
+
+class GDFN(Module):
+    """Gated depthwise feed forward network.
+
+    Feed-forward block used in Restormer [ZAM22]_. It first expands channels,
+    applies a depthwise convolution, then uses a gated interaction between two
+    channel splits before projecting back to the input width.
+
+    References
+    ----------
+    .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for
+       high-resolution image restoration." CVPR 2022.
+    """
+
+    def __init__(self, n_dim: int, n_channels: int, mlp_ratio: float):
+        """Initialize GDFN.
+
+        Parameters
+        ----------
+        n_dim
+            The number of spatial dimensions of the input tensor.
+        n_channels
+            The number of channels in the input tensor.
+        mlp_ratio
+            Ratio for hidden dimension expansion
+        """
+        super().__init__()
+
+        hidden_features = int(n_channels * mlp_ratio)
+        # Twice the hidden width: one half is the value path, the other half the gate.
+        self.project_in = convND(n_dim)(n_channels, hidden_features * 2, kernel_size=1)
+        self.depthwise_conv = convND(n_dim)(
+            hidden_features * 2,
+            hidden_features * 2,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            groups=hidden_features * 2,
+        )
+        self.project_out = convND(n_dim)(hidden_features, n_channels, kernel_size=1)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply the gated depthwise feed forward network.
+
+        Parameters
+        ----------
+        x
+            Input tensor
+
+        Returns
+        -------
+            Output tensor
+        """
+        # Delegate to Module.__call__ so that registered hooks run around forward.
+        return super().__call__(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass for GDFN."""
+        x = self.project_in(x)
+        # Split along the channel dimension into value and gate halves.
+        value, gate = self.depthwise_conv(x).chunk(2, dim=1)
+        return self.project_out(value * torch.sigmoid(gate))
+
+
+class RestormerBlock(CondMixin, Module):
+    """Transformer block with transposed attention and gated depthwise feed forward network."""
+
+    def __init__(self, n_dim: int, n_channels: int, n_heads: int, mlp_ratio: float, cond_dim: int = 0):
+        """Initialize RestormerBlock.
+
+        Parameters
+        ----------
+        n_dim
+            The number of spatial dimensions of the input tensor.
+        n_channels
+            The number of channels in the input tensor.
+        n_heads
+            Number of attention heads
+        mlp_ratio
+            Ratio for hidden dimension expansion
+        cond_dim
+            Dimension of conditioning input. If 0, no conditioning is applied.
+        """
+        super().__init__()
+        self.norm1 = Sequential(instanceNormND(n_dim)(n_channels))
+        self.attn = TransposedAttention(n_dim, n_channels, n_channels, n_heads)
+        self.norm2 = Sequential(instanceNormND(n_dim)(n_channels))
+        self.ffn = GDFN(n_dim, n_channels, mlp_ratio)
+        if cond_dim > 0:
+            # FiLM is appended to norm2 only, i.e. conditioning acts on the input of the feed forward branch.
+            self.norm2.append(FiLM(channels=n_channels, cond_dim=cond_dim))
+
+    def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
+        """Apply Restormer block.
+
+        Parameters
+        ----------
+        x
+            Input tensor
+        cond
+            Conditioning tensor. If None, no conditioning is applied.
+
+        Returns
+        -------
+            Output tensor
+        """
+        return super().__call__(x, cond=cond)
+
+    def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
+        """Forward pass for RestormerBlock."""
+        x = x + self.attn(self.norm1(x))
+        x = x + self.ffn(self.norm2(x, cond=cond))
+        return x
+
+
+class Restormer(UNetBase):
+    """Restormer architecture.
+
+    Implements the Restormer [ZAM22]_ network, which is a U-shaped transformer
+    with channel wise attention and depthwise convolutions in the feed forward network.
+
+    References
+    ----------
+    .. [ZAM22] Zamir, Syed Waqas, et al. "Restormer: Efficient transformer for high-resolution image restoration."
+       CVPR 2022, https://arxiv.org/pdf/2111.09881.pdf
+    """
+
+    def __init__(
+        self,
+        n_dim: int,
+        n_channels_in: int,
+        n_channels_out: int,
+        n_blocks: Sequence[int] = (4, 6, 6, 8),
+        n_refinement_blocks: int = 4,
+        n_heads: Sequence[int] = (1, 2, 4, 8),
+        n_channels_per_head: int = 48,
+        mlp_ratio: float = 2.66,
+        cond_dim: int = 0,
+    ):
+        """Initialize Restormer.
+
+        Parameters
+        ----------
+        n_dim
+            The number of spatial dimensions of the input tensor.
+        n_channels_in
+            The number of input channels.
+        n_channels_out
+            The number of output channels.
+        n_blocks
+            Number of blocks in each stage
+        n_refinement_blocks
+            Number of refinement blocks
+        n_heads
+            Number of attention heads in each stage
+        n_channels_per_head
+            Number of channels per attention head
+        mlp_ratio
+            Ratio for hidden dimension expansion
+        cond_dim
+            Dimension of conditioning input. If 0, no conditioning is applied.
+
+        Raises
+        ------
+        ValueError
+            If `n_blocks` and `n_heads` differ in length or are empty.
+        """
+        if len(n_blocks) != len(n_heads):
+            raise ValueError('n_blocks and n_heads must have the same length.')
+        if len(n_heads) == 0:
+            raise ValueError('n_blocks and n_heads must not be empty.')
+
+        def stage(heads: int, depth: int) -> Sequential:
+            """Build one stage of `depth` Restormer blocks with `heads` attention heads."""
+            n_channels = n_channels_per_head * heads
+            layers = Sequential(*(RestormerBlock(n_dim, n_channels, heads, mlp_ratio) for _ in range(depth)))
+            # The blocks themselves are built without conditioning; a stage with more
+            # than one block gets a single FiLM layer right after its first block.
+            if cond_dim > 0 and depth > 1:
+                layers.insert(1, FiLM(channels=n_channels, cond_dim=cond_dim))
+            return layers
+
+        first_block = convND(n_dim)(n_channels_in, n_channels_per_head, kernel_size=3, stride=1, padding=1, bias=False)
+        encoder_blocks = [stage(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)]
+        down_blocks = [
+            PixelUnshuffleDownsample(n_dim, n_channels_per_head * head_current, n_channels_per_head * head_next)
+            for head_current, head_next in pairwise(n_heads)
+        ]
+        middle_block = stage(n_heads[-1], n_blocks[-1])
+        encoder = UNetEncoder(
+            first_block=first_block,
+            blocks=encoder_blocks,
+            down_blocks=down_blocks,
+            middle_block=middle_block,
+        )
+
+        # Decoder lists run from the deepest stage to the shallowest, hence the reversals.
+        up_blocks = [
+            PixelShuffleUpsample(n_dim, n_channels_per_head * head_next, n_channels_per_head * head_current)
+            for head_current, head_next in pairwise(n_heads)
+        ][::-1]
+        concat_blocks = [
+            Sequential(
+                Concat(),
+                convND(n_dim)(2 * n_channels_per_head * head, n_channels_per_head * head, kernel_size=1),
+            )
+            for head in n_heads[-2::-1]
+        ]
+        decoder_blocks = [stage(head, block) for head, block in zip(n_heads[:-1], n_blocks[:-1], strict=True)][::-1]
+        last_block = Sequential(
+            *(RestormerBlock(n_dim, n_channels_per_head, n_heads[0], mlp_ratio) for _ in range(n_refinement_blocks)),
+            convND(n_dim)(n_channels_per_head, n_channels_out, kernel_size=3, stride=1, padding=1),
+        )
+        decoder = UNetDecoder(
+            blocks=decoder_blocks,
+            up_blocks=up_blocks,
+            concat_blocks=concat_blocks,
+            last_block=last_block,
+        )
+
+        super().__init__(encoder=encoder, decoder=decoder)
diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py
index af7dbdc5e..a0a9a6ad4 100644
--- a/src/mrpro/nn/nets/__init__.py
+++ b/src/mrpro/nn/nets/__init__.py
@@ -1,4 +1,5 @@
 from mrpro.nn.nets.BasicCNN import BasicCNN
+from mrpro.nn.nets.Restormer import Restormer
 from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet
 from mrpro.nn.nets.MLP import MLP
 
@@ -6,5 +7,6 @@
     "AttentionGatedUNet",
     "BasicCNN",
     "MLP",
+    "Restormer",
     "UNet",
 ]
diff --git a/tests/nn/nets/test_restormer.py b/tests/nn/nets/test_restormer.py
new file mode 100644
index 000000000..68c84a689
--- /dev/null
+++ b/tests/nn/nets/test_restormer.py
@@ -0,0 +1,62 @@
+"""Tests for Restormer network."""
+
+from typing import cast
+
+import pytest
+import torch
+from mrpro.nn.nets import Restormer
+
+
+@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled'])
+@pytest.mark.parametrize(
+    'device',
+    [
+        pytest.param('cpu', id='cpu'),
+        pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'),
+    ],
+)
+def test_restormer_forward(torch_compile: bool, device: str) -> None:
+    """Test the forward pass of the restormer."""
+    restormer = Restormer(
+        n_dim=2,
+        n_channels_in=1,
+        n_channels_out=1,
+        n_heads=(1, 2, 4),
+        n_blocks=(2, 1, 1),
+        cond_dim=32,
+        n_channels_per_head=2,
+    )
+
+    # Inputs are created directly on the target device, so only the model has to be moved.
+    x = torch.zeros(1, 1, 16, 16, device=device)
+    cond = torch.zeros(1, 32, device=device)
+    restormer = restormer.to(device)
+    if torch_compile:
+        restormer = cast(Restormer, torch.compile(restormer))
+    y = restormer(x, cond=cond)
+    assert y.shape == (1, 1, 16, 16)
+
+
+def test_restormer_backward() -> None:
+    """Test that gradients reach the input, the conditioning and all parameters."""
+    restormer = Restormer(
+        n_dim=1,
+        n_channels_in=1,
+        n_channels_out=1,
+        n_heads=(1, 2),
+        n_blocks=(2, 2),
+        cond_dim=32,
+        n_channels_per_head=4,
+    )
+
+    x = torch.zeros(1, 1, 16, requires_grad=True)
+    cond = torch.zeros(1, 32, requires_grad=True)
+    y = restormer(x, cond=cond)
+    y.sum().backward()
+    assert x.grad is not None, 'x.grad is None'
+    assert not x.grad.isnan().any(), 'x.grad is NaN'
+    assert cond.grad is not None, 'cond.grad is None'
+    assert not cond.grad.isnan().any(), 'cond.grad is NaN'
+    for name, parameter in restormer.named_parameters():
+        assert parameter.grad is not None, f'{name}.grad is None'
+        assert not parameter.grad.isnan().any(), f'{name}.grad is NaN'