Skip to content

Commit 8912fc7

Browse files
committed
Bump version
1 parent 5554684 commit 8912fc7

2 files changed

Lines changed: 36 additions & 1 deletion

File tree

nodes.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,39 @@ def hslerp(a, b, t):
6767

6868
return result
6969

70+
def stable_slerp(a, b, t: float, eps: float = 1e-6):
71+
"""
72+
Numerically stable spherical linear interpolation over the channel dimension.
73+
74+
Treat each BCHW location's C-vector as a point on a hypersphere and SLERP from a->b.
75+
Falls back to LERP when the angle is very small or vectors are near-zero.
76+
"""
77+
if a.shape != b.shape:
78+
raise ValueError("Input tensors a and b must have the same shape.")
79+
80+
# Norms across channel dimension
81+
a_norm = torch.linalg.norm(a, dim=1, keepdim=True).clamp_min(eps)
82+
b_norm = torch.linalg.norm(b, dim=1, keepdim=True).clamp_min(eps)
83+
a_n = a / a_norm
84+
b_n = b / b_norm
85+
86+
# Cosine of angle between vectors
87+
dot = (a_n * b_n).sum(dim=1, keepdim=True).clamp(-1.0 + eps, 1.0 - eps)
88+
theta = torch.acos(dot)
89+
sin_theta = torch.sin(theta).clamp_min(eps)
90+
91+
# Scalar t is expected; keep broadcast-friendly
92+
s0 = torch.sin((1.0 - t) * theta) / sin_theta
93+
s1 = torch.sin(t * theta) / sin_theta
94+
95+
slerp_out = s0 * a + s1 * b
96+
lerp_out = (1.0 - t) * a + t * b
97+
98+
# When angle is too small, prefer LERP to avoid instabilities
99+
use_lerp = (theta < 1e-3).squeeze(1)
100+
out = torch.where(use_lerp.unsqueeze(1), lerp_out, slerp_out)
101+
return out
102+
70103
blending_modes = {
71104
# Args:
72105
# - a (tensor): Latent input 1
@@ -84,6 +117,8 @@ def hslerp(a, b, t):
84117
# Interpolates between tensors a and b using normalized linear interpolation,
85118
# with a twist when t is greater than or equal to 0.5.
86119
'hslerp': hslerp,
120+
# Numerically stable SLERP over channel vectors
121+
'stable_slerp': stable_slerp,
87122
# Adds tensor b to tensor a, scaled by t.
88123
'inject': lambda a, b, t: a + b * t,
89124
# Interpolates between tensors a and b using linear interpolation.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "freeu_advanced"
33
description = "This custom node provides advanced settings for FreeU."
4-
version = "1.0.0"
4+
version = "1.1.0"
55
license = { file = "LICENSE" }
66

77
[project.urls]

0 commit comments

Comments
 (0)