diff --git a/src/mrpro/operators/PatchOp.py b/src/mrpro/operators/PatchOp.py index a1b7c7abb..0e8939ff2 100644 --- a/src/mrpro/operators/PatchOp.py +++ b/src/mrpro/operators/PatchOp.py @@ -21,6 +21,7 @@ def __init__( stride: Sequence[int] | int | None = None, dilation: Sequence[int] | int = 1, domain_size: int | Sequence[int] | None = None, + flatten_patches: bool = True, ) -> None: """Initialize the PatchOp. @@ -38,6 +39,9 @@ def __init__( Size of the domain in the dimnsions `dim`. If None, it is inferred from the input tensor on the first call. This is only used in the adjoint method. + flatten_patches + If True, flatten the leading grid dimensions to a single patch dimension. + If False, keep shape ``(*grid_size, ...)`` for the forward output. """ super().__init__() self.dim = (dim,) if isinstance(dim, int) else dim @@ -60,6 +64,7 @@ def check(param: int | Sequence[int], name: str) -> tuple[int, ...]: self.stride = check(stride, 'stride') if stride is not None else self.patch_size self.dilation = check(dilation, 'dilation') self.domain_size = check(domain_size, 'domain_size') if domain_size is not None else None + self.flatten_patches = flatten_patches def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]: """Extract N-dimensional patches from an input tensor using a sliding window. @@ -100,9 +105,59 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: stride=self.stride, dilation=self.dilation, ) - patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) + if self.flatten_patches: + patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) return (patches,) + def _adjoint_fast(self, patches: torch.Tensor) -> torch.Tensor: + """Adjoint via reshape/permute for non-overlapping patches.""" + assert self.domain_size is not None # mypy # noqa: S101 + grid = tuple(s // p for s, p in zip(self.domain_size, self.patch_size, strict=True)) + n_dim = len(grid) + if self.flatten_patches: + patches = patches.unflatten(0, grid) + permutation: list[int] = [] + reshape: list[int] = [] + dim = [d % (patches.ndim - n_dim) for d in self.dim] + for i, size in enumerate(patches.shape[n_dim:]): + if i in dim: + j = dim.index(i) + permutation.extend([j, n_dim + i]) + reshape.append(grid[j] * self.patch_size[j]) + else: + permutation.append(n_dim + i) + reshape.append(size) + return patches.permute(*permutation).reshape(reshape) + + def _adjoint_scatter(self, patches: torch.Tensor) -> torch.Tensor: + """Adjoint via scatter for overlapping patches.""" + assert self.domain_size is not None # mypy # noqa: S101 + k = len(self.dim) + if not self.flatten_patches: + patches = patches.flatten(start_dim=0, end_dim=k - 1) + output_shape_ = list(patches.shape[1:]) + for dim, size in zip(self.dim, self.domain_size, strict=True): + output_shape_[dim] = size + output_shape = torch.Size(output_shape_) + indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_) + windowed_indices = sliding_window( + x=indices, + window_shape=self.patch_size, + dim=self.dim, + stride=self.stride, + dilation=self.dilation, + ).flatten(start_dim=0, end_dim=k - 1) + if windowed_indices.shape[0] != patches.shape[0]: + raise ValueError( + f'Number of patches {patches.shape[0]} does not match the number of ' + f'expected patches {windowed_indices.shape[0]}' + ) + + assembled = patches.new_zeros(output_shape.numel()) + assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten()) + assembled = assembled.reshape(output_shape) + return assembled + def adjoint( self, patches: torch.Tensor, @@ -127,26 +182,11 @@ def adjoint( """ if self.domain_size is None: raise ValueError('Domain size is not set. Please call forward first or set it at initialization.') - - output_shape_ = list(patches.shape[1:]) - for dim, size in zip(self.dim, self.domain_size, strict=True): - output_shape_[dim] = size - output_shape = torch.Size(output_shape_) - indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_) - windowed_indices = sliding_window( - x=indices, - window_shape=self.patch_size, - dim=self.dim, - stride=self.stride, - dilation=self.dilation, - ).flatten(start_dim=0, end_dim=len(self.dim) - 1) - if windowed_indices.shape[0] != patches.shape[0]: - raise ValueError( - f'Number of patches {patches.shape[0]} does not match the number of ' - f'expected patches {windowed_indices.shape[0]}' - ) - - assembled = patches.new_zeros(output_shape.numel()) - assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten()) - assembled = assembled.reshape(output_shape) - return (assembled,) + if ( + self.stride == self.patch_size # no overlap + and all(d == 1 for d in self.dilation) # no dilation + and all(s % p == 0 for s, p in zip(self.domain_size, self.patch_size, strict=True)) # divisible + ): + return (self._adjoint_fast(patches),) + else: + return (self._adjoint_scatter(patches),) diff --git a/tests/operators/test_patch_op.py b/tests/operators/test_patch_op.py index 4c435b04f..e518b4d32 100644 --- a/tests/operators/test_patch_op.py +++ b/tests/operators/test_patch_op.py @@ -1,6 +1,7 @@ """Tests for Rearrange Operator.""" from collections.abc import Sequence +from typing import Any import pytest import torch @@ -14,13 +15,15 @@ [ ((3, 4, 5), {'dim': (0, 1), 'patch_size': (1, 3), 'stride': (3, 1), 'dilation': (2, 1)}, (2, 1, 3, 5)), ((1, 20), {'dim': -1, 'patch_size': 3, 'stride': 3, 'dilation': 5}, (4, 1, 3)), + ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': (2, 3), 'dilation': 1}, (24, 3, 2)), + ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': None, 'flatten_patches': False}, (8, 3, 3, 2)), ], ) @TESTCASES def test_patch_op_adjointness( - input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int] + input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int] ) -> None: """Test adjointness and shape of Rearrange Op.""" rng = RandomGenerator(seed=0) @@ -36,9 +39,7 @@ def test_patch_op_adjointness( @TESTCASES -def test_patch_op_autodiff( - input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int] -) -> None: +def test_patch_op_autodiff(input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int]) -> None: """Test autodiff works for PatchOp.""" rng = RandomGenerator(seed=0) u = rng.complex64_tensor(size=input_shape)