Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
231 commits
Select commit Hold shift + click to select a range
2be3c4f
cg matrix op
fzimmermann89 Apr 7, 2025
1c31fed
first draft
fzimmermann89 Apr 6, 2025
aaff916
types
fzimmermann89 Apr 7, 2025
a336ded
test
fzimmermann89 Apr 16, 2025
5d6fe98
update
fzimmermann89 Apr 23, 2025
dc6429d
update
fzimmermann89 Apr 23, 2025
c89a25a
fix merge
fzimmermann89 Apr 23, 2025
391979e
cleanup
fzimmermann89 Apr 27, 2025
c3ffa9c
fix miniml
fzimmermann89 Apr 27, 2025
d9c024a
fix
fzimmermann89 Apr 28, 2025
97f102b
py310
fzimmermann89 Apr 28, 2025
19087c9
py310
fzimmermann89 Apr 28, 2025
41b5aec
py310
fzimmermann89 Apr 29, 2025
3118792
py310
fzimmermann89 Apr 29, 2025
6454662
pyr310
fzimmermann89 Apr 29, 2025
c9e7160
Merge branch 'main' into diff
fzimmermann89 Apr 29, 2025
c8f51ec
review
fzimmermann89 Apr 29, 2025
227646a
norm
fzimmermann89 Apr 29, 2025
78b570d
fix rhs norm zero
fzimmermann89 Apr 29, 2025
637cdf0
Merge branch 'main' into diff
fzimmermann89 Apr 29, 2025
904f3c9
fix doc
fzimmermann89 May 10, 2025
a458855
start
fzimmermann89 May 10, 2025
7e83be7
update
fzimmermann89 May 12, 2025
26467bf
update
fzimmermann89 May 13, 2025
c39a9af
update
fzimmermann89 May 13, 2025
633682b
update
fzimmermann89 May 13, 2025
420cdc1
update
fzimmermann89 May 13, 2025
9cfae55
update
fzimmermann89 May 14, 2025
cf4be7f
uformer
fzimmermann89 May 14, 2025
54a66b6
fix
fzimmermann89 May 14, 2025
3ae37d1
update
fzimmermann89 May 15, 2025
912d7c8
update
fzimmermann89 May 15, 2025
b6a1db3
doc
fzimmermann89 May 16, 2025
3d01220
update
fzimmermann89 May 18, 2025
7f37fa9
update
fzimmermann89 May 18, 2025
97115e7
update
fzimmermann89 May 19, 2025
33d9557
update
fzimmermann89 May 19, 2025
62e04a1
update
fzimmermann89 May 19, 2025
7e9d121
update
fzimmermann89 May 19, 2025
3d259bb
update
fzimmermann89 May 20, 2025
52c8630
update
fzimmermann89 May 20, 2025
d626bbb
update
fzimmermann89 May 20, 2025
01881fe
Refactor imports to use lowercase 'ndmodules' and update method signa…
fzimmermann89 May 21, 2025
4f6a603
Enhance conversion functions between Linear and Conv layers by refini…
fzimmermann89 May 21, 2025
7d608b5
Refactor method signatures in neural network modules to use keyword-o…
fzimmermann89 May 21, 2025
3757b1f
Add EMADict class for Exponential Moving Average functionality and up…
fzimmermann89 May 22, 2025
f3aaa6a
Refactor EfficientViTBlock and Encoder/Decoder stages to use dynamic …
fzimmermann89 May 22, 2025
4c49734
Refactor Restormer and Uformer networks to utilize UNetEncoder and UN…
fzimmermann89 May 22, 2025
eafbfc6
Add SpatialTransformerBlock and integrate into UNet architecture
fzimmermann89 May 22, 2025
afd7a45
Refactor FiLM and Uformer modules for improved clarity and functionality
fzimmermann89 May 22, 2025
24bfbc9
Refactor ZeroPadOp and pad_or_crop utility for improved functionality…
fzimmermann89 May 22, 2025
376be37
Add Upsample module for tensor resizing functionality
fzimmermann89 Jun 2, 2025
b1ff7f8
Refactor neural network modules to standardize feature dimension hand…
fzimmermann89 Jun 2, 2025
dca726b
Enhance UNet architecture with improved attention handling and modula…
fzimmermann89 Jun 2, 2025
59569dc
- Updated AttentionGate to include a new 'concatenate' parameter, all…
fzimmermann89 Jun 2, 2025
6b942ca
wip
fzimmermann89 Jun 2, 2025
ea3109e
update
fzimmermann89 Jun 2, 2025
4c1ec0f
Refactor parameter documentation in encoding and normalization modules
fzimmermann89 Jun 3, 2025
e20d6f7
wip
fzimmermann89 Jun 3, 2025
c7e588e
Refactor AttentionGate and DropPath modules for improved functionality
fzimmermann89 Jun 4, 2025
10f1994
Refactor import statements and enhance EMA documentation
fzimmermann89 Jun 4, 2025
07f9f7a
Update src/mrpro/operators/ConjugateGradientOp.py
fzimmermann89 Jun 10, 2025
44c8ac0
Apply suggestions from code review
fzimmermann89 Jun 10, 2025
edb3b0f
review
fzimmermann89 Jun 10, 2025
ca88e3a
review
fzimmermann89 Jun 10, 2025
630f146
Merge branch 'main' into diff
fzimmermann89 Jun 10, 2025
8d8667a
Merge branch 'main' into diff
fzimmermann89 Jun 10, 2025
8c9943f
docstring
fzimmermann89 Jun 13, 2025
fb6eb41
update
fzimmermann89 Jun 22, 2025
aaa68e9
separable
fzimmermann89 Jun 23, 2025
8209d11
fix
fzimmermann89 Jun 27, 2025
71850ea
Merge branch 'main' into pinqi
fzimmermann89 Jun 27, 2025
6d85954
Merge branch 'diff' into pinqi
fzimmermann89 Jun 27, 2025
41e4216
wip
fzimmermann89 Jun 28, 2025
a242e55
add
fzimmermann89 Jun 28, 2025
6861ca9
change tol
fzimmermann89 Jul 1, 2025
eadd3a6
fix brainweb
fzimmermann89 Jul 1, 2025
7d9dd14
fix sat rec
fzimmermann89 Jul 1, 2025
1322b51
fix csmdata init typing
fzimmermann89 Jul 1, 2025
d0d333e
fix nn
fzimmermann89 Jul 1, 2025
c834065
train_pinqi
fzimmermann89 Jul 1, 2025
096038e
update nn
fzimmermann89 Jul 1, 2025
a66f0db
update dataclass
fzimmermann89 Jul 1, 2025
4b9508d
train pinqi
fzimmermann89 Jul 1, 2025
da8aef1
update pinqi
fzimmermann89 Jul 1, 2025
dfa7282
modl
fzimmermann89 Jul 2, 2025
78a6322
oberator subtractino
fzimmermann89 Jul 3, 2025
5b40ad1
fastmri: fix padding undo
fzimmermann89 Jul 3, 2025
cc34beb
fix dataclass error
fzimmermann89 Jul 3, 2025
777443f
nn
fzimmermann89 Jul 3, 2025
09b70c5
inati: no nans
fzimmermann89 Jul 3, 2025
b0a8517
modl
fzimmermann89 Jul 3, 2025
816f3a3
pinqi
fzimmermann89 Jul 4, 2025
d8bb305
fix ssim
fzimmermann89 Jul 4, 2025
6b46b3a
fix cg
fzimmermann89 Jul 3, 2025
7cd0d7f
modl
fzimmermann89 Jul 4, 2025
3a91b5d
fix test
fzimmermann89 Jul 8, 2025
3b17a71
apply pinqi
fzimmermann89 Jul 8, 2025
08cbeca
Merge branch 'main' into pinqi
fzimmermann89 Jul 8, 2025
1be7d32
update
fzimmermann89 Jul 10, 2025
c15cb10
Merge branch 'main' into pinqi
fzimmermann89 Jul 11, 2025
db742b0
Merge branch 'main' into nn
fzimmermann89 Jul 11, 2025
3386923
Squashed commit of the following:
fzimmermann89 Jul 11, 2025
d6cb116
Squashed commit of the following:
fzimmermann89 Jul 11, 2025
2270236
pull changes form pinqi
fzimmermann89 Jul 11, 2025
2311625
Merge branch 'nn' into pinqi_new
fzimmermann89 Jul 11, 2025
9310fb5
simpliy unet
fzimmermann89 Jul 12, 2025
1dc3d9b
Squashed commit of the following:
fzimmermann89 Jul 12, 2025
25914d1
Merge remote-tracking branch 'origin/nn' into pinqi_new
fzimmermann89 Jul 12, 2025
12f41b8
simplidy unet
fzimmermann89 Jul 12, 2025
a17e507
fewer param in pinqi
fzimmermann89 Jul 12, 2025
132491e
simplidy unet
fzimmermann89 Jul 12, 2025
a4e46f6
update
fzimmermann89 Jul 13, 2025
c31a2ee
update
fzimmermann89 Jul 14, 2025
da5baff
Refactor variable names in GluMBConvResBlock and PixelShuffle classes…
fzimmermann89 Jul 14, 2025
79124d1
tests
fzimmermann89 Jul 14, 2025
b3c8d40
fix swin attention
fzimmermann89 Jul 14, 2025
a748ced
allow reflection padding
fzimmermann89 Jul 14, 2025
3e55d01
Add RMSNorm module, update NeighborhoodSelfAttention to accept device…
fzimmermann89 Jul 14, 2025
019fe5c
Merge branch 'main' into nn
fzimmermann89 Jul 14, 2025
f734e3f
Add LinearSelfAttention module, update join.py to accept string mode,…
fzimmermann89 Jul 14, 2025
e91b6dc
bump torch version
fzimmermann89 Jul 14, 2025
b387ca5
fix pad
fzimmermann89 Jul 14, 2025
233ecf5
include build essentials
fzimmermann89 Jul 15, 2025
1c8b52e
Merge branch 'build_essentials' into nn
fzimmermann89 Jul 15, 2025
0edd9e7
dev and full
fzimmermann89 Jul 15, 2025
a8851a7
Merge branch 'build_essentials' into nn
fzimmermann89 Jul 15, 2025
be7c9a1
fix
fzimmermann89 Jul 15, 2025
a208315
Merge branch 'reflect_pad' into nn
fzimmermann89 Jul 15, 2025
ef42f4f
update dependencies
fzimmermann89 Jul 16, 2025
3bbd66e
Add tests for NeighborhoodSelfAttention module
fzimmermann89 Jul 16, 2025
a2f7eb6
move attention
fzimmermann89 Jul 17, 2025
2bfba60
dc
fzimmermann89 Jul 17, 2025
396d4b6
fix
fzimmermann89 Jul 17, 2025
8311bc5
dc
fzimmermann89 Jul 17, 2025
5c31952
fix
fzimmermann89 Jul 17, 2025
6495eb0
Add fully sampled Cartesian trajectory generation and improve error h…
fzimmermann89 Jul 17, 2025
089d312
docstring
fzimmermann89 Jul 17, 2025
1959fda
fix test
fzimmermann89 Jul 17, 2025
1a86c3a
Merge branch 'phantom' into nn
fzimmermann89 Jul 17, 2025
956fc2f
Refactor data consistency modules and enhance attention imports
fzimmermann89 Jul 17, 2025
fd67024
dc
fzimmermann89 Jul 17, 2025
e629a55
update
fzimmermann89 Jul 17, 2025
3b92c67
rope
fzimmermann89 Jul 21, 2025
cb08caf
rope
fzimmermann89 Jul 21, 2025
bf63a4b
hourglass v1
fzimmermann89 Jul 21, 2025
6257723
Add AxialRoPE to nn module and introduce Interpolate class
fzimmermann89 Jul 21, 2025
93666bb
docstrings
fzimmermann89 Jul 21, 2025
08e20b7
fix unet
fzimmermann89 Jul 21, 2025
88fe7a2
cahnge rope shape
fzimmermann89 Jul 21, 2025
d5895ce
fix restormer
fzimmermann89 Jul 21, 2025
5526bac
tests
fzimmermann89 Jul 21, 2025
772529c
Merge branch 'main' into nn
fzimmermann89 Jul 21, 2025
f63e059
update
fzimmermann89 Jul 21, 2025
b9a3d4e
fixes
fzimmermann89 Jul 21, 2025
da5cc63
update test
fzimmermann89 Jul 22, 2025
fb0e0a4
fix
fzimmermann89 Jul 22, 2025
6882b6f
encodings
fzimmermann89 Jul 22, 2025
dc911f5
nocover
fzimmermann89 Jul 22, 2025
de43dce
hourglass
fzimmermann89 Jul 22, 2025
aa39110
ignore tensorcode warning in tests
fzimmermann89 Jul 22, 2025
a3187ae
filter warning
fzimmermann89 Jul 22, 2025
a228291
docstring
fzimmermann89 Jul 22, 2025
74b240c
formatting
fzimmermann89 Jul 22, 2025
3ee4c76
formatting
fzimmermann89 Jul 22, 2025
f624cb7
typo
fzimmermann89 Jul 22, 2025
4a60856
fix NA
fzimmermann89 Jul 23, 2025
7582506
Merge branch 'main' into nn
fzimmermann89 Jul 23, 2025
09606f5
python 2.3
fzimmermann89 Jul 24, 2025
3f7f36a
torch filter
fzimmermann89 Jul 24, 2025
ba2588f
Merge branch 'main' into nn
fzimmermann89 Jul 24, 2025
6bc0bf6
fix
fzimmermann89 Jul 24, 2025
c466ce3
fix tocvhvesion version
fzimmermann89 Jul 24, 2025
2480a08
version filter
fzimmermann89 Jul 24, 2025
f6ca670
fix
fzimmermann89 Jul 24, 2025
9718c3c
fix
fzimmermann89 Jul 24, 2025
c130c9e
fix
fzimmermann89 Jul 24, 2025
f88b546
fix
fzimmermann89 Jul 24, 2025
920ed5a
fix
fzimmermann89 Jul 24, 2025
f9465bc
Merge branch 'main' into nn
fzimmermann89 Jul 28, 2025
ae22196
fix?
fzimmermann89 Jul 28, 2025
d927ffa
fix?
fzimmermann89 Jul 28, 2025
8365752
fix
fzimmermann89 Jul 28, 2025
3ca441d
fix??
fzimmermann89 Jul 28, 2025
263fec7
fix
fzimmermann89 Jul 28, 2025
5cb8f72
Add SeparableResBlock implementation and corresponding tests
fzimmermann89 Jul 28, 2025
180048b
fix cuda
fzimmermann89 Jul 28, 2025
5444f4f
mypy
fzimmermann89 Jul 28, 2025
a0d0fe8
fix?
fzimmermann89 Jul 28, 2025
a23d5da
test
fzimmermann89 Jul 28, 2025
c8c91e3
test
fzimmermann89 Jul 28, 2025
e8bd9a3
fix?
fzimmermann89 Jul 28, 2025
d52e688
try
fzimmermann89 Jul 29, 2025
cd9541c
ignore warning
fzimmermann89 Jul 29, 2025
44a6032
Merge branch 'main' into nn
fzimmermann89 Jul 29, 2025
de08c3c
cleanup
fzimmermann89 Jul 29, 2025
51a8960
rename
fzimmermann89 Jul 29, 2025
d9928bf
text
fzimmermann89 Jul 29, 2025
e8c2106
Merge branch 'nn' into pinqi
fzimmermann89 Jul 29, 2025
3b65a69
Merge branch 'pinqi' into pinqi_new
fzimmermann89 Jul 29, 2025
4b8a52b
add core nn foundations, layers, and resize blocks
fzimmermann89 Feb 8, 2026
1e186e5
add data consistency modules and tests
fzimmermann89 Feb 8, 2026
9376390
add positional encodings and attention modules
fzimmermann89 Feb 8, 2026
1513fb3
add unet, basic cnn, and residual blocks
fzimmermann89 Feb 8, 2026
100057b
add restormer architecture and tests
fzimmermann89 Feb 8, 2026
497e58f
add swinir architecture and tests
fzimmermann89 Feb 8, 2026
920d1f6
add uformer architecture and tests
fzimmermann89 Feb 8, 2026
f9baf2a
add hourglass transformer architecture and tests
fzimmermann89 Feb 8, 2026
5aedd68
add vae and dcvae architectures with mbconv block
fzimmermann89 Feb 8, 2026
ffe4850
Merge remote-tracking branch 'origin/nn-stacked-v2' into pinqi_new
fzimmermann89 Feb 9, 2026
515e1ec
fix precommit
fzimmermann89 Feb 9, 2026
102dc44
Brainweb: Add ULF values
fzimmermann89 Feb 9, 2026
863edae
pinqi
fzimmermann89 Feb 9, 2026
3e85c27
Add fast path to PatchOp adjoint
fzimmermann89 Feb 10, 2026
c0be007
add core nn foundations, layers, and resize blocks
fzimmermann89 Feb 10, 2026
05db778
add data consistency modules and tests
fzimmermann89 Feb 10, 2026
d8eb662
add positional encodings and attention modules
fzimmermann89 Feb 10, 2026
88e72b7
add unet, basic cnn, and residual blocks
fzimmermann89 Feb 10, 2026
774db72
Add MLP network and tests
fzimmermann89 Feb 10, 2026
7ea4450
add restormer architecture and tests
fzimmermann89 Feb 10, 2026
650bfe5
add swinir architecture and tests
fzimmermann89 Feb 10, 2026
9ce3038
add uformer architecture and tests
fzimmermann89 Feb 10, 2026
6b51894
add hourglass transformer architecture and tests
fzimmermann89 Feb 10, 2026
fccecc4
add vae
fzimmermann89 Feb 10, 2026
512c4c9
add dit
fzimmermann89 Feb 10, 2026
495d916
Squashed commit of the following:
fzimmermann89 Feb 10, 2026
8636290
Merge remote-tracking branch 'origin/nn-stacked-v2' into pinqi_new
fzimmermann89 Feb 10, 2026
3ed29c9
update
fzimmermann89 Feb 10, 2026
6431fc4
cleanup
fzimmermann89 Feb 10, 2026
e55d062
[pre-commit] auto fixes from pre-commit hooks
pre-commit-ci[bot] Feb 10, 2026
ae7d128
Merge branch 'nn-stacked-v2' into pinqi_new
fzimmermann89 Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
609 changes: 609 additions & 0 deletions examples/notebooks/apply_pinqi.ipynb

