From 18bff992cc62a53997f194fa72ae8eeb5f6347e8 Mon Sep 17 00:00:00 2001 From: ljleb Date: Thu, 9 Nov 2023 23:17:27 -0500 Subject: [PATCH 01/44] oft --- sd_meh/merge_methods.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index c10c459..5e18df2 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -209,3 +209,14 @@ def filter_top_k(a: Tensor, k: float): k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), k) top_k_filter = (torch.abs(a) >= k_value).float() return a * top_k_filter + + +def orthogonal_rotation(a: Tensor, b: Tensor, **kwargs): + if a.shape == (): + return a + + a_reshape = a.reshape(-1, a.shape[-1]).float() + b_reshape = b.reshape(-1, b.shape[-1]).float() + U, _, V = torch.svd(torch.matmul(a_reshape.T, b_reshape)) + Q = torch.matmul(U, V.T) + return torch.matmul(a_reshape, Q).reshape_as(a) From 47bdac9ed5d495a9b4dab108c7ca3d5b257abfad Mon Sep 17 00:00:00 2001 From: ljleb Date: Thu, 9 Nov 2023 23:20:35 -0500 Subject: [PATCH 02/44] device --- sd_meh/merge_methods.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 5e18df2..4e7c6ac 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -214,9 +214,8 @@ def filter_top_k(a: Tensor, k: float): def orthogonal_rotation(a: Tensor, b: Tensor, **kwargs): if a.shape == (): return a - a_reshape = a.reshape(-1, a.shape[-1]).float() b_reshape = b.reshape(-1, b.shape[-1]).float() U, _, V = torch.svd(torch.matmul(a_reshape.T, b_reshape)) Q = torch.matmul(U, V.T) - return torch.matmul(a_reshape, Q).reshape_as(a) + return torch.matmul(a_reshape, Q).reshape_as(a).to(dtype=a.dtype) From 06814587c9833ee8932eb6960ab295c222fcabd9 Mon Sep 17 00:00:00 2001 From: ljleb Date: Thu, 9 Nov 2023 23:36:09 -0500 Subject: [PATCH 03/44] rename --- sd_meh/merge_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 4e7c6ac..38a3b6c 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -211,7 +211,7 @@ def filter_top_k(a: Tensor, k: float): return a * top_k_filter -def orthogonal_rotation(a: Tensor, b: Tensor, **kwargs): +def rotate(a: Tensor, b: Tensor, **kwargs): if a.shape == (): return a a_reshape = a.reshape(-1, a.shape[-1]).float() From 1fe0882ae8fd9a4aa0ce40e9ec6f0ea727832978 Mon Sep 17 00:00:00 2001 From: ljleb Date: Thu, 9 Nov 2023 23:40:09 -0500 Subject: [PATCH 04/44] fix black --- sd_meh/rebasin.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sd_meh/rebasin.py b/sd_meh/rebasin.py index 2fbb418..010d67f 100644 --- a/sd_meh/rebasin.py +++ b/sd_meh/rebasin.py @@ -2200,11 +2200,9 @@ def apply_permutation(ps: PermutationSpec, perm, params): def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha): for k in model_a: try: - perm_params = get_permuted_param( - ps, perm, k, model_a - ) + perm_params = get_permuted_param(ps, perm, k, model_a) model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params - except RuntimeError: # dealing with pix2pix and inpainting models + except RuntimeError: # dealing with pix2pix and inpainting models continue return model_a From 2810d89bc71cb2ebfe69395d1896279d711d42d0 Mon Sep 17 00:00:00 2001 From: ljleb Date: Fri, 10 Nov 2023 00:30:34 -0500 Subject: [PATCH 05/44] cayley interpolation for alpha --- sd_meh/merge_methods.py | 15 ++++++++++----- sd_meh/utils.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 42 
insertions(+), 5 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 38a3b6c..292d013 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -211,11 +211,16 @@ def filter_top_k(a: Tensor, k: float): return a * top_k_filter -def rotate(a: Tensor, b: Tensor, **kwargs): - if a.shape == (): +def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): + if len(a.shape) <= 1: return a + + # make sure matrices are at most 2D a_reshape = a.reshape(-1, a.shape[-1]).float() b_reshape = b.reshape(-1, b.shape[-1]).float() - U, _, V = torch.svd(torch.matmul(a_reshape.T, b_reshape)) - Q = torch.matmul(U, V.T) - return torch.matmul(a_reshape, Q).reshape_as(a).to(dtype=a.dtype) + + u, _, v = torch.svd(torch.matmul(a_reshape.T, b_reshape)) + from .utils import interpolate_cayley + transform = interpolate_cayley(torch.matmul(u, v.T), alpha) + + return torch.matmul(a_reshape, transform).reshape_as(a).to(dtype=a.dtype) diff --git a/sd_meh/utils.py b/sd_meh/utils.py index f507ae8..c361923 100644 --- a/sd_meh/utils.py +++ b/sd_meh/utils.py @@ -1,5 +1,6 @@ import inspect import logging +import torch from sd_meh import merge_methods from sd_meh.merge import NUM_TOTAL_BLOCKS @@ -124,3 +125,34 @@ def weights_and_bases( bases |= bases_beta return weights, bases + + +def is_orthogonal(matrix, tol=1e-6): + identity = torch.eye(matrix.size(0), dtype=matrix.dtype, device=matrix.device) + matrix_transpose = matrix.t() + return torch.allclose(matrix @ matrix_transpose, identity, atol=tol) and \ + torch.allclose(matrix_transpose @ matrix, identity, atol=tol) + + +# Cayley transform of matrix A +def cayley_transform(A): + I = torch.eye(A.size(0), dtype=A.dtype, device=A.device) + return (I - A) @ torch.inverse(I + A) + + +# Inverse Cayley transform +def inverse_cayley_transform(X): + I = torch.eye(X.size(0), dtype=X.dtype, device=X.device) + return torch.inverse(I - X) @ (I + X) + + +# Interpolate between identity and orthogonal matrix using Cayley transform +def interpolate_cayley(m, t): + # Cayley transform of A + X = cayley_transform(m) + + # Scale X + Y = X * t + + # Inverse Cayley transform + return inverse_cayley_transform(Y) From 09bed888cde6e3e77e9e72e44972201c21660da4 Mon Sep 17 00:00:00 2001 From: ljleb Date: Fri, 10 Nov 2023 00:34:54 -0500 Subject: [PATCH 06/44] refact --- sd_meh/merge_methods.py | 8 +++++++- sd_meh/utils.py | 31 ------------------------------- 2 files changed, 7 insertions(+), 32 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 292d013..8d908fe 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -220,7 +220,13 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): b_reshape = b.reshape(-1, b.shape[-1]).float() u, _, v = torch.svd(torch.matmul(a_reshape.T, b_reshape)) - from .utils import interpolate_cayley transform = interpolate_cayley(torch.matmul(u, v.T), alpha) return torch.matmul(a_reshape, transform).reshape_as(a).to(dtype=a.dtype) + + +def interpolate_cayley(matrix: Tensor, power: float): + identity = torch.eye(matrix.size(0), dtype=matrix.dtype, device=matrix.device) + cayley = (identity - matrix) @ torch.inverse(identity + matrix) + scaled_cayley = cayley * power + return torch.inverse(identity - scaled_cayley) @ (identity + scaled_cayley) diff --git a/sd_meh/utils.py b/sd_meh/utils.py index c361923..39efe98 100644 --- a/sd_meh/utils.py +++ b/sd_meh/utils.py @@ -125,34 +125,3 @@ def weights_and_bases( bases |= bases_beta return weights, bases - - -def is_orthogonal(matrix, tol=1e-6): - 
identity = torch.eye(matrix.size(0), dtype=matrix.dtype, device=matrix.device) - matrix_transpose = matrix.t() - return torch.allclose(matrix @ matrix_transpose, identity, atol=tol) and \ - torch.allclose(matrix_transpose @ matrix, identity, atol=tol) - - -# Cayley transform of matrix A -def cayley_transform(A): - I = torch.eye(A.size(0), dtype=A.dtype, device=A.device) - return (I - A) @ torch.inverse(I + A) - - -# Inverse Cayley transform -def inverse_cayley_transform(X): - I = torch.eye(X.size(0), dtype=X.dtype, device=X.device) - return torch.inverse(I - X) @ (I + X) - - -# Interpolate between identity and orthogonal matrix using Cayley transform -def interpolate_cayley(m, t): - # Cayley transform of A - X = cayley_transform(m) - - # Scale X - Y = X * t - - # Inverse Cayley transform - return inverse_cayley_transform(Y) From f11c0548dffa1cc5ba9bb10280645e3562c736c1 Mon Sep 17 00:00:00 2001 From: ljleb Date: Fri, 10 Nov 2023 02:27:08 -0500 Subject: [PATCH 07/44] add method to __all__ --- sd_meh/merge_methods.py | 1 + sd_meh/utils.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 8d908fe..debc0c1 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -17,6 +17,7 @@ "similarity_add_difference", "distribution_crossover", "ties_add_difference", + "rotate", ] diff --git a/sd_meh/utils.py b/sd_meh/utils.py index 39efe98..f507ae8 100644 --- a/sd_meh/utils.py +++ b/sd_meh/utils.py @@ -1,6 +1,5 @@ import inspect import logging -import torch from sd_meh import merge_methods from sd_meh.merge import NUM_TOTAL_BLOCKS From 1f497e918dbbe02fb4c666d3d9b7b3acd147fe38 Mon Sep 17 00:00:00 2001 From: ljleb Date: Fri, 10 Nov 2023 02:55:01 -0500 Subject: [PATCH 08/44] include 1D 'rotation' --- sd_meh/merge_methods.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index debc0c1..ff4f34d 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -213,10 +213,10 @@ def filter_top_k(a: Tensor, k: float): def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): - if len(a.shape) <= 1: + if len(a.shape) == 0: return a - # make sure matrices are at most 2D + # make sure matrices are 2D a_reshape = a.reshape(-1, a.shape[-1]).float() b_reshape = b.reshape(-1, b.shape[-1]).float() From 36fccaad2bd103fc9f471fc9f0c6024e01af1828 Mon Sep 17 00:00:00 2001 From: ljleb Date: Fri, 10 Nov 2023 05:22:00 -0500 Subject: [PATCH 09/44] ignore alpha for now --- sd_meh/merge_methods.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index ff4f34d..6ba2685 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,6 +1,7 @@ import math from typing import Tuple - +import scipy +import numpy as np import torch from torch import Tensor @@ -216,18 +217,13 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): if len(a.shape) == 0: return a + shape = (a.shape[-1], -1) + # make sure matrices are 2D - a_reshape = a.reshape(-1, a.shape[-1]).float() - b_reshape = b.reshape(-1, b.shape[-1]).float() + a_reshape = a.reshape(*shape).float() + b_reshape = b.reshape(*shape).float() u, _, v = torch.svd(torch.matmul(a_reshape.T, b_reshape)) - transform = interpolate_cayley(torch.matmul(u, v.T), alpha) + transform = torch.matmul(u, v.T) return torch.matmul(a_reshape, transform).reshape_as(a).to(dtype=a.dtype) - - -def interpolate_cayley(matrix: Tensor, power: float): - 
identity = torch.eye(matrix.size(0), dtype=matrix.dtype, device=matrix.device) - cayley = (identity - matrix) @ torch.inverse(identity + matrix) - scaled_cayley = cayley * power - return torch.inverse(identity - scaled_cayley) @ (identity + scaled_cayley) From f18208d98aa0c3803729c8c9b9d5cd6f4a83d16c Mon Sep 17 00:00:00 2001 From: ljleb Date: Fri, 10 Nov 2023 05:39:14 -0500 Subject: [PATCH 10/44] refact --- sd_meh/merge_methods.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 6ba2685..665dccc 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -217,13 +217,11 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): if len(a.shape) == 0: return a - shape = (a.shape[-1], -1) + a_reshape = a.reshape(-1, a.shape[-1]).float() + b_reshape = b.reshape(-1, a.shape[-1]).float() - # make sure matrices are 2D - a_reshape = a.reshape(*shape).float() - b_reshape = b.reshape(*shape).float() - - u, _, v = torch.svd(torch.matmul(a_reshape.T, b_reshape)) + cross_covariance = torch.matmul(a_reshape.T, b_reshape) + u, _, v = torch.svd(cross_covariance) transform = torch.matmul(u, v.T) - - return torch.matmul(a_reshape, transform).reshape_as(a).to(dtype=a.dtype) + rotated_a = torch.matmul(a_reshape, transform) + return rotated_a.reshape_as(a).to(dtype=a.dtype) From 1dafe833850f317a5469e06d625805e1bc73d09f Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 11 Nov 2023 02:34:58 -0500 Subject: [PATCH 11/44] implement fractional rotations --- sd_meh/merge.py | 2 ++ sd_meh/merge_methods.py | 36 ++++++++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/sd_meh/merge.py b/sd_meh/merge.py index 760391e..7c462e7 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -420,6 +420,8 @@ def merge_key( merged_key = merge_args["b"] else: merged_key = merge_method(**merge_args).to(device) + gc.collect() + torch.cuda.empty_cache() if weights_clip: merged_key = clip_weights_key(thetas, merged_key, key) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 665dccc..04a543e 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,7 +1,6 @@ import math from typing import Tuple import scipy -import numpy as np import torch from torch import Tensor @@ -217,11 +216,32 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): if len(a.shape) == 0: return a - a_reshape = a.reshape(-1, a.shape[-1]).float() - b_reshape = b.reshape(-1, a.shape[-1]).float() + a_2d = a.reshape(-1, a.shape[-1]).float() + b_2d = b.reshape(-1, a.shape[-1]).float() + u, _, v = torch.svd(b_2d.T @ a_2d) + del _, b_2d - cross_covariance = torch.matmul(a_reshape.T, b_reshape) - u, _, v = torch.svd(cross_covariance) - transform = torch.matmul(u, v.T) - rotated_a = torch.matmul(a_reshape, transform) - return rotated_a.reshape_as(a).to(dtype=a.dtype) + if alpha == round(alpha): + transform = u @ v.T + if alpha != 1: + transform.copy_(torch.linalg.matrix_power(transform, round(alpha))) + else: + # remove flips: make det(transform) > 0 + # otherwise orthogonal_power(transform, alpha) will have a complex component + d = torch.ones(a_2d.shape[1], device=a_2d.device) + d[-1] = torch.linalg.det(u) * torch.linalg.det(v.T) + transform = u @ d @ v.T + del d + transform.copy_(fractional_matrix_power(transform, alpha)) + del u, v + + a_2d.copy_(a_2d @ transform) + return a_2d.reshape_as(a).to(dtype=a.dtype) + + +def fractional_matrix_power(matrix: Tensor, power: float): + eigenvalues, eigenvectors 
= torch.linalg.eig(matrix) + eigenvalues.copy_(eigenvalues ** power) + return ( + eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.T + ).real.to(dtype=matrix.dtype) From 149ab1607316e049d4f284a886db48831efef6ba Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 11 Nov 2023 03:03:46 -0500 Subject: [PATCH 12/44] fix transform direction --- sd_meh/merge_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 04a543e..1e5fb3f 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -218,7 +218,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): a_2d = a.reshape(-1, a.shape[-1]).float() b_2d = b.reshape(-1, a.shape[-1]).float() - u, _, v = torch.svd(b_2d.T @ a_2d) + u, _, v = torch.svd(a_2d.T @ b_2d) del _, b_2d if alpha == round(alpha): From 1f71391fa6e8cf51c7f40abaeb2fce410501c8ea Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 11 Nov 2023 03:18:15 -0500 Subject: [PATCH 13/44] fix eye --- sd_meh/merge_methods.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 1e5fb3f..9f5de08 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -228,8 +228,8 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): else: # remove flips: make det(transform) > 0 # otherwise orthogonal_power(transform, alpha) will have a complex component - d = torch.ones(a_2d.shape[1], device=a_2d.device) - d[-1] = torch.linalg.det(u) * torch.linalg.det(v.T) + d = torch.eye(a_2d.shape[1], device=a_2d.device) + d[-1, -1] = torch.linalg.det(u) * torch.linalg.det(v.T) transform = u @ d @ v.T del d transform.copy_(fractional_matrix_power(transform, alpha)) From b464fd350ddb5da2f4b780e00f36ba413b55bf40 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 11 Nov 2023 17:02:17 -0500 Subject: [PATCH 14/44] rewrite with out= --- sd_meh/merge_methods.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 9f5de08..6a7cc27 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,6 +1,7 @@ import math +import sys from typing import Tuple -import scipy +from kmeans_pytorch import kmeans import torch from torch import Tensor @@ -218,30 +219,27 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): a_2d = a.reshape(-1, a.shape[-1]).float() b_2d = b.reshape(-1, a.shape[-1]).float() - u, _, v = torch.svd(a_2d.T @ b_2d) - del _, b_2d + u, sigma, v = torch.svd(a_2d.T @ b_2d) + transform = u @ v.T if alpha == round(alpha): - transform = u @ v.T if alpha != 1: - transform.copy_(torch.linalg.matrix_power(transform, round(alpha))) + torch.linalg.matrix_power(transform, round(alpha), out=transform) else: - # remove flips: make det(transform) > 0 - # otherwise orthogonal_power(transform, alpha) will have a complex component - d = torch.eye(a_2d.shape[1], device=a_2d.device) - d[-1, -1] = torch.linalg.det(u) * torch.linalg.det(v.T) - transform = u @ d @ v.T - del d + if torch.linalg.det(transform) < 0: + # remove reflection, otherwise we get a complex component + u[:, -1] *= -1 + torch.matmul(u, v.T, out=transform) + transform.copy_(fractional_matrix_power(transform, alpha)) - del u, v - a_2d.copy_(a_2d @ transform) + torch.matmul(a_2d, transform, out=a_2d) return a_2d.reshape_as(a).to(dtype=a.dtype) def fractional_matrix_power(matrix: Tensor, power: float): eigenvalues, eigenvectors = torch.linalg.eig(matrix) - eigenvalues.copy_(eigenvalues 
** power) + eigenvalues.pow_(power) return ( eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.T ).real.to(dtype=matrix.dtype) From e1dc59cb0e441d03ac3fcf5a2cb1503a80f0bf24 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 01:33:48 -0500 Subject: [PATCH 15/44] it works; opt now --- sd_meh/merge_methods.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 6a7cc27..de52e0a 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -214,23 +214,20 @@ def filter_top_k(a: Tensor, k: float): def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): - if len(a.shape) == 0: + if len(a.shape) == 0 or (a == b).all(): return a - a_2d = a.reshape(-1, a.shape[-1]).float() - b_2d = b.reshape(-1, a.shape[-1]).float() - u, sigma, v = torch.svd(a_2d.T @ b_2d) - transform = u @ v.T + a_2d = a.reshape(-1, a.shape[-1]).double() + b_2d = b.reshape(-1, b.shape[-1]).double() + u, _, v_t = torch.linalg.svd(torch.matmul(a_2d.T, b_2d)) + transform = torch.matmul(u, v_t) if alpha == round(alpha): if alpha != 1: torch.linalg.matrix_power(transform, round(alpha), out=transform) else: - if torch.linalg.det(transform) < 0: - # remove reflection, otherwise we get a complex component - u[:, -1] *= -1 - torch.matmul(u, v.T, out=transform) - + u[:, -1] *= torch.det(transform) + torch.matmul(u, v_t, out=transform) transform.copy_(fractional_matrix_power(transform, alpha)) torch.matmul(a_2d, transform, out=a_2d) @@ -240,6 +237,8 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): def fractional_matrix_power(matrix: Tensor, power: float): eigenvalues, eigenvectors = torch.linalg.eig(matrix) eigenvalues.pow_(power) - return ( - eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.T - ).real.to(dtype=matrix.dtype) + result = ( + eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.H + ) + print(f"complex error: {torch.linalg.norm(result.imag)}") + return result.real.to(dtype=matrix.dtype) From cbb6a06e5fbc2b1585e60122e2148de019a73743 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 02:45:48 -0500 Subject: [PATCH 16/44] optimize: 45m -> 7m --- sd_meh/merge_methods.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index de52e0a..5e602c8 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -217,9 +217,9 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): if len(a.shape) == 0 or (a == b).all(): return a - a_2d = a.reshape(-1, a.shape[-1]).double() - b_2d = b.reshape(-1, b.shape[-1]).double() - u, _, v_t = torch.linalg.svd(torch.matmul(a_2d.T, b_2d)) + a_2d = a.reshape(-1, a.shape[-1]).float() + b_2d = b.reshape(-1, b.shape[-1]).float() + u, _, v_t = torch.linalg.svd(torch.matmul(a_2d.T, b_2d), driver="gesvd" if a.is_cuda else None) transform = torch.matmul(u, v_t) if alpha == round(alpha): @@ -227,11 +227,12 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): torch.linalg.matrix_power(transform, round(alpha), out=transform) else: u[:, -1] *= torch.det(transform) - torch.matmul(u, v_t, out=transform) + transform.copy_(torch.matmul(u, v_t)) transform.copy_(fractional_matrix_power(transform, alpha)) - torch.matmul(a_2d, transform, out=a_2d) - return a_2d.reshape_as(a).to(dtype=a.dtype) + torch.matmul(a_2d.float(), transform, out=a_2d) + res = a_2d.reshape_as(a).to(dtype=a.dtype) + return res def fractional_matrix_power(matrix: Tensor, power: float): @@ -241,4 +242,4 @@ def 
fractional_matrix_power(matrix: Tensor, power: float): eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.H ) print(f"complex error: {torch.linalg.norm(result.imag)}") - return result.real.to(dtype=matrix.dtype) + return result.real From ce62946547472ce448ce0637f6fb9342e7ce36ed Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 02:49:29 -0500 Subject: [PATCH 17/44] rm print --- sd_meh/merge_methods.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 5e602c8..d28b7f2 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -231,15 +231,11 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): transform.copy_(fractional_matrix_power(transform, alpha)) torch.matmul(a_2d.float(), transform, out=a_2d) - res = a_2d.reshape_as(a).to(dtype=a.dtype) - return res + return a_2d.reshape_as(a).to(dtype=a.dtype) def fractional_matrix_power(matrix: Tensor, power: float): eigenvalues, eigenvectors = torch.linalg.eig(matrix) eigenvalues.pow_(power) - result = ( - eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.H - ) - print(f"complex error: {torch.linalg.norm(result.imag)}") + result = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.H return result.real From 8172927287c609184f83579d3e11a87294b164b4 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 13:50:05 -0500 Subject: [PATCH 18/44] fix precision issues --- sd_meh/merge_methods.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index d28b7f2..2766283 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -220,22 +220,23 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): a_2d = a.reshape(-1, a.shape[-1]).float() b_2d = b.reshape(-1, b.shape[-1]).float() u, _, v_t = torch.linalg.svd(torch.matmul(a_2d.T, b_2d), driver="gesvd" if a.is_cuda else None) - transform = torch.matmul(u, v_t) if alpha == round(alpha): + transform = torch.matmul(u, v_t) if alpha != 1: - torch.linalg.matrix_power(transform, round(alpha), out=transform) + transform = torch.linalg.matrix_power(transform, round(alpha)) else: - u[:, -1] *= torch.det(transform) - transform.copy_(torch.matmul(u, v_t)) - transform.copy_(fractional_matrix_power(transform, alpha)) + u[:, -1] /= torch.det(u) * torch.det(v_t) + transform = torch.matmul(u, v_t) + transform = fractional_matrix_power(transform, alpha) - torch.matmul(a_2d.float(), transform, out=a_2d) + a_2d = torch.matmul(a_2d, transform) return a_2d.reshape_as(a).to(dtype=a.dtype) def fractional_matrix_power(matrix: Tensor, power: float): - eigenvalues, eigenvectors = torch.linalg.eig(matrix) + eigenvalues, eigenvectors = torch.linalg.eig(matrix.double()) eigenvalues.pow_(power) - result = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.H - return result.real + result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors) + error = torch.linalg.vector_norm(result.imag) + return result.real.to(dtype=matrix.dtype) From 19fcc0a1646413699dfc5430a768325c6dccb72e Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 13:50:28 -0500 Subject: [PATCH 19/44] fix precision issues --- sd_meh/merge_methods.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 2766283..94e8481 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -238,5 +238,4 @@ def fractional_matrix_power(matrix: Tensor, power: float): eigenvalues, 
eigenvectors = torch.linalg.eig(matrix.double()) eigenvalues.pow_(power) result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors) - error = torch.linalg.vector_norm(result.imag) return result.real.to(dtype=matrix.dtype) From f9542703d4a011e2df4c9f0318ef0f0d317d2fa1 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 13:52:05 -0500 Subject: [PATCH 20/44] black --- sd_meh/merge_methods.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 94e8481..7e1ffad 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -219,7 +219,8 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): a_2d = a.reshape(-1, a.shape[-1]).float() b_2d = b.reshape(-1, b.shape[-1]).float() - u, _, v_t = torch.linalg.svd(torch.matmul(a_2d.T, b_2d), driver="gesvd" if a.is_cuda else None) + svd_driver = "gesvd" if a.is_cuda else None + u, _, v_t = torch.linalg.svd(torch.matmul(a_2d.T, b_2d), driver=svd_driver) if alpha == round(alpha): transform = torch.matmul(u, v_t) From e94e25275d9c9d5e5a232bbdeb806d9c4444f3ea Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 13:53:17 -0500 Subject: [PATCH 21/44] dont change --- sd_meh/merge.py | 2 -- sd_meh/merge_methods.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/sd_meh/merge.py b/sd_meh/merge.py index 7c462e7..760391e 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -420,8 +420,6 @@ def merge_key( merged_key = merge_args["b"] else: merged_key = merge_method(**merge_args).to(device) - gc.collect() - torch.cuda.empty_cache() if weights_clip: merged_key = clip_weights_key(thetas, merged_key, key) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 7e1ffad..84f4197 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,7 +1,5 @@ import math -import sys from typing import Tuple -from kmeans_pytorch import kmeans import torch from torch import Tensor From 1f380c8997bf7cc0780546056dc264a822f7d492 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 13:53:42 -0500 Subject: [PATCH 22/44] imps --- sd_meh/merge_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 84f4197..da040c7 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,7 +1,7 @@ import math -from typing import Tuple import torch from torch import Tensor +from typing import Tuple __all__ = [ "weighted_sum", From ea95b6665a1ec58fafe183734f566a35f26a10bf Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 17:14:40 -0500 Subject: [PATCH 23/44] beta is deformation --- sd_meh/merge_methods.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index da040c7..9ed6703 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -211,7 +211,7 @@ def filter_top_k(a: Tensor, k: float): return a * top_k_filter -def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): +def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): if len(a.shape) == 0 or (a == b).all(): return a @@ -220,16 +220,25 @@ def rotate(a: Tensor, b: Tensor, alpha: float, **kwargs): svd_driver = "gesvd" if a.is_cuda else None u, _, v_t = torch.linalg.svd(torch.matmul(a_2d.T, b_2d), driver=svd_driver) - if alpha == round(alpha): - transform = torch.matmul(u, v_t) - if alpha != 1: - transform = torch.linalg.matrix_power(transform, round(alpha)) - else: + alpha_is_float = 
alpha != round(alpha) + if alpha_is_float: u[:, -1] /= torch.det(u) * torch.det(v_t) - transform = torch.matmul(u, v_t) + + transform = rotation = u @ v_t + if beta != 0: + # remove the full rotation used to apply beta + alpha -= 1 + + if alpha_is_float: transform = fractional_matrix_power(transform, alpha) + elif alpha != 1: + transform = torch.linalg.matrix_power(transform, round(alpha)) + + if beta != 0: + a_2d = weighted_sum(a_2d @ rotation, b_2d, beta) + # alpha was decremented, no need to apply @ rotation.T - a_2d = torch.matmul(a_2d, transform) + a_2d @= transform return a_2d.reshape_as(a).to(dtype=a.dtype) From 1751f5967c8aa9104c165cd84a02ef7eab1a36ce Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 17:55:47 -0500 Subject: [PATCH 24/44] simplify --- sd_meh/merge_methods.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 9ed6703..0123220 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -225,18 +225,13 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): u[:, -1] /= torch.det(u) * torch.det(v_t) transform = rotation = u @ v_t - if beta != 0: - # remove the full rotation used to apply beta - alpha -= 1 - if alpha_is_float: transform = fractional_matrix_power(transform, alpha) elif alpha != 1: transform = torch.linalg.matrix_power(transform, round(alpha)) if beta != 0: - a_2d = weighted_sum(a_2d @ rotation, b_2d, beta) - # alpha was decremented, no need to apply @ rotation.T + a_2d = weighted_sum(a_2d, b_2d @ rotation.T, beta) a_2d @= transform return a_2d.reshape_as(a).to(dtype=a.dtype) From c69bb952977eed0128fc7e2a324d1e16932606b2 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 12 Nov 2023 18:23:46 -0500 Subject: [PATCH 25/44] @ --- sd_meh/merge_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 0123220..7e4f61a 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -218,7 +218,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): a_2d = a.reshape(-1, a.shape[-1]).float() b_2d = b.reshape(-1, b.shape[-1]).float() svd_driver = "gesvd" if a.is_cuda else None - u, _, v_t = torch.linalg.svd(torch.matmul(a_2d.T, b_2d), driver=svd_driver) + u, _, v_t = torch.linalg.svd(a_2d.T @ b_2d, driver=svd_driver) alpha_is_float = alpha != round(alpha) if alpha_is_float: From 0d5160b351a881cd3520900ef3c0d2edd95ff1a5 Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 13 Nov 2023 04:13:06 -0500 Subject: [PATCH 26/44] backup --- sd_meh/merge_methods.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 7e4f61a..16000f2 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -215,8 +215,27 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): if len(a.shape) == 0 or (a == b).all(): return a - a_2d = a.reshape(-1, a.shape[-1]).float() - b_2d = b.reshape(-1, b.shape[-1]).float() + if len(a.shape) == 4: # conv + # ideally we should stack each n x n kernel into 1D for more freedom + # however, this brings the number of dimensions of the covariance matrix + # to a very high number for some layers (> 10k x 10k) + # SVD is not practical in these cases + # so instead, we break down the conv kernel into individual input features + # to lock their angles and distances along with the + a_2d = a.permute(0, 2, 3, 1).reshape(-1, 
a.shape[1]).float() + b_2d = b.permute(0, 2, 3, 1).reshape(-1, a.shape[1]).float() + + def reshape_fn(m): + m = m.reshape(a.shape[0], a.shape[2], a.shape[3], a.shape[1]) + m = m.permute(0, 3, 1, 2) + return m.contiguous() # apparently needed for saving + else: + a_2d = a.reshape(-1, a.shape[-1]).float() + b_2d = b.reshape(-1, b.shape[-1]).float() + + def reshape_fn(m): + return m.reshape_as(a) + svd_driver = "gesvd" if a.is_cuda else None u, _, v_t = torch.linalg.svd(a_2d.T @ b_2d, driver=svd_driver) @@ -234,7 +253,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): a_2d = weighted_sum(a_2d, b_2d @ rotation.T, beta) a_2d @= transform - return a_2d.reshape_as(a).to(dtype=a.dtype) + return reshape_fn(a_2d).to(dtype=a.dtype) def fractional_matrix_power(matrix: Tensor, power: float): From 1920496b3641b28f64024825cb414540f5550caf Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 13 Nov 2023 16:08:44 -0500 Subject: [PATCH 27/44] deal with conv attention shape, rotate centroids --- sd_meh/merge_methods.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 16000f2..14d4a8c 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -212,29 +212,22 @@ def filter_top_k(a: Tensor, k: float): def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): - if len(a.shape) == 0 or (a == b).all(): + if len(a.shape) == 0 or (len(a.shape) == 4 and a.shape[-1] != 1) or (a == b).all(): return a - if len(a.shape) == 4: # conv - # ideally we should stack each n x n kernel into 1D for more freedom - # however, this brings the number of dimensions of the covariance matrix - # to a very high number for some layers (> 10k x 10k) - # SVD is not practical in these cases - # so instead, we break down the conv kernel into individual input features - # to lock their angles and distances along with the - a_2d = a.permute(0, 2, 3, 1).reshape(-1, a.shape[1]).float() - b_2d = b.permute(0, 2, 3, 1).reshape(-1, a.shape[1]).float() - - def reshape_fn(m): - m = m.reshape(a.shape[0], a.shape[2], a.shape[3], a.shape[1]) - m = m.permute(0, 3, 1, 2) - return m.contiguous() # apparently needed for saving + if len(a.shape) == 4: + shape_2d = (-1, a.shape[1]) else: - a_2d = a.reshape(-1, a.shape[-1]).float() - b_2d = b.reshape(-1, b.shape[-1]).float() + shape_2d = (-1, a.shape[-1]) - def reshape_fn(m): - return m.reshape_as(a) + a_2d = a.reshape(*shape_2d).double() + b_2d = b.reshape(*shape_2d).double() + + a_centroid = a_2d.mean(0) + b_centroid = b_2d.mean(0) + new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha) + a_2d -= a_centroid + b_2d -= b_centroid svd_driver = "gesvd" if a.is_cuda else None u, _, v_t = torch.linalg.svd(a_2d.T @ b_2d, driver=svd_driver) @@ -253,7 +246,8 @@ def reshape_fn(m): a_2d = weighted_sum(a_2d, b_2d @ rotation.T, beta) a_2d @= transform - return reshape_fn(a_2d).to(dtype=a.dtype) + a_2d += new_centroid + return a_2d.reshape_as(a).to(dtype=a.dtype) def fractional_matrix_power(matrix: Tensor, power: float): @@ -261,3 +255,10 @@ def fractional_matrix_power(matrix: Tensor, power: float): eigenvalues.pow_(power) result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors) return result.real.to(dtype=matrix.dtype) + + +def sample_ellipsis(a, b, t): + return torch.column_stack((a, b)) @ torch.tensor([ + math.sin(t), + math.cos(t), + ], dtype=a.dtype, device=a.device) From 5a1c776ab53d6a3639d684e443976ea610b38216 
Mon Sep 17 00:00:00 2001 From: ljleb Date: Thu, 16 Nov 2023 20:31:12 -0500 Subject: [PATCH 28/44] black --- sd_meh/merge_methods.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 14d4a8c..b90065b 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -258,7 +258,11 @@ def fractional_matrix_power(matrix: Tensor, power: float): def sample_ellipsis(a, b, t): - return torch.column_stack((a, b)) @ torch.tensor([ - math.sin(t), - math.cos(t), - ], dtype=a.dtype, device=a.device) + return torch.column_stack((a, b)) @ torch.tensor( + [ + math.sin(t), + math.cos(t), + ], + dtype=a.dtype, + device=a.device, + ) From f61d6aa15741f18ed48ad49627975009a61c5738 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 18 Nov 2023 00:31:23 -0500 Subject: [PATCH 29/44] wip --- sd_meh/merge.py | 6 ++++- sd_meh/merge_methods.py | 58 ++++++++++++++++++++++++++++++++--------- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/sd_meh/merge.py b/sd_meh/merge.py index 760391e..b1ac4af 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -419,7 +419,11 @@ def merge_key( else: merged_key = merge_args["b"] else: - merged_key = merge_method(**merge_args).to(device) + try: + merged_key = merge_method(**merge_args).to(device) + except Exception as e: + print(merge_args["a"].shape, e) + raise if weights_clip: merged_key = clip_weights_key(thetas, merged_key, key) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index b90065b..04babca 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,4 +1,7 @@ +import functools import math +import operator + import torch from torch import Tensor from typing import Tuple @@ -212,25 +215,40 @@ def filter_top_k(a: Tensor, k: float): def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): - if len(a.shape) == 0 or (len(a.shape) == 4 and a.shape[-1] != 1) or (a == b).all(): + if alpha == 0 and beta == 0: return a + is_large_conv = len(a.shape) == 4 #and a.shape[-1] != 1 + if len(a.shape) == 0 or is_large_conv or torch.allclose(a, b): + return weighted_sum(a, b, beta) + if len(a.shape) == 4: shape_2d = (-1, a.shape[1]) else: shape_2d = (-1, a.shape[-1]) - a_2d = a.reshape(*shape_2d).double() - b_2d = b.reshape(*shape_2d).double() + a_neurons = a.reshape(*shape_2d) + b_neurons = b.reshape(*shape_2d) + + # reciprocal function used to reduce the dimensionality of neurons + # this allows us to solve the procrustes problem on: + # - all dimensions of small neurons (< 2560) + # - a fraction of dimensions for larger neurons (>= 2560) + # for a tradeoff between more quality and longer merge time + similarity_threshold = 1#min(1.0, 2 / a_neurons.shape[1] + 0.999) - a_centroid = a_2d.mean(0) - b_centroid = b_2d.mean(0) + neuron_dims = reduce_dimensions(a_neurons, b_neurons, similarity_threshold) + a_neurons = a_neurons[:, neuron_dims].double() + b_neurons = b_neurons[:, neuron_dims].double() + + a_centroid = a_neurons.mean(0) + b_centroid = b_neurons.mean(0) new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha) - a_2d -= a_centroid - b_2d -= b_centroid + a_neurons -= a_centroid + b_neurons -= b_centroid svd_driver = "gesvd" if a.is_cuda else None - u, _, v_t = torch.linalg.svd(a_2d.T @ b_2d, driver=svd_driver) + u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) alpha_is_float = alpha != round(alpha) if alpha_is_float: @@ -243,11 +261,14 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, 
**kwargs): transform = torch.linalg.matrix_power(transform, round(alpha)) if beta != 0: - a_2d = weighted_sum(a_2d, b_2d @ rotation.T, beta) + a_neurons = weighted_sum(a_neurons, b_neurons @ rotation.T, beta) + + a_neurons @= transform + a_neurons += new_centroid - a_2d @= transform - a_2d += new_centroid - return a_2d.reshape_as(a).to(dtype=a.dtype) + a.view(*shape_2d)[:, neuron_dims] = a_neurons.to(a.dtype) + # a_2d_large[:, ~mask] = weighted_sum(a_2d_large[:, ~mask], b_2d_large[:, ~mask], beta) + return a def fractional_matrix_power(matrix: Tensor, power: float): @@ -266,3 +287,16 @@ def sample_ellipsis(a, b, t): dtype=a.dtype, device=a.device, ) + + +def reduce_dimensions(a_neurons, b_neurons, threshold): + a_unit = a_neurons / torch.linalg.vector_norm(a_neurons, dim=0, keepdim=True) + b_unit = b_neurons / torch.linalg.vector_norm(b_neurons, dim=0, keepdim=True) + neuron_similarities = torch.sum(a_unit * b_unit, dim=0) + + neuron_dims = torch.ones(a_neurons.shape[1], dtype=torch.bool) + indices_to_disable = torch.nonzero(neuron_similarities >= threshold) + neuron_dims[indices_to_disable] = False + + print(f"\n{a_neurons.shape[1]} -> {len(torch.nonzero(neuron_dims))} by {threshold}") + return neuron_dims From 6ac82d66e1277c4e141e3ececcc6abb8e780ca4e Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 18 Nov 2023 03:44:03 -0500 Subject: [PATCH 30/44] refact --- sd_meh/merge_methods.py | 44 +++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 04babca..002a964 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,7 +1,4 @@ -import functools import math -import operator - import torch from torch import Tensor from typing import Tuple @@ -218,8 +215,8 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): if alpha == 0 and beta == 0: return a - is_large_conv = len(a.shape) == 4 #and a.shape[-1] != 1 - if len(a.shape) == 0 or is_large_conv or torch.allclose(a, b): + is_conv = len(a.shape) == 4 #and a.shape[-1] != 1 + if len(a.shape) == 0 or is_conv or torch.allclose(a, b): return weighted_sum(a, b, beta) if len(a.shape) == 4: @@ -227,25 +224,25 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): else: shape_2d = (-1, a.shape[-1]) - a_neurons = a.reshape(*shape_2d) - b_neurons = b.reshape(*shape_2d) + a_view = a.view(*shape_2d) + b_view = b.view(*shape_2d) + + a_centroid = a_view.mean(0) + b_centroid = b_view.mean(0) + new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha) + a_neurons_large = a_view - a_centroid + b_neurons_large = b_view - b_centroid # reciprocal function used to reduce the dimensionality of neurons # this allows us to solve the procrustes problem on: # - all dimensions of small neurons (< 2560) # - a fraction of dimensions for larger neurons (>= 2560) # for a tradeoff between more quality and longer merge time - similarity_threshold = 1#min(1.0, 2 / a_neurons.shape[1] + 0.999) + similarity_threshold = min(1.0, 2 / a_neurons_large.shape[1] + 0.999) - neuron_dims = reduce_dimensions(a_neurons, b_neurons, similarity_threshold) - a_neurons = a_neurons[:, neuron_dims].double() - b_neurons = b_neurons[:, neuron_dims].double() - - a_centroid = a_neurons.mean(0) - b_centroid = b_neurons.mean(0) - new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha) - a_neurons -= a_centroid - b_neurons -= b_centroid + rotation_dims = reduce_dimensions(a_neurons_large, b_neurons_large, 
similarity_threshold) + a_neurons = a_neurons_large[:, rotation_dims].double() + b_neurons = b_neurons_large[:, rotation_dims].double() svd_driver = "gesvd" if a.is_cuda else None u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) @@ -264,11 +261,16 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): a_neurons = weighted_sum(a_neurons, b_neurons @ rotation.T, beta) a_neurons @= transform - a_neurons += new_centroid - a.view(*shape_2d)[:, neuron_dims] = a_neurons.to(a.dtype) - # a_2d_large[:, ~mask] = weighted_sum(a_2d_large[:, ~mask], b_2d_large[:, ~mask], beta) - return a + res = torch.empty_like(a_neurons_large) + res[:, rotation_dims] = a_neurons.to(res.dtype) + res[:, ~rotation_dims] = weighted_sum( + a_neurons_large[:, ~rotation_dims], + b_neurons_large[:, ~rotation_dims], + beta, + ) + res += new_centroid + return res.reshape_as(a) def fractional_matrix_power(matrix: Tensor, power: float): From 7fff089cd55651cea923e6f8f0026c9d29f0637e Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 20 Nov 2023 18:30:04 -0500 Subject: [PATCH 31/44] backup --- sd_meh/merge.py | 3 ++ sd_meh/merge_methods.py | 76 ++++++++++++++++++++++++++++------------- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/sd_meh/merge.py b/sd_meh/merge.py index b1ac4af..8fdcd96 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -424,6 +424,9 @@ def merge_key( except Exception as e: print(merge_args["a"].shape, e) raise + finally: + gc.collect() + torch.cuda.empty_cache() if weights_clip: merged_key = clip_weights_key(thetas, merged_key, key) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 002a964..b9aa00f 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,4 +1,6 @@ +import functools import math +import operator import torch from torch import Tensor from typing import Tuple @@ -215,58 +217,78 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): if alpha == 0 and beta == 0: return a - is_conv = len(a.shape) == 4 #and a.shape[-1] != 1 + is_conv = len(a.shape) == 4 and a.shape[-1] != 1 if len(a.shape) == 0 or is_conv or torch.allclose(a, b): return weighted_sum(a, b, beta) if len(a.shape) == 4: - shape_2d = (-1, a.shape[1]) + shape_2d = (-1, functools.reduce(operator.mul, a.shape[1:])) else: shape_2d = (-1, a.shape[-1]) - a_view = a.view(*shape_2d) - b_view = b.view(*shape_2d) + a_neurons_large = a.reshape(*shape_2d).double() + b_neurons_large = b.reshape(*shape_2d).double() - a_centroid = a_view.mean(0) - b_centroid = b_view.mean(0) + a_centroid = a_neurons_large.mean(0) + b_centroid = b_neurons_large.mean(0) new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha) - a_neurons_large = a_view - a_centroid - b_neurons_large = b_view - b_centroid + a_neurons_large -= a_centroid + b_neurons_large -= b_centroid + + if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1: + print(a.shape, "is 1D") + return new_centroid.reshape_as(a) # reciprocal function used to reduce the dimensionality of neurons + # it determines the cosine similarity threshold to drop dimensions # this allows us to solve the procrustes problem on: # - all dimensions of small neurons (< 2560) # - a fraction of dimensions for larger neurons (>= 2560) # for a tradeoff between more quality and longer merge time - similarity_threshold = min(1.0, 2 / a_neurons_large.shape[1] + 0.999) + threshold = 1#min(1.0, 0.125 / shape_2d[1] + 0.9999) - rotation_dims = reduce_dimensions(a_neurons_large, b_neurons_large, 
similarity_threshold) - a_neurons = a_neurons_large[:, rotation_dims].double() - b_neurons = b_neurons_large[:, rotation_dims].double() + dims_to_rotate = get_irreducible_dimensions( + a_neurons_large, + b_neurons_large, + threshold, + max_dims=5760, # dims of one of the large layers + ) + # rotation_dims[:] = True + a_neurons = a_neurons_large[:, dims_to_rotate] + b_neurons = b_neurons_large[:, dims_to_rotate] svd_driver = "gesvd" if a.is_cuda else None u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) alpha_is_float = alpha != round(alpha) if alpha_is_float: + # cancel reflection. without this, eigenvalues often have a complex component + # and then we can't obtain a valid dtype for the merge u[:, -1] /= torch.det(u) * torch.det(v_t) transform = rotation = u @ v_t if alpha_is_float: transform = fractional_matrix_power(transform, alpha) + elif alpha == 0: + transform = torch.eye( + len(transform), + dtype=transform.dtype, + device=transform.device, + ) elif alpha != 1: transform = torch.linalg.matrix_power(transform, round(alpha)) if beta != 0: + # interpolate the relationship between the neurons a_neurons = weighted_sum(a_neurons, b_neurons @ rotation.T, beta) a_neurons @= transform res = torch.empty_like(a_neurons_large) - res[:, rotation_dims] = a_neurons.to(res.dtype) - res[:, ~rotation_dims] = weighted_sum( - a_neurons_large[:, ~rotation_dims], - b_neurons_large[:, ~rotation_dims], + res[:, dims_to_rotate] = a_neurons.to(res.dtype) + res[:, ~dims_to_rotate] = weighted_sum( + a_neurons_large[:, ~dims_to_rotate], + b_neurons_large[:, ~dims_to_rotate], beta, ) res += new_centroid @@ -291,14 +313,20 @@ def sample_ellipsis(a, b, t): ) -def reduce_dimensions(a_neurons, b_neurons, threshold): +def get_irreducible_dimensions(a_neurons, b_neurons, threshold, max_dims): a_unit = a_neurons / torch.linalg.vector_norm(a_neurons, dim=0, keepdim=True) b_unit = b_neurons / torch.linalg.vector_norm(b_neurons, dim=0, keepdim=True) - neuron_similarities = torch.sum(a_unit * b_unit, dim=0) - - neuron_dims = torch.ones(a_neurons.shape[1], dtype=torch.bool) - indices_to_disable = torch.nonzero(neuron_similarities >= threshold) - neuron_dims[indices_to_disable] = False - - print(f"\n{a_neurons.shape[1]} -> {len(torch.nonzero(neuron_dims))} by {threshold}") + similarities = torch.sum(a_unit * b_unit, dim=0) + + neuron_dims = torch.zeros(a_neurons.shape[1], dtype=torch.bool) + indices_to_keep = torch.nonzero(similarities < threshold) + if len(indices_to_keep) > max_dims: + sorted_indices = sorted( + indices_to_keep, + key=similarities.__getitem__, + ) + indices_to_keep = sorted_indices[:max_dims] + neuron_dims[indices_to_keep] = True + + print(f"\n{a_neurons.shape[1]} -> {len(torch.nonzero(neuron_dims))}, {100 * len(torch.nonzero(neuron_dims)) / a_neurons.shape[1]:.2f}% under {threshold}, max {float(max(similarities[i] for i in indices_to_keep))}") return neuron_dims From 6ddc503c3fe683d685f5b50b721396281577338e Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 20 Nov 2023 19:01:47 -0500 Subject: [PATCH 32/44] remove approx --- sd_meh/merge_methods.py | 62 ++++++----------------------------------- 1 file changed, 8 insertions(+), 54 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index b9aa00f..8fea6c6 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -226,37 +226,18 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): else: shape_2d = (-1, a.shape[-1]) - a_neurons_large = a.reshape(*shape_2d).double() - b_neurons_large 
= b.reshape(*shape_2d).double() + a_neurons = a.reshape(*shape_2d).double() + b_neurons = b.reshape(*shape_2d).double() - a_centroid = a_neurons_large.mean(0) - b_centroid = b_neurons_large.mean(0) + a_centroid = a_neurons.mean(0) + b_centroid = b_neurons.mean(0) new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha) - a_neurons_large -= a_centroid - b_neurons_large -= b_centroid + a_neurons -= a_centroid + b_neurons -= b_centroid if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1: - print(a.shape, "is 1D") return new_centroid.reshape_as(a) - # reciprocal function used to reduce the dimensionality of neurons - # it determines the cosine similarity threshold to drop dimensions - # this allows us to solve the procrustes problem on: - # - all dimensions of small neurons (< 2560) - # - a fraction of dimensions for larger neurons (>= 2560) - # for a tradeoff between more quality and longer merge time - threshold = 1#min(1.0, 0.125 / shape_2d[1] + 0.9999) - - dims_to_rotate = get_irreducible_dimensions( - a_neurons_large, - b_neurons_large, - threshold, - max_dims=5760, # dims of one of the large layers - ) - # rotation_dims[:] = True - a_neurons = a_neurons_large[:, dims_to_rotate] - b_neurons = b_neurons_large[:, dims_to_rotate] - svd_driver = "gesvd" if a.is_cuda else None u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) @@ -283,16 +264,8 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): a_neurons = weighted_sum(a_neurons, b_neurons @ rotation.T, beta) a_neurons @= transform - - res = torch.empty_like(a_neurons_large) - res[:, dims_to_rotate] = a_neurons.to(res.dtype) - res[:, ~dims_to_rotate] = weighted_sum( - a_neurons_large[:, ~dims_to_rotate], - b_neurons_large[:, ~dims_to_rotate], - beta, - ) - res += new_centroid - return res.reshape_as(a) + a_neurons += new_centroid + return a_neurons.reshape_as(a).to(a.dtype) def fractional_matrix_power(matrix: Tensor, power: float): @@ -311,22 +284,3 @@ def sample_ellipsis(a, b, t): dtype=a.dtype, device=a.device, ) - - -def get_irreducible_dimensions(a_neurons, b_neurons, threshold, max_dims): - a_unit = a_neurons / torch.linalg.vector_norm(a_neurons, dim=0, keepdim=True) - b_unit = b_neurons / torch.linalg.vector_norm(b_neurons, dim=0, keepdim=True) - similarities = torch.sum(a_unit * b_unit, dim=0) - - neuron_dims = torch.zeros(a_neurons.shape[1], dtype=torch.bool) - indices_to_keep = torch.nonzero(similarities < threshold) - if len(indices_to_keep) > max_dims: - sorted_indices = sorted( - indices_to_keep, - key=similarities.__getitem__, - ) - indices_to_keep = sorted_indices[:max_dims] - neuron_dims[indices_to_keep] = True - - print(f"\n{a_neurons.shape[1]} -> {len(torch.nonzero(neuron_dims))}, {100 * len(torch.nonzero(neuron_dims)) / a_neurons.shape[1]:.2f}% under {threshold}, max {float(max(similarities[i] for i in indices_to_keep))}") - return neuron_dims From a6742b35608a33f774fd879ac8a00ff0cf015ed4 Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 20 Nov 2023 19:11:01 -0500 Subject: [PATCH 33/44] dont edit --- sd_meh/merge.py | 9 +-------- sd_meh/merge_methods.py | 2 +- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/sd_meh/merge.py b/sd_meh/merge.py index 8fdcd96..760391e 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -419,14 +419,7 @@ def merge_key( else: merged_key = merge_args["b"] else: - try: - merged_key = merge_method(**merge_args).to(device) - except Exception as e: - print(merge_args["a"].shape, e) - raise - finally: - gc.collect() - 
torch.cuda.empty_cache() + merged_key = merge_method(**merge_args).to(device) if weights_clip: merged_key = clip_weights_key(thetas, merged_key, key) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 8fea6c6..89d4b75 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -269,7 +269,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): def fractional_matrix_power(matrix: Tensor, power: float): - eigenvalues, eigenvectors = torch.linalg.eig(matrix.double()) + eigenvalues, eigenvectors = torch.linalg.eig(matrix) eigenvalues.pow_(power) result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors) return result.real.to(dtype=matrix.dtype) From d84b7768905d0ed66fb2e5aaf670d810faabe935 Mon Sep 17 00:00:00 2001 From: ljleb Date: Thu, 7 Dec 2023 18:49:37 -0500 Subject: [PATCH 34/44] fix fp16 and fp32 merges --- sd_meh/merge_methods.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 89d4b75..e33617e 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -218,7 +218,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): return a is_conv = len(a.shape) == 4 and a.shape[-1] != 1 - if len(a.shape) == 0 or is_conv or torch.allclose(a, b): + if len(a.shape) == 0 or is_conv or torch.allclose(a.half(), b.half()): return weighted_sum(a, b, beta) if len(a.shape) == 4: @@ -232,12 +232,12 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): a_centroid = a_neurons.mean(0) b_centroid = b_neurons.mean(0) new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha) - a_neurons -= a_centroid - b_neurons -= b_centroid - if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1: return new_centroid.reshape_as(a) + a_neurons -= a_centroid + b_neurons -= b_centroid + svd_driver = "gesvd" if a.is_cuda else None u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) From d812ea8fe83eb7c3fe23e3619960b21b046347d7 Mon Sep 17 00:00:00 2001 From: ljleb Date: Fri, 8 Dec 2023 01:13:23 -0500 Subject: [PATCH 35/44] reduced svd --- sd_meh/merge_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index e33617e..e83348d 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -239,7 +239,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): b_neurons -= b_centroid svd_driver = "gesvd" if a.is_cuda else None - u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) + u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, full_matrices=False, driver=svd_driver) alpha_is_float = alpha != round(alpha) if alpha_is_float: From 38d4db663e403dc7a4943c3b8dbd2043f37d1707 Mon Sep 17 00:00:00 2001 From: ljleb Date: Fri, 8 Dec 2023 01:14:47 -0500 Subject: [PATCH 36/44] black --- sd_meh/merge_methods.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index e83348d..7b7d962 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -239,7 +239,9 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): b_neurons -= b_centroid svd_driver = "gesvd" if a.is_cuda else None - u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, full_matrices=False, driver=svd_driver) + u, _, v_t = torch.linalg.svd( + a_neurons.T @ b_neurons, full_matrices=False, driver=svd_driver + ) 
alpha_is_float = alpha != round(alpha) if alpha_is_float: @@ -248,6 +250,11 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): u[:, -1] /= torch.det(u) * torch.det(v_t) transform = rotation = u @ v_t + print("shape:", transform.shape) + det = torch.det(transform) + if torch.abs(det.abs() - 1) > 1e-6: + print("determinant error:", det) + if alpha_is_float: transform = fractional_matrix_power(transform, alpha) elif alpha == 0: @@ -272,6 +279,8 @@ def fractional_matrix_power(matrix: Tensor, power: float): eigenvalues, eigenvectors = torch.linalg.eig(matrix) eigenvalues.pow_(power) result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors) + if ((error := result.imag) > 1e-4).any(): + print("image error:", error) return result.real.to(dtype=matrix.dtype) From 3c90395dc5c2302f7c23be442fe584fd5c4e93ea Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 17 Dec 2023 17:30:48 -0500 Subject: [PATCH 37/44] dont ellipsis --- sd_meh/merge_methods.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 7b7d962..981be83 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -231,7 +231,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): a_centroid = a_neurons.mean(0) b_centroid = b_neurons.mean(0) - new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha) + new_centroid = weighted_sum(a_centroid, b_centroid, alpha) if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1: return new_centroid.reshape_as(a) @@ -239,15 +239,13 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): b_neurons -= b_centroid svd_driver = "gesvd" if a.is_cuda else None - u, _, v_t = torch.linalg.svd( - a_neurons.T @ b_neurons, full_matrices=False, driver=svd_driver - ) + u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) alpha_is_float = alpha != round(alpha) if alpha_is_float: # cancel reflection. 
without this, eigenvalues often have a complex component # and then we can't obtain a valid dtype for the merge - u[:, -1] /= torch.det(u) * torch.det(v_t) + u[:, -1] *= torch.nan_to_num(1 / (torch.det(u) * torch.det(v_t))) transform = rotation = u @ v_t print("shape:", transform.shape) @@ -282,14 +280,3 @@ def fractional_matrix_power(matrix: Tensor, power: float): if ((error := result.imag) > 1e-4).any(): print("image error:", error) return result.real.to(dtype=matrix.dtype) - - -def sample_ellipsis(a, b, t): - return torch.column_stack((a, b)) @ torch.tensor( - [ - math.sin(t), - math.cos(t), - ], - dtype=a.dtype, - device=a.device, - ) From f506831822e780d039c921e85975229db9f4a19a Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 17 Dec 2023 20:58:12 -0500 Subject: [PATCH 38/44] print more info for debug --- sd_meh/merge_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 981be83..6961f55 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -248,7 +248,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): u[:, -1] *= torch.nan_to_num(1 / (torch.det(u) * torch.det(v_t))) transform = rotation = u @ v_t - print("shape:", transform.shape) + print(f"shape: {a.shape} -> {a_neurons.shape} -> {transform.shape}") det = torch.det(transform) if torch.abs(det.abs() - 1) > 1e-6: print("determinant error:", det) From 1b460568b97bf53f0dc2ef2e5cb3fcbb2b67cedc Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 18 Dec 2023 02:54:22 -0500 Subject: [PATCH 39/44] dont merge sdxl kek --- sd_meh/merge_methods.py | 15 ++++++++------- sd_meh/utils.py | 8 +++----- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 6961f55..96f2895 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -1,6 +1,8 @@ import functools import math import operator +import textwrap + import torch from torch import Tensor from typing import Tuple @@ -245,13 +247,14 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): if alpha_is_float: # cancel reflection. without this, eigenvalues often have a complex component # and then we can't obtain a valid dtype for the merge - u[:, -1] *= torch.nan_to_num(1 / (torch.det(u) * torch.det(v_t))) + u[:, -1] /= (torch.det(u) * torch.det(v_t)) transform = rotation = u @ v_t - print(f"shape: {a.shape} -> {a_neurons.shape} -> {transform.shape}") - det = torch.det(transform) - if torch.abs(det.abs() - 1) > 1e-6: - print("determinant error:", det) + if not torch.isfinite(u).all(): + raise ValueError(textwrap.dedent(f"""determinant error: {torch.det(rotation)}. + This can happen when merging on the CPU with the "rotate" method. + Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks. 
+ See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484""")) if alpha_is_float: transform = fractional_matrix_power(transform, alpha) @@ -277,6 +280,4 @@ def fractional_matrix_power(matrix: Tensor, power: float): eigenvalues, eigenvectors = torch.linalg.eig(matrix) eigenvalues.pow_(power) result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors) - if ((error := result.imag) > 1e-4).any(): - print("image error:", error) return result.real.to(dtype=matrix.dtype) diff --git a/sd_meh/utils.py b/sd_meh/utils.py index f507ae8..27e135c 100644 --- a/sd_meh/utils.py +++ b/sd_meh/utils.py @@ -17,12 +17,10 @@ def compute_weights(weights, base): if not weights: return [base] * NUM_TOTAL_BLOCKS - if "," not in weights: - return weights - w_alpha = list(map(float, weights.split(","))) - if len(w_alpha) == NUM_TOTAL_BLOCKS: - return w_alpha + w_alpha[len(w_alpha):NUM_TOTAL_BLOCKS] = [w_alpha[-1]] * max(0, NUM_TOTAL_BLOCKS - len(w_alpha)) + w_alpha[NUM_TOTAL_BLOCKS:] = () + return w_alpha def assemble_weights_and_bases(preset, weights, base, greek_letter): From aeb8c99688ba906824fa26c9bce59a0264f7a7c1 Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 18 Dec 2023 14:41:25 -0500 Subject: [PATCH 40/44] black --- sd_meh/merge_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 96f2895..4d20e3a 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -247,7 +247,7 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): if alpha_is_float: # cancel reflection. without this, eigenvalues often have a complex component # and then we can't obtain a valid dtype for the merge - u[:, -1] /= (torch.det(u) * torch.det(v_t)) + u[:, -1] /= torch.det(u) * torch.det(v_t) transform = rotation = u @ v_t if not torch.isfinite(u).all(): From de7110290f17fd93d4410fdb996024e614ada5b6 Mon Sep 17 00:00:00 2001 From: ljleb Date: Mon, 18 Dec 2023 14:48:11 -0500 Subject: [PATCH 41/44] revert utils.py --- sd_meh/merge_methods.py | 8 ++++++-- sd_meh/utils.py | 8 +++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 4d20e3a..48d4821 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -251,10 +251,14 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): transform = rotation = u @ v_t if not torch.isfinite(u).all(): - raise ValueError(textwrap.dedent(f"""determinant error: {torch.det(rotation)}. + raise ValueError( + textwrap.dedent( + f"""determinant error: {torch.det(rotation)}. This can happen when merging on the CPU with the "rotate" method. Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks. 
-                See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"""))
+                See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"""
+            )
+        )
 
     if alpha_is_float:
         transform = fractional_matrix_power(transform, alpha)
diff --git a/sd_meh/utils.py b/sd_meh/utils.py
index 27e135c..f507ae8 100644
--- a/sd_meh/utils.py
+++ b/sd_meh/utils.py
@@ -17,10 +17,12 @@ def compute_weights(weights, base):
     if not weights:
         return [base] * NUM_TOTAL_BLOCKS
 
+    if "," not in weights:
+        return weights
+
     w_alpha = list(map(float, weights.split(",")))
-    w_alpha[len(w_alpha):NUM_TOTAL_BLOCKS] = [w_alpha[-1]] * max(0, NUM_TOTAL_BLOCKS - len(w_alpha))
-    w_alpha[NUM_TOTAL_BLOCKS:] = ()
-    return w_alpha
+    if len(w_alpha) == NUM_TOTAL_BLOCKS:
+        return w_alpha
 
 
 def assemble_weights_and_bases(preset, weights, base, greek_letter):

From a01e016c65132ca2ef73b620c030975e9d154e07 Mon Sep 17 00:00:00 2001
From: ljleb
Date: Sat, 27 Jan 2024 10:39:10 -0500
Subject: [PATCH 42/44] cache impl

---
 sd_meh/merge.py         | 12 ++++++++-
 sd_meh/merge_methods.py | 55 ++++++++++++++++++++++++++---------------
 2 files changed, 46 insertions(+), 21 deletions(-)

diff --git a/sd_meh/merge.py b/sd_meh/merge.py
index 760391e..28cc40d 100644
--- a/sd_meh/merge.py
+++ b/sd_meh/merge.py
@@ -141,6 +141,7 @@ def merge_models(
     work_device: Optional[str] = None,
     prune: bool = False,
     threads: int = 1,
+    cache: Optional[Dict] = None,
 ) -> Dict:
     thetas = load_thetas(models, prune, device, precision)
 
@@ -169,6 +170,7 @@ def merge_models(
         device=device,
         work_device=work_device,
         threads=threads,
+        cache=cache,
     )
 
     return un_prune_model(merged, thetas, models, device, prune, precision)
@@ -221,6 +223,7 @@ def simple_merge(
     device: str = "cpu",
     work_device: Optional[str] = None,
     threads: int = 1,
+    cache: Optional[Dict] = None,
 ) -> Dict:
     futures = []
     with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress:
@@ -238,6 +241,7 @@ def simple_merge(
                 weights_clip,
                 device,
                 work_device,
+                cache,
             )
             futures.append(future)
 
@@ -367,6 +371,7 @@ def merge_key(
     weights_clip: bool = False,
     device: str = "cpu",
     work_device: Optional[str] = None,
+    cache: Optional[Dict] = None,
 ) -> Optional[Tuple[str, Dict]]:
     if work_device is None:
         work_device = device
@@ -410,7 +415,7 @@ def merge_key(
     except AttributeError as e:
         raise ValueError(f"{merge_mode} not implemented, aborting merge!") from e
 
-    merge_args = get_merge_method_args(current_bases, thetas, key, work_device)
+    merge_args = get_merge_method_args(current_bases, thetas, key, work_device, cache)
 
     # dealing with pix2pix and inpainting models
     if (a_size := merge_args["a"].size()) != (b_size := merge_args["b"].size()):
@@ -460,11 +465,16 @@ def get_merge_method_args(
     thetas: Dict,
     key: str,
     work_device: str,
+    cache: Optional[Dict],
 ) -> Dict:
+    if cache is not None and key not in cache:
+        cache[key] = {}
+
     merge_method_args = {
         "a": thetas["model_a"][key].to(work_device),
         "b": thetas["model_b"][key].to(work_device),
         **current_bases,
+        "cache": cache[key]
     }
 
     if "model_c" in thetas:
diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py
index 48d4821..4a5f201 100644
--- a/sd_meh/merge_methods.py
+++ b/sd_meh/merge_methods.py
@@ -240,28 +240,35 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
     a_neurons -= a_centroid
     b_neurons -= b_centroid
 
-    svd_driver = "gesvd" if a.is_cuda else None
-    u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver)
-
     alpha_is_float = alpha != round(alpha)
-    if alpha_is_float:
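# Reviewer note on the cache plumbing above: `cache` is an ordinary dict.
# get_merge_method_args gives each state-dict key its own sub-dict, and a
# merge method may stash expensive intermediates there (the rotation and
# eigendecomposition below), so re-merging the same pair of models with a
# different alpha skips the SVD entirely. A self-contained sketch of the
# lookup pattern -- the helper name is illustrative, not part of this patch:

import torch
from typing import Callable, Optional

def cached_tensor(
    cache: Optional[dict], name: str, compute: Callable[[], torch.Tensor], device
) -> torch.Tensor:
    # hit: replay the stored CPU copy on the work device
    if cache is not None and name in cache:
        return cache[name].to(device)
    # miss: compute once, keep a CPU copy so VRAM is not held between keys
    value = compute()
    if cache is not None:
        cache[name] = value.cpu()
    return value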
- # cancel reflection. without this, eigenvalues often have a complex component - # and then we can't obtain a valid dtype for the merge - u[:, -1] /= torch.det(u) * torch.det(v_t) - - transform = rotation = u @ v_t - if not torch.isfinite(u).all(): - raise ValueError( - textwrap.dedent( - f"""determinant error: {torch.det(rotation)}. - This can happen when merging on the CPU with the "rotate" method. - Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks. - See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484""" + + if kwargs["cache"] is not None and "rotation" in kwargs["cache"]: + rotation = transform = kwargs["cache"]["rotation"].to(a.device) + else: + svd_driver = "gesvd" if a.is_cuda else None + u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver) + + if alpha_is_float: + # cancel reflection. without this, eigenvalues often have a complex component + # and then we can't obtain a valid dtype for the merge + u[:, -1] /= torch.det(u) * torch.det(v_t) + + rotation = transform = u @ v_t + if not torch.isfinite(u).all(): + raise ValueError( + textwrap.dedent( + f"""determinant error: {torch.det(rotation)}. + This can happen when merging on the CPU with the "rotate" method. + Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks. + See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484""" + ) ) - ) + + if kwargs["cache"] is not None: + kwargs["cache"]["rotation"] = rotation.cpu() if alpha_is_float: - transform = fractional_matrix_power(transform, alpha) + transform = fractional_matrix_power(transform, alpha, kwargs["cache"]) elif alpha == 0: transform = torch.eye( len(transform), @@ -280,8 +287,16 @@ def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs): return a_neurons.reshape_as(a).to(a.dtype) -def fractional_matrix_power(matrix: Tensor, power: float): - eigenvalues, eigenvectors = torch.linalg.eig(matrix) +def fractional_matrix_power(matrix: Tensor, power: float, cache: dict): + if cache is not None and "eigenvalues" in cache: + eigenvalues = cache["eigenvalues"].to(matrix.device) + eigenvectors = cache["eigenvectors"].to(matrix.device) + else: + eigenvalues, eigenvectors = torch.linalg.eig(matrix) + if cache is not None: + cache["eigenvalues"] = eigenvalues.cpu() + cache["eigenvectors"] = eigenvectors.cpu() + eigenvalues.pow_(power) result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors) return result.real.to(dtype=matrix.dtype) From 003017e8b0a04623e9de0e73e5aa83adead99917 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 27 Jan 2024 11:11:52 -0500 Subject: [PATCH 43/44] cache eigen inv --- sd_meh/merge_methods.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sd_meh/merge_methods.py b/sd_meh/merge_methods.py index 4a5f201..32cab22 100644 --- a/sd_meh/merge_methods.py +++ b/sd_meh/merge_methods.py @@ -291,12 +291,15 @@ def fractional_matrix_power(matrix: Tensor, power: float, cache: dict): if cache is not None and "eigenvalues" in cache: eigenvalues = cache["eigenvalues"].to(matrix.device) eigenvectors = cache["eigenvectors"].to(matrix.device) + eigenvectors_inv = cache["eigenvectors_inv"].to(matrix.device) else: eigenvalues, eigenvectors = torch.linalg.eig(matrix) + eigenvectors_inv = torch.linalg.inv(eigenvectors) if cache is not None: cache["eigenvalues"] = eigenvalues.cpu() cache["eigenvectors"] = eigenvectors.cpu() + 
cache["eigenvectors_inv"] = eigenvectors_inv.cpu() eigenvalues.pow_(power) - result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors) + result = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors_inv return result.real.to(dtype=matrix.dtype) From 81515bda3cba52edd263964c5517f4713faad86e Mon Sep 17 00:00:00 2001 From: ljleb Date: Sun, 28 Jan 2024 20:12:52 -0500 Subject: [PATCH 44/44] Update merge.py --- sd_meh/merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sd_meh/merge.py b/sd_meh/merge.py index 28cc40d..38d3989 100644 --- a/sd_meh/merge.py +++ b/sd_meh/merge.py @@ -474,7 +474,7 @@ def get_merge_method_args( "a": thetas["model_a"][key].to(work_device), "b": thetas["model_b"][key].to(work_device), **current_bases, - "cache": cache[key] + "cache": cache[key] if cache is not None else None, } if "model_c" in thetas: