Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 64 additions & 24 deletions src/mrpro/operators/PatchOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
stride: Sequence[int] | int | None = None,
dilation: Sequence[int] | int = 1,
domain_size: int | Sequence[int] | None = None,
flatten_patches: bool = True,
) -> None:
"""Initialize the PatchOp.

Expand All @@ -38,6 +39,9 @@ def __init__(
Size of the domain in the dimnsions `dim`.
If None, it is inferred from the input tensor on the first call.
This is only used in the adjoint method.
flatten_patches
If True, flatten the leading grid dimensions to a single patch dimension.
If False, keep shape ``(*grid_size, ...)`` for the forward output.
"""
super().__init__()
self.dim = (dim,) if isinstance(dim, int) else dim
Expand All @@ -60,6 +64,7 @@ def check(param: int | Sequence[int], name: str) -> tuple[int, ...]:
self.stride = check(stride, 'stride') if stride is not None else self.patch_size
self.dilation = check(dilation, 'dilation')
self.domain_size = check(domain_size, 'domain_size') if domain_size is not None else None
self.flatten_patches = flatten_patches

def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Extract N-dimensional patches from an input tensor using a sliding window.
Expand Down Expand Up @@ -100,9 +105,59 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
stride=self.stride,
dilation=self.dilation,
)
patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1)
if self.flatten_patches:
patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1)
return (patches,)

def _adjoint_fast(self, patches: torch.Tensor) -> torch.Tensor:
"""Adjoint via reshape/permute for non-overlapping patches."""
assert self.domain_size is not None # mypy # noqa: S101
grid = tuple(s // p for s, p in zip(self.domain_size, self.patch_size, strict=True))
n_dim = len(grid)
if self.flatten_patches:
patches = patches.unflatten(0, grid)
permutation: list[int] = []
reshape: list[int] = []
dim = [d % (patches.ndim - n_dim) for d in self.dim]
for i, size in enumerate(patches.shape[n_dim:]):
if i in dim:
j = dim.index(i)
permutation.extend([j, n_dim + i])
reshape.append(grid[j] * self.patch_size[j])
else:
permutation.append(n_dim + i)
reshape.append(size)
return patches.permute(*permutation).reshape(reshape)

def _adjoint_scatter(self, patches: torch.Tensor) -> torch.Tensor:
"""Adjoint via scatter for overlapping patches."""
assert self.domain_size is not None # mypy # noqa: S101
k = len(self.dim)
if not self.flatten_patches:
patches = patches.flatten(start_dim=0, end_dim=k - 1)
output_shape_ = list(patches.shape[1:])
for dim, size in zip(self.dim, self.domain_size, strict=True):
output_shape_[dim] = size
output_shape = torch.Size(output_shape_)
indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_)
windowed_indices = sliding_window(
x=indices,
window_shape=self.patch_size,
dim=self.dim,
stride=self.stride,
dilation=self.dilation,
).flatten(start_dim=0, end_dim=k - 1)
if windowed_indices.shape[0] != patches.shape[0]:
raise ValueError(
f'Number of patches {patches.shape[0]} does not match the number of '
f'expected patches {windowed_indices.shape[0]}'
)

assembled = patches.new_zeros(output_shape.numel())
assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten())
assembled = assembled.reshape(output_shape)
return assembled

def adjoint(
self,
patches: torch.Tensor,
Expand All @@ -127,26 +182,11 @@ def adjoint(
"""
if self.domain_size is None:
raise ValueError('Domain size is not set. Please call forward first or set it at initialization.')

output_shape_ = list(patches.shape[1:])
for dim, size in zip(self.dim, self.domain_size, strict=True):
output_shape_[dim] = size
output_shape = torch.Size(output_shape_)
indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_)
windowed_indices = sliding_window(
x=indices,
window_shape=self.patch_size,
dim=self.dim,
stride=self.stride,
dilation=self.dilation,
).flatten(start_dim=0, end_dim=len(self.dim) - 1)
if windowed_indices.shape[0] != patches.shape[0]:
raise ValueError(
f'Number of patches {patches.shape[0]} does not match the number of '
f'expected patches {windowed_indices.shape[0]}'
)

assembled = patches.new_zeros(output_shape.numel())
assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten())
assembled = assembled.reshape(output_shape)
return (assembled,)
if (
self.stride == self.patch_size # no overlap
and all(d == 1 for d in self.dilation) # no dilation
and all(s % p == 0 for s, p in zip(self.domain_size, self.patch_size, strict=True)) # divisible
):
return (self._adjoint_fast(patches),)
else:
return (self._adjoint_scatter(patches),)
9 changes: 5 additions & 4 deletions tests/operators/test_patch_op.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for Rearrange Operator."""

from collections.abc import Sequence
from typing import Any

import pytest
import torch
Expand All @@ -14,13 +15,15 @@
[
((3, 4, 5), {'dim': (0, 1), 'patch_size': (1, 3), 'stride': (3, 1), 'dilation': (2, 1)}, (2, 1, 3, 5)),
((1, 20), {'dim': -1, 'patch_size': 3, 'stride': 3, 'dilation': 5}, (4, 1, 3)),
((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': (2, 3), 'dilation': 1}, (24, 3, 2)),
((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': None, 'flatten_patches': False}, (8, 3, 3, 2)),
],
)


@TESTCASES
def test_patch_op_adjointness(
input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int]
input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int]
) -> None:
"""Test adjointness and shape of Rearrange Op."""
rng = RandomGenerator(seed=0)
Expand All @@ -36,9 +39,7 @@ def test_patch_op_adjointness(


@TESTCASES
def test_patch_op_autodiff(
input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int]
) -> None:
def test_patch_op_autodiff(input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int]) -> None:
"""Test autodiff works for PatchOp."""
rng = RandomGenerator(seed=0)
u = rng.complex64_tensor(size=input_shape)
Expand Down
Loading