From 0703f160feb6091d6a48c604f8426bd036ca2ff8 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 02:06:36 +0100 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- src/mrpro/operators/PatchOp.py | 74 ++++++++++++++++++++++---------- tests/operators/test_patch_op.py | 1 + 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/src/mrpro/operators/PatchOp.py b/src/mrpro/operators/PatchOp.py index 190bf28fb..9745ed331 100644 --- a/src/mrpro/operators/PatchOp.py +++ b/src/mrpro/operators/PatchOp.py @@ -103,6 +103,49 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) return (patches,) + def _adjoint_fast(self, patches: torch.Tensor) -> torch.Tensor: + """Adjoint via reshape/permute for non-overlapping patches.""" + assert self.domain_size is not None # mypy # noqa: S101 + grid = tuple(s // p for s, p in zip(self.domain_size, self.patch_size, strict=True)) + permutation: list[int] = [] + reshape: list[int] = [] + dim = [d % (patches.ndim - 1) for d in self.dim] + for i, size in enumerate(patches.shape[1:]): + if i in dim: + j = dim.index(i) + permutation.extend([j, len(self.dim) + i]) + reshape.append(grid[j] * self.patch_size[j]) + else: + permutation.append(len(self.dim) + i) + reshape.append(size) + return patches.unflatten(0, grid).permute(*permutation).reshape(reshape) + + def _adjoint_scatter(self, patches: torch.Tensor) -> torch.Tensor: + """Adjoint via scatter for overlapping patches.""" + assert self.domain_size is not None # mypy # noqa: S101 + output_shape_ = list(patches.shape[1:]) + for dim, size in zip(self.dim, self.domain_size, strict=True): + output_shape_[dim] = size + output_shape = torch.Size(output_shape_) + indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_) + windowed_indices = sliding_window( + x=indices, + window_shape=self.patch_size, + dim=self.dim, + stride=self.stride, + dilation=self.dilation, + ).flatten(start_dim=0, end_dim=len(self.dim) - 1) + if windowed_indices.shape[0] != patches.shape[0]: + raise ValueError( + f'Number of patches {patches.shape[0]} does not match the number of ' + f'expected patches {windowed_indices.shape[0]}' + ) + + assembled = patches.new_zeros(output_shape.numel()) + assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten()) + assembled = assembled.reshape(output_shape) + return assembled + def adjoint( self, patches: torch.Tensor, @@ -127,26 +170,11 @@ def adjoint( """ if self.domain_size is None: raise ValueError('Domain size is not set. Please call forward first or set it at initialization.') - - output_shape_ = list(patches.shape[1:]) - for dim, size in zip(self.dim, self.domain_size, strict=True): - output_shape_[dim] = size - output_shape = torch.Size(output_shape_) - indices = torch.arange(output_shape.numel(), device=patches.device).reshape(output_shape_) - windowed_indices = sliding_window( - x=indices, - window_shape=self.patch_size, - dim=self.dim, - stride=self.stride, - dilation=self.dilation, - ).flatten(start_dim=0, end_dim=len(self.dim) - 1) - if windowed_indices.shape[0] != patches.shape[0]: - raise ValueError( - f'Number of patches {patches.shape[0]} does not match the number of ' - f'expected patches {windowed_indices.shape[0]}' - ) - - assembled = patches.new_zeros(output_shape.numel()) - assembled.scatter_add_(dim=0, index=windowed_indices.flatten(), src=patches.flatten()) - assembled = assembled.reshape(output_shape) - return (assembled,) + if ( + self.stride == self.patch_size # no overlap + and all(d == 1 for d in self.dilation) # no dilation + and all(s % p == 0 for s, p in zip(self.domain_size, self.patch_size, strict=True)) # divisible + ): + return (self._adjoint_fast(patches),) + else: + return (self._adjoint_scatter(patches),) diff --git a/tests/operators/test_patch_op.py b/tests/operators/test_patch_op.py index 46b856178..6aeaf8868 100644 --- a/tests/operators/test_patch_op.py +++ b/tests/operators/test_patch_op.py @@ -14,6 +14,7 @@ [ ((3, 4, 5), {'dim': (0, 1), 'patch_size': (1, 3), 'stride': (3, 1), 'dilation': (2, 1)}, (2, 1, 3, 5)), ((1, 20), {'dim': -1, 'patch_size': 3, 'stride': 3, 'dilation': 5}, (4, 1, 3)), + ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': (2, 3), 'dilation': 1}, (24, 3, 2)), ], ) From 80448ed8a180ae41f13d60c6bafedc892d2fe0f6 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 10 Feb 2026 14:12:45 +0100 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- src/mrpro/operators/PatchOp.py | 26 +++++++++++++++++++------- tests/operators/test_patch_op.py | 8 ++++---- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/mrpro/operators/PatchOp.py b/src/mrpro/operators/PatchOp.py index 9745ed331..9339f74d0 100644 --- a/src/mrpro/operators/PatchOp.py +++ b/src/mrpro/operators/PatchOp.py @@ -21,6 +21,7 @@ def __init__( stride: Sequence[int] | int | None = None, dilation: Sequence[int] | int = 1, domain_size: int | Sequence[int] | None = None, + flatten_patches: bool = True, ) -> None: """Initialize the PatchOp. @@ -38,6 +39,9 @@ def __init__( Size of the domain in the dimnsions `dim`. If None, it is inferred from the input tensor on the first call. This is only used in the adjoint method. + flatten_patches + If True, flatten the leading grid dimensions to a single patch dimension. + If False, keep shape ``(*grid_size, ...)`` for the forward output. """ super().__init__() self.dim = (dim,) if isinstance(dim, int) else dim @@ -60,6 +64,7 @@ def check(param: int | Sequence[int], name: str) -> tuple[int, ...]: self.stride = check(stride, 'stride') if stride is not None else self.patch_size self.dilation = check(dilation, 'dilation') self.domain_size = check(domain_size, 'domain_size') if domain_size is not None else None + self.flatten_patches = flatten_patches def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]: """Extract N-dimensional patches from an input tensor using a sliding window. @@ -100,29 +105,36 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: stride=self.stride, dilation=self.dilation, ) - patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) + if self.flatten_patches: + patches = patches.flatten(start_dim=0, end_dim=len(self.dim) - 1) return (patches,) def _adjoint_fast(self, patches: torch.Tensor) -> torch.Tensor: """Adjoint via reshape/permute for non-overlapping patches.""" assert self.domain_size is not None # mypy # noqa: S101 grid = tuple(s // p for s, p in zip(self.domain_size, self.patch_size, strict=True)) + n_dim = len(grid) + if self.flatten_patches: + patches = patches.unflatten(0, grid) permutation: list[int] = [] reshape: list[int] = [] - dim = [d % (patches.ndim - 1) for d in self.dim] - for i, size in enumerate(patches.shape[1:]): + dim = [d % (patches.ndim - n_dim) for d in self.dim] + for i, size in enumerate(patches.shape[n_dim:]): if i in dim: j = dim.index(i) - permutation.extend([j, len(self.dim) + i]) + permutation.extend([j, n_dim + i]) reshape.append(grid[j] * self.patch_size[j]) else: - permutation.append(len(self.dim) + i) + permutation.append(n_dim + i) reshape.append(size) - return patches.unflatten(0, grid).permute(*permutation).reshape(reshape) + return patches.permute(*permutation).reshape(reshape) def _adjoint_scatter(self, patches: torch.Tensor) -> torch.Tensor: """Adjoint via scatter for overlapping patches.""" assert self.domain_size is not None # mypy # noqa: S101 + k = len(self.dim) + if not self.flatten_patches: + patches = patches.flatten(start_dim=0, end_dim=k - 1) output_shape_ = list(patches.shape[1:]) for dim, size in zip(self.dim, self.domain_size, strict=True): output_shape_[dim] = size @@ -134,7 +146,7 @@ def _adjoint_scatter(self, patches: torch.Tensor) -> torch.Tensor: dim=self.dim, stride=self.stride, dilation=self.dilation, - ).flatten(start_dim=0, end_dim=len(self.dim) - 1) + ).flatten(start_dim=0, end_dim=k - 1) if windowed_indices.shape[0] != patches.shape[0]: raise ValueError( f'Number of patches {patches.shape[0]} does not match the number of ' diff --git a/tests/operators/test_patch_op.py b/tests/operators/test_patch_op.py index 6aeaf8868..49395f10c 100644 --- a/tests/operators/test_patch_op.py +++ b/tests/operators/test_patch_op.py @@ -1,6 +1,7 @@ """Tests for Rearrange Operator.""" from collections.abc import Sequence +from typing import Any import pytest import torch @@ -15,13 +16,14 @@ ((3, 4, 5), {'dim': (0, 1), 'patch_size': (1, 3), 'stride': (3, 1), 'dilation': (2, 1)}, (2, 1, 3, 5)), ((1, 20), {'dim': -1, 'patch_size': 3, 'stride': 3, 'dilation': 5}, (4, 1, 3)), ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': (2, 3), 'dilation': 1}, (24, 3, 2)), + ((9, 16), {'dim': (-1, 0), 'patch_size': (2, 3), 'stride': None, 'flatten_patches': False}, (8, 3, 3, 2)), ], ) @TESTCASES def test_patch_op_adjointness( - input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int] + input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int] ) -> None: """Test adjointness and shape of Rearrange Op.""" rng = RandomGenerator(seed=0) @@ -37,9 +39,7 @@ def test_patch_op_adjointness( @TESTCASES -def test_patch_op_autodiff( - input_shape: Sequence[int], arguments: dict[str, int | Sequence[int]], output_shape: Sequence[int] -) -> None: +def test_patch_op_autodiff(input_shape: Sequence[int], arguments: dict[str, Any], output_shape: Sequence[int]) -> None: """Test autodiff works for PatchOp.""" rng = RandomGenerator(seed=0) u = rng.complex64_tensor(size=input_shape)