diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index c10c459..1546833 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -17,6 +17,7 @@ "similarity_add_difference", "distribution_crossover", "ties_add_difference", + "add_perpendicular", ] @@ -209,3 +210,16 @@ def filter_top_k(a: Tensor, k: float): k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), k) top_k_filter = (torch.abs(a) >= k_value).float() return a * top_k_filter + + +def add_perpendicular( + a: Tensor, b: Tensor, alpha: float, c: Tensor = None, **kwargs +) -> Tensor: + a_diff = a.float() - c.float() + b_diff = b.float() - c.float() + a_ortho = a_diff * (a_diff / torch.linalg.norm(a_diff) * (b_diff / torch.linalg.norm(a_diff))).sum() + b_perp = b_diff - a_ortho + res = a + alpha * b_perp + if torch.isnan(res).any(): + return a + return res.to(a.dtype)