Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ docs = [
"sphinx-autodoc-typehints>=3, <3.1",
"sphinx-copybutton>=0.5, <0.6",
"sphinx-last-updated-by-git>=0.3, <0.4",
"snowballstemmer>=2.2, <3.0",
]
notebooks = [
"zenodo_get>=2.0",
Expand All @@ -119,10 +120,12 @@ testpaths = ["tests"]
filterwarnings = [
"error",
"ignore:'write_like_original':DeprecationWarning:pydicom:",
"ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd
"ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5
"ignore:Anomaly Detection has been enabled:UserWarning", # torch.autograd
"ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning", # einops and dynamo<2.5
"ignore:TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled:UserWarning", # torch cuda
"ignore:.*In the future, this object will be coerced as if it was first converted using.*:FutureWarning", # numpy 1.2
"ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10
"ignore:.*load_module.*:DeprecationWarning", # torch compile in torch<2.6
"ignore:`torch.jit.script` is deprecated:DeprecationWarning", # torch 2.10
]
addopts = "-n auto --dist loadfile --maxprocesses=8"
markers = ["cuda : Tests only to be run when cuda device is available"]
Expand Down Expand Up @@ -232,6 +235,7 @@ arange = "arange" # torch.arange
Ba = "Ba"
wht = "wht" # Brainweb tissue class
nd = "nd" # pad_nd function from torchnd
ND = "ND" # Short for N-dimensional

[tool.typos.files]
extend-exclude = [
Expand Down
6 changes: 4 additions & 2 deletions src/mrpro/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from mrpro._version import __version__
from mrpro import algorithms, operators, data, phantoms, utils
from mrpro import algorithms, operators, data, phantoms, utils, nn

__all__ = [
"__version__",
"algorithms",
"data",
"nn",
"operators",
"phantoms",
"utils"
]
]
59 changes: 59 additions & 0 deletions src/mrpro/nn/ComplexAsChannel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""ComplexAsChannel: handling complex-valued tensors as channels."""

import torch
from einops import rearrange
from torch.nn import Module

from mrpro.nn.CondMixin import CondMixin, call_with_cond


class ComplexAsChannel(CondMixin, Module):
"""Wrap module to treat complex numbers as a channel dimension."""

def __init__(self, module: Module, convert_back: bool = True):
"""Initialize the ComplexAsChannel module.

Wraps a module to treat complex numbers as a channel dimension.
For each complex tensor in the input, real and imaginary parts are concatenated along the channel dimension
before being passed to the wrapped module.


Parameters
----------
module
The module to wrap. Should output a single real tensor.
convert_back
If True, the output is converted back to a complex tensor.
The output should have a number of channels that is a multiple of 2.
"""
super().__init__()
self.module = module
self.convert_back = convert_back

def __call__(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply the module.

Parameters
----------
x
The input tensor.
cond
The conditioning tensor (if used by the wrapped module)
"""
return super().__call__(*x, cond=cond)

def forward(self, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply the module."""
x_real = [
rearrange(torch.view_as_real(c), 'batch channel ... complex -> batch (channel complex) ...')
if c.is_complex()
else c
for c in x
]

y = call_with_cond(self.module, *x_real, cond=cond)

if self.convert_back:
y = rearrange(y, 'b (channel complex) ... -> b channel ... complex', complex=2).contiguous()
y = torch.view_as_complex(y)
return y
22 changes: 22 additions & 0 deletions src/mrpro/nn/CondMixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Base class for modules using a conditioning."""

import torch
from torch.nn import Module


def call_with_cond(module: Module, *x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Call a module with conditioning if it is a CondMixin."""
if isinstance(module, CondMixin):
return module(*x, cond=cond)
return module(*x)


class CondMixin(Module):
"""Mixin for modules using a conditioning.

Used to determine if a module uses a conditioning within a Sequential container.
"""

def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply the module to the input."""
return super().__call__(x, cond=cond)
55 changes: 55 additions & 0 deletions src/mrpro/nn/DropPath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""DropPath (stochastic depth)."""

import torch
from torch.nn import Module


class DropPath(Module):
"""Drop path or stochastic depth.

Drops full samples from batch with probability `droprate`.
Should be used in the main path of a Resblock.

References
----------
.. [HUANG16] Huang, G., Sun, Y., Liu, Z., Sedra, D., & Weinberger, K. Q. Deep networks with stochastic depth.
ECCV 2016. https://link.springer.com/chapter/10.1007/978-3-319-46493-0_39
"""

def __init__(self, droprate: float = 0.0, scale_by_keep: bool = False):
"""Initialize the DropPath module.

Parameters
----------
droprate
Drop probability
scale_by_keep
If True, the kept samples are scaled by :math:`1/(1-droprate)`
"""
super().__init__()
self.droprate = droprate
self.scale_by_keep = scale_by_keep

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply DropPath.

Parameters
----------
x
Input tensor

Returns
-------
Tensor with batch samples randomly dropped
"""
return super().__call__(x)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply DropPath."""
if self.droprate == 0 or not self.training:
return x
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
mask = ((1 - self.droprate) + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_()
if self.scale_by_keep:
mask = mask.div_(1 - self.droprate)
return x * mask
68 changes: 68 additions & 0 deletions src/mrpro/nn/FiLM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Feature-wise Linear Modulation."""

import torch
from torch.nn import Linear, Module

from mrpro.nn.CondMixin import CondMixin
from mrpro.utils.reshape import unsqueeze_tensors_right


class FiLM(CondMixin, Module):
"""Feature-wise Linear Modulation.

Feature-wise Linear Modulation from [FiLM]_ to condition a network on a conditioning tensor.


References
----------
..[FiLM] Perez, L., Strub, F., de Vries, H., Dumoulin, V., & Courville, A. "FiLM : Visual reasoning with a general
conditioning layer." AAAI (2018). https://arxiv.org/abs/1709.07871
"""

features_last: bool

def __init__(self, channels: int, cond_dim: int, features_last: bool = False) -> None:
"""Initialize FiLM.

Parameters
----------
channels
The number of channels in the input tensor.
cond_dim
The dimension of the conditioning tensor.
features_last
Whether the features are in the last dimension of the input tensor (e.g. transformer tokens)
or in the second dimension (e.g. image tensors).
"""
super().__init__()
self.project = Linear(cond_dim, 2 * channels) if cond_dim > 0 else None
self.features_last = features_last

def __call__(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply FiLM.

Parameters
----------
x
The input tensor.
cond
The conditioning tensor.
"""
return super().__call__(x, cond=cond)

def forward(self, x: torch.Tensor, *, cond: torch.Tensor | None = None) -> torch.Tensor:
"""Apply FiLM."""
if cond is None or self.project is None:
return x

if self.features_last:
x = x.moveaxis(-1, 1)

scale, shift = self.project(cond).chunk(2, dim=1)
scale, shift = unsqueeze_tensors_right(scale, shift, ndim=x.ndim)
x = x * (1 + scale) + shift

if self.features_last:
x = x.moveaxis(1, -1)

return x
50 changes: 50 additions & 0 deletions src/mrpro/nn/FourierFeatures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Random Fourier feature embedding."""

import torch
from torch.nn import Module


class FourierFeatures(Module):
"""Fourier feature encoding layer.

Projects input features into a higher dimensional space using random Fourier features.
Used in INRs and to embed the time or other continuous variables.
"""

weight: torch.Tensor

def __init__(self, n_features_in: int, n_features_out: int, std: float = 1.0):
"""Initialize Fourier feature encoding layer.

Parameters
----------
n_features_in
Number of input features
n_features_out
Number of output features (must be even)
std
Standard deviation for random initialization
"""
if n_features_out % 2 != 0:
raise ValueError('n_features_out must be even.')
super().__init__()
self.register_buffer('weight', torch.randn([n_features_out // 2, n_features_in]) * std)

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Fourier feature encoding.

Parameters
----------
x
Input tensor of shape (..., in_features)

Returns
-------
Encoded features of shape (..., out_features)
"""
return super().__call__(x)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Fourier feature encoding."""
f = 2 * torch.pi * x @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
56 changes: 56 additions & 0 deletions src/mrpro/nn/GEGLU.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Gated linear unit activation function."""

import torch
from torch.nn import Linear, Module


class GEGLU(Module):
r"""Gated linear unit activation function.

References
----------
..[GLU] Shazeer, N. (2020). GLU variants improve transformer. https://arxiv.org/abs/2002.05202
"""

def __init__(self, n_channels_in: int, n_channels_out: int | None = None, features_last: bool = False):
"""Initialize the GEGLU activation function.

Parameters
----------
n_channels_in
The number of input features/channels.
n_channels_out
The number of output features/channels. If None, the number of
output features is the same as the number of input features.
features_last
If True, the channel dimension is the last dimension, else in the second dimension.
"""
super().__init__()
out_channels_ = n_channels_in if n_channels_out is None else n_channels_out
self.proj = Linear(n_channels_in, out_channels_ * 2) # gate and output stacked
self.features_last = features_last

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the GEGLU activation."""
if not self.features_last:
x = x.moveaxis(1, -1)
h, gate = self.proj(x).chunk(2, dim=-1)
gate = torch.nn.functional.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
out = h * gate
if not self.features_last:
out = out.moveaxis(-1, 1)
return out

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the GEGLU activation.

Parameters
----------
x
Input tensor

Returns
-------
Activated tensor
"""
return super().__call__(x)
Loading
Loading