Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/mrpro/nn/ResBlock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Residual convolution block with two convolutions."""

import torch
from torch.nn import Identity, Module, SiLU

from mrpro.nn.CondMixin import CondMixin
from mrpro.nn.FiLM import FiLM
from mrpro.nn.GroupNorm import GroupNorm
from mrpro.nn.ndmodules import convND
from mrpro.nn.Sequential import Sequential


class ResBlock(CondMixin, Module):
"""Residual convolution block with two convolutions."""

def __init__(self, n_dim: int, n_channels_in: int, n_channels_out: int, cond_dim: int) -> None:
"""Initialize the ResBlock.

Parameters
----------
n_dim
The number of dimensions, i.e. 1, 2 or 3.
n_channels_in
The number of channels in the input tensor.
n_channels_out
The number of channels in the output tensor.
cond_dim
The number of features in the conditioning tensor used in a FiLM.
If set to 0 no FiLM is used.

"""
super().__init__()
self.rezero = torch.nn.Parameter(torch.tensor(0.1))
self.block = Sequential(
GroupNorm(n_channels_in),
SiLU(),
convND(n_dim)(n_channels_in, n_channels_out, kernel_size=3, padding=1),
GroupNorm(n_channels_out),
SiLU(),
convND(n_dim)(n_channels_out, n_channels_out, kernel_size=3, padding=1),
)
if cond_dim > 0:
self.block.insert(1, FiLM(n_channels_in, cond_dim))
self.block.insert(-2, FiLM(n_channels_out, cond_dim))

if n_channels_out == n_channels_in:
self.skip_connection: Module = Identity()
else:
self.skip_connection = convND(n_dim)(n_channels_in, n_channels_out, kernel_size=1)

def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply the ResBlock.

Parameters
----------
x
The input tensor.
cond
A conditioning tensor to be used for FiLM.

Returns
-------
The output tensor.
"""
return super().__call__(x, cond=cond)

def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply the ResBlock."""
h = self.block(x, cond=cond)
x = self.skip_connection(x) + self.rezero * h
return x
89 changes: 89 additions & 0 deletions src/mrpro/nn/SeparableResBlock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Residual block with separable convolutions."""

from collections.abc import Sequence

import torch
from torch.nn import Module, SiLU

from mrpro.nn.FiLM import FiLM
from mrpro.nn.GroupNorm import GroupNorm
from mrpro.nn.ndmodules import convND
from mrpro.nn.PermutedBlock import PermutedBlock
from mrpro.nn.Sequential import Sequential


class SeparableResBlock(Module):
"""Residual block with separable convolutions."""

def __init__(
self,
dim_groups: Sequence[Sequence[int]],
n_channels_in: int,
n_channels_out: int,
cond_dim: int,
) -> None:
"""Initialize the SeparableResBlock.

Applies convolutions as separable convolutions with SilU activation and group normalization.
For example, if ``dim_groups = ((-1,-2), (-3))`` then one 2D convolution is applied to the last two dimensions,
and one 1D convolution is applied to the last dimension.
The order within the block is Norm->Activation->Conv.
The whole sequence for all dimension groups is performed twice, with optional FiLM conditioning in between.
So for two `dim_groups`, a total of 4 convolutions are applied.

Parameters
----------
dim_groups
Sequence of dimension groups to use in the convolutions.
n_channels_in
Number of input channels.
n_channels_out
Number of output channels.
cond_dim
Number of channels in the conditioning tensor. If 0, no conditioning is applied.
"""
super().__init__()
self.rezero = torch.nn.Parameter(torch.tensor(0.1))

def block(dims: Sequence[int], channels_in: int) -> Module:
return Sequential(
GroupNorm(channels_in),
SiLU(),
PermutedBlock(dims, convND(len(dims))(channels_in, n_channels_out, 3, padding=1)),
)

blocks = Sequential(*(block(d, n_channels_in if i == 0 else n_channels_out) for i, d in enumerate(dim_groups)))
if cond_dim > 0:
blocks.append(FiLM(n_channels_out, cond_dim))
blocks.extend(block(d, n_channels_out) for d in dim_groups)
self.block = blocks
self.skip_connection = None
if n_channels_in != n_channels_out:
self.skip_connection = torch.nn.Linear(n_channels_in, n_channels_out)

def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply the SeparableResBlock.

Parameters
----------
x
Input tensor.
cond
Conditioning tensor.

