diff --git a/examples/ndimages/affine_2d_test.py b/examples/ndimages/affine_2d_test.py new file mode 100644 index 0000000000..7b335f0efd --- /dev/null +++ b/examples/ndimages/affine_2d_test.py @@ -0,0 +1,116 @@ +""" +Example for 2D images with 3 operations popup view +""" + +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image, ImageFile +import heat as ht +from heat.ndimage.affine import affine_transform +import scipy.ndimage as dnimg + +# ------------------------------------------------------------ +# Load RGB image +# ------------------------------------------------------------ +img: ImageFile = Image.open( + "/home/leonk/projects/heat/examples/ndimages/test_images/jason-leung-Iwlo4RuPefM-unsplash_small.jpg" +).convert("RGB") + +img_np = np.asarray(img, dtype=np.float32) # +print(f"shape of image as numpy array {img_np.shape}") # HWC + +# HWC +x = ht.array(img_np) # HWC +print(f"shape of image converted from numpy to heat array {x.shape}") + +H, W = img_np.shape[:2] +cx, cy = W / 2, H / 2 # NOTE: (x, y) +img_center = np.array([cx, cy, 0], dtype=np.float32) +# img_center = np.array([0, 0, 0], dtype=np.float32) + +fig, axs = plt.subplots(5, 2, figsize=(10, 16)) +axs = axs.ravel() + + +# ------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------ +def centered_linear(A_xy): + """ + Build a 3×4 affine matrix for a linear transform A applied about the image center. + """ + b = img_center - A_xy @ img_center + # b = img_center + A_xy @ (-img_center) + # b = np.array([0, 256, 0], dtype=np.float32) + return np.hstack([A_xy, b[:, None]]).astype(np.float32) + + +def apply(M: np.ndarray, title, idx, mode="nearest", constant_value=0.0): + + print(f"matrix shape: {M.shape}") + print(M) + print(f"image shape: {x.shape}") + + index = idx * 2 + + result = affine_transform( + x, ht.array(M), order=0, mode=mode, cval=constant_value, prefilter=True + ) + + compare = dnimg.affine_transform( + img_np, M, order=0, mode=mode, cval=constant_value, prefilter=True + ) + axs[index].imshow(result.larray.permute(0, 1, 2).cpu().numpy().astype(np.uint8)) + axs[index + 1].imshow(compare.astype(np.uint8)) + axs[index].set_title(title) + axs[index].axis("off") + axs[index + 1].set_title("") + axs[index + 1].axis("off") + axs[index].scatter(img_center[0], img_center[1]) + axs[index + 1].scatter(img_center[0], img_center[1]) + + +# ------------------------------------------------------------ +# Identity +# ------------------------------------------------------------ +apply(np.eye(3, 4, dtype=np.float32), "Identity", 0) + +# ------------------------------------------------------------ +# Translation +100 px RIGHT (b = [tx, ty] = [100, 0]) +# Tip: use mode="constant" to make the shift super obvious +# ------------------------------------------------------------ +M_tr = np.eye(3, 4, dtype=np.float32) +M_tr[:, 3] = [0, 256, 0] # (tx, ty, tz) +apply(M_tr, "Translate +100px (right)", 1, mode="constant", constant_value=0.0) + +# ------------------------------------------------------------ +# Rotate 30° around center (in x,y coords) +# ------------------------------------------------------------ +theta = np.deg2rad(30) +A_rot = np.array( + [[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]], + dtype=np.float32, +) +apply(centered_linear(A_rot), "Rotate 30°", 2) + +# ------------------------------------------------------------ +# Scale ×1.5 around center (in x,y coords) +# ------------------------------------------------------------ +A_scale = np.array([[0.75, 0, 0], [0, 1.5, 0], [0, 0, 1]], dtype=np.float32) +apply(centered_linear(A_scale), "Scale ×1.5", 3) + +# ------------------------------------------------------------ +# Combo: centered (scale→rotate) + then translate (tx,ty) +# ------------------------------------------------------------ +A_combo = A_rot @ A_scale +t = np.array([100, -50, 0], dtype=np.float32) # (tx, ty) +b_combo = img_center - A_combo @ img_center + t +M_combo = np.hstack([A_combo, b_combo[:, None]]).astype(np.float32) +apply(M_combo, "Combo", 4) + +# ------------------------------------------------------------ +# Hide unused subplot +# ------------------------------------------------------------ + +plt.tight_layout() +plt.show() diff --git a/examples/ndimages/affine_3d_test.py b/examples/ndimages/affine_3d_test.py new file mode 100644 index 0000000000..8bc7ab3936 --- /dev/null +++ b/examples/ndimages/affine_3d_test.py @@ -0,0 +1,148 @@ +""" +Non-distributed affine demo on a NIfTI volume (Heat). + +Applies: +- 2D rotation (centered) +- 2D scaling (centered) +- 2D translation +- 2D shear +- 3D rotation (centered) + +Handles Heat channel dimensions correctly. +""" + +import numpy as np +import nibabel as nib +import matplotlib.pyplot as plt +import heat as ht + +from heat.ndimage.affine import affine_transform + + +# ============================================================ +# Helpers +# ============================================================ + +def centered_linear_2d(A, H, W): + """2×3 affine around image center (y, x).""" + c = np.array([H / 2, W / 2], dtype=np.float32) + b = c - A @ c + return np.hstack([A, b[:, None]]).astype(np.float32) + + +def centered_linear_3d(A, D, H, W): + """3×4 affine around volume center (z, y, x).""" + c = np.array([D / 2, H / 2, W / 2], dtype=np.float32) + b = c - A @ c + return np.hstack([A, b[:, None]]).astype(np.float32) + + +def show(title, img): + """Safe grayscale visualization.""" + if img.ndim == 3 and img.shape[0] == 1: + img = img[0] + plt.imshow(img, cmap="gray") + plt.title(title) + plt.axis("off") + + +# ============================================================ +# Load MRI +# ============================================================ + +nii = nib.load( + "PATH" +) +x_np = nii.get_fdata().astype(np.float32) + +print("Loaded MRI:", x_np.shape) + +# Heat array (NO split) +x = ht.array(x_np) + +D, H, W = x_np.shape +mid = D // 2 + +# 2D middle slice +slice2d = x[mid] + + +# ============================================================ +# Plot +# ============================================================ + +plt.figure(figsize=(12, 8)) + +# ------------------------------------------------------------ +# Original +# ------------------------------------------------------------ +plt.subplot(2, 3, 1) +show("Original (middle slice)", slice2d.larray.cpu().numpy()) + +# ------------------------------------------------------------ +# Rotate 20° (2D) +# ------------------------------------------------------------ +theta = np.deg2rad(20) +A_rot = np.array( + [[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]], + dtype=np.float32, +) +M_rot = centered_linear_2d(A_rot, H, W) +y_rot = affine_transform(slice2d, M_rot, order=1) + +plt.subplot(2, 3, 2) +show("Rotate 20°", y_rot.larray.cpu().numpy()) + +# ------------------------------------------------------------ +# Scale ×1.2 +# ------------------------------------------------------------ +A_scale = np.array([[1.2, 0], [0, 1.2]], dtype=np.float32) +M_scale = centered_linear_2d(A_scale, H, W) +y_scale = affine_transform(slice2d, M_scale, order=1) + +plt.subplot(2, 3, 3) +show("Scale ×1.2", y_scale.larray.cpu().numpy()) + +# ------------------------------------------------------------ +# Translate (+20, −20) +# NOTE: backward warping → negate translation +# ------------------------------------------------------------ +M_tr = np.eye(2, 3, dtype=np.float32) +M_tr[:, 2] = [-20, 20] +y_tr = affine_transform(slice2d, M_tr, order=1) + +plt.subplot(2, 3, 4) +show("Translate (+20, −20)", y_tr.larray.cpu().numpy()) + +# ------------------------------------------------------------ +# Shear (0.3) +# ------------------------------------------------------------ +A_shear = np.array([[1, 0.3], [0, 1]], dtype=np.float32) +M_shear = centered_linear_2d(A_shear, H, W) +y_shear = affine_transform(slice2d, M_shear, order=1) + +plt.subplot(2, 3, 5) +show("Shear (0.3)", y_shear.larray.cpu().numpy()) + +# ------------------------------------------------------------ +# 3D rotation around Z-axis (35°) +# ------------------------------------------------------------ +theta3 = np.deg2rad(35) +A3 = np.array( + [[1, 0, 0], + [0, np.cos(theta3), -np.sin(theta3)], + [0, np.sin(theta3), np.cos(theta3)]], + dtype=np.float32, +) +M3 = centered_linear_3d(A3, D, H, W) +y3 = affine_transform(x, M3, order=1) + +# REMOVE channel dimension before slicing +vol3 = y3.larray.squeeze(0) # (D, H, W) + +plt.subplot(2, 3, 6) +show("3D Rotation (Z-axis 35°)", vol3[mid].cpu().numpy()) + +plt.tight_layout() +plt.show() diff --git a/examples/ndimages/run_affine_on_nifti.py b/examples/ndimages/run_affine_on_nifti.py new file mode 100644 index 0000000000..82e031aad0 --- /dev/null +++ b/examples/ndimages/run_affine_on_nifti.py @@ -0,0 +1,136 @@ +import numpy as np +import matplotlib.pyplot as plt +import nibabel as nib +import heat as ht +from mpi4py import MPI + +from heat.ndimage.affine import affine_transform + + +# ============================================================ +# Helpers +# ============================================================ + +def canonicalize_to_ZHW(t): + """ + Force tensor to shape (Z, H, W) + """ + while t.ndim > 3: + t = t.squeeze(0) + + if t.ndim == 2: + t = t.unsqueeze(0) + + if t.ndim != 3: + raise RuntimeError(f"Unexpected tensor shape: {t.shape}") + + return t + + +def strongest_slice(vol): + """ + vol: torch.Tensor (Z, H, W) + """ + scores = vol.abs().amax(dim=(1, 2)) + score, idx = scores.max(dim=0) + return int(idx.item()), float(score.item()) + + +def apply_affine(x, M): + y = affine_transform(x, M, order=0, mode="nearest") + y_local = canonicalize_to_ZHW(y.larray) + return y_local + + +# ============================================================ +# MAIN +# ============================================================ + +def main(): + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + + # -------------------------------------------------------- + # Load MRI on rank 0 + # -------------------------------------------------------- + if rank == 0: + nii = nib.load( + "PATH" + ) + vol = nii.get_fdata().astype(np.float32) + print(f"[rank 0] Loaded MRI: {vol.shape}", flush=True) + else: + vol = None + + vol = comm.bcast(vol, root=0) + + # -------------------------------------------------------- + # Heat array (distributed over Z) + # -------------------------------------------------------- + x = ht.array(vol, split=0) + print(f"[rank {rank}] local input shape = {x.larray.shape}", flush=True) + + comm.Barrier() + + # ======================================================== + # Define affine transforms (SPACE-based) + # ======================================================== + M_identity = np.eye(3, 4, dtype=np.float32) + + M_scale = np.eye(3, 4, dtype=np.float32) + M_scale[0, 0] = 1.2 + M_scale[1, 1] = 1.2 + M_scale[2, 2] = 1.2 + + M_translate = np.eye(3, 4, dtype=np.float32) + M_translate[0, 3] = 10.0 # +10 in Z (SPACE) + + cases = [ + ("Identity", M_identity), + ("Scale ×1.2", M_scale), + ("Translate +10 Z", M_translate), + ] + + # ======================================================== + # Apply all transforms + # ======================================================== + results = [] + + for name, M in cases: + y_local = apply_affine(x, M) + idx, score = strongest_slice(y_local) + + print( + f"[rank {rank}] {name}: strongest slice idx={idx}, score={score:.3e}", + flush=True, + ) + + results.append((name, y_local, idx, score)) + + # ======================================================== + # Visualization — ONE WINDOW PER RANK + # ======================================================== + fig, axs = plt.subplots(1, 3, figsize=(15, 5)) + fig.suptitle(f"Rank {rank} — SPACE affine operations", fontsize=14) + + for ax, (name, vol_local, idx, score) in zip(axs, results): + if score < 1e-3: + ax.set_title(f"{name}\nEMPTY") + ax.axis("off") + continue + + ax.imshow(vol_local[idx].cpu().numpy(), cmap="gray") + ax.set_title(f"{name}\nslice {idx}") + ax.axis("off") + + plt.tight_layout() + plt.show() + + comm.Barrier() + + if rank == 0: + print("\nDONE — multi-operation SPACE-affine demo completed\n", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/ndimages/sample.jpg b/examples/ndimages/sample.jpg new file mode 100644 index 0000000000..87d6d9c04f Binary files /dev/null and b/examples/ndimages/sample.jpg differ diff --git a/examples/ndimages/test_images/jason-leung-Iwlo4RuPefM-unsplash.jpg b/examples/ndimages/test_images/jason-leung-Iwlo4RuPefM-unsplash.jpg new file mode 100644 index 0000000000..519c341f41 Binary files /dev/null and b/examples/ndimages/test_images/jason-leung-Iwlo4RuPefM-unsplash.jpg differ diff --git a/examples/ndimages/test_images/jason-leung-Iwlo4RuPefM-unsplash_small.jpg b/examples/ndimages/test_images/jason-leung-Iwlo4RuPefM-unsplash_small.jpg new file mode 100644 index 0000000000..956d81e804 Binary files /dev/null and b/examples/ndimages/test_images/jason-leung-Iwlo4RuPefM-unsplash_small.jpg differ diff --git a/examples/ndimages/test_images/sample.jpg b/examples/ndimages/test_images/sample.jpg new file mode 100644 index 0000000000..87d6d9c04f Binary files /dev/null and b/examples/ndimages/test_images/sample.jpg differ diff --git a/examples/ndimages/view_mri_scroll.py b/examples/ndimages/view_mri_scroll.py new file mode 100644 index 0000000000..03cf8cf01f --- /dev/null +++ b/examples/ndimages/view_mri_scroll.py @@ -0,0 +1,78 @@ +""" +viewing the original data with scroll mech +""" + +import os +import nibabel as nib +import matplotlib.pyplot as plt + +# ============================================================ +# Paths (ONLY files that actually exist) +# ============================================================ + +paths = { + "Original": "PATH", + +} + +# ============================================================ +# Load volumes safely +# ============================================================ + +volumes = {} +for name, path in paths.items(): + if not os.path.exists(path): + print(f"[SKIP] {name}: file not found") + continue + + vol = nib.load(path).get_fdata() + volumes[name] = vol + print(f"[LOAD] {name}: shape={vol.shape}") + +if not volumes: + raise RuntimeError("No volumes loaded") + +# ============================================================ +# Setup figure +# ============================================================ + +titles = list(volumes.keys()) +data = list(volumes.values()) +depths = [v.shape[0] for v in data] + +slice_indices = [d // 2 for d in depths] # one index per volume + +fig, axes = plt.subplots(1, len(data), figsize=(4 * len(data), 5)) +if len(data) == 1: + axes = [axes] + +images = [] + +for ax, title, vol, idx in zip(axes, titles, data, slice_indices): + img = ax.imshow(vol[idx]) + ax.set_title(f"{title}\nslice {idx}") + ax.axis("off") + images.append(img) + +fig.suptitle("Independent slice scrolling per volume") + +# ============================================================ +# Keyboard navigation (ALL volumes together) +# ============================================================ + +def on_key(event): + for i, vol in enumerate(data): + if event.key == "up": + slice_indices[i] = min(slice_indices[i] + 1, vol.shape[0] - 1) + elif event.key == "down": + slice_indices[i] = max(slice_indices[i] - 1, 0) + else: + return + + images[i].set_data(vol[slice_indices[i]]) + axes[i].set_title(f"{titles[i]}\nslice {slice_indices[i]}") + + fig.canvas.draw_idle() + +fig.canvas.mpl_connect("key_press_event", on_key) +plt.show() diff --git a/heat/datasets/flair.nii.gz b/heat/datasets/flair.nii.gz new file mode 100644 index 0000000000..0764a0f2de Binary files /dev/null and b/heat/datasets/flair.nii.gz differ diff --git a/heat/datasets/mri_sample_LICENSE.txt b/heat/datasets/mri_sample_LICENSE.txt new file mode 100644 index 0000000000..22bc3e69c8 --- /dev/null +++ b/heat/datasets/mri_sample_LICENSE.txt @@ -0,0 +1,10 @@ +MIT License + +Copyright (c) 2018 Adam Wolf + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this file to deal in the file without restriction, including without +limitation the rights to use, copy, modify, merge, publish, distribute, +sublicense, and/or sell copies of the file. + +THE FILE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND. diff --git a/heat/ndimage/affine.py b/heat/ndimage/affine.py new file mode 100644 index 0000000000..371e9ad5e6 --- /dev/null +++ b/heat/ndimage/affine.py @@ -0,0 +1,361 @@ +""" +Affine transformations for N-dimensional Heat arrays. + +This module implements backward-warping affine transformations +(translation, rotation, scaling) for 2D and 3D data stored as +Heat DNDarrays, using a PyTorch backend. + +The affine matrix M is interpreted as a *forward* transform +in affine (x, y [, z]) coordinates: + + out = A @ inp + b + +where M = [A | b] has shape (ND, ND+1). + +Internally, backward warping is used for resampling: + + inp = A^{-1} @ (out - b) + +Spatial axis conventions in Heat: +- 2D arrays: (H, W) == (y, x) +- 3D arrays: (D, H, W) == (z, y, x) + +Interpolation and boundary handling: +- order=0: nearest-neighbor +- order=1: bilinear (2D only; 3D falls back to nearest) +- padding modes: 'nearest', 'wrap', 'reflect', 'constant' + +Distributed arrays: +- Non-spatial splits are handled locally without communication. +- Spatial splits support only simple, axis-aligned transforms + (e.g. translation or diagonal scaling). +- More general affine transforms (rotation/shear) would require + halo exchange and are intentionally not supported yet. + +The public entry point is `affine_transform`. +""" + +import numpy as np +import torch +from torch.nn.functional import affine_grid +from torch.nn.functional import grid_sample +import heat as ht +from mpi4py import MPI +import warnings +from ..core.dndarray import DNDarray +from ..core import factories +from ..core import manipulations +from ..core.linalg.basics import transpose + +MODE_TO_PADDING = { + # SciPy mode # torch padding_mode + "constant": "zeros", # fill with ``cval`` (default 0) + "reflect": "reflection", # reflect at the border + "mirror": "reflection", # also mapped to reflection + "nearest": "border", # replicate the edge pixel + # The following SciPy modes have no exact Torch counterpart. + # We keep them as ``None`` and raise an error if they are used. + "wrap": None, + "grid-wrap": None, + "grid-constant": None, +} + +ORDER_TO_MODE = { + 0: "nearest", # order‑0 → nearest‑neighbour + 1: "bilinear", # order‑1 → bilinear (linear) sampling + 3: "bicubic", # order‑3 → bicubic sampling + # SciPy supports orders 2,4,5 as well – they have no direct Torch counterpart. + # We fall back to the closest supported mode. + 2: "bilinear", # closest supported mode + 4: "bicubic", + 5: "bicubic", + # any other order (should never happen) → default to bilinear +} + +filtering_map = {} + + +# ============================================================ +# Helper utilities +# ============================================================ +def _remove_slice(A: torch.Tensor, idx: int, dim: int) -> torch.Tensor: + # Keep rows before and after the removed row + rows_before = torch.arange(0, idx) + print(f"{A.size(dim)=}") + print(f"{idx+1=}") + if (idx + 1) < A.size(dim): + rows_after = torch.arange(idx + 1, A.size(dim)) + print(f"{rows_after=}") + return torch.cat( + [A.index_select(dim, rows_before), A.index_select(dim, rows_after)], dim=dim + ) + else: + return A.index_select(dim, rows_before) + + +def to_full_affine(mat): + """ + Convert reduced affine matrices to full homogeneous form. + + Works with single matrices or batches of any dimensionality: + - (D, D+1) → (D+1, D+1) # single + - (N, D, D+1) → (N, D+1, D+1) # batch + + Args: + mat: Reduced affine tensor + + Returns + ------- + Full homogeneous affine tensor + """ + # Detect if batched by checking number of dimensions + if mat.dim() == 2: + # Single matrix case: (D, D+1) + D = mat.shape[0] # spatial dimension + full = torch.zeros(D + 1, D + 1, dtype=mat.dtype, device=mat.device) + full[:D, :] = mat # copy top D rows + full[D, D] = 1.0 # set homogeneous coordinate + return full + + elif mat.dim() == 3: + # Batched case: (N, D, D+1) + N, D, _ = mat.shape + full = torch.zeros(N, D + 1, D + 1, dtype=mat.dtype, device=mat.device) + full[:, :D, :] = mat # copy top D rows for each batch + full[:, D, D] = 1.0 # set homogeneous coordinate for each batch + return full + + else: + raise ValueError(f"Expected 2D or 3D tensor, got {mat.dim()}D") + + +def _matrix_pixel_to_normalized_coords(M: torch.Tensor, sizes): + """ + Convert scipy affine matrix to PyTorch normalized coordinates. + + Args: + M: Torch Tensor of shape (D+1, D+1) — [a_ij | t_i] + sizes: image sizes [H, W, D, ...] of length D + + Returns + ------- + theta: torch affine matrix of shape (D, D+1) + """ + print() + print("START NORMALIZATION") + print() + D = len(sizes) + print(f"{D=}") + print(f"{sizes=}") + print(f"{M.shape}") + + # construct coord space transform + scales = (torch.as_tensor(sizes) - 1) / 2.0 + M_scales = torch.diag(scales) + D = len(sizes) + T_np = torch.zeros((D + 1, D + 1)) + T_np[:D, :D] = M_scales + T_np[D, D] = 1 + T_np[:D, D] = scales + T_pn = T_np.inverse() + print("construct coord transforms") + print(f"{T_np=}") + + full_transformed = T_pn @ M @ T_np + print(f"{full_transformed=}") + + print() + print("END NORMALIZATION") + print() + + return full_transformed[:, :D, :] + + +def _swap_rows_cols(A, row_pair, col_pair): + print() + print("swapping stuff") + print() + print(A) + i, j = row_pair + k, m = col_pair + + row_idx = torch.arange(A.size(0)) + row_idx[i] = j + row_idx[j] = i + + col_idx = torch.arange(A.size(1)) + col_idx[k] = m + col_idx[m] = k + A = A[row_idx[:, None], col_idx[None, :]] + print(A) + print() + print("end swapping stuff") + print() + return A + + +# ============================================================ +# main methods +# ============================================================ +def affine_transform( + input: DNDarray, + matrix: DNDarray, + offset=0.0, + output_shape=None, + output=None, + order=3, + mode="constant", + cval=0.0, + prefilter=True, +) -> DNDarray: + """ + Input is expected to have shape H x W x C because that is consistent with PIL+Numpy to get Image data + -> to be consistent with scipy the Matrix then has to be of shape 3x4 (3x3, 4x4 also valid) + """ + # TODO: Implement cases 3x3, 2x3, 2x2 + + matrix_torch: torch.Tensor + if input.ndim == 3: # 2d image where third dimension are color vectors + if matrix.shape == (3, 4): + # remove axis that represents transforming the color dimension, because + # torch does not support that + + matrix_torch = _remove_slice(matrix.larray, 2, 0) + matrix_torch = _remove_slice(matrix_torch, 2, 1) + matrix_torch = to_full_affine(matrix_torch) + matrix_torch = _swap_rows_cols(matrix_torch, (0, 1), (0, 1)) + elif matrix.shape == (4, 4): + # remove axis that represents transforming the color dimension, because + # torch does not support that + matrix_torch = _remove_slice(matrix.larray, 2, 0) + matrix_torch = _remove_slice(matrix_torch, 2, 1) + matrix_torch = _swap_rows_cols(matrix_torch, (0, 1), (0, 1)) + else: + raise NotImplementedError() + else: + raise ValueError("transform matrix has no valid shape") + + if matrix_torch.dim() == 2: + matrix_torch = matrix_torch.unsqueeze(0) + + # for now matrix has shape 3x3xB + + t_input = transpose(input, (2, 0, 1)) # to C x H x W + matrix_torch = _matrix_pixel_to_normalized_coords(matrix_torch, t_input.shape[1:]) + input_torch = t_input.larray.unsqueeze(0) + + sample_padding = MODE_TO_PADDING[mode] + sample_mode = ORDER_TO_MODE[order] + size = torch.Size((1, t_input.shape[0], t_input.shape[1], t_input.shape[2])) + print(f"{size=}") + sample_grid: torch.Tensor = affine_grid(matrix_torch, size) + + transformed = grid_sample( + input_torch, sample_grid, padding_mode=sample_padding, mode=sample_mode + ) + return ht.array(transformed.squeeze(0).permute(1, 2, 0)) + + +# ============================================================ +# Helper utilities (old, to be determined if still helpful) +# ============================================================ + + +def _is_identity_affine(M, ND): + """ + Check whether an affine matrix represents the identity transform. + + Parameters + ---------- + M : array-like + Affine matrix of shape (ND, ND+1). + ND : int + Number of spatial dimensions. + + Returns + ------- + bool + True if A is the identity matrix and b is zero. + """ + A = M[:, :ND] + b = M[:, ND:] + return np.allclose(A, np.eye(ND)) and np.allclose(b, 0) + + +def _normalize_input(x, ND): + """ + Normalize a Heat array to a unified internal layout. + + For sampling, inputs are reshaped to include synthetic + batch and channel dimensions: + + - 2D: (N, C, H, W) + - 3D: (N, C, D, H, W) + + These dimensions are internal only and do not imply + semantic batching or channels in the input data. + + Parameters + ---------- + x : ht.DNDarray + Input array. + ND : int + Number of spatial dimensions. + + Returns + ------- + torch.Tensor + Local torch tensor with added batch/channel dimensions. + tuple + Original shape of the input array. + """ + orig_shape = x.shape + t = x.larray + + if ND == 2: + if x.ndim == 2: + t = t.unsqueeze(0).unsqueeze(0) + elif x.ndim == 3: + t = t.unsqueeze(0) + else: + if x.ndim == 3: + t = t.unsqueeze(0).unsqueeze(0) + elif x.ndim == 4: + t = t.unsqueeze(0) + + return t, orig_shape + + +def _make_grid(spatial, device): + """ + Construct a coordinate grid in Heat spatial axis order. + + The grid contains integer coordinates corresponding to + output pixel locations and is later mapped through the + inverse affine transform. + + Parameters + ---------- + spatial : tuple + Spatial shape (H, W) or (D, H, W). + device : torch.device + Target device. + + Returns + ------- + torch.Tensor + Coordinate grid of shape (ND, *spatial) in Heat order. + """ + if len(spatial) == 2: + H, W = spatial + y = torch.arange(H, device=device) + x = torch.arange(W, device=device) + gy, gx = torch.meshgrid(y, x, indexing="ij") + return torch.stack([gy, gx], dim=0) + else: + D, H, W = spatial + z = torch.arange(D, device=device) + y = torch.arange(H, device=device) + x = torch.arange(W, device=device) + gz, gy, gx = torch.meshgrid(z, y, x, indexing="ij") + return torch.stack([gz, gy, gx], dim=0) diff --git a/heat/ndimage/tests/test_affine_transform_distributed.py b/heat/ndimage/tests/test_affine_transform_distributed.py new file mode 100644 index 0000000000..8625879ab3 --- /dev/null +++ b/heat/ndimage/tests/test_affine_transform_distributed.py @@ -0,0 +1,109 @@ +import numpy as np +import pytest +import heat as ht +from mpi4py import MPI +#TODO: need updating, method signature is not compatible anymore +# from heat.ndimage.affine import affine_transform + +# comm = MPI.COMM_WORLD +# rank = comm.Get_rank() +# size = comm.Get_size() + +# @pytest.mark.mpi +# def test_undistributed_affine_translation_backward(): +# """ +# Backward warping with nearest padding. + +# out[z, y, x] = in[z, y, x - 1] +# with x < 0 clamped to 0. +# """ +# data = np.arange(24, dtype=np.float32).reshape(4, 3, 2) +# x = ht.array(data, split=None) + +# M = np.array( +# [ +# [1, 0, 0, 1], +# [0, 1, 0, 0], +# [0, 0, 1, 0], +# ], +# dtype=np.float64, +# ) + +# y = affine_transform(x, M, order=0, mode="nearest").numpy() + +# # correct backward-warp reference +# ref = np.zeros_like(data) +# ref[:, :, 0] = data[:, :, 0] +# ref[:, :, 1] = data[:, :, 0] + +# assert np.allclose(y, ref) + + +# @pytest.mark.mpi +# def test_distributed_non_split_axis_translation_matches_undistributed(): +# data = np.arange(48, dtype=np.float32).reshape(6, 4, 2) + +# M = np.array( +# [ +# [1, 0, 0, 1], +# [0, 1, 0, 0], +# [0, 0, 1, 0], +# ], +# dtype=np.float64, +# ) + +# x_full = ht.array(data, split=None) +# y_ref = affine_transform(x_full, M, order=0).numpy() + +# x_dist = ht.array(data, split=0) +# y_dist = affine_transform(x_dist, M, order=0) + +# assert y_dist.split == 0 +# assert np.allclose(y_dist.resplit(None).numpy(), y_ref) + + +# @pytest.mark.mpi +# def test_split_axis_translation_supported_via_resplit(): +# data = np.zeros((8, 4, 4), dtype=np.float32) +# if rank == 0: +# data[1, 2, 2] = 1.0 +# data = comm.bcast(data, root=0) + +# x = ht.array(data, split=0) + +# # translate +3 along z +# M = np.array( +# [ +# [1, 0, 0, 0], +# [0, 1, 0, 0], +# [0, 0, 1, 3], +# ], +# dtype=np.float64, +# ) + +# y = affine_transform(x, M, order=0) + +# ref = affine_transform(ht.array(data, split=None), M, order=0).numpy() +# got = y.resplit(None).numpy() + +# assert np.allclose(got, ref) + + +# @pytest.mark.mpi +# def test_distributed_vs_undistributed_equivalence(): +# rng = np.random.default_rng(0) +# data = rng.normal(size=(8, 5, 4)).astype(np.float32) + +# M = np.array( +# [ +# [1, 0, 0, 1], +# [0, 1, 0, 0], +# [0, 0, 1, 0], +# ], +# dtype=np.float64, +# ) + +# y_ref = affine_transform(ht.array(data, split=None), M, order=0).numpy() +# y_dist = affine_transform(ht.array(data, split=0), M, order=0) + +# assert np.allclose(y_dist.resplit(None).numpy(), y_ref) diff --git a/pyproject.toml b/pyproject.toml index 3e4f5d0ea2..d3997cbe06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -210,3 +210,8 @@ convention = "numpy" [tool.ruff.format] docstring-code-format = true + +[tool.pytest.ini_options] +markers = [ + "mpi: tests that require mpirun / MPI execution", +]