diff --git a/examples/notebooks/apply_pinqi.ipynb b/examples/notebooks/apply_pinqi.ipynb new file mode 100644 index 000000000..9d1afb4d5 --- /dev/null +++ b/examples/notebooks/apply_pinqi.ipynb @@ -0,0 +1,609 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PTB-MR/mrpro/blob/main/examples/notebooks/apply_pinqi.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "import importlib\n", + "\n", + "if not importlib.util.find_spec('mrpro'):\n", + " %pip install mrpro[notebooks]" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "# End-to-end physics informed network for quantitative MRI (PINQI)\n", + "A recent DL approach, PINQI, approaches learned quantitative MRI by half quadratic splitting to alternate between two\n", + "subproblems. The first is a linear image reconstruction task\n", + "$$\n", + "\\underset{\\mathbf{x}}{\\min} \\frac{1}{2} \\| \\mathbf{A} \\mathbf{x} - \\mathbf{y} \\|_2^2\n", + "+ \\frac{\\lambda_\\mathbf{x}}{2} \\left\\| \\mathbf{x} - \\mathbf{x}_{\\text{reg}} \\right\\|_2^2\n", + "+ \\frac{\\lambda_{\\mathbf{q}}}{2} \\left\\| \\mathbf{q}(\\mathbf{p}) - \\mathbf{x} \\right\\|_2^2\n", + "$$\n", + "with $\\mathbf{x}$ being intermediary qualitative images, $\\lambda_{\\mathbf{x}}$ and $\\lambda_{\\mathbf{q}}$ being\n", + "regularization strengths and $\\mathbf{x}_{\\text{reg}}$ denoting an image prior for regularization.\n", + "The second, non-linear, subproblem is finding the quantitative parameters by solving\n", + "$$\n", + "\\underset{\\mathbf{p}}{\\min} \\frac{\\lambda_{\\mathbf{q}}}{2}\\left \\| \\mathbf{q}(\\vec{p}) - \\mathbf{x} \\right\\|_2^2\n", + "+ \\frac{\\lambda_{\\mathbf{p}}}{2} \\left\\| \\mathbf{p} - \\mathbf{p}_{\\text{reg}} \\right\\|_2^2.\n", + "$$\n", + "Here, $\\mathbf{p}_{\\text{reg}}$ is a prior on the parameter maps and $\\lambda_{\\mathbf{p}}$ the associated weight for\n", + "regularization.\n", + "In PINQI, a solution is found by iterating between both subproblems. In each iteration $k=1,\\ldots,T$,\n", + "the image and parameter priors are updated by U-Nets. The network parameters and the regularization strengths\n", + "are trained end-to-end.\n", + "Here, we apply a trained PINQI model to a validation set. We first define the dataset, then define the PINQI model,\n", + "before loading the model weights and applying it to the dataset." + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Dataset\n", + "We base the dataset on the BrainWeb phantom (`mrpro.phantoms.brainweb.BrainwebSlices`) and simulate Cartesian random\n", + "undersampling in phase encode direction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from collections.abc import Sequence\n", + "from copy import deepcopy\n", + "from pathlib import Path\n", + "from typing import Literal, TypedDict\n", + "\n", + "import einops\n", + "import mrpro\n", + "import torch\n", + "\n", + "# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "class BatchType(TypedDict):\n", + " \"\"\"Typehint for a batch of data.\"\"\"\n", + "\n", + " kdata: mrpro.data.KData\n", + " csm: mrpro.data.CsmData\n", + " m0: torch.Tensor\n", + " t1: torch.Tensor\n", + " mask: torch.Tensor\n", + "\n", + "\n", + "class Dataset(torch.utils.data.Dataset[BatchType]):\n", + " \"\"\"A brainweb based cartesian qMRI dataset.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " folder: Path,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " n_images: int,\n", + " size: int,\n", + " acceleration: int,\n", + " n_coils: int,\n", + " max_noise: float,\n", + " orientation: Sequence[Literal['axial', 'coronal', 'sagittal']],\n", + " random: bool = True,\n", + " ):\n", + " \"\"\"Initialize the dataset.\"\"\"\n", + " if random:\n", + " augment = mrpro.phantoms.brainweb.augment(size=size)\n", + " else:\n", + " augment = mrpro.phantoms.brainweb.augment(\n", + " size=size,\n", + " max_random_shear=0,\n", + " max_random_rotation=0,\n", + " max_random_scaling_factor=0,\n", + " p_horizontal_flip=0,\n", + " p_vertical_flip=1.0,\n", + " )\n", + " self.phantom = mrpro.phantoms.brainweb.BrainwebSlices(\n", + " folder=folder,\n", + " what=('m0', 't1', 'mask'),\n", + " seed='index' if not random else 'random',\n", + " slice_preparation=augment,\n", + " orientation=orientation,\n", + " )\n", + " self.signalmodel = deepcopy(signalmodel)\n", + " self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size)\n", + " self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25)\n", + " self.acceleration = acceleration\n", + " self.n_coils = n_coils\n", + " self._random = random\n", + " self.max_noise = max_noise\n", + " self._n_images = n_images\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"Get the length of the dataset.\"\"\"\n", + " return len(self.phantom)\n", + "\n", + " def __getitem__(self, index: int):\n", + " \"\"\"Get an item from the dataset.\"\"\"\n", + " phantom = self.phantom[index]\n", + " (images,) = self.signalmodel(phantom['m0'], phantom['t1'])\n", + " seed = int(torch.randint(0, 1000000, (1,))) if self._random else index\n", + "\n", + " traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density(\n", + " encoding_matrix=self.encoding_matrix,\n", + " seed=seed,\n", + " acceleration=self.acceleration,\n", + " fwhm_ratio=1.5,\n", + " n_center=10,\n", + " n_other=(self._n_images,),\n", + " )\n", + " header = mrpro.data.KHeader(\n", + " encoding_matrix=self.encoding_matrix,\n", + " recon_matrix=self.encoding_matrix,\n", + " recon_fov=self.fov,\n", + " encoding_fov=self.fov,\n", + " )\n", + "\n", + " if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery):\n", + " header.ti = self.signalmodel.saturation_time.tolist()\n", + " elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery):\n", + " header.ti = self.signalmodel.ti.tolist()\n", + "\n", + " fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj)\n", + " csm = mrpro.data.CsmData(\n", + " mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix),\n", + " header,\n", + " )\n", + " images = einops.rearrange(images, 't y x -> t 1 1 y x')\n", + " (data,) = (fourier_op @ csm.as_operator())(images)\n", + " data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std()\n", + " kdata = mrpro.data.KData(header, data, traj)\n", + " return {'kdata': kdata, 'csm': csm, **phantom}" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "## PINQI\n", + "Next, We define the PINQI model. Here we can make use of the diffferntiable optimization operators in MRpro." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "class PINQI(torch.nn.Module):\n", + " \"\"\"PINQI model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " n_images: int,\n", + " n_iterations: int,\n", + " n_features_parameter_net: Sequence[int],\n", + " n_features_image_net: Sequence[int],\n", + " ):\n", + " \"\"\"Initialize the PINQI model.\"\"\"\n", + " super().__init__()\n", + " self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op\n", + " self.constraints_op = constraints_op\n", + " self._n_images = n_images\n", + " self._parameter_is_complex = parameter_is_complex\n", + " real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex)\n", + " self.parameter_net = mrpro.nn.nets.UNet(\n", + " dim=2,\n", + " channels_in=n_images * 2,\n", + " channels_out=real_parameters,\n", + " attention_depths=(-1, -2),\n", + " n_features=n_features_parameter_net,\n", + " cond_dim=128,\n", + " )\n", + "\n", + " self.image_net = mrpro.nn.nets.UNet(\n", + " 2,\n", + " channels_in=2,\n", + " channels_out=2,\n", + " attention_depths=(),\n", + " n_features=n_features_image_net,\n", + " cond_dim=128,\n", + " )\n", + " self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3))\n", + " self.softplus = torch.nn.Softplus(beta=5)\n", + " self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128)\n", + "\n", + " def objective_factory(\n", + " lambda_parameters: torch.Tensor,\n", + " image: torch.Tensor,\n", + " *parameter_reg: torch.Tensor,\n", + " ) -> torch.operators.Operator:\n", + " dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel\n", + " reg = mrpro.operators.ProximableFunctionalSeparableSum(\n", + " *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg]\n", + " )\n", + " return dc + lambda_parameters * reg\n", + "\n", + " self.nonlinear_solver = mrpro.operators.OptimizerOp(\n", + " objective_factory,\n", + " lambda _l, _i, *parameter_reg: parameter_reg,\n", + " )\n", + " # This can be done once, as the signal model is the same for all samples.\n", + "\n", + " def get_linear_solver(self, gram: mrpro.operators.LinearOperator) -> mrpro.operators.ConjugateGradientOp:\n", + " \"\"\"Set up the linear solver.\"\"\"\n", + " # This needs to be done for each sample, as the undersampling pattern and csm are different for each sample,\n", + " # thus the gram operator of the acquisition operator is different for each sample.\n", + "\n", + " def operator_factory(\n", + " lambda_image: torch.Tensor,\n", + " lambda_q: torch.Tensor,\n", + " *_,\n", + " ):\n", + " return gram + lambda_image + lambda_q\n", + "\n", + " def rhs_factory(\n", + " lambda_image: torch.Tensor,\n", + " lambda_q: torch.Tensor,\n", + " image_reg: torch.Tensor,\n", + " signal: torch.Tensor,\n", + " zero_filled_image: torch.Tensor,\n", + " ):\n", + " return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,)\n", + "\n", + " return mrpro.operators.ConjugateGradientOp(\n", + " operator_factory=operator_factory,\n", + " rhs_factory=rhs_factory,\n", + " )\n", + "\n", + " def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]:\n", + " \"\"\"Get the parameter regularization.\"\"\"\n", + " image = einops.rearrange(\n", + " torch.view_as_real(image),\n", + " 'batch t 1 1 y x complex-> batch (t complex) y x',\n", + " )\n", + " cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None]\n", + " parameters = self.parameter_net(image.contiguous(), cond=cond)\n", + " parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x')\n", + " i = 0\n", + " result = []\n", + " for is_complex in self._parameter_is_complex:\n", + " if is_complex:\n", + " result.append(torch.complex(parameters[i], parameters[i + 1]))\n", + " i += 2\n", + " else:\n", + " result.append(parameters[i])\n", + " i += 1\n", + " return tuple(result)\n", + "\n", + " def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor:\n", + " \"\"\"Get the image regularization.\"\"\"\n", + " batch = image.shape[0]\n", + " image = einops.rearrange(\n", + " torch.view_as_real(image),\n", + " 'batch t 1 1 y x complex-> (batch t) complex y x',\n", + " )\n", + " cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None]\n", + " image = image + self.image_net(image.contiguous(), cond=cond)\n", + " image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch)\n", + " return torch.view_as_complex(image.contiguous())\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]:\n", + " \"\"\"Estimate the quantitative parameters.\n", + "\n", + " Parameters\n", + " ----------\n", + " kdata\n", + " The k-space data.\n", + " csm\n", + " The coil sensitivity maps.\n", + "\n", + " Returns\n", + " -------\n", + " images\n", + " The qualitative images.\n", + " parameters\n", + " The quantitative parameters.\n", + " \"\"\"\n", + " csm_op = csm.as_operator()\n", + " fourier_op = mrpro.operators.FourierOp.from_kdata(kdata)\n", + " acquisition_op = fourier_op @ csm_op\n", + " gram = acquisition_op.gram\n", + " (zero_filled_image,) = acquisition_op.H(kdata.data)\n", + " images = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)\n", + " parameters = self.get_parameter_reg(images, 0)\n", + " linear_solver = self.get_linear_solver(gram)\n", + "\n", + " for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)):\n", + " # linear subproblem 1\n", + " image_reg = self.get_image_reg(images, i)\n", + " (signal,) = self.signalmodel(*parameters)\n", + " images = linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)\n", + " # nonlinear subproblem 2\n", + " parameters_reg = self.get_parameter_reg(images, i + 1)\n", + " parameters = self.nonlinear_solver(lambda_parameter, images, *parameters_reg)\n", + " if self.constraints_op is not None:\n", + " # map the parameters into the constrained space\n", + " parameters = self.constraints_op(*parameters)\n", + " return parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# As a baseline methods for comparison, we use a simple non-learned approach. We reconstruct the qualitative images at\n", + "# different saturation times using iterative SENSE. We then perform a constrained non-linear least squares regression\n", + "# using L-BFGS to obtain the parameter maps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "def baseline_solution(\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " kdata: mrpro.data.KData,\n", + " csm: mrpro.data.CsmData,\n", + ") -> tuple[torch.Tensor, ...]:\n", + " \"\"\"Compute a baseline solution using SENSE + Regression.\"\"\"\n", + " sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm)\n", + " images = sense(kdata)\n", + " objective = mrpro.operators.functionals.L2NormSquared(images.data) @ signalmodel @ constraints_op\n", + " initial_values = tuple(\n", + " torch.zeros(\n", + " images.shape[1:],\n", + " device=images.device,\n", + " dtype=torch.complex64 if is_complex else torch.float32,\n", + " )\n", + " for is_complex in parameter_is_complex\n", + " )\n", + " solution = constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values))\n", + " return solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "data_folder = Path('/home/zimmer08/.cache/mrpro/brainweb')\n", + "\n", + "signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0))\n", + "constraints_op = mrpro.operators.ConstraintsOp(\n", + " bounds=(\n", + " (-2, 2), # M0 in [-2, 2]\n", + " (0.01, 6.0), # T1 is constrained between 10 ms and 6 s\n", + " )\n", + ")\n", + "n_images = len(signalmodel.saturation_time)\n", + "parameter_is_complex = [True, False]\n", + "\n", + "\n", + "dataset = torch.utils.data.Subset(\n", + " Dataset(\n", + " folder=data_folder,\n", + " signalmodel=signalmodel,\n", + " n_images=n_images,\n", + " size=192,\n", + " acceleration=8,\n", + " n_coils=8,\n", + " max_noise=0.05,\n", + " orientation=('axial',),\n", + " random=False,\n", + " ),\n", + " list(range(500)),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "checkpoint = torch.load('./examples/scripts/last.ckpt', map_location='cpu')\n", + "hyper_parameters = checkpoint['hyper_parameters']\n", + "\n", + "\n", + "pinqi = PINQI(\n", + " signalmodel=signalmodel,\n", + " constraints_op=constraints_op,\n", + " parameter_is_complex=parameter_is_complex,\n", + " n_images=n_images,\n", + " n_iterations=hyper_parameters['n_iterations'],\n", + " n_features_parameter_net=hyper_parameters['n_features_parameter_net'],\n", + " n_features_image_net=hyper_parameters['n_features_image_net'],\n", + ")\n", + "state_dict = {\n", + " k.replace('pinqi.', '').replace('_orig_mod.', ''): v\n", + " for k, v in checkpoint['state_dict'].items()\n", + " if 'baseline' not in k\n", + "}\n", + "pinqi.load_state_dict(state_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "batch = dataset[40]\n", + "csm, kdata = batch['csm'], batch['kdata']\n", + "\n", + "if torch.cuda.is_available():\n", + " pinqi, csm, kdata = pinqi.cuda(), csm.cuda(), kdata.cuda()\n", + "images, parameters = pinqi(kdata[None], csm[None])\n", + "with torch.no_grad():\n", + " predicted_m0, predicted_t1 = (p.cpu().detach().squeeze() for p in parameters[-1])\n", + "baseline_m0, baseline_t1 = baseline_solution(signalmodel, constraints_op, parameter_is_complex, kdata, csm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "(ssim_t1,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(predicted_t1[None])\n", + "(mse_t1,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(predicted_t1)\n", + "\n", + "(mse_baseline,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(baseline_t1)\n", + "nrmse_t1 = torch.sqrt(mse_t1) / batch['t1'][batch['mask']].max()\n", + "(ssim_baseline,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(baseline_t1[None])\n", + "nrmse_baseline = torch.sqrt(mse_baseline) / batch['t1'][batch['mask']].max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from cmap import Colormap\n", + "\n", + "cmap = Colormap('lipari').to_matplotlib()\n", + "\n", + "print(f'SSIM: {ssim_baseline.item():.4f}, NRMSE: {nrmse_baseline.item():.4f}')\n", + "print(f'SSIM: {ssim_t1.item():.4f}, NRMSE: {nrmse_t1.item():.4f}')\n", + "\n", + "\n", + "fig, ax = plt.subplots(\n", + " 1,\n", + " 5,\n", + " gridspec_kw={\n", + " 'width_ratios': [1, 1, 1, 0.28, 0.075],\n", + " 'wspace': -0.25,\n", + " },\n", + " figsize=(6.5, 2.5),\n", + ")\n", + "baseline_t1 = baseline_t1.squeeze()\n", + "baseline_t1[~batch['mask']] = torch.nan\n", + "ax[0].imshow(baseline_t1, vmin=0, vmax=2, cmap=cmap)\n", + "ax[0].axis('off')\n", + "ax[0].set_title('SENSE + NLS')\n", + "ax[0].text(\n", + " 0.5,\n", + " -0.00,\n", + " f'SSIM: {ssim_baseline.item():.2f}',\n", + " color='black',\n", + " horizontalalignment='center',\n", + " verticalalignment='top',\n", + " transform=ax[0].transAxes,\n", + " size=11,\n", + ")\n", + "predicted_t1 = predicted_t1.squeeze()\n", + "predicted_t1[~batch['mask']] = torch.nan\n", + "ax[1].imshow(predicted_t1, vmin=0, vmax=2, cmap=cmap)\n", + "ax[1].axis('off')\n", + "ax[1].set_title('PINQI')\n", + "ax[1].text(\n", + " 0.5,\n", + " -0.0,\n", + " f'SSIM: {ssim_t1.item():.2f}',\n", + " color='black',\n", + " horizontalalignment='center',\n", + " verticalalignment='top',\n", + " transform=ax[1].transAxes,\n", + " size=11,\n", + ")\n", + "\n", + "target_t1 = batch['t1'].squeeze()\n", + "target_t1[~batch['mask']] = torch.nan\n", + "im = ax[2].imshow(target_t1, vmin=0, vmax=2, cmap=cmap)\n", + "ax[2].axis('off')\n", + "ax[2].set_title(\n", + " 'Ground Truth',\n", + ")\n", + "ax[-2].axis('off')\n", + "fig.tight_layout()\n", + "plt.colorbar(im, cax=ax[-1], label='$T_1$ (s)')\n", + "fig.savefig(\n", + " '/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1_3.pdf',\n", + " bbox_inches='tight',\n", + " pad_inches=0,\n", + ")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "jupytext": { + "cell_metadata_filter": "mystnb,tags,-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/modl.ipynb b/examples/notebooks/modl.ipynb new file mode 100644 index 000000000..54743423d --- /dev/null +++ b/examples/notebooks/modl.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PTB-MR/mrpro/blob/main/examples/notebooks/modl.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "import importlib\n", + "\n", + "if not importlib.util.find_spec('mrpro'):\n", + " %pip install mrpro[notebooks]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "from collections.abc import Sequence\n", + "from pathlib import Path\n", + "from typing import TypedDict\n", + "\n", + "import matplotlib.axes\n", + "import matplotlib.pyplot as plt\n", + "import mrpro\n", + "import torch\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "class BatchType(TypedDict):\n", + " \"\"\"A single Batch.\"\"\"\n", + "\n", + " data: mrpro.data.KData\n", + " target: mrpro.data.IData\n", + " csm: mrpro.data.CsmData\n", + "\n", + "\n", + "class AcceleratedFastMRI(torch.utils.data.Dataset):\n", + " \"\"\"An undersampled FastMRI Dataset.\"\"\"\n", + "\n", + " def __init__(self, path: Path, acceleration: float = 12, noise_level: float = 0.1):\n", + " \"\"\"Create an undersampled FastMRI Dataset.\n", + "\n", + " Parameters\n", + " ----------\n", + " path\n", + " Path to the FastMRI dataset.\n", + " acceleration\n", + " Undersampling factor; higher values mean more acceleration. Default is 12.\n", + " noise_level\n", + " Level of additive Gaussian noise applied to the FastMRI dataset. Default is 0.1.\n", + " \"\"\"\n", + " self.acceleration = acceleration\n", + " files = list(path.glob('*AXT1*'))\n", + " self.dataset = mrpro.phantoms.FastMRIKDataDataset(files)\n", + " self.noise_level = noise_level\n", + "\n", + " def __len__(self):\n", + " \"\"\"Get length of the dataset.\"\"\"\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, index: int) -> BatchType:\n", + " \"\"\"Get a single batch of data.\n", + "\n", + " Parameters\n", + " ----------\n", + " index\n", + " Index of the batch.\n", + "\n", + " Returns\n", + " -------\n", + " A single batch of data with keys 'data', 'target', and 'csm'.and\n", + " \"\"\"\n", + " data = self.dataset[index]\n", + " data = data.remove_readout_os()\n", + " data.data /= data.data.std()\n", + " reconstruction = mrpro.algorithms.reconstruction.DirectReconstruction(\n", + " data, csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64)\n", + " )\n", + " csm = reconstruction.csm\n", + " target = reconstruction(data)\n", + "\n", + " n = max(data.data.shape[-2:])\n", + " distance = (torch.linspace(-1, 1, n)[:, None] ** 2 + torch.linspace(-1, 1, n) ** 2).sqrt()\n", + " random = 0.1 / (distance + 0.1) + torch.rand_like(distance)\n", + " threshold = torch.kthvalue(random.ravel(), int(n**2 * (1 - 1 / self.acceleration))).values\n", + " undersampling_mask = mrpro.utils.pad_or_crop(random > threshold, data.data.shape[-2:])\n", + " data_undersampled = data[..., undersampling_mask].rearrange('k ... 1 -> ... k')\n", + "\n", + " noise = mrpro.utils.RandomGenerator(seed=index).randn_like(data_undersampled.data)\n", + " data_undersampled.data += self.noise_level * noise\n", + "\n", + " assert csm is not None # for mypy\n", + " return {'data': data_undersampled, 'target': target, 'csm': csm}\n", + "\n", + "\n", + "class MODL(torch.nn.Module):\n", + " \"\"\"MODL network.\"\"\"\n", + "\n", + " def __init__(self, iterations: int = 8, n_features: Sequence[int] = (64, 64, 64, 64)):\n", + " \"\"\"Initialize MODL network.\n", + "\n", + " Parameters\n", + " ----------\n", + " iterations\n", + " Number of iterations.\n", + " n_features\n", + " Number of features in the network.\n", + " \"\"\"\n", + " super().__init__()\n", + " cnn = mrpro.nn.nets.BasicCNN(\n", + " dim=2,\n", + " channels_in=2,\n", + " channels_out=2,\n", + " n_features=n_features,\n", + " batch_norm=True,\n", + " )\n", + " self.network = mrpro.nn.Residual(mrpro.nn.ComplexAsChannel(mrpro.nn.PermutedBlock((-1, -2), cnn)))\n", + " self.network = torch.compile(self.network, dynamic=True, fullgraph=True)\n", + " self.iterations = iterations\n", + " self.regularization_weights = torch.nn.Parameter(0.2 * torch.ones(iterations))\n", + "\n", + " def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData:\n", + " \"\"\"Apply MODL network.\n", + "\n", + " Parameters\n", + " ----------\n", + " kdata\n", + " The k-space data.\n", + " csm\n", + " The coil sensitivity maps.\n", + "\n", + " Returns\n", + " -------\n", + " The reconstructed image.\n", + " \"\"\"\n", + " return super().__call__(kdata, csm)\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData:\n", + " \"\"\"Apply the MODL network.\"\"\"\n", + " fourier_op = mrpro.operators.FourierOp.from_kdata(kdata)\n", + " acquisition_op = fourier_op @ csm.as_operator()\n", + " (zero_filled_image,) = acquisition_op.H(kdata.data)\n", + " gram = acquisition_op.gram\n", + " data_consistency_op = mrpro.operators.ConjugateGradientOp(\n", + " operator_factory=lambda _image, weight: gram + weight,\n", + " rhs_factory=lambda image, weight: zero_filled_image + weight * image,\n", + " )\n", + "\n", + " (image,) = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=5)\n", + " for iteration in range(self.iterations):\n", + " regularization = self.network(image)\n", + " (image,) = data_consistency_op(regularization, self.regularization_weights[iteration])\n", + "\n", + " return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header))\n", + "\n", + "\n", + "def plot(batch: BatchType, prediction: mrpro.data.IData, step: int) -> None:\n", + " \"\"\"Plot the direct, sense, and modl reconstructions.\"\"\"\n", + " target = batch['target'].rss().cpu().squeeze()\n", + " direct = mrpro.algorithms.reconstruction.DirectReconstruction(batch['data'], csm=batch['csm'])(batch['data'])\n", + " direct = direct.rss().cpu().squeeze()\n", + " direct *= target.std() / direct.std()\n", + " sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(batch['data'], csm=batch['csm'])(batch['data'])\n", + " sense = sense.rss().cpu().squeeze()\n", + " prediction_ = prediction.rss().cpu().squeeze().detach()\n", + "\n", + " ssim = mrpro.operators.functionals.SSIM(mrpro.utils.pad_or_crop(target[None], (320, 320)))\n", + "\n", + " def show(ax: matplotlib.axes.Axes, data: torch.Tensor, label: str):\n", + " data = mrpro.utils.pad_or_crop(data, (320, 320))\n", + " ax.imshow(data, vmin=0, vmax=target.max().item(), cmap='gray')\n", + " if label != 'Ground Truth':\n", + " (ssim_value,) = ssim(data[None])\n", + " ax.text(\n", + " 0.98,\n", + " 0.1,\n", + " f'SSIM: {ssim_value.item():.2f}',\n", + " color='white',\n", + " horizontalalignment='right',\n", + " verticalalignment='top',\n", + " transform=ax.transAxes,\n", + " )\n", + " ax.set_title(label)\n", + " ax.set_axis_off()\n", + "\n", + " fig, ax = plt.subplots(1, 4)\n", + " show(ax[0], direct, 'Direct')\n", + " show(ax[1], sense, 'CG-SENSE')\n", + " show(ax[2], prediction_, 'MODL')\n", + " show(ax[3], target, 'Ground Truth')\n", + " fig.tight_layout()\n", + " fig.savefig(f'modl_{step}.pdf', bbox_inches='tight', pad_inches=0)\n", + "\n", + "\n", + "# %%.\n", + "path = Path('/echo/allgemein/resources/publicTrainingData/fastmri/brain_multicoil_train/')\n", + "dataset = AcceleratedFastMRI(path)\n", + "dataloader = torch.utils.data.DataLoader(dataset, num_workers=16, shuffle=True, collate_fn=lambda batch: batch[0])\n", + "modl = MODL().cuda()\n", + "optimizer = torch.optim.Adam(modl.parameters(), lr=1e-3)\n", + "pbar = tqdm(dataloader)\n", + "for i, batch in enumerate(pbar):\n", + " optimizer.zero_grad()\n", + " kdata, csm, target = (batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda())\n", + " prediction = modl(kdata, csm)\n", + " objective = 0.5 * mrpro.operators.functionals.MSE(target.data) - mrpro.operators.functionals.SSIM(target.data)\n", + " (loss,) = objective(prediction.data)\n", + " loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(modl.parameters(), 5.0)\n", + " optimizer.step()\n", + "\n", + " pbar.set_postfix(loss=loss.item())\n", + " if i % 200 == 0:\n", + " plot(batch, prediction, i)\n", + " print(modl.regularization_weights)\n", + " state = {'modl': modl.state_dict(), 'optimizer': optimizer.state_dict()}\n", + " torch.save(state, f'modl_{i}.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "jupytext": { + "cell_metadata_filter": "mystnb,tags,-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/train_pinqi.ipynb b/examples/notebooks/train_pinqi.ipynb new file mode 100644 index 000000000..c64bfa584 --- /dev/null +++ b/examples/notebooks/train_pinqi.ipynb @@ -0,0 +1,760 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PTB-MR/mrpro/blob/main/examples/notebooks/train_pinqi.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "import importlib\n", + "\n", + "if not importlib.util.find_spec('mrpro'):\n", + " %pip install mrpro[notebooks]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# ruff: noqa: D102, ANN201\n", + "from collections.abc import Sequence\n", + "from copy import deepcopy\n", + "from pathlib import Path\n", + "from typing import Any, Literal, TypedDict\n", + "\n", + "import einops\n", + "import matplotlib.pyplot as plt\n", + "import mrpro\n", + "import numpy as np\n", + "import pytorch_lightning as pl # type:ignore[import-not-found]\n", + "import torch\n", + "import torch.utils.data._utils\n", + "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint # type:ignore[import-not-found]\n", + "from pytorch_lightning.loggers import NeptuneLogger # type:ignore[import-not-found]\n", + "\n", + "\n", + "class BatchType(TypedDict):\n", + " \"\"\"Typehint for a batch of data.\"\"\"\n", + "\n", + " kdata: mrpro.data.KData\n", + " csm: mrpro.data.CsmData\n", + " m0: torch.Tensor\n", + " t1: torch.Tensor\n", + " mask: torch.Tensor\n", + "\n", + "\n", + "class Dataset(torch.utils.data.Dataset):\n", + " \"\"\"A brainweb based cartesian qMRI dataset.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " folder: Path,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " n_images: int,\n", + " size: int,\n", + " acceleration: int,\n", + " n_coils: int,\n", + " max_noise: float,\n", + " orientation: Sequence[Literal['axial', 'coronal', 'sagittal']],\n", + " random: bool = True,\n", + " ):\n", + " \"\"\"Initialize the dataset.\"\"\"\n", + " if random:\n", + " augment = mrpro.phantoms.brainweb.augment(size=size)\n", + " else:\n", + " augment = mrpro.phantoms.brainweb.augment(\n", + " size=size,\n", + " max_random_shear=0,\n", + " max_random_rotation=0,\n", + " max_random_scaling_factor=0,\n", + " p_horizontal_flip=0,\n", + " p_vertical_flip=1.0,\n", + " )\n", + " self.phantom = mrpro.phantoms.brainweb.BrainwebSlices(\n", + " folder=folder,\n", + " what=('m0', 't1', 'mask'),\n", + " seed='index' if not random else 'random',\n", + " slice_preparation=augment,\n", + " orientation=orientation,\n", + " )\n", + " self.signalmodel = signalmodel\n", + " self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size)\n", + " self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25)\n", + " self.acceleration = acceleration\n", + " self.n_coils = n_coils\n", + " self._random = random\n", + " self.max_noise = max_noise\n", + " self._n_images = n_images\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"Get the length of the dataset.\"\"\"\n", + " return len(self.phantom)\n", + "\n", + " def __getitem__(self, index: int):\n", + " \"\"\"Get an item from the dataset.\"\"\"\n", + " phantom = self.phantom[index]\n", + " (images,) = self.signalmodel(phantom['m0'], phantom['t1'])\n", + " seed = int(torch.randint(0, 1000000, (1,))) if self._random else index\n", + "\n", + " traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density(\n", + " encoding_matrix=self.encoding_matrix,\n", + " seed=seed,\n", + " acceleration=self.acceleration,\n", + " fwhm_ratio=1.5,\n", + " n_center=12,\n", + " n_other=(self._n_images,),\n", + " )\n", + " header = mrpro.data.KHeader(\n", + " encoding_matrix=self.encoding_matrix,\n", + " recon_matrix=self.encoding_matrix,\n", + " recon_fov=self.fov,\n", + " encoding_fov=self.fov,\n", + " )\n", + "\n", + " if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery):\n", + " header.ti = self.signalmodel.saturation_time.tolist()\n", + " elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery):\n", + " header.ti = self.signalmodel.ti.tolist()\n", + "\n", + " fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj)\n", + " if self.n_coils > 1:\n", + " csm_tensor = mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix)\n", + " else:\n", + " csm_tensor = torch.ones(1, 1, *self.encoding_matrix.zyx)\n", + " csm = mrpro.data.CsmData(csm_tensor, header)\n", + " images = einops.rearrange(images, 't y x -> t 1 1 y x')\n", + " (data,) = (fourier_op @ csm.as_operator())(images)\n", + " data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std()\n", + " kdata = mrpro.data.KData(header, data, traj)\n", + " return {'kdata': kdata, 'csm': csm, **phantom}\n", + "\n", + "\n", + "def collate_fn(batch: Any): # noqa: ANN401\n", + " \"\"\"Join dataclasses to a batch.\"\"\"\n", + " return torch.utils.data._utils.collate.collate(\n", + " batch,\n", + " collate_fn_map={\n", + " mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), # noqa: ARG005\n", + " **torch.utils.data._utils.collate.default_collate_fn_map,\n", + " },\n", + " )\n", + "\n", + "\n", + "class PINQI(torch.nn.Module):\n", + " \"\"\"PINQI model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " n_images: int,\n", + " n_iterations: int,\n", + " n_features_parameter_net: Sequence[int],\n", + " n_features_image_net: Sequence[int],\n", + " ):\n", + " \"\"\"Initialize the PINQI model.\"\"\"\n", + " super().__init__()\n", + " self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op\n", + " self.constraints_op = constraints_op\n", + " self._n_images = n_images\n", + " self._parameter_is_complex = parameter_is_complex\n", + " real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex)\n", + " self.parameter_net = torch.compile(\n", + " mrpro.nn.nets.UNet(\n", + " n_dim=2,\n", + " n_channels_in=n_images * 2,\n", + " n_channels_out=real_parameters,\n", + " attention_depths=(-1, -2),\n", + " n_features=n_features_parameter_net,\n", + " cond_dim=128,\n", + " ),\n", + " dynamic=False,\n", + " fullgraph=True,\n", + " )\n", + " self.image_net = torch.compile(\n", + " mrpro.nn.nets.UNet(\n", + " n_dim=2,\n", + " n_channels_in=2,\n", + " n_channels_out=2,\n", + " attention_depths=(),\n", + " n_features=n_features_image_net,\n", + " cond_dim=128,\n", + " ),\n", + " dynamic=False,\n", + " fullgraph=True,\n", + " )\n", + " self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3))\n", + " self.softplus = torch.nn.Softplus(beta=5)\n", + " self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128)\n", + "\n", + " def objective_factory(\n", + " lambda_parameters: torch.Tensor,\n", + " image: torch.Tensor,\n", + " *parameter_reg: torch.Tensor,\n", + " ):\n", + " dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel\n", + " reg = mrpro.operators.ProximableFunctionalSeparableSum(\n", + " *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg]\n", + " )\n", + " return dc + lambda_parameters * reg\n", + "\n", + " self.nonlinear_solver = mrpro.operators.OptimizerOp(\n", + " objective_factory,\n", + " lambda _l, _i, *parameter_reg: parameter_reg,\n", + " )\n", + "\n", + " def get_linear_solver(self, gram: mrpro.operators.LinearOperator):\n", + " def operator_factory(\n", + " lambda_image: torch.Tensor,\n", + " lambda_q: torch.Tensor,\n", + " *_,\n", + " ):\n", + " return gram + lambda_image + lambda_q\n", + "\n", + " def rhs_factory(\n", + " lambda_image: torch.Tensor,\n", + " lambda_q: torch.Tensor,\n", + " image_reg: torch.Tensor,\n", + " signal: torch.Tensor,\n", + " zero_filled_image: torch.Tensor,\n", + " ):\n", + " return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,)\n", + "\n", + " return mrpro.operators.ConjugateGradientOp(\n", + " operator_factory=operator_factory,\n", + " rhs_factory=rhs_factory,\n", + " )\n", + "\n", + " def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]:\n", + " image = einops.rearrange(\n", + " torch.view_as_real(image),\n", + " 'batch t 1 1 y x complex-> batch (t complex) y x',\n", + " )\n", + " cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None]\n", + " parameters = self.parameter_net(image.contiguous(), cond=cond)\n", + " parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x')\n", + " i = 0\n", + " result = []\n", + " for is_complex in self._parameter_is_complex:\n", + " if is_complex:\n", + " result.append(torch.complex(parameters[i], parameters[i + 1]))\n", + " i += 2\n", + " else:\n", + " result.append(parameters[i])\n", + " i += 1\n", + " return tuple(result)\n", + "\n", + " def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor:\n", + " batch = image.shape[0]\n", + " image = einops.rearrange(\n", + " torch.view_as_real(image),\n", + " 'batch t 1 1 y x complex-> (batch t) complex y x',\n", + " )\n", + " cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None]\n", + " image = image + self.image_net(image.contiguous(), cond=cond)\n", + " image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch)\n", + " return torch.view_as_complex(image.contiguous())\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData):\n", + " csm_op = csm.as_operator()\n", + " fourier_op = mrpro.operators.FourierOp.from_kdata(kdata)\n", + " acquisition_op = fourier_op @ csm_op\n", + " gram = acquisition_op.gram\n", + " (zero_filled_image,) = acquisition_op.H(kdata.data)\n", + " images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2))\n", + " parameters = [self.get_parameter_reg(images[-1], 0)]\n", + " linear_solver = self.get_linear_solver(gram)\n", + "\n", + " for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)):\n", + " image_reg = self.get_image_reg(images[-1], i + 1)\n", + " (signal,) = self.signalmodel(*parameters[-1])\n", + " images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image))\n", + " parameters_reg = self.get_parameter_reg(images[-1], i + 1)\n", + " parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg))\n", + " if self.constraints_op is not None:\n", + " parameters = [self.constraints_op(*p) for p in parameters]\n", + " return images, parameters\n", + "\n", + "\n", + "class DataModule(pl.LightningDataModule):\n", + " \"\"\"Data module for training the PINQI model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " folder: Path,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " n_images: int,\n", + " size: int = 192,\n", + " acceleration: int = 10,\n", + " n_coils: int = 8,\n", + " max_noise: float = 0.1,\n", + " orientation_train: Sequence[Literal['axial', 'coronal', 'sagittal']] = (\n", + " 'axial',\n", + " 'coronal',\n", + " 'sagittal',\n", + " ),\n", + " orientation_val: Sequence[Literal['axial', 'coronal', 'sagittal']] = ('axial',),\n", + " batch_size: int = 16,\n", + " num_workers: int = 4,\n", + " ):\n", + " \"\"\"Initialize the data module.\"\"\"\n", + " super().__init__()\n", + " self.save_hyperparameters(ignore=['signalmodel', 'folder', 'num_workers'])\n", + " self.batch_size = batch_size\n", + " self.num_workers = num_workers\n", + " self.train_dataset = Dataset(\n", + " folder=folder,\n", + " signalmodel=signalmodel,\n", + " n_images=n_images,\n", + " size=size,\n", + " acceleration=acceleration,\n", + " n_coils=n_coils,\n", + " max_noise=max_noise,\n", + " orientation=orientation_train,\n", + " random=True,\n", + " )\n", + " self.val_dataset = torch.utils.data.Subset(\n", + " Dataset(\n", + " folder=folder,\n", + " signalmodel=signalmodel,\n", + " n_images=n_images,\n", + " size=size,\n", + " acceleration=acceleration,\n", + " n_coils=n_coils,\n", + " max_noise=max_noise,\n", + " orientation=orientation_val,\n", + " random=False,\n", + " ),\n", + " list(range(30, 500, 20)),\n", + " )\n", + "\n", + " def train_dataloader(self):\n", + " return torch.utils.data.DataLoader(\n", + " self.train_dataset,\n", + " batch_size=self.batch_size,\n", + " shuffle=True,\n", + " num_workers=self.num_workers,\n", + " pin_memory=False,\n", + " persistent_workers=self.num_workers > 0,\n", + " collate_fn=collate_fn,\n", + " worker_init_fn=lambda *_: torch.set_num_threads(1),\n", + " )\n", + "\n", + " def val_dataloader(self):\n", + " return torch.utils.data.DataLoader(\n", + " self.val_dataset,\n", + " batch_size=4,\n", + " shuffle=False,\n", + " num_workers=self.num_workers,\n", + " pin_memory=False,\n", + " persistent_workers=self.num_workers > 0,\n", + " collate_fn=collate_fn,\n", + " )\n", + "\n", + "\n", + "class PinqiModule(pl.LightningModule):\n", + " \"\"\"Module for training the PINQI model.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " n_images: int,\n", + " n_iterations: int = 4,\n", + " n_features_parameter_net: Sequence[int] = (64, 128, 192, 256),\n", + " n_features_image_net: Sequence[int] = (32, 48, 64, 96),\n", + " lr: float = 3e-4, # noqa: ARG002\n", + " weight_decay: float = 1e-3, # noqa: ARG002\n", + " loss_weights: Sequence[float] = (0.2, 0.1, 0.1, 0.1, 0.8),\n", + " ):\n", + " \"\"\"Initialize the PINQI module.\"\"\"\n", + " super().__init__()\n", + " self.save_hyperparameters(ignore=['signalmodel', 'constraints_op'])\n", + " if len(loss_weights) != n_iterations + 1:\n", + " raise ValueError(f'loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations')\n", + " signalmodel = deepcopy(signalmodel)\n", + " constraints_op = deepcopy(constraints_op)\n", + " self.pinqi = PINQI(\n", + " signalmodel=signalmodel,\n", + " constraints_op=constraints_op,\n", + " parameter_is_complex=parameter_is_complex,\n", + " n_images=n_images,\n", + " n_iterations=n_iterations,\n", + " n_features_parameter_net=n_features_parameter_net,\n", + " n_features_image_net=n_features_image_net,\n", + " )\n", + "\n", + " self.validation_step_outputs: dict[str, list] = {}\n", + " self.baseline = Baseline(signalmodel, constraints_op, parameter_is_complex)\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData):\n", + " \"\"\"Apply the PINQI model to the data.\"\"\"\n", + " return self.pinqi(kdata, csm)\n", + "\n", + " def loss(self, predictions: Sequence[torch.Tensor], batch: BatchType) -> torch.Tensor:\n", + " \"\"\"Compute the loss.\"\"\"\n", + " loss = torch.tensor(0.0, device=self.device)\n", + " target_m0, target_t1, mask = map(torch.squeeze, (batch['m0'], batch['t1'], batch['mask']))\n", + " for prediction, weight in zip(predictions, self.hparams.loss_weights, strict=False):\n", + " prediction_m0, prediction_t1 = map(torch.squeeze, prediction)\n", + " loss_t1 = torch.nn.functional.mse_loss(prediction_t1[mask], target_t1[mask])\n", + " loss_m0 = torch.nn.functional.mse_loss(\n", + " torch.view_as_real(prediction_m0[mask]),\n", + " torch.view_as_real(target_m0[mask]),\n", + " )\n", + " loss_outside = prediction_m0[~mask].abs().mean()\n", + " loss = loss + weight * (loss_t1 + 0.5 * loss_m0 + 0.1 * loss_outside)\n", + " return loss\n", + "\n", + " def training_step(self, batch: BatchType, _batch_idx: int) -> torch.Tensor:\n", + " \"\"\"Training step.\"\"\"\n", + " _images, parameters = self(batch['kdata'], batch['csm'])\n", + " loss = self.loss(parameters, batch)\n", + " self.log(\n", + " 'train/loss',\n", + " loss,\n", + " on_step=True,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " sync_dist=True,\n", + " batch_size=len(batch['mask']),\n", + " )\n", + " return loss\n", + "\n", + " def validation_step(self, batch: BatchType, batch_idx: int) -> None:\n", + " \"\"\"Validate.\n", + "\n", + " Needs to be adapted for other signal models than Saturation Recovery.\n", + " \"\"\"\n", + " _images, parameters = self(batch['kdata'], batch['csm'])\n", + " loss = self.loss(parameters, batch)\n", + "\n", + " pred_m0, pred_t1 = parameters[-1]\n", + " target_t1, target_m0 = batch['t1'][:, None, None], batch['m0'][:, None, None]\n", + " mask = batch['mask']\n", + " batch_size = len(batch['mask'])\n", + " (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1)\n", + " (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1)\n", + " (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0)\n", + " self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", + " self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", + " self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", + " self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size)\n", + "\n", + " if batch_idx == 0 and self.trainer.is_global_zero:\n", + " self.validation_step_outputs['target_t1'] = batch['t1'].cpu()\n", + " self.validation_step_outputs['pred_t1'] = pred_t1.cpu()\n", + " self.validation_step_outputs['pred_m0'] = pred_m0.cpu()\n", + " self.validation_step_outputs['target_m0'] = target_m0.cpu()\n", + " self.validation_step_outputs['mask'] = batch['mask'].cpu()\n", + " baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm'])\n", + " self.validation_step_outputs['baseline_t1'] = baseline_t1.cpu()\n", + " self.validation_step_outputs['baseline_m0'] = baseline_m0.cpu()\n", + "\n", + " def on_validation_epoch_end(self):\n", + " \"\"\"Validate.\n", + "\n", + " Needs to be adapted for other signal models than Saturation Recovery.\n", + " \"\"\"\n", + " if not self.trainer.is_global_zero:\n", + " return\n", + " outputs = self.validation_step_outputs\n", + "\n", + " samples = len(outputs['mask'])\n", + " fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16), squeeze=False)\n", + "\n", + " for i in range(samples):\n", + " self.result_plot(\n", + " outputs['target_t1'][i],\n", + " outputs['pred_t1'][i],\n", + " outputs['mask'][i],\n", + " axes[:, i],\n", + " outputs['baseline_t1'][i],\n", + " '$T_1$ (s)',\n", + " )\n", + " fig.suptitle(f'$T_1$ Epoch {self.current_epoch}')\n", + " self.logger.run['val/images/t1'].log(fig)\n", + " plt.close(fig)\n", + "\n", + " fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 12))\n", + " for i in range(samples):\n", + " self.result_plot(\n", + " outputs['target_m0'][i].abs(),\n", + " outputs['pred_m0'][i].abs(),\n", + " outputs['mask'][i],\n", + " axes[:, i],\n", + " outputs['baseline_m0'][i].abs(),\n", + " '$|M_0|$ (a.u.)',\n", + " )\n", + " fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}')\n", + " self.logger.run['val/images/m0'].log(fig)\n", + " plt.close(fig)\n", + " self.validation_step_outputs.clear()\n", + "\n", + " def result_plot(\n", + " self,\n", + " target: torch.Tensor,\n", + " pred: torch.Tensor,\n", + " mask: torch.Tensor,\n", + " axes: Sequence[plt.Axes],\n", + " baseline: torch.Tensor,\n", + " label: str,\n", + " ) -> None:\n", + " \"\"\"Plot the results.\"\"\"\n", + " target = target.squeeze().cpu()\n", + " pred = pred.squeeze().detach().cpu()\n", + " mask = mask.squeeze().detach().bool().cpu()\n", + " baseline = baseline.squeeze().detach().cpu()\n", + " target[~mask] = torch.nan\n", + " pred[~mask] = torch.nan\n", + " baseline[~mask] = torch.nan\n", + " difference = (target - pred) / target * 100\n", + " vmax = np.nanmax(target.numpy())\n", + "\n", + " im0 = axes[0].imshow(target, vmin=0, vmax=vmax)\n", + " axes[0].set_title('Ground Truth')\n", + " axes[0].axis('off')\n", + " plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label=label)\n", + "\n", + " im1 = axes[1].imshow(baseline, vmin=0, vmax=vmax)\n", + " axes[1].set_title('SENSE + Regression')\n", + " axes[1].axis('off')\n", + " plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label=label)\n", + "\n", + " im2 = axes[2].imshow(pred, vmin=0, vmax=vmax)\n", + " axes[2].set_title('PINQI')\n", + " axes[2].axis('off')\n", + " plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04, label=label)\n", + "\n", + " diff_vmax = np.nanpercentile(difference.abs().numpy(), 90)\n", + " im3 = axes[3].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax)\n", + " axes[3].set_title('rel. Error')\n", + " axes[3].axis('off')\n", + " plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04, label='%')\n", + "\n", + " def configure_optimizers(\n", + " self,\n", + " ) -> dict:\n", + " \"\"\"Configure the optimizer and the learning rate scheduler.\"\"\"\n", + " scalars = ('lambdas_raw', 'rezero')\n", + " params, scalar_params = [], []\n", + " for n, p in self.named_parameters():\n", + " if not p.requires_grad:\n", + " continue\n", + " if any(s in n for s in scalars):\n", + " scalar_params.append(p)\n", + " else:\n", + " params.append(p)\n", + " optimizer = torch.optim.AdamW(\n", + " [\n", + " {\n", + " 'params': params,\n", + " 'weight_decay': self.hparams.weight_decay,\n", + " 'lr': self.hparams.lr,\n", + " },\n", + " {\n", + " 'params': scalar_params,\n", + " 'weight_decay': 0.0,\n", + " 'lr': self.hparams.lr * 10,\n", + " },\n", + " ],\n", + " )\n", + " scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", + " optimizer,\n", + " max_lr=[self.hparams.lr, 10 * self.hparams.lr],\n", + " total_steps=self.trainer.estimated_stepping_batches,\n", + " pct_start=0.1,\n", + " div_factor=20,\n", + " final_div_factor=300,\n", + " )\n", + " return {\n", + " 'optimizer': optimizer,\n", + " 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'},\n", + " }\n", + "\n", + "\n", + "class Baseline(torch.nn.Module):\n", + " \"\"\"Baseline solution using SENSE + Regression.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " signalmodel: mrpro.operators.SignalModel,\n", + " constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp,\n", + " parameter_is_complex: Sequence[bool],\n", + " ):\n", + " \"\"\"Initialize the baseline.\"\"\"\n", + " super().__init__()\n", + " self.signalmodel = signalmodel\n", + " self.constraints_op = constraints_op\n", + " self.parameter_is_complex = parameter_is_complex\n", + "\n", + " def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]:\n", + " \"\"\"Compute the baseline solution.\"\"\"\n", + " sense = mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction(\n", + " kdata, csm=csm, regularization_weight=0.01, n_iterations=3\n", + " )\n", + " images = sense(kdata).rearrange('batch time ...-> time batch ...')\n", + "\n", + " objective = mrpro.operators.functionals.L2NormSquared(images.data) @ self.signalmodel @ self.constraints_op\n", + " initial_values = tuple(\n", + " torch.zeros(\n", + " images.shape[1:],\n", + " device=images.device,\n", + " dtype=torch.complex64 if is_complex else torch.float32,\n", + " )\n", + " for is_complex in self.parameter_is_complex\n", + " )\n", + " solution = self.constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values))\n", + " return solution\n", + "\n", + "\n", + "class LogLambdasCallback(pl.Callback):\n", + " \"\"\"Log the lambdas.\"\"\"\n", + "\n", + " def on_train_batch_end(\n", + " self,\n", + " trainer: pl.Trainer,\n", + " pl_module: PinqiModule,\n", + " _outputs: dict,\n", + " _batch: BatchType,\n", + " _batch_idx: int,\n", + " ) -> None:\n", + " if trainer.global_step % 10 == 0:\n", + " lambdas = pl_module.pinqi.softplus(pl_module.pinqi.lambdas_raw).detach().cpu().numpy()\n", + " for iteration, (lambda_image, lambda_q, lambda_parameter) in enumerate(lambdas):\n", + " self.log_dict(\n", + " {\n", + " f'parameter/lambda_image_{iteration}': lambda_image,\n", + " f'parameter/lambda_q_{iteration}': lambda_q,\n", + " f'parameter/lambda_parameter_{iteration}': lambda_parameter,\n", + " },\n", + " on_step=True,\n", + " on_epoch=False,\n", + " )\n", + "\n", + "\n", + "if __name__ == '__main__':\n", + " torch.multiprocessing.set_sharing_strategy('file_system')\n", + " torch.set_float32_matmul_precision('high')\n", + " torch._inductor.config.compile_threads = 4\n", + " torch._inductor.config.worker_start_method = 'fork'\n", + " torch._dynamo.config.capture_scalar_outputs = True\n", + " torch._dynamo.config.cache_size_limit = 256\n", + " torch._functorch.config.activation_memory_budget = 0.5\n", + "\n", + " data_folder = Path(' /echo/zimmer08/brainweb')\n", + " if not data_folder.exists():\n", + " data_folder.mkdir(parents=True, exist_ok=True)\n", + " mrpro.phantoms.brainweb.download_brainweb(output_directory=data_folder, workers=2, progress=True)\n", + "\n", + " signalmodel = mrpro.operators.models.SaturationRecovery((0.2, 0.8, 4.0))\n", + " constraints_op = mrpro.operators.ConstraintsOp(\n", + " bounds=(\n", + " (-2, 2), # M0 in [-2, 2]\n", + " (0.01, 6.0), # T1 is constrained between 10 ms and 6 s\n", + " )\n", + " )\n", + " n_images = len(signalmodel.saturation_time)\n", + " parameter_is_complex = [True, False]\n", + "\n", + " dm = DataModule(\n", + " folder=data_folder,\n", + " signalmodel=signalmodel,\n", + " n_images=n_images,\n", + " batch_size=8,\n", + " num_workers=8,\n", + " size=192,\n", + " acceleration=6,\n", + " n_coils=1,\n", + " max_noise=0.3,\n", + " )\n", + "\n", + " model = PinqiModule(\n", + " signalmodel=signalmodel,\n", + " constraints_op=constraints_op,\n", + " parameter_is_complex=parameter_is_complex,\n", + " n_images=n_images,\n", + " )\n", + "\n", + " neptune_logger = NeptuneLogger(\n", + " log_model_checkpoints=False,\n", + " dependencies='infer',\n", + " )\n", + " neptune_logger.log_model_summary(model=model, max_depth=-1)\n", + "\n", + " checkpoint_callback = ModelCheckpoint(\n", + " monitor='val/loss',\n", + " mode='min',\n", + " save_top_k=2,\n", + " dirpath=Path('checkpoints') / str(neptune_logger.version),\n", + " filename='{epoch:02d}-{val/loss:.4f}',\n", + " save_last=True,\n", + " )\n", + "\n", + " strategy = 'auto' # DDPStrategy(find_unused_parameters=False)\n", + " trainer = pl.Trainer(\n", + " max_epochs=100,\n", + " accelerator='gpu',\n", + " devices=1,\n", + " strategy=strategy,\n", + " logger=neptune_logger,\n", + " callbacks=[\n", + " LearningRateMonitor(logging_interval='step'),\n", + " checkpoint_callback,\n", + " LogLambdasCallback(),\n", + " ],\n", + " log_every_n_steps=10,\n", + " gradient_clip_algorithm='norm',\n", + " gradient_clip_val=5.0,\n", + " )\n", + "\n", + " # trainer.fit(model, datamodule=dm)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "jupytext": { + "cell_metadata_filter": "mystnb,tags,-all" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/scripts/apply_pinqi.py b/examples/scripts/apply_pinqi.py new file mode 100644 index 000000000..9d7bde957 --- /dev/null +++ b/examples/scripts/apply_pinqi.py @@ -0,0 +1,461 @@ +# %% [markdown] +# # End-to-end physics informed network for quantitative MRI (PINQI) +# A recent DL approach, PINQI, approaches learned quantitative MRI by half quadratic splitting to alternate between two +# subproblems. The first is a linear image reconstruction task +# $$ +# \underset{\mathbf{x}}{\min} \frac{1}{2} \| \mathbf{A} \mathbf{x} - \mathbf{y} \|_2^2 +# + \frac{\lambda_\mathbf{x}}{2} \left\| \mathbf{x} - \mathbf{x}_{\text{reg}} \right\|_2^2 +# + \frac{\lambda_{\mathbf{q}}}{2} \left\| \mathbf{q}(\mathbf{p}) - \mathbf{x} \right\|_2^2 +# $$ +# with $\mathbf{x}$ being intermediary qualitative images, $\lambda_{\mathbf{x}}$ and $\lambda_{\mathbf{q}}$ being +# regularization strengths and $\mathbf{x}_{\text{reg}}$ denoting an image prior for regularization. +# The second, non-linear, subproblem is finding the quantitative parameters by solving +# $$ +# \underset{\mathbf{p}}{\min} \frac{\lambda_{\mathbf{q}}}{2}\left \| \mathbf{q}(\vec{p}) - \mathbf{x} \right\|_2^2 +# + \frac{\lambda_{\mathbf{p}}}{2} \left\| \mathbf{p} - \mathbf{p}_{\text{reg}} \right\|_2^2. +# $$ +# Here, $\mathbf{p}_{\text{reg}}$ is a prior on the parameter maps and $\lambda_{\mathbf{p}}$ the associated weight for +# regularization. +# In PINQI, a solution is found by iterating between both subproblems. In each iteration $k=1,\ldots,T$, +# the image and parameter priors are updated by U-Nets. The network parameters and the regularization strengths +# are trained end-to-end. +# Here, we apply a trained PINQI model to a validation set. We first define the dataset, then define the PINQI model, +# before loading the model weights and applying it to the dataset. + +# %% [markdown] +# ## Dataset +# We base the dataset on the BrainWeb phantom (`mrpro.phantoms.brainweb.BrainwebSlices`) and simulate Cartesian random +# undersampling in phase encode direction. + +# %% +from collections.abc import Sequence +from copy import deepcopy +from pathlib import Path +from typing import Literal, TypedDict + +import einops +import mrpro +import torch + +# mrpro.phantoms.brainweb.download_brainweb(workers=2, progress=True) + + +# %% +class BatchType(TypedDict): + """Typehint for a batch of data.""" + + kdata: mrpro.data.KData + csm: mrpro.data.CsmData + m0: torch.Tensor + t1: torch.Tensor + mask: torch.Tensor + + +class Dataset(torch.utils.data.Dataset[BatchType]): + """A brainweb based cartesian qMRI dataset.""" + + def __init__( + self, + folder: Path, + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int, + acceleration: int, + n_coils: int, + max_noise: float, + orientation: Sequence[Literal['axial', 'coronal', 'sagittal']], + random: bool = True, + ): + """Initialize the dataset.""" + if random: + augment = mrpro.phantoms.brainweb.augment(size=size) + else: + augment = mrpro.phantoms.brainweb.augment( + size=size, + max_random_shear=0, + max_random_rotation=0, + max_random_scaling_factor=0, + p_horizontal_flip=0, + p_vertical_flip=1.0, + ) + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + folder=folder, + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=augment, + orientation=orientation, + ) + self.signalmodel = deepcopy(signalmodel) + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + self.max_noise = max_noise + self._n_images = n_images + + def __len__(self) -> int: + """Get the length of the dataset.""" + return len(self.phantom) + + def __getitem__(self, index: int): + """Get an item from the dataset.""" + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = int(torch.randint(0, 1000000, (1,))) if self._random else index + + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=1.5, + n_center=10, + n_other=(self._n_images,), + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + + if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery): + header.ti = self.signalmodel.saturation_time.tolist() + elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery): + header.ti = self.signalmodel.ti.tolist() + + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + csm = mrpro.data.CsmData( + mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix), + header, + ) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + +# %% [markdown] +# ## PINQI +# Next, We define the PINQI model. Here we can make use of the diffferntiable optimization operators in MRpro. + + +# %% +class PINQI(torch.nn.Module): + """PINQI model.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int, + n_features_parameter_net: Sequence[int], + n_features_image_net: Sequence[int], + ): + """Initialize the PINQI model.""" + super().__init__() + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op + self.constraints_op = constraints_op + self._n_images = n_images + self._parameter_is_complex = parameter_is_complex + real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex) + self.parameter_net = mrpro.nn.nets.UNet( + dim=2, + channels_in=n_images * 2, + channels_out=real_parameters, + attention_depths=(-1, -2), + n_features=n_features_parameter_net, + cond_dim=128, + ) + + self.image_net = mrpro.nn.nets.UNet( + 2, + channels_in=2, + channels_out=2, + attention_depths=(), + n_features=n_features_image_net, + cond_dim=128, + ) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus(beta=5) + self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128) + + def objective_factory( + lambda_parameters: torch.Tensor, + image: torch.Tensor, + *parameter_reg: torch.Tensor, + ) -> torch.operators.Operator: + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.ProximableFunctionalSeparableSum( + *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] + ) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp( + objective_factory, + lambda _l, _i, *parameter_reg: parameter_reg, + ) + # This can be done once, as the signal model is the same for all samples. + + def get_linear_solver(self, gram: mrpro.operators.LinearOperator) -> mrpro.operators.ConjugateGradientOp: + """Set up the linear solver.""" + # This needs to be done for each sample, as the undersampling pattern and csm are different for each sample, + # thus the gram operator of the acquisition operator is different for each sample. + + def operator_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + *_, + ): + return gram + lambda_image + lambda_q + + def rhs_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + image_reg: torch.Tensor, + signal: torch.Tensor, + zero_filled_image: torch.Tensor, + ): + return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,) + + return mrpro.operators.ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=rhs_factory, + ) + + def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]: + """Get the parameter regularization.""" + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> batch (t complex) y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + parameters = self.parameter_net(image.contiguous(), cond=cond) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + i = 0 + result = [] + for is_complex in self._parameter_is_complex: + if is_complex: + result.append(torch.complex(parameters[i], parameters[i + 1])) + i += 2 + else: + result.append(parameters[i]) + i += 1 + return tuple(result) + + def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor: + """Get the image regularization.""" + batch = image.shape[0] + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> (batch t) complex y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + image = image + self.image_net(image.contiguous(), cond=cond) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]: + """Estimate the quantitative parameters. + + Parameters + ---------- + kdata + The k-space data. + csm + The coil sensitivity maps. + + Returns + ------- + images + The qualitative images. + parameters + The quantitative parameters. + """ + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm_op + gram = acquisition_op.gram + (zero_filled_image,) = acquisition_op.H(kdata.data) + images = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2) + parameters = self.get_parameter_reg(images, 0) + linear_solver = self.get_linear_solver(gram) + + for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)): + # linear subproblem 1 + image_reg = self.get_image_reg(images, i) + (signal,) = self.signalmodel(*parameters) + images = linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image) + # nonlinear subproblem 2 + parameters_reg = self.get_parameter_reg(images, i + 1) + parameters = self.nonlinear_solver(lambda_parameter, images, *parameters_reg) + if self.constraints_op is not None: + # map the parameters into the constrained space + parameters = self.constraints_op(*parameters) + return parameters + + +# %% +# As a baseline methods for comparison, we use a simple non-learned approach. We reconstruct the qualitative images at +# different saturation times using iterative SENSE. We then perform a constrained non-linear least squares regression +# using L-BFGS to obtain the parameter maps. +# %% +def baseline_solution( + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + kdata: mrpro.data.KData, + csm: mrpro.data.CsmData, +) -> tuple[torch.Tensor, ...]: + """Compute a baseline solution using SENSE + Regression.""" + sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(kdata, csm=csm) + images = sense(kdata) + objective = mrpro.operators.functionals.L2NormSquared(images.data) @ signalmodel @ constraints_op + initial_values = tuple( + torch.zeros( + images.shape[1:], + device=images.device, + dtype=torch.complex64 if is_complex else torch.float32, + ) + for is_complex in parameter_is_complex + ) + solution = constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values)) + return solution + + +# %% +data_folder = Path('/home/zimmer08/.cache/mrpro/brainweb') + +signalmodel = mrpro.operators.models.SaturationRecovery((0.5, 1.0, 1.5, 2.0, 8.0)) +constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (-2, 2), # M0 in [-2, 2] + (0.01, 6.0), # T1 is constrained between 10 ms and 6 s + ) +) +n_images = len(signalmodel.saturation_time) +parameter_is_complex = [True, False] + + +dataset = torch.utils.data.Subset( + Dataset( + folder=data_folder, + signalmodel=signalmodel, + n_images=n_images, + size=192, + acceleration=8, + n_coils=8, + max_noise=0.05, + orientation=('axial',), + random=False, + ), + list(range(500)), +) +# %% +checkpoint = torch.load('./examples/scripts/last.ckpt', map_location='cpu') +hyper_parameters = checkpoint['hyper_parameters'] + + +pinqi = PINQI( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + n_iterations=hyper_parameters['n_iterations'], + n_features_parameter_net=hyper_parameters['n_features_parameter_net'], + n_features_image_net=hyper_parameters['n_features_image_net'], +) +state_dict = { + k.replace('pinqi.', '').replace('_orig_mod.', ''): v + for k, v in checkpoint['state_dict'].items() + if 'baseline' not in k +} +pinqi.load_state_dict(state_dict) +# %% +batch = dataset[40] +csm, kdata = batch['csm'], batch['kdata'] + +if torch.cuda.is_available(): + pinqi, csm, kdata = pinqi.cuda(), csm.cuda(), kdata.cuda() +images, parameters = pinqi(kdata[None], csm[None]) +with torch.no_grad(): + predicted_m0, predicted_t1 = (p.cpu().detach().squeeze() for p in parameters[-1]) +baseline_m0, baseline_t1 = baseline_solution(signalmodel, constraints_op, parameter_is_complex, kdata, csm) +# %% +(ssim_t1,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(predicted_t1[None]) +(mse_t1,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(predicted_t1) + +(mse_baseline,) = mrpro.operators.functionals.MSE(batch['t1'], batch['mask'])(baseline_t1) +nrmse_t1 = torch.sqrt(mse_t1) / batch['t1'][batch['mask']].max() +(ssim_baseline,) = mrpro.operators.functionals.SSIM(batch['t1'][None], batch['mask'][None])(baseline_t1[None]) +nrmse_baseline = torch.sqrt(mse_baseline) / batch['t1'][batch['mask']].max() + + +# %% +import matplotlib.pyplot as plt +from cmap import Colormap + +cmap = Colormap('lipari').to_matplotlib() + +print(f'SSIM: {ssim_baseline.item():.4f}, NRMSE: {nrmse_baseline.item():.4f}') +print(f'SSIM: {ssim_t1.item():.4f}, NRMSE: {nrmse_t1.item():.4f}') + + +fig, ax = plt.subplots( + 1, + 5, + gridspec_kw={ + 'width_ratios': [1, 1, 1, 0.28, 0.075], + 'wspace': -0.25, + }, + figsize=(6.5, 2.5), +) +baseline_t1 = baseline_t1.squeeze() +baseline_t1[~batch['mask']] = torch.nan +ax[0].imshow(baseline_t1, vmin=0, vmax=2, cmap=cmap) +ax[0].axis('off') +ax[0].set_title('SENSE + NLS') +ax[0].text( + 0.5, + -0.00, + f'SSIM: {ssim_baseline.item():.2f}', + color='black', + horizontalalignment='center', + verticalalignment='top', + transform=ax[0].transAxes, + size=11, +) +predicted_t1 = predicted_t1.squeeze() +predicted_t1[~batch['mask']] = torch.nan +ax[1].imshow(predicted_t1, vmin=0, vmax=2, cmap=cmap) +ax[1].axis('off') +ax[1].set_title('PINQI') +ax[1].text( + 0.5, + -0.0, + f'SSIM: {ssim_t1.item():.2f}', + color='black', + horizontalalignment='center', + verticalalignment='top', + transform=ax[1].transAxes, + size=11, +) + +target_t1 = batch['t1'].squeeze() +target_t1[~batch['mask']] = torch.nan +im = ax[2].imshow(target_t1, vmin=0, vmax=2, cmap=cmap) +ax[2].axis('off') +ax[2].set_title( + 'Ground Truth', +) +ax[-2].axis('off') +fig.tight_layout() +plt.colorbar(im, cax=ax[-1], label='$T_1$ (s)') +fig.savefig( + '/home/zimmer08/code/mrpro/examples/scripts/pinqi_t1_3.pdf', + bbox_inches='tight', + pad_inches=0, +) diff --git a/examples/scripts/modl.py b/examples/scripts/modl.py new file mode 100644 index 000000000..e2fb227e0 --- /dev/null +++ b/examples/scripts/modl.py @@ -0,0 +1,204 @@ +# %% +# %matplotlib inline +from collections.abc import Sequence +from pathlib import Path +from typing import TypedDict + +import matplotlib.axes +import matplotlib.pyplot as plt +import mrpro +import torch +from tqdm import tqdm + + +class BatchType(TypedDict): + """A single Batch.""" + + data: mrpro.data.KData + target: mrpro.data.IData + csm: mrpro.data.CsmData + + +class AcceleratedFastMRI(torch.utils.data.Dataset): + """An undersampled FastMRI Dataset.""" + + def __init__(self, path: Path, acceleration: float = 12, noise_level: float = 0.1): + """Create an undersampled FastMRI Dataset. + + Parameters + ---------- + path + Path to the FastMRI dataset. + acceleration + Undersampling factor; higher values mean more acceleration. Default is 12. + noise_level + Level of additive Gaussian noise applied to the FastMRI dataset. Default is 0.1. + """ + self.acceleration = acceleration + files = list(path.glob('*AXT1*')) + self.dataset = mrpro.phantoms.FastMRIKDataDataset(files) + self.noise_level = noise_level + + def __len__(self): + """Get length of the dataset.""" + return len(self.dataset) + + def __getitem__(self, index: int) -> BatchType: + """Get a single batch of data. + + Parameters + ---------- + index + Index of the batch. + + Returns + ------- + A single batch of data with keys 'data', 'target', and 'csm'.and + """ + data = self.dataset[index] + data = data.remove_readout_os() + data.data /= data.data.std() + reconstruction = mrpro.algorithms.reconstruction.DirectReconstruction( + data, csm=lambda data: mrpro.data.CsmData.from_idata_inati(data, downsampled_size=64) + ) + csm = reconstruction.csm + target = reconstruction(data) + + n = max(data.data.shape[-2:]) + distance = (torch.linspace(-1, 1, n)[:, None] ** 2 + torch.linspace(-1, 1, n) ** 2).sqrt() + random = 0.1 / (distance + 0.1) + torch.rand_like(distance) + threshold = torch.kthvalue(random.ravel(), int(n**2 * (1 - 1 / self.acceleration))).values + undersampling_mask = mrpro.utils.pad_or_crop(random > threshold, data.data.shape[-2:]) + data_undersampled = data[..., undersampling_mask].rearrange('k ... 1 -> ... k') + + noise = mrpro.utils.RandomGenerator(seed=index).randn_like(data_undersampled.data) + data_undersampled.data += self.noise_level * noise + + assert csm is not None # for mypy + return {'data': data_undersampled, 'target': target, 'csm': csm} + + +class MODL(torch.nn.Module): + """MODL network.""" + + def __init__(self, iterations: int = 8, n_features: Sequence[int] = (64, 64, 64, 64)): + """Initialize MODL network. + + Parameters + ---------- + iterations + Number of iterations. + n_features + Number of features in the network. + """ + super().__init__() + cnn = mrpro.nn.nets.BasicCNN( + dim=2, + channels_in=2, + channels_out=2, + n_features=n_features, + batch_norm=True, + ) + self.network = mrpro.nn.Residual(mrpro.nn.ComplexAsChannel(mrpro.nn.PermutedBlock((-1, -2), cnn))) + self.network = torch.compile(self.network, dynamic=True, fullgraph=True) + self.iterations = iterations + self.regularization_weights = torch.nn.Parameter(0.2 * torch.ones(iterations)) + + def __call__(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: + """Apply MODL network. + + Parameters + ---------- + kdata + The k-space data. + csm + The coil sensitivity maps. + + Returns + ------- + The reconstructed image. + """ + return super().__call__(kdata, csm) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> mrpro.data.IData: + """Apply the MODL network.""" + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm.as_operator() + (zero_filled_image,) = acquisition_op.H(kdata.data) + gram = acquisition_op.gram + data_consistency_op = mrpro.operators.ConjugateGradientOp( + operator_factory=lambda _image, weight: gram + weight, + rhs_factory=lambda image, weight: zero_filled_image + weight * image, + ) + + (image,) = mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=5) + for iteration in range(self.iterations): + regularization = self.network(image) + (image,) = data_consistency_op(regularization, self.regularization_weights[iteration]) + + return mrpro.data.IData(image, header=mrpro.data.IHeader.from_kheader(kdata.header)) + + +def plot(batch: BatchType, prediction: mrpro.data.IData, step: int) -> None: + """Plot the direct, sense, and modl reconstructions.""" + target = batch['target'].rss().cpu().squeeze() + direct = mrpro.algorithms.reconstruction.DirectReconstruction(batch['data'], csm=batch['csm'])(batch['data']) + direct = direct.rss().cpu().squeeze() + direct *= target.std() / direct.std() + sense = mrpro.algorithms.reconstruction.IterativeSENSEReconstruction(batch['data'], csm=batch['csm'])(batch['data']) + sense = sense.rss().cpu().squeeze() + prediction_ = prediction.rss().cpu().squeeze().detach() + + ssim = mrpro.operators.functionals.SSIM(mrpro.utils.pad_or_crop(target[None], (320, 320))) + + def show(ax: matplotlib.axes.Axes, data: torch.Tensor, label: str): + data = mrpro.utils.pad_or_crop(data, (320, 320)) + ax.imshow(data, vmin=0, vmax=target.max().item(), cmap='gray') + if label != 'Ground Truth': + (ssim_value,) = ssim(data[None]) + ax.text( + 0.98, + 0.1, + f'SSIM: {ssim_value.item():.2f}', + color='white', + horizontalalignment='right', + verticalalignment='top', + transform=ax.transAxes, + ) + ax.set_title(label) + ax.set_axis_off() + + fig, ax = plt.subplots(1, 4) + show(ax[0], direct, 'Direct') + show(ax[1], sense, 'CG-SENSE') + show(ax[2], prediction_, 'MODL') + show(ax[3], target, 'Ground Truth') + fig.tight_layout() + fig.savefig(f'modl_{step}.pdf', bbox_inches='tight', pad_inches=0) + + +# %%. +path = Path('/echo/allgemein/resources/publicTrainingData/fastmri/brain_multicoil_train/') +dataset = AcceleratedFastMRI(path) +dataloader = torch.utils.data.DataLoader(dataset, num_workers=16, shuffle=True, collate_fn=lambda batch: batch[0]) +modl = MODL().cuda() +optimizer = torch.optim.Adam(modl.parameters(), lr=1e-3) +pbar = tqdm(dataloader) +for i, batch in enumerate(pbar): + optimizer.zero_grad() + kdata, csm, target = (batch['data'].cuda(), batch['csm'].cuda(), batch['target'].cuda()) + prediction = modl(kdata, csm) + objective = 0.5 * mrpro.operators.functionals.MSE(target.data) - mrpro.operators.functionals.SSIM(target.data) + (loss,) = objective(prediction.data) + loss.backward() + torch.nn.utils.clip_grad_norm_(modl.parameters(), 5.0) + optimizer.step() + + pbar.set_postfix(loss=loss.item()) + if i % 200 == 0: + plot(batch, prediction, i) + print(modl.regularization_weights) + state = {'modl': modl.state_dict(), 'optimizer': optimizer.state_dict()} + torch.save(state, f'modl_{i}.pt') + +# %% diff --git a/examples/scripts/train_pinqi.py b/examples/scripts/train_pinqi.py new file mode 100644 index 000000000..d1c3dd5fd --- /dev/null +++ b/examples/scripts/train_pinqi.py @@ -0,0 +1,698 @@ +# %% +# ruff: noqa: D102, ANN201 +from collections.abc import Sequence +from copy import deepcopy +from pathlib import Path +from typing import Any, Literal, TypedDict + +import einops +import matplotlib.pyplot as plt +import mrpro +import numpy as np +import pytorch_lightning as pl # type:ignore[import-not-found] +import torch +import torch.utils.data._utils +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint # type:ignore[import-not-found] +from pytorch_lightning.loggers import NeptuneLogger # type:ignore[import-not-found] + + +class BatchType(TypedDict): + """Typehint for a batch of data.""" + + kdata: mrpro.data.KData + csm: mrpro.data.CsmData + m0: torch.Tensor + t1: torch.Tensor + mask: torch.Tensor + + +class Dataset(torch.utils.data.Dataset): + """A brainweb based cartesian qMRI dataset.""" + + def __init__( + self, + folder: Path, + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int, + acceleration: int, + n_coils: int, + max_noise: float, + orientation: Sequence[Literal['axial', 'coronal', 'sagittal']], + random: bool = True, + ): + """Initialize the dataset.""" + if random: + augment = mrpro.phantoms.brainweb.augment(size=size) + else: + augment = mrpro.phantoms.brainweb.augment( + size=size, + max_random_shear=0, + max_random_rotation=0, + max_random_scaling_factor=0, + p_horizontal_flip=0, + p_vertical_flip=1.0, + ) + self.phantom = mrpro.phantoms.brainweb.BrainwebSlices( + folder=folder, + what=('m0', 't1', 'mask'), + seed='index' if not random else 'random', + slice_preparation=augment, + orientation=orientation, + ) + self.signalmodel = signalmodel + self.encoding_matrix = mrpro.data.SpatialDimension(1, size, size) + self.fov = mrpro.data.SpatialDimension(0.01, 0.25, 0.25) + self.acceleration = acceleration + self.n_coils = n_coils + self._random = random + self.max_noise = max_noise + self._n_images = n_images + + def __len__(self) -> int: + """Get the length of the dataset.""" + return len(self.phantom) + + def __getitem__(self, index: int): + """Get an item from the dataset.""" + phantom = self.phantom[index] + (images,) = self.signalmodel(phantom['m0'], phantom['t1']) + seed = int(torch.randint(0, 1000000, (1,))) if self._random else index + + traj = mrpro.data.traj_calculators.KTrajectoryCartesian.gaussian_variable_density( + encoding_matrix=self.encoding_matrix, + seed=seed, + acceleration=self.acceleration, + fwhm_ratio=1.5, + n_center=12, + n_other=(self._n_images,), + ) + header = mrpro.data.KHeader( + encoding_matrix=self.encoding_matrix, + recon_matrix=self.encoding_matrix, + recon_fov=self.fov, + encoding_fov=self.fov, + ) + + if isinstance(self.signalmodel, mrpro.operators.models.SaturationRecovery): + header.ti = self.signalmodel.saturation_time.tolist() + elif isinstance(self.signalmodel, mrpro.operators.models.InversionRecovery): + header.ti = self.signalmodel.ti.tolist() + + fourier_op = mrpro.operators.FourierOp(self.encoding_matrix, self.encoding_matrix, traj) + if self.n_coils > 1: + csm_tensor = mrpro.phantoms.coils.birdcage_2d(self.n_coils, self.encoding_matrix) + else: + csm_tensor = torch.ones(1, 1, *self.encoding_matrix.zyx) + csm = mrpro.data.CsmData(csm_tensor, header) + images = einops.rearrange(images, 't y x -> t 1 1 y x') + (data,) = (fourier_op @ csm.as_operator())(images) + data = data + torch.randn_like(data) * torch.rand(1) * self.max_noise * data.std() + kdata = mrpro.data.KData(header, data, traj) + return {'kdata': kdata, 'csm': csm, **phantom} + + +def collate_fn(batch: Any): # noqa: ANN401 + """Join dataclasses to a batch.""" + return torch.utils.data._utils.collate.collate( + batch, + collate_fn_map={ + mrpro.data.Dataclass: lambda batch, *, collate_fn_map: batch[0].stack(*batch[1:]), # noqa: ARG005 + **torch.utils.data._utils.collate.default_collate_fn_map, + }, + ) + + +class PINQI(torch.nn.Module): + """PINQI model.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int, + n_features_parameter_net: Sequence[int], + n_features_image_net: Sequence[int], + ): + """Initialize the PINQI model.""" + super().__init__() + self.signalmodel = mrpro.operators.RearrangeOp('t batch ... -> batch t ...') @ signalmodel @ constraints_op + self.constraints_op = constraints_op + self._n_images = n_images + self._parameter_is_complex = parameter_is_complex + real_parameters = sum(1 for c in parameter_is_complex if c) + len(parameter_is_complex) + self.parameter_net = torch.compile( + mrpro.nn.nets.UNet( + n_dim=2, + n_channels_in=n_images * 2, + n_channels_out=real_parameters, + attention_depths=(-1, -2), + n_features=n_features_parameter_net, + cond_dim=128, + ), + dynamic=False, + fullgraph=True, + ) + self.image_net = torch.compile( + mrpro.nn.nets.UNet( + n_dim=2, + n_channels_in=2, + n_channels_out=2, + attention_depths=(), + n_features=n_features_image_net, + cond_dim=128, + ), + dynamic=False, + fullgraph=True, + ) + self.lambdas_raw = torch.nn.Parameter(torch.ones(n_iterations, 3)) + self.softplus = torch.nn.Softplus(beta=5) + self.iteration_embedding = torch.nn.Embedding(n_iterations + 1, 128) + + def objective_factory( + lambda_parameters: torch.Tensor, + image: torch.Tensor, + *parameter_reg: torch.Tensor, + ): + dc = mrpro.operators.functionals.L2NormSquared(image) @ self.signalmodel + reg = mrpro.operators.ProximableFunctionalSeparableSum( + *[mrpro.operators.functionals.L2NormSquared(r) for r in parameter_reg] + ) + return dc + lambda_parameters * reg + + self.nonlinear_solver = mrpro.operators.OptimizerOp( + objective_factory, + lambda _l, _i, *parameter_reg: parameter_reg, + ) + + def get_linear_solver(self, gram: mrpro.operators.LinearOperator): + def operator_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + *_, + ): + return gram + lambda_image + lambda_q + + def rhs_factory( + lambda_image: torch.Tensor, + lambda_q: torch.Tensor, + image_reg: torch.Tensor, + signal: torch.Tensor, + zero_filled_image: torch.Tensor, + ): + return (zero_filled_image + lambda_image * image_reg + lambda_q * signal,) + + return mrpro.operators.ConjugateGradientOp( + operator_factory=operator_factory, + rhs_factory=rhs_factory, + ) + + def get_parameter_reg(self, image: torch.Tensor, iteration: int = 0) -> tuple[torch.Tensor, ...]: + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> batch (t complex) y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + parameters = self.parameter_net(image.contiguous(), cond=cond) + parameters = einops.rearrange(parameters, 'batch parameters y x-> parameters batch 1 1 y x') + i = 0 + result = [] + for is_complex in self._parameter_is_complex: + if is_complex: + result.append(torch.complex(parameters[i], parameters[i + 1])) + i += 2 + else: + result.append(parameters[i]) + i += 1 + return tuple(result) + + def get_image_reg(self, image: torch.Tensor, iteration: int = 0) -> torch.Tensor: + batch = image.shape[0] + image = einops.rearrange( + torch.view_as_real(image), + 'batch t 1 1 y x complex-> (batch t) complex y x', + ) + cond = self.iteration_embedding(torch.tensor(iteration, device=image.device))[None] + image = image + self.image_net(image.contiguous(), cond=cond) + image = einops.rearrange(image, '(batch t) complex y x-> batch t 1 1 y x complex', batch=batch) + return torch.view_as_complex(image.contiguous()) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + csm_op = csm.as_operator() + fourier_op = mrpro.operators.FourierOp.from_kdata(kdata) + acquisition_op = fourier_op @ csm_op + gram = acquisition_op.gram + (zero_filled_image,) = acquisition_op.H(kdata.data) + images = list(mrpro.algorithms.optimizers.cg(gram, zero_filled_image, max_iterations=2)) + parameters = [self.get_parameter_reg(images[-1], 0)] + linear_solver = self.get_linear_solver(gram) + + for i, (lambda_image, lambda_q, lambda_parameter) in enumerate(self.softplus(self.lambdas_raw)): + image_reg = self.get_image_reg(images[-1], i + 1) + (signal,) = self.signalmodel(*parameters[-1]) + images.extend(linear_solver(lambda_image, lambda_q, image_reg, signal, zero_filled_image)) + parameters_reg = self.get_parameter_reg(images[-1], i + 1) + parameters.append(self.nonlinear_solver(lambda_parameter, images[-1], *parameters_reg)) + if self.constraints_op is not None: + parameters = [self.constraints_op(*p) for p in parameters] + return images, parameters + + +class DataModule(pl.LightningDataModule): + """Data module for training the PINQI model.""" + + def __init__( + self, + folder: Path, + signalmodel: mrpro.operators.SignalModel, + n_images: int, + size: int = 192, + acceleration: int = 10, + n_coils: int = 8, + max_noise: float = 0.1, + orientation_train: Sequence[Literal['axial', 'coronal', 'sagittal']] = ( + 'axial', + 'coronal', + 'sagittal', + ), + orientation_val: Sequence[Literal['axial', 'coronal', 'sagittal']] = ('axial',), + batch_size: int = 16, + num_workers: int = 4, + ): + """Initialize the data module.""" + super().__init__() + self.save_hyperparameters(ignore=['signalmodel', 'folder', 'num_workers']) + self.batch_size = batch_size + self.num_workers = num_workers + self.train_dataset = Dataset( + folder=folder, + signalmodel=signalmodel, + n_images=n_images, + size=size, + acceleration=acceleration, + n_coils=n_coils, + max_noise=max_noise, + orientation=orientation_train, + random=True, + ) + self.val_dataset = torch.utils.data.Subset( + Dataset( + folder=folder, + signalmodel=signalmodel, + n_images=n_images, + size=size, + acceleration=acceleration, + n_coils=n_coils, + max_noise=max_noise, + orientation=orientation_val, + random=False, + ), + list(range(30, 500, 20)), + ) + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=False, + persistent_workers=self.num_workers > 0, + collate_fn=collate_fn, + worker_init_fn=lambda *_: torch.set_num_threads(1), + ) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.val_dataset, + batch_size=4, + shuffle=False, + num_workers=self.num_workers, + pin_memory=False, + persistent_workers=self.num_workers > 0, + collate_fn=collate_fn, + ) + + +class PinqiModule(pl.LightningModule): + """Module for training the PINQI model.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp, + parameter_is_complex: Sequence[bool], + n_images: int, + n_iterations: int = 4, + n_features_parameter_net: Sequence[int] = (64, 128, 192, 256), + n_features_image_net: Sequence[int] = (32, 48, 64, 96), + lr: float = 3e-4, # noqa: ARG002 + weight_decay: float = 1e-3, # noqa: ARG002 + loss_weights: Sequence[float] = (0.2, 0.1, 0.1, 0.1, 0.8), + ): + """Initialize the PINQI module.""" + super().__init__() + self.save_hyperparameters(ignore=['signalmodel', 'constraints_op']) + if len(loss_weights) != n_iterations + 1: + raise ValueError(f'loss_weights must be of length {n_iterations + 1} for {n_iterations} iterations') + signalmodel = deepcopy(signalmodel) + constraints_op = deepcopy(constraints_op) + self.pinqi = PINQI( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + n_iterations=n_iterations, + n_features_parameter_net=n_features_parameter_net, + n_features_image_net=n_features_image_net, + ) + + self.validation_step_outputs: dict[str, list] = {} + self.baseline = Baseline(signalmodel, constraints_op, parameter_is_complex) + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData): + """Apply the PINQI model to the data.""" + return self.pinqi(kdata, csm) + + def loss(self, predictions: Sequence[torch.Tensor], batch: BatchType) -> torch.Tensor: + """Compute the loss.""" + loss = torch.tensor(0.0, device=self.device) + target_m0, target_t1, mask = map(torch.squeeze, (batch['m0'], batch['t1'], batch['mask'])) + for prediction, weight in zip(predictions, self.hparams.loss_weights, strict=False): + prediction_m0, prediction_t1 = map(torch.squeeze, prediction) + loss_t1 = torch.nn.functional.mse_loss(prediction_t1[mask], target_t1[mask]) + loss_m0 = torch.nn.functional.mse_loss( + torch.view_as_real(prediction_m0[mask]), + torch.view_as_real(target_m0[mask]), + ) + loss_outside = prediction_m0[~mask].abs().mean() + loss = loss + weight * (loss_t1 + 0.5 * loss_m0 + 0.1 * loss_outside) + return loss + + def training_step(self, batch: BatchType, _batch_idx: int) -> torch.Tensor: + """Training step.""" + _images, parameters = self(batch['kdata'], batch['csm']) + loss = self.loss(parameters, batch) + self.log( + 'train/loss', + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=len(batch['mask']), + ) + return loss + + def validation_step(self, batch: BatchType, batch_idx: int) -> None: + """Validate. + + Needs to be adapted for other signal models than Saturation Recovery. + """ + _images, parameters = self(batch['kdata'], batch['csm']) + loss = self.loss(parameters, batch) + + pred_m0, pred_t1 = parameters[-1] + target_t1, target_m0 = batch['t1'][:, None, None], batch['m0'][:, None, None] + mask = batch['mask'] + batch_size = len(batch['mask']) + (ssim_t1,) = mrpro.operators.functionals.SSIM(target_t1, mask)(pred_t1) + (l1_t1,) = mrpro.operators.functionals.L1Norm(target_t1, mask)(pred_t1) + (l1_m0,) = mrpro.operators.functionals.L1Norm(target_m0, mask)(pred_m0) + self.log('val/ssim_t1', ssim_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/l1_t1', l1_t1, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/l1_m0', l1_m0, on_epoch=True, sync_dist=True, batch_size=batch_size) + self.log('val/loss', loss, on_epoch=True, sync_dist=True, batch_size=batch_size) + + if batch_idx == 0 and self.trainer.is_global_zero: + self.validation_step_outputs['target_t1'] = batch['t1'].cpu() + self.validation_step_outputs['pred_t1'] = pred_t1.cpu() + self.validation_step_outputs['pred_m0'] = pred_m0.cpu() + self.validation_step_outputs['target_m0'] = target_m0.cpu() + self.validation_step_outputs['mask'] = batch['mask'].cpu() + baseline_m0, baseline_t1 = self.baseline(batch['kdata'], batch['csm']) + self.validation_step_outputs['baseline_t1'] = baseline_t1.cpu() + self.validation_step_outputs['baseline_m0'] = baseline_m0.cpu() + + def on_validation_epoch_end(self): + """Validate. + + Needs to be adapted for other signal models than Saturation Recovery. + """ + if not self.trainer.is_global_zero: + return + outputs = self.validation_step_outputs + + samples = len(outputs['mask']) + fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 16), squeeze=False) + + for i in range(samples): + self.result_plot( + outputs['target_t1'][i], + outputs['pred_t1'][i], + outputs['mask'][i], + axes[:, i], + outputs['baseline_t1'][i], + '$T_1$ (s)', + ) + fig.suptitle(f'$T_1$ Epoch {self.current_epoch}') + self.logger.run['val/images/t1'].log(fig) + plt.close(fig) + + fig, axes = plt.subplots(4, samples, figsize=(4 * samples, 12)) + for i in range(samples): + self.result_plot( + outputs['target_m0'][i].abs(), + outputs['pred_m0'][i].abs(), + outputs['mask'][i], + axes[:, i], + outputs['baseline_m0'][i].abs(), + '$|M_0|$ (a.u.)', + ) + fig.suptitle(f'$|M_0|$ Epoch {self.current_epoch}') + self.logger.run['val/images/m0'].log(fig) + plt.close(fig) + self.validation_step_outputs.clear() + + def result_plot( + self, + target: torch.Tensor, + pred: torch.Tensor, + mask: torch.Tensor, + axes: Sequence[plt.Axes], + baseline: torch.Tensor, + label: str, + ) -> None: + """Plot the results.""" + target = target.squeeze().cpu() + pred = pred.squeeze().detach().cpu() + mask = mask.squeeze().detach().bool().cpu() + baseline = baseline.squeeze().detach().cpu() + target[~mask] = torch.nan + pred[~mask] = torch.nan + baseline[~mask] = torch.nan + difference = (target - pred) / target * 100 + vmax = np.nanmax(target.numpy()) + + im0 = axes[0].imshow(target, vmin=0, vmax=vmax) + axes[0].set_title('Ground Truth') + axes[0].axis('off') + plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, label=label) + + im1 = axes[1].imshow(baseline, vmin=0, vmax=vmax) + axes[1].set_title('SENSE + Regression') + axes[1].axis('off') + plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, label=label) + + im2 = axes[2].imshow(pred, vmin=0, vmax=vmax) + axes[2].set_title('PINQI') + axes[2].axis('off') + plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04, label=label) + + diff_vmax = np.nanpercentile(difference.abs().numpy(), 90) + im3 = axes[3].imshow(difference, cmap='coolwarm', vmin=-diff_vmax, vmax=diff_vmax) + axes[3].set_title('rel. Error') + axes[3].axis('off') + plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04, label='%') + + def configure_optimizers( + self, + ) -> dict: + """Configure the optimizer and the learning rate scheduler.""" + scalars = ('lambdas_raw', 'rezero') + params, scalar_params = [], [] + for n, p in self.named_parameters(): + if not p.requires_grad: + continue + if any(s in n for s in scalars): + scalar_params.append(p) + else: + params.append(p) + optimizer = torch.optim.AdamW( + [ + { + 'params': params, + 'weight_decay': self.hparams.weight_decay, + 'lr': self.hparams.lr, + }, + { + 'params': scalar_params, + 'weight_decay': 0.0, + 'lr': self.hparams.lr * 10, + }, + ], + ) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=[self.hparams.lr, 10 * self.hparams.lr], + total_steps=self.trainer.estimated_stepping_batches, + pct_start=0.1, + div_factor=20, + final_div_factor=300, + ) + return { + 'optimizer': optimizer, + 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}, + } + + +class Baseline(torch.nn.Module): + """Baseline solution using SENSE + Regression.""" + + def __init__( + self, + signalmodel: mrpro.operators.SignalModel, + constraints_op: mrpro.operators.ConstraintsOp | mrpro.operators.MultiIdentityOp, + parameter_is_complex: Sequence[bool], + ): + """Initialize the baseline.""" + super().__init__() + self.signalmodel = signalmodel + self.constraints_op = constraints_op + self.parameter_is_complex = parameter_is_complex + + def forward(self, kdata: mrpro.data.KData, csm: mrpro.data.CsmData) -> tuple[torch.Tensor, ...]: + """Compute the baseline solution.""" + sense = mrpro.algorithms.reconstruction.RegularizedIterativeSENSEReconstruction( + kdata, csm=csm, regularization_weight=0.01, n_iterations=3 + ) + images = sense(kdata).rearrange('batch time ...-> time batch ...') + + objective = mrpro.operators.functionals.L2NormSquared(images.data) @ self.signalmodel @ self.constraints_op + initial_values = tuple( + torch.zeros( + images.shape[1:], + device=images.device, + dtype=torch.complex64 if is_complex else torch.float32, + ) + for is_complex in self.parameter_is_complex + ) + solution = self.constraints_op(*mrpro.algorithms.optimizers.lbfgs(objective, initial_values)) + return solution + + +class LogLambdasCallback(pl.Callback): + """Log the lambdas.""" + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: PinqiModule, + _outputs: dict, + _batch: BatchType, + _batch_idx: int, + ) -> None: + if trainer.global_step % 10 == 0: + lambdas = pl_module.pinqi.softplus(pl_module.pinqi.lambdas_raw).detach().cpu().numpy() + for iteration, (lambda_image, lambda_q, lambda_parameter) in enumerate(lambdas): + self.log_dict( + { + f'parameter/lambda_image_{iteration}': lambda_image, + f'parameter/lambda_q_{iteration}': lambda_q, + f'parameter/lambda_parameter_{iteration}': lambda_parameter, + }, + on_step=True, + on_epoch=False, + ) + + +if __name__ == '__main__': + torch.multiprocessing.set_sharing_strategy('file_system') + torch.set_float32_matmul_precision('high') + torch._inductor.config.compile_threads = 4 + torch._inductor.config.worker_start_method = 'fork' + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.cache_size_limit = 256 + torch._functorch.config.activation_memory_budget = 0.5 + + data_folder = Path(' /echo/zimmer08/brainweb') + if not data_folder.exists(): + data_folder.mkdir(parents=True, exist_ok=True) + mrpro.phantoms.brainweb.download_brainweb(output_directory=data_folder, workers=2, progress=True) + + signalmodel = mrpro.operators.models.SaturationRecovery((0.2, 0.8, 4.0)) + constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (-2, 2), # M0 in [-2, 2] + (0.01, 6.0), # T1 is constrained between 10 ms and 6 s + ) + ) + n_images = len(signalmodel.saturation_time) + parameter_is_complex = [True, False] + + dm = DataModule( + folder=data_folder, + signalmodel=signalmodel, + n_images=n_images, + batch_size=8, + num_workers=8, + size=192, + acceleration=6, + n_coils=1, + max_noise=0.3, + ) + + model = PinqiModule( + signalmodel=signalmodel, + constraints_op=constraints_op, + parameter_is_complex=parameter_is_complex, + n_images=n_images, + ) + + neptune_logger = NeptuneLogger( + log_model_checkpoints=False, + dependencies='infer', + ) + neptune_logger.log_model_summary(model=model, max_depth=-1) + + checkpoint_callback = ModelCheckpoint( + monitor='val/loss', + mode='min', + save_top_k=2, + dirpath=Path('checkpoints') / str(neptune_logger.version), + filename='{epoch:02d}-{val/loss:.4f}', + save_last=True, + ) + + strategy = 'auto' # DDPStrategy(find_unused_parameters=False) + trainer = pl.Trainer( + max_epochs=100, + accelerator='gpu', + devices=1, + strategy=strategy, + logger=neptune_logger, + callbacks=[ + LearningRateMonitor(logging_interval='step'), + checkpoint_callback, + LogLambdasCallback(), + ], + log_every_n_steps=10, + gradient_clip_algorithm='norm', + gradient_clip_val=5.0, + ) + + # trainer.fit(model, datamodule=dm) + +# %% diff --git a/src/mrpro/phantoms/__init__.py b/src/mrpro/phantoms/__init__.py index d8265a504..ff81fccd8 100644 --- a/src/mrpro/phantoms/__init__.py +++ b/src/mrpro/phantoms/__init__.py @@ -17,4 +17,4 @@ "brainweb", "coils", "mdcnn" -] \ No newline at end of file +] diff --git a/src/mrpro/phantoms/brainweb.py b/src/mrpro/phantoms/brainweb.py index 0764bb84c..076728ae5 100644 --- a/src/mrpro/phantoms/brainweb.py +++ b/src/mrpro/phantoms/brainweb.py @@ -227,6 +227,22 @@ def trim_indices(mask: torch.Tensor) -> tuple[slice, slice]: return slice(row_min, row_max), slice(col_min, col_max) +VALUES_ULF_RANDOMIZED: Mapping[TClassNames, BrainwebTissue] = MappingProxyType( + { + 'skl': BrainwebTissue((0.100, 0.400), (0.005, 0.015), (0.00, 0.05), (-0.2, 0.2)), + 'gry': BrainwebTissue((0.350, 0.430), (0.090, 0.115), (0.70, 1.00), (-0.2, 0.2)), + 'wht': BrainwebTissue((0.240, 0.280), (0.075, 0.085), (0.50, 0.90), (-0.2, 0.2)), + 'csf': BrainwebTissue((1.500, 2.500), (1.000, 1.600), (0.95, 1.00), (-0.2, 0.2)), + 'mrw': BrainwebTissue((0.150, 0.250), (0.060, 0.100), (0.70, 1.00), (-0.2, 0.2)), + 'dura': BrainwebTissue((0.300, 0.600), (0.100, 0.200), (0.90, 1.00), (-0.2, 0.2)), + 'fat': BrainwebTissue((0.120, 0.160), (0.080, 0.130), (0.90, 1.00), (-0.2, 0.2)), + 'fat2': BrainwebTissue((0.140, 0.180), (0.080, 0.130), (0.60, 0.90), (-0.2, 0.2)), + 'mus': BrainwebTissue((0.160, 0.200), (0.035, 0.045), (0.90, 1.00), (-0.2, 0.2)), + 'm-s': BrainwebTissue((0.200, 0.400), (0.100, 0.250), (0.90, 1.00), (-0.2, 0.2)), + 'ves': BrainwebTissue((0.300, 0.500), (0.150, 0.300), (0.80, 1.00), (-0.2, 0.2)), + } +) + VALUES_3T_RANDOMIZED: Mapping[TClassNames, BrainwebTissue] = MappingProxyType( { 'skl': BrainwebTissue((0.000, 2.000), (0.000, 0.010), (0.00, 0.05), (-0.2, 0.2)),