@@ -67,6 +67,39 @@ def hslerp(a, b, t):
6767
6868 return result
6969
70+ def stable_slerp (a , b , t : float , eps : float = 1e-6 ):
71+ """
72+ Numerically stable spherical linear interpolation over the channel dimension.
73+
74+ Treat each BCHW location's C-vector as a point on a hypersphere and SLERP from a->b.
75+ Falls back to LERP when the angle is very small or vectors are near-zero.
76+ """
77+ if a .shape != b .shape :
78+ raise ValueError ("Input tensors a and b must have the same shape." )
79+
80+ # Norms across channel dimension
81+ a_norm = torch .linalg .norm (a , dim = 1 , keepdim = True ).clamp_min (eps )
82+ b_norm = torch .linalg .norm (b , dim = 1 , keepdim = True ).clamp_min (eps )
83+ a_n = a / a_norm
84+ b_n = b / b_norm
85+
86+ # Cosine of angle between vectors
87+ dot = (a_n * b_n ).sum (dim = 1 , keepdim = True ).clamp (- 1.0 + eps , 1.0 - eps )
88+ theta = torch .acos (dot )
89+ sin_theta = torch .sin (theta ).clamp_min (eps )
90+
91+ # Scalar t is expected; keep broadcast-friendly
92+ s0 = torch .sin ((1.0 - t ) * theta ) / sin_theta
93+ s1 = torch .sin (t * theta ) / sin_theta
94+
95+ slerp_out = s0 * a + s1 * b
96+ lerp_out = (1.0 - t ) * a + t * b
97+
98+ # When angle is too small, prefer LERP to avoid instabilities
99+ use_lerp = (theta < 1e-3 ).squeeze (1 )
100+ out = torch .where (use_lerp .unsqueeze (1 ), lerp_out , slerp_out )
101+ return out
102+
70103blending_modes = {
71104 # Args:
72105 # - a (tensor): Latent input 1
@@ -84,6 +117,8 @@ def hslerp(a, b, t):
84117 # Interpolates between tensors a and b using normalized linear interpolation,
85118 # with a twist when t is greater than or equal to 0.5.
86119 'hslerp' : hslerp ,
120+ # Numerically stable SLERP over channel vectors
121+ 'stable_slerp' : stable_slerp ,
87122 # Adds tensor b to tensor a, scaled by t.
88123 'inject' : lambda a , b , t : a + b * t ,
89124 # Interpolates between tensors a and b using linear interpolation.
0 commit comments