From bde6437a277466b2087245cb42fbedb45f14836d Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sat, 7 Feb 2026 23:26:24 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- src/mrpro/nn/GluMBConvResBlock.py | 118 ++++++++++++ src/mrpro/nn/nets/DCVAE.py | 302 ++++++++++++++++++++++++++++++ src/mrpro/nn/nets/VAE.py | 64 +++++++ src/mrpro/nn/nets/__init__.py | 4 + tests/nn/nets/test_dcvae.py | 82 ++++++++ 5 files changed, 570 insertions(+) create mode 100644 src/mrpro/nn/GluMBConvResBlock.py create mode 100644 src/mrpro/nn/nets/DCVAE.py create mode 100644 src/mrpro/nn/nets/VAE.py create mode 100644 tests/nn/nets/test_dcvae.py diff --git a/src/mrpro/nn/GluMBConvResBlock.py b/src/mrpro/nn/GluMBConvResBlock.py new file mode 100644 index 000000000..c17bc019f --- /dev/null +++ b/src/mrpro/nn/GluMBConvResBlock.py @@ -0,0 +1,118 @@ +"""Gateded MBConv Residual Block.""" + +import torch +from torch.nn import Identity, Module, Sequential, SiLU + +from mrpro.nn.CondMixin import CondMixin +from mrpro.nn.FiLM import FiLM +from mrpro.nn.ndmodules import convND +from mrpro.nn.RMSNorm import RMSNorm + + +class GluMBConvResBlock(CondMixin, Module): + """Gated MBConv residual block. + + Gated variant [DCAE]_ of the MBConv block [EffNet]_ with a residual connection and (optional) conditioning. + + References + ---------- + .. [DCAE] Chen et al. Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models. ICLR 2025 + https://arxiv.org/abs/2410.10733 + .. [EffNet] Tan et al. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. ICML 2019 + https://arxiv.org/abs/1905.11946 + """ + + def __init__( + self, + n_dim: int, + n_channels_in: int, + n_channels_out: int, + expand_ratio: int = 6, + stride: int = 1, + kernel_size: int = 3, + cond_dim: int = 0, + ): + """Initialize MBConv block. + + Parameters + ---------- + n_dim + Number of spatial dimensions. + n_channels_in + Number of input channels. + n_channels_out + Number of output channels. + expand_ratio + Expansion ratio inside the block. + stride + Stride of the depthwise convolution. + kernel_size + Kernel size of the depthwise convolution. + cond_dim + Dimension of the conditioning tensor used in a FiLM. If 0, no FiLM is used. + """ + super().__init__() + channels_mid = n_channels_in * expand_ratio + if stride == 1 and n_channels_in == n_channels_out: + self.skip: Module = Identity() + else: + self.skip = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1, stride=stride) + self.inverted_conv = Sequential( + convND(n_dim)( + n_channels_in, + channels_mid * 2, + kernel_size=1, + ), + SiLU(), + ) + self.depth_conv = Sequential( + convND(n_dim)( + channels_mid * 2, + channels_mid * 2, + kernel_size=kernel_size, + stride=stride, + padding='same', + groups=channels_mid * 2, + ), + SiLU(), + ) + self.point_conv = Sequential( + convND(n_dim)( + channels_mid, + n_channels_out, + kernel_size=1, + ), + RMSNorm(n_channels_out), + SiLU(), + ) + if cond_dim > 0: + self.film: FiLM | None = FiLM(channels_mid, cond_dim) + else: + self.film = None + + def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply MBConv block. + + Parameters + ---------- + x + Input tensor. + cond + Conditioning tensor. If `None`, no conditioning is applied. + + Returns + ------- + Output tensor. + """ + return super().__call__(x, cond=cond) + + def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: + """Apply MBConv block.""" + h = self.inverted_conv(x) + h = self.depth_conv(h) + h, gate = torch.chunk(h, 2, dim=1) + h = h * torch.nn.functional.silu(gate) + if self.film is not None: + h = self.film(h, cond=cond) + h = self.point_conv(h) + return self.skip(x) + h diff --git a/src/mrpro/nn/nets/DCVAE.py b/src/mrpro/nn/nets/DCVAE.py new file mode 100644 index 000000000..5faba554b --- /dev/null +++ b/src/mrpro/nn/nets/DCVAE.py @@ -0,0 +1,302 @@ +"""Deep Compression Autoencoder.""" + +from collections.abc import Sequence +from typing import Literal + +import torch +from torch.nn import Module, ReLU, SiLU + +from mrpro.nn.attention.LinearSelfAttention import LinearSelfAttention +from mrpro.nn.attention.MultiHeadAttention import MultiHeadAttention +from mrpro.nn.GluMBConvResBlock import GluMBConvResBlock +from mrpro.nn.ndmodules import convND +from mrpro.nn.nets.VAE import VAE +from mrpro.nn.PixelShuffle import PixelShuffleUpsample, PixelUnshuffleDownsample +from mrpro.nn.Residual import Residual +from mrpro.nn.RMSNorm import RMSNorm +from mrpro.nn.Sequential import Sequential + + +class CNNBlock(Residual): + """Block with two convolutions and normalization. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + ): + """Initialize the CNNBlock. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + """ + super().__init__( + Sequential( + convND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1), + SiLU(True), + convND(n_dim)(n_channels, n_channels, kernel_size=3, padding=1, bias=False), + RMSNorm(n_channels), + ) + ) + + +class EfficientViTBlock(Module): + """Efficient Vision Transformer block with optional linear attention. + + As used in the DCAE [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + n_heads: int, + expand_ratio: int = 4, + linear_attn: bool = False, + ): + """Initialize the EfficientViTBlock. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + n_heads + The number of attention heads. + expand_ratio + The expansion ratio of the GluMBConvResBlock. + linear_attn + Whether to use linear attention instead of softmax attention with quadratic complexity. + """ + super().__init__() + if linear_attn: + attention: Module = LinearSelfAttention(n_channels, n_channels, n_heads) + else: + attention = MultiHeadAttention(n_channels, n_channels, n_heads, features_last=False) + self.context_module = Residual(Sequential(attention, RMSNorm(n_channels))) + self.local_module = GluMBConvResBlock( + n_dim=n_dim, + n_channels_in=n_channels, + n_channels_out=n_channels, + expand_ratio=expand_ratio, + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the EfficientViTBlock. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + Output tensor + """ + return super().__call__(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for EfficientViTBlock.""" + x = self.context_module(x) + x = self.local_module(x) + return x + + +class Encoder(Sequential): + """Encoder for DCAE. + + As used in the DC-Autoencoder [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int = 2, + n_channels_in: int = 3, + n_channels_out: int = 32, + block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), + widths: Sequence[int] = (256, 512, 512, 1024, 1024), + depths: Sequence[int] = (4, 6, 2, 2, 2), + ): + """Initialize the Encoder. + + The length of the `block_types`, `widths`, and `depths` must be the same and determine + the number of stages in the encoder. Between the stages, downsampling is performed. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor, i.e. the latent space + n_channels_out + The number of channels in the output tensor, i.e. the original space + block_types + The types of blocks to use in the decoder. + widths + The widths of the blocks in the decoder, i.e. the number of channels in the blocks + depths + The depths of the blocks in the decoder, i.e. the number blocks in the stage + """ + super().__init__() + self.append(PixelUnshuffleDownsample(n_dim, n_channels_in, widths[0], downscale_factor=2, residual=False)) + if len(block_types) != len(widths) or len(block_types) != len(depths): + raise ValueError('block_types, widths, and depths must have the same length') + for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): + match block_type: + case 'CNN': + stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] + case 'LinearViT': + stage = [ + EfficientViTBlock(n_dim, width, max(1, width // 32), linear_attn=True) for _ in range(depth) + ] + case 'ViT': + stage = [EfficientViTBlock(n_dim, width, max(1, width // 32)) for _ in range(depth)] + case _: + raise ValueError(f'Block type {block_type} not supported') + self.append(Sequential(*stage)) + if next_width: + self.append(PixelUnshuffleDownsample(n_dim, width, next_width, downscale_factor=2, residual=True)) + self.append( + Sequential( + RMSNorm(widths[-1]), + ReLU(), + PixelUnshuffleDownsample(n_dim, widths[-1], n_channels_out, downscale_factor=1, residual=True), + ) + ) + + +class Decoder(Sequential): + """Decoder for DCAE. + + As used in the DC-Autoencoder [DCAE]_. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int = 2, + n_channels_in: int = 32, + n_channels_out: int = 3, + block_types: Sequence[Literal['ViT', 'LinearViT', 'CNN']] = ('ViT', 'LinearViT', 'LinearViT', 'CNN', 'CNN'), + widths: Sequence[int] = (1024, 1024, 512, 512, 256), + depths: Sequence[int] = (2, 2, 2, 6, 4), + ): + """Initialize the Decoder. + + The length of the `block_types`, `widths`, and `depths` must be the same and determine + the number of stages in the decoder. Between the stages, upsampling is performed. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels_in + The number of channels in the input tensor, i.e. the latent space + n_channels_out + The number of channels in the output tensor, i.e. the original space + block_types + The types of blocks to use in the decoder. + widths + The widths of the blocks in the decoder, i.e. the number of channels in the blocks + depths + The depths of the blocks in the decoder, i.e. the number blocks in the stage + """ + super().__init__() + if not (len(block_types) == len(widths) == len(depths)): + raise ValueError('block_types, widths, and depths must have the same length') + self.append(PixelShuffleUpsample(n_dim, n_channels_in, widths[0], upscale_factor=1, residual=True)) + + for block_type, width, next_width, depth in zip(block_types, widths, (*widths[1:], None), depths, strict=False): + match block_type: + case 'CNN': + stage: list[Module] = [CNNBlock(n_dim, width) for _ in range(depth)] + case 'LinearViT': + stage = [ + EfficientViTBlock(n_dim, width, n_heads=max(1, width // 32), linear_attn=True) + for _ in range(depth) + ] + case 'ViT': + stage = [ + EfficientViTBlock(n_dim, width, n_heads=max(1, width // 32), linear_attn=False) + for _ in range(depth) + ] + case _: + raise ValueError(f'Block type {block_type} not supported') + self.append(Sequential(*stage)) + if next_width: + self.append(PixelShuffleUpsample(n_dim, width, next_width, upscale_factor=2, residual=True)) + + self.append( + Sequential( + RMSNorm(widths[-1]), + ReLU(), + PixelShuffleUpsample(n_dim, widths[-1], n_channels_out, upscale_factor=2), + ) + ) + + +class DCVAE(VAE): + """Variational Autoencoder based on DCAE. + + References + ---------- + .. [DCAE] Chen, J., Cai, H., Chen, J., Xie, E., Yang, S., Tang, H., ... & Han, S. Deep compression autoencoder + for efficient high-resolution diffusion models. ICLR 2025. https://arxiv.org/abs/2410.10733 + """ + + def __init__( + self, + n_dim: int, + n_channels: int, + latent_dim: int = 32, + block_types: Sequence[Literal['CNN', 'LinearViT', 'ViT']] = ('CNN', 'CNN', 'LinearViT', 'LinearViT', 'ViT'), + widths: Sequence[int] = (256, 512, 512, 1024, 1024), + depths: Sequence[int] = (4, 6, 2, 2, 2), + ): + """Initialize the DCVAE. + + Parameters + ---------- + n_dim + The number of spatial dimensions of the input tensor. + n_channels + The number of channels in the input tensor. + latent_dim + The number of channels in the latent space. + block_types + The types of blocks to use in the encoder and decoder. + widths + The widths of the blocks in the encoder and decoder. + depths + The depths of the blocks in the encoder and decoder. + """ + encoder = Encoder(n_dim, n_channels, latent_dim * 2, block_types, widths, depths) + decoder = Decoder(n_dim, latent_dim, n_channels, block_types[::-1], widths[::-1], depths[::-1]) + super().__init__(encoder, decoder) diff --git a/src/mrpro/nn/nets/VAE.py b/src/mrpro/nn/nets/VAE.py new file mode 100644 index 000000000..e0b1bfc58 --- /dev/null +++ b/src/mrpro/nn/nets/VAE.py @@ -0,0 +1,64 @@ +"""Variational Autoencoder with a Gaussian latent space.""" + +import torch +from torch.nn import Module + + +class VAE(Module): + """Basic Variational Autoencoder. + + Consists of an encoder to transform the input into a latent space and a decoder to transform the latent space back + into the original space. The encoder should return twice the number of channels as the decoder needs to reconstruct + the input: half of the channels are the mean and the other half the log variance of the latent space. + The reparameterization trick is used to sample from the latent space. + The forward pass returns the reconstructed image and the KL divergence between the latent space and the standard + normal distribution. + """ + + def __init__(self, encoder: Module, decoder: Module): + """Initialize the VAE. + + Parameters + ---------- + encoder + Encoder module. Should return double the number of channels of the latent space. + decoder + Decoder module + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE. + + Calculates the reconstruction as well as the KL divergence between the latent space and the + standard normal distribution. + + Parameters + ---------- + x + Input tensor + + Returns + ------- + tuple of the reconstructed image and + the KL divergence between the latent space and the standard normal distribution. + """ + return self.forward(x) + + def mode(self, x: torch.Tensor) -> torch.Tensor: + """Mode of the VAE.""" + z = self.encoder(x) + mean, _ = z.chunk(2, dim=1) + return self.decoder(mean) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass of the VAE.""" + z = self.encoder(x) + mean, logvar = z.chunk(2, dim=1) + std = torch.exp(0.5 * logvar) + sample = mean + torch.randn_like(std) * std + reconstruction = self.decoder(sample) + kl = -0.5 * torch.sum(1 + logvar - mean.square() - std.square()) + return reconstruction, kl diff --git a/src/mrpro/nn/nets/__init__.py b/src/mrpro/nn/nets/__init__.py index b5a6a76af..87d9075f7 100644 --- a/src/mrpro/nn/nets/__init__.py +++ b/src/mrpro/nn/nets/__init__.py @@ -1,16 +1,20 @@ from mrpro.nn.nets.BasicCNN import BasicCNN +from mrpro.nn.nets.DCVAE import DCVAE from mrpro.nn.nets.HourglassTransformer import HourglassTransformer from mrpro.nn.nets.Restormer import Restormer from mrpro.nn.nets.SwinIR import SwinIR from mrpro.nn.nets.UNet import AttentionGatedUNet, UNet from mrpro.nn.nets.Uformer import Uformer +from mrpro.nn.nets.VAE import VAE __all__ = [ 'AttentionGatedUNet', 'BasicCNN', + 'DCVAE', 'HourglassTransformer', 'Restormer', 'SwinIR', 'UNet', 'Uformer', + 'VAE', ] diff --git a/tests/nn/nets/test_dcvae.py b/tests/nn/nets/test_dcvae.py new file mode 100644 index 000000000..ff5371b7b --- /dev/null +++ b/tests/nn/nets/test_dcvae.py @@ -0,0 +1,82 @@ +"""Tests for DCVAE network.""" + +from typing import cast + +import pytest +import torch +from mrpro.nn.nets import DCVAE + + +@pytest.mark.parametrize('torch_compile', [True, False], ids=['compiled', 'uncompiled']) +@pytest.mark.parametrize( + 'device', + [ + pytest.param('cpu', id='cpu'), + pytest.param('cuda', marks=pytest.mark.cuda, id='cuda'), + ], +) +def test_dcvae_forward(torch_compile: bool, device: str) -> None: + """Test the forward pass of the DCVAE.""" + dcvae = DCVAE( + n_dim=2, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(32, 64, 32), + depths=(1, 2, 2), + ) + + x = torch.zeros(1, 1, 16, 16, device=device) + dcvae = dcvae.to(device) + x = x.to(device) + if torch_compile: + dcvae = cast(DCVAE, torch.compile(dcvae)) + y, kl = dcvae(x) + assert y.shape == (1, 1, 16, 16) + assert kl.shape == () + latent = dcvae.encoder(x) + assert latent.shape == (1, 2 * 4, 2, 2) # 2 because of mean and logvar + + +def test_dcvae_backward_kl() -> None: + """Test the backward pass of the DCVAE wrt kl.""" + dcvae = DCVAE( + n_dim=1, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(8, 12, 16), + depths=(2, 2, 3), + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + + _, kl = dcvae(x) + kl.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in dcvae.encoder.named_parameters(): # only the encoder parameters can influence kl + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN' + + +def test_dcvae_backward_y() -> None: + """Test the backward pass of the DCVAE wrt y.""" + dcvae = DCVAE( + n_dim=1, + n_channels=1, + latent_dim=4, + block_types=('CNN', 'LinearViT', 'ViT'), + widths=(8, 12, 16), + depths=(2, 2, 3), + ) + + x = torch.zeros(1, 1, 16, requires_grad=True) + + y, _ = dcvae(x) + y.sum().backward() + assert x.grad is not None, 'x.grad is None' + assert not x.grad.isnan().any(), 'x.grad is NaN' + for name, parameter in dcvae.named_parameters(): + assert parameter.grad is not None, f'{name}.grad is None' + assert not parameter.grad.isnan().any(), f'{name}.grad is NaN'