Large diffs are not rendered by default.

266 changes: 266 additions & 0 deletions examples/notebooks/modl.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PTB-MR/mrpro/blob/main/examples/notebooks/modl.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"import importlib\n",
"\n",
"if not importlib.util.find_spec('mrpro'):\n",
" %pip install mrpro[notebooks]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"from collections.abc import Sequence\n",
"from pathlib import Path\n",
"from typing import TypedDict\n",
"\n",
"import matplotlib.axes\n",
"import matplotlib.pyplot as plt\n",
"import mrpro\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"class BatchType(TypedDict):\n",
" \"\"\"A single Batch.\"\"\"\n",
"\n",
" data: mrpro.data.KData\n",
" target: mrpro.data.IData\n",
" csm: mrpro.data.CsmData\n",
"\n",
"\n",
"class AcceleratedFastMRI(torch.utils.data.Dataset):\n",
" \"\"\"An undersampled FastMRI Dataset.\"\"\"\n",
"\n",
" def __init__(self, path: Path, acceleration: float = 12, noise_level: float = 0.1):\n",
" \"\"\"Create an undersampled FastMRI Dataset.\n",
"\n",
" Parameters\n",
" ----------\n",
" path\n",
" Path to the FastMRI dataset.\n",
" acceleration\n",
" Undersampling factor; higher values mean more acceleration. Default is 12.\n",
" noise_level\n",
" Level of additive Gaussian noise applied to the FastMRI dataset. Default is 0.1.\n",
" \"\"\"\n",
" self.acceleration = acceleration\n",
" files = list(path.glob('*AXT1*'))\n",
" self.dataset = mrpro.phantoms.FastMRIKDataDataset(files)\n",
" self.noise_level = noise_level\n",
"\n",
" def __len__(self):\n",
" \"\"\"Get length of the dataset.\"\"\"\n",
" return len(self.dataset)\n",
"\n",
" def __getitem__(self, index: int) -> BatchType:\n",
" \"\"\"Get a single batch of data.\n",
"\n",
" Parameters\n",
" ----------\n",
" index\n",
" Index of the batch.\n",
"\n",
" Returns\n",
" -------\n",
" A single batch of data with keys 'data', 'target', and 'csm'.and\n",
" \"\"\"\n",
" data = self.dataset[index]\n",
" data = data.remove_readout_os()\n",
" data.data /= data.data.std()\n",
" reconstruction = mrpro.algorithms.reconstruction.DirectReconstruction(\n",
" data, csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64)\n",
" )\n",
" csm = reconstruction.csm\n",
" target = reconstruction(data)\n",
"\n",
" n = max(data.data.shape[-2:])\n",
" distance = (torch.linspace(-1, 1, n)[:, None] ** 2 + torch.linspace(-1, 1, n) ** 2).sqrt()\n",
" random = 0.1 / (distance + 0.1) + torch.rand_like(distance)\n",
" threshold = torch.kthvalue(random.ravel(), int(n**2 * (1 - 1 / self.acceleration))).values\n",
" undersampling_mask = mrpro.utils.pad_or_crop(random > threshold, data.data.shape[-2:])\n",
" data_undersampled = data[..., undersampling_mask].rearrange('k ... 1 -> ... k')\n",
"\n",
" noise = mrpro.utils.RandomGenerator(seed=index).randn_like(data_undersampled.data)\n",
" data_undersampled.data += self.noise_level * noise\n",
"\n",
" assert csm is not None # for mypy\n",
" return {'data': data_undersampled, 'target': target, 'csm': csm}\n",
"\n",
"\n",
"class MODL(torch.nn.Module):\n",
" \"\"\"MODL network.\"\"\"\n",
"\n",
" def __init__(self, iterations: int = 8, n_features: Sequence[int] = (64, 64, 64, 64)):\n",
" \"\"\"Initialize MODL network.\n",
"\n",
" Parameters\n",
" ----------\n",
" iterations\n",
" Number of iterations.\n",
" n_features\n",
" Number of features in the network.\n",
" \"\"\"\n",
" super().__init__()\n",
" cnn = mrpro.nn.nets.BasicCNN(\n",
" dim=2,\n",
" channels_in=2,\n",
" channels_out=2,\n",
" n_features=n_features,\n",
" batch_norm=True,\n",
" )\n",
" self.network = mrpro.nn.Residual(mrpro.nn.ComplexAsChannel(mrpro.nn.PermutedBlock((-1, -2), cnn)))\n",
" self.network = torch.compile(self.network, dynamic=True, fullgraph=True)\n",
" self.iterations = iterations\n",
" self.regularization_weights = torch.nn.Parameter(0.2 * torch.ones(iterations))\n",
"\n",
" def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData:\n",
" \"\"\"Apply MODL network.\n",
"\n",
" Parameters\n",
" ----------\n",
" kdata\n",
" The k-space data.\n",
" csm\n",
" The coil sensitivity maps.\n",
"\n",
" Returns\n",
" -------\n",
" The reconstructed image.\n",
" \"\"\"\n",
" return super().__call__(kdata, csm)\n",
"\n",
" def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData:\n",
" \"\"\"Apply the MODL network.\"\"\"\n",
" fourier_op = mrpro.operators.FourierOp.from_kdata(kdata)\n",
" acquisition_op = fourier_op @ csm.as_operator()\n",
" (zero_filled_image,) = acquisition_op.H(kdata.data)\n",
" gram = acquisition_op.gram\n",
" data_consistency_op = mrpro.operators.ConjugateGradientOp(\n",
" operator_factory=lambda _image, weight: gram + weight,\n",
" rhs_factory=lambda image, weight: zero_filled_image + weight * image,\n",
" )\n",
"\n",
" (image,) = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=5)\n",
" for iteration in range(self.iterations):\n",
" regularization = self.network(image)\n",
" (image,) = data_consistency_op(regularization, self.regularization_weights[iteration])\n",
"\n",
" return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header))\n",
"\n",
"\n",
"def plot(batch: BatchType, prediction: mrpro.data.IData, step: int) -> None:\n",
" \"\"\"Plot the direct, sense, and modl reconstructions.\"\"\"\n",
" target = batch['target'].rss().cpu().squeeze()\n",
" direct = mrpro.algorithms.reconstruction.DirectReconstruction(batch['data'], csm=batch['csm'])(batch['data'])\n",
" direct = direct.rss().cpu().squeeze()\n",
" direct *= target.std() / direct.std()\n",
" sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(batch['data'], csm=batch['csm'])(batch['data'])\n",
" sense = sense.rss().cpu().squeeze()\n",
" prediction_ = prediction.rss().cpu().squeeze().detach()\n",
"\n",
" ssim = mrpro.operators.functionals.SSIM(mrpro.utils.pad_or_crop(target[None], (320, 320)))\n",
"\n",
" def show(ax: matplotlib.axes.Axes, data: torch.Tensor, label: str):\n",
" data = mrpro.utils.pad_or_crop(data, (320, 320))\n",
" ax.imshow(data, vmin=0, vmax=target.max().item(), cmap='gray')\n",
" if label != 'Ground Truth':\n",
" (ssim_value,) = ssim(data[None])\n",
" ax.text(\n",
" 0.98,\n",
" 0.1,\n",
" f'SSIM: {ssim_value.item():.2f}',\n",
" color='white',\n",
" horizontalalignment='right',\n",
" verticalalignment='top',\n",
" transform=ax.transAxes,\n",
" )\n",
" ax.set_title(label)\n",
" ax.set_axis_off()\n",
"\n",
" fig, ax = plt.subplots(1, 4)\n",
" show(ax[0], direct, 'Direct')\n",
" show(ax[1], sense, 'CG-SENSE')\n",
" show(ax[2], prediction_, 'MODL')\n",
" show(ax[3], target, 'Ground Truth')\n",
" fig.tight_layout()\n",
" fig.savefig(f'modl_{step}.pdf', bbox_inches='tight', pad_inches=0)\n",
"\n",
"\n",
"# %%.\n",
"path = Path('/echo/allgemein/resources/publicTrainingData/fastmri/brain_multicoil_train/')\n",
"dataset = AcceleratedFastMRI(path)\n",
"dataloader = torch.utils.data.DataLoader(dataset, num_workers=16, shuffle=True, collate_fn=lambda batch: batch[0])\n",
"modl = MODL().cuda()\n",
"optimizer = torch.optim.Adam(modl.parameters(), lr=1e-3)\n",
"pbar = tqdm(dataloader)\n",
"for i, batch in enumerate(pbar):\n",
" optimizer.zero_grad()\n",
" kdata, csm, target = (batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda())\n",
" prediction = modl(kdata, csm)\n",
" objective = 0.5 * mrpro.operators.functionals.MSE(target.data) - mrpro.operators.functionals.SSIM(target.data)\n",
" (loss,) = objective(prediction.data)\n",
" loss.backward()\n",
" torch.nn.utils.clip_grad_norm_(modl.parameters(), 5.0)\n",
" optimizer.step()\n",
"\n",
" pbar.set_postfix(loss=loss.item())\n",
" if i % 200 == 0:\n",
" plot(batch, prediction, i)\n",
" print(modl.regularization_weights)\n",
" state = {'modl': modl.state_dict(), 'optimizer': optimizer.state_dict()}\n",
" torch.save(state, f'modl_{i}.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"jupytext": {
"cell_metadata_filter": "mystnb,tags,-all"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading