Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
af28e1e
affine for ND still to fix 4D
markak47 Nov 16, 2025
6798efb
Fix Normalize shape (npx expo start -c
markak47 Nov 16, 2025
cb9b523
changed nii.gz into heat datasets with liscence file removed unnecess…
markak47 Dec 5, 2025
4bbd12e
changed nii.gz into heat datasets with liscence file ,removed unneces…
markak47 Dec 5, 2025
188bc5c
changed nii.gz into heat datasets with liscence file ,removed unneces…
markak47 Dec 5, 2025
d986169
added documentation
markak47 Dec 14, 2025
2cda26b
added documentation to mri scan
markak47 Dec 14, 2025
970c25c
removed duplicate file
markak47 Dec 14, 2025
681ba56
chore: trigger copilot review
markak47 Dec 14, 2025
2a29fda
chore: trigger copilot review
markak47 Dec 14, 2025
20ec2d9
updated test cases to also test for all valid split
markak47 Dec 14, 2025
74f37fa
examples
markak47 Dec 14, 2025
144a884
fixed mpirun + new test_cube.py example file
markak47 Jan 4, 2026
7a6fd1c
MPI-safe + raise error, TODO Halo exchange for big data sets
markak47 Jan 16, 2026
16e0314
halo exchange testing
markak47 Jan 22, 2026
d930852
identity + translation scaling yet to do
markak47 Jan 24, 2026
99be835
distrubted in space
markak47 Jan 27, 2026
eb81dd3
doc string added
markak47 Jan 27, 2026
cf10b33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2026
bebdcd1
Merge branch 'main' into features/1900-Image_transformations_beyond_FFT
ClaudiaComito Feb 3, 2026
2e245b7
test
markak47 Feb 3, 2026
0bc02e9
test
markak47 Feb 3, 2026
33e8c79
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2026
d93343c
intitial implementation of affine_transform using affine_grid and gri…
Jun 8, 2026
39c19b6
Merge remote-tracking branch 'upstream' into features/1900-Image_tran…
brownbaerchen Jun 8, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions examples/ndimages/affine_2d_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Example for 2D images with 3 operations popup view
"""

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFile
import heat as ht
from heat.ndimage.affine import affine_transform
import scipy.ndimage as dnimg

# ------------------------------------------------------------
# Load RGB image
# ------------------------------------------------------------
img: ImageFile = Image.open(
"/home/leonk/projects/heat/examples/ndimages/test_images/jason-leung-Iwlo4RuPefM-unsplash_small.jpg"
).convert("RGB")

img_np = np.asarray(img, dtype=np.float32) #
print(f"shape of image as numpy array {img_np.shape}") # HWC

# HWC
x = ht.array(img_np) # HWC
print(f"shape of image converted from numpy to heat array {x.shape}")

H, W = img_np.shape[:2]
cx, cy = W / 2, H / 2 # NOTE: (x, y)
img_center = np.array([cx, cy, 0], dtype=np.float32)
# img_center = np.array([0, 0, 0], dtype=np.float32)

fig, axs = plt.subplots(5, 2, figsize=(10, 16))
axs = axs.ravel()


# ------------------------------------------------------------
# Helpers
# ------------------------------------------------------------
def centered_linear(A_xy):
"""
Build a 3×4 affine matrix for a linear transform A applied about the image center.
"""
b = img_center - A_xy @ img_center
# b = img_center + A_xy @ (-img_center)
# b = np.array([0, 256, 0], dtype=np.float32)
return np.hstack([A_xy, b[:, None]]).astype(np.float32)


def apply(M: np.ndarray, title, idx, mode="nearest", constant_value=0.0):

print(f"matrix shape: {M.shape}")
print(M)
print(f"image shape: {x.shape}")

index = idx * 2

result = affine_transform(
x, ht.array(M), order=0, mode=mode, cval=constant_value, prefilter=True
)

compare = dnimg.affine_transform(
img_np, M, order=0, mode=mode, cval=constant_value, prefilter=True
)
axs[index].imshow(result.larray.permute(0, 1, 2).cpu().numpy().astype(np.uint8))
axs[index + 1].imshow(compare.astype(np.uint8))
axs[index].set_title(title)
axs[index].axis("off")
axs[index + 1].set_title("")
axs[index + 1].axis("off")
axs[index].scatter(img_center[0], img_center[1])
axs[index + 1].scatter(img_center[0], img_center[1])


# ------------------------------------------------------------
# Identity
# ------------------------------------------------------------
apply(np.eye(3, 4, dtype=np.float32), "Identity", 0)

# ------------------------------------------------------------
# Translation +100 px RIGHT (b = [tx, ty] = [100, 0])
# Tip: use mode="constant" to make the shift super obvious
# ------------------------------------------------------------
M_tr = np.eye(3, 4, dtype=np.float32)
M_tr[:, 3] = [0, 256, 0] # (tx, ty, tz)
apply(M_tr, "Translate +100px (right)", 1, mode="constant", constant_value=0.0)

# ------------------------------------------------------------
# Rotate 30° around center (in x,y coords)
# ------------------------------------------------------------
theta = np.deg2rad(30)
A_rot = np.array(
[[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]],
dtype=np.float32,
)
apply(centered_linear(A_rot), "Rotate 30°", 2)

# ------------------------------------------------------------
# Scale ×1.5 around center (in x,y coords)
# ------------------------------------------------------------
A_scale = np.array([[0.75, 0, 0], [0, 1.5, 0], [0, 0, 1]], dtype=np.float32)
apply(centered_linear(A_scale), "Scale ×1.5", 3)

# ------------------------------------------------------------
# Combo: centered (scale→rotate) + then translate (tx,ty)
# ------------------------------------------------------------
A_combo = A_rot @ A_scale
t = np.array([100, -50, 0], dtype=np.float32) # (tx, ty)
b_combo = img_center - A_combo @ img_center + t
M_combo = np.hstack([A_combo, b_combo[:, None]]).astype(np.float32)
apply(M_combo, "Combo", 4)

# ------------------------------------------------------------
# Hide unused subplot
# ------------------------------------------------------------

plt.tight_layout()
plt.show()
148 changes: 148 additions & 0 deletions examples/ndimages/affine_3d_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
Non-distributed affine demo on a NIfTI volume (Heat).

Applies:
- 2D rotation (centered)
- 2D scaling (centered)
- 2D translation
- 2D shear
- 3D rotation (centered)

Handles Heat channel dimensions correctly.
"""

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import heat as ht

from heat.ndimage.affine import affine_transform


# ============================================================
# Helpers
# ============================================================

def centered_linear_2d(A, H, W):
"""2×3 affine around image center (y, x)."""
c = np.array([H / 2, W / 2], dtype=np.float32)
b = c - A @ c
return np.hstack([A, b[:, None]]).astype(np.float32)


def centered_linear_3d(A, D, H, W):
"""3×4 affine around volume center (z, y, x)."""
c = np.array([D / 2, H / 2, W / 2], dtype=np.float32)
b = c - A @ c
return np.hstack([A, b[:, None]]).astype(np.float32)


def show(title, img):
"""Safe grayscale visualization."""
if img.ndim == 3 and img.shape[0] == 1:
img = img[0]
plt.imshow(img, cmap="gray")
plt.title(title)
plt.axis("off")


# ============================================================
# Load MRI
# ============================================================

nii = nib.load(
"PATH"
)
x_np = nii.get_fdata().astype(np.float32)

print("Loaded MRI:", x_np.shape)

# Heat array (NO split)
x = ht.array(x_np)

D, H, W = x_np.shape
mid = D // 2

# 2D middle slice
slice2d = x[mid]


# ============================================================
# Plot
# ============================================================

plt.figure(figsize=(12, 8))

# ------------------------------------------------------------
# Original
# ------------------------------------------------------------
plt.subplot(2, 3, 1)
show("Original (middle slice)", slice2d.larray.cpu().numpy())

# ------------------------------------------------------------
# Rotate 20° (2D)
# ------------------------------------------------------------
theta = np.deg2rad(20)
A_rot = np.array(
[[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]],
dtype=np.float32,
)
M_rot = centered_linear_2d(A_rot, H, W)
y_rot = affine_transform(slice2d, M_rot, order=1)

plt.subplot(2, 3, 2)
show("Rotate 20°", y_rot.larray.cpu().numpy())

# ------------------------------------------------------------
# Scale ×1.2
# ------------------------------------------------------------
A_scale = np.array([[1.2, 0], [0, 1.2]], dtype=np.float32)
M_scale = centered_linear_2d(A_scale, H, W)
y_scale = affine_transform(slice2d, M_scale, order=1)

plt.subplot(2, 3, 3)
show("Scale ×1.2", y_scale.larray.cpu().numpy())

# ------------------------------------------------------------
# Translate (+20, −20)
# NOTE: backward warping → negate translation
# ------------------------------------------------------------
M_tr = np.eye(2, 3, dtype=np.float32)
M_tr[:, 2] = [-20, 20]
y_tr = affine_transform(slice2d, M_tr, order=1)

plt.subplot(2, 3, 4)
show("Translate (+20, −20)", y_tr.larray.cpu().numpy())

# ------------------------------------------------------------
# Shear (0.3)
# ------------------------------------------------------------
A_shear = np.array([[1, 0.3], [0, 1]], dtype=np.float32)
M_shear = centered_linear_2d(A_shear, H, W)
y_shear = affine_transform(slice2d, M_shear, order=1)

plt.subplot(2, 3, 5)
show("Shear (0.3)", y_shear.larray.cpu().numpy())

# ------------------------------------------------------------
# 3D rotation around Z-axis (35°)
# ------------------------------------------------------------
theta3 = np.deg2rad(35)
A3 = np.array(
[[1, 0, 0],
[0, np.cos(theta3), -np.sin(theta3)],
[0, np.sin(theta3), np.cos(theta3)]],
dtype=np.float32,
)
M3 = centered_linear_3d(A3, D, H, W)
y3 = affine_transform(x, M3, order=1)

# REMOVE channel dimension before slicing
vol3 = y3.larray.squeeze(0) # (D, H, W)

plt.subplot(2, 3, 6)
show("3D Rotation (Z-axis 35°)", vol3[mid].cpu().numpy())

plt.tight_layout()
plt.show()
Loading
Loading