Returns
-------
Output tensor with the same number and order of dimensions as the input.
"""
return super().__call__(x, cond=cond)

def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply the SeparableResBlock."""
h = self.block(x, cond=cond)
if self.skip_connection is None:
skip = x
else:
skip = torch.moveaxis(x, 1, -1)
skip = self.skip_connection(skip)
skip = torch.moveaxis(skip, -1, 1)
return skip + self.rezero * h
6 changes: 6 additions & 0 deletions src/mrpro/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from mrpro.nn.LayerNorm import LayerNorm
from mrpro.nn.PermutedBlock import PermutedBlock
from mrpro.nn.RMSNorm import RMSNorm
from mrpro.nn.ResBlock import ResBlock
from mrpro.nn.Residual import Residual
from mrpro.nn.SeparableResBlock import SeparableResBlock
from mrpro.nn.Sequential import Sequential
from mrpro.nn import attention
from mrpro.nn import data_consistency
from mrpro.nn import nets
from mrpro.nn.ndmodules import (
adaptiveAvgPoolND,
avgPoolND,
Expand All @@ -39,7 +42,9 @@
'LayerNorm',
'PermutedBlock',
'RMSNorm',
'ResBlock',
'Residual',
'SeparableResBlock',
'Sequential',
'adaptiveAvgPoolND',
'attention',
Expand All @@ -50,4 +55,5 @@
'data_consistency',
'instanceNormND',
'maxPoolND',
'nets',
]
105 changes: 105 additions & 0 deletions src/mrpro/nn/nets/BasicCNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Basic CNN."""

from collections.abc import Sequence
from itertools import pairwise
from typing import Literal

import torch
from torch.nn import LeakyReLU, ReLU, SiLU

from mrpro.nn.FiLM import FiLM
from mrpro.nn.GroupNorm import GroupNorm
from mrpro.nn.ndmodules import batchNormND, convND
from mrpro.nn.Sequential import Sequential


class BasicCNN(Sequential):
"""Basic CNN.

A series of convolutions (window 3, stride 1, padding 1), normalization and activation.
Allows to use FiLM conditioning.
Order is Conv -> Norm (optional) -> FiLM (optional) -> Activation.

If you need more flexibility, use `~mrpro.nn.Sequential` directly.
"""

def __init__(
self,
n_dim: int,
n_channels_in: int,
n_channels_out: int,
norm: Literal['batch', 'group', 'instance', 'none', 'layer'] = 'none',
activation: Literal['relu', 'silu', 'leaky_relu'] = 'relu',
n_features: Sequence[int] = (64, 64, 64),
cond_dim: int = 0,
):
"""Initialize a basic CNN.

Parameters
----------
n_dim
The number of spatial dimensions of the input tensor.
n_channels_in
The number of input channels.
n_channels_out
The number of output channels.
norm
The type of normalization to use. If 'batch', use batch normalization. If 'group', use group normalization,
if 'instance', use instance normalization, and if `layer`, use layer normalization.
If 'none', use no normalization.
activation
The type of activation to use. If 'relu', use ReLU. If 'silu', use SiLU. If 'leaky_relu', use LeakyReLU.
n_features
The number of features in the hidden layers. The length of this sequence determines the number of hidden
layers. The total number of convolutions is `len(n_features) + 1`.
cond_dim
The dimension of the condition tensor. If 0, no FiLM conditioning is applied.
Otherwise, between convolutions, after normalization, FiLM conditioning is applied.
"""
super().__init__()
use_film = cond_dim > 0

self.append(convND(n_dim)(n_channels_in, n_features[0], kernel_size=3, padding='same'))

for c_in, c_out in pairwise((*n_features, n_channels_out)):
if norm.lower() == 'batch':
self.append(batchNormND(n_dim)(c_in, affine=not use_film))
elif norm.lower() == 'group':
self.append(GroupNorm(c_in, affine=not use_film))
elif norm.lower() == 'instance':
self.append(GroupNorm(c_in, n_groups=c_in, affine=not use_film)) # is instance norm
elif norm.lower() == 'layer':
self.append(GroupNorm(c_in, n_groups=1, affine=not use_film)) # is layer norm
elif norm.lower() != 'none':
raise ValueError(f'Invalid normalization type: {norm}')

if use_film:
self.append(FiLM(c_in, cond_dim))

if activation.lower() == 'relu':
self.append(ReLU(True))
elif activation.lower() == 'silu':
self.append(SiLU(inplace=True))
elif activation.lower() == 'leaky_relu':
self.append(LeakyReLU(inplace=True))
else:
raise ValueError(f'Invalid activation type: {activation}')

self.append(convND(n_dim)(c_in, c_out, kernel_size=3, padding='same'))

def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor: # type: ignore[override]
"""Apply the basic CNN to the input tensor.

Parameters
----------
x
The input tensor. Should be of shape `(batch_size, channels_in, *spatial dimensions)`
with `spatial dimensions` being of length `dim`.
cond
The condition tensor. If None, no FiLM conditioning is applied.

Returns
-------
The output tensor.
"""
return super().__call__(x, cond=cond)
Loading
Loading