Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
18bff99
oft
ljleb Nov 10, 2023
47bdac9
device
ljleb Nov 10, 2023
0681458
rename
ljleb Nov 10, 2023
1fe0882
fix black
ljleb Nov 10, 2023
2810d89
cayley interpolation for alpha
ljleb Nov 10, 2023
09bed88
refact
ljleb Nov 10, 2023
f11c054
add method to __all__
ljleb Nov 10, 2023
1f497e9
include 1D 'rotation'
ljleb Nov 10, 2023
36fccaa
ignore alpha for now
ljleb Nov 10, 2023
f18208d
refact
ljleb Nov 10, 2023
1dafe83
implement fractional rotations
ljleb Nov 11, 2023
149ab16
fix transform direction
ljleb Nov 11, 2023
1f71391
fix eye
ljleb Nov 11, 2023
b464fd3
rewrite with out=
ljleb Nov 11, 2023
e1dc59c
it works; opt now
ljleb Nov 12, 2023
cbb6a06
optimize: 45m -> 7m
ljleb Nov 12, 2023
ce62946
rm print
ljleb Nov 12, 2023
8172927
fix precision issues
ljleb Nov 12, 2023
19fcc0a
fix precision issues
ljleb Nov 12, 2023
f954270
black
ljleb Nov 12, 2023
e94e252
dont change
ljleb Nov 12, 2023
1f380c8
imps
ljleb Nov 12, 2023
ea95b66
beta is deformation
ljleb Nov 12, 2023
1751f59
simplify
ljleb Nov 12, 2023
c69bb95
@
ljleb Nov 12, 2023
0d5160b
backup
ljleb Nov 13, 2023
1920496
deal with conv attention shape, rotate centroids
ljleb Nov 13, 2023
5a1c776
black
ljleb Nov 17, 2023
f61d6aa
wip
ljleb Nov 18, 2023
6ac82d6
refact
ljleb Nov 18, 2023
7fff089
backup
ljleb Nov 20, 2023
6ddc503
remove approx
ljleb Nov 21, 2023
a6742b3
dont edit
ljleb Nov 21, 2023
d84b776
fix fp16 and fp32 merges
ljleb Dec 7, 2023
d812ea8
reduced svd
ljleb Dec 8, 2023
38d4db6
black
ljleb Dec 8, 2023
3c90395
dont ellipsis
ljleb Dec 17, 2023
f506831
print more info for debug
ljleb Dec 18, 2023
1b46056
dont merge sdxl kek
ljleb Dec 18, 2023
aeb8c99
black
ljleb Dec 18, 2023
de71102
revert utils.py
ljleb Dec 18, 2023
a01e016
cache impl
ljleb Jan 27, 2024
003017e
cache eigen inv
ljleb Jan 27, 2024
81515bd
Update merge.py
ljleb Jan 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion sd_meh/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def merge_models(
work_device: Optional[str] = None,
prune: bool = False,
threads: int = 1,
cache: Optional[Dict] = None,
) -> Dict:
thetas = load_thetas(models, prune, device, precision)

Expand Down Expand Up @@ -169,6 +170,7 @@ def merge_models(
device=device,
work_device=work_device,
threads=threads,
cache=cache,
)

return un_prune_model(merged, thetas, models, device, prune, precision)
Expand Down Expand Up @@ -221,6 +223,7 @@ def simple_merge(
device: str = "cpu",
work_device: Optional[str] = None,
threads: int = 1,
cache: Optional[Dict] = None,
) -> Dict:
futures = []
with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress:
Expand All @@ -238,6 +241,7 @@ def simple_merge(
weights_clip,
device,
work_device,
cache,
)
futures.append(future)

Expand Down Expand Up @@ -367,6 +371,7 @@ def merge_key(
weights_clip: bool = False,
device: str = "cpu",
work_device: Optional[str] = None,
cache: Optional[Dict] = None,
) -> Optional[Tuple[str, Dict]]:
if work_device is None:
work_device = device
Expand Down Expand Up @@ -410,7 +415,7 @@ def merge_key(
except AttributeError as e:
raise ValueError(f"{merge_mode} not implemented, aborting merge!") from e

merge_args = get_merge_method_args(current_bases, thetas, key, work_device)
merge_args = get_merge_method_args(current_bases, thetas, key, work_device, cache)

# dealing wiht pix2pix and inpainting models
if (a_size := merge_args["a"].size()) != (b_size := merge_args["b"].size()):
Expand Down Expand Up @@ -460,11 +465,16 @@ def get_merge_method_args(
thetas: Dict,
key: str,
work_device: str,
cache: Optional[Dict],
) -> Dict:
if cache is not None and key not in cache:
cache[key] = {}

merge_method_args = {
"a": thetas["model_a"][key].to(work_device),
"b": thetas["model_b"][key].to(work_device),
**current_bases,
"cache": cache[key] if cache is not None else None,
}

if "model_c" in thetas:
Expand Down
96 changes: 95 additions & 1 deletion sd_meh/merge_methods.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import functools
import math
from typing import Tuple
import operator
import textwrap

import torch
from torch import Tensor
from typing import Tuple

__all__ = [
"weighted_sum",
Expand All @@ -17,6 +20,7 @@
"similarity_add_difference",
"distribution_crossover",
"ties_add_difference",
"rotate",
]


Expand Down Expand Up @@ -209,3 +213,93 @@ def filter_top_k(a: Tensor, k: float):
k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), k)
top_k_filter = (torch.abs(a) >= k_value).float()
return a * top_k_filter


def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
    """Merge two weight tensors by rotating the neurons of ``a`` toward ``b``.

    The tensors are flattened to 2D ("neurons" x features), their centroids
    are interpolated, and an orthogonal transform aligning ``a`` to ``b`` is
    computed via SVD of the cross-covariance.

    Args:
        a, b: weight tensors of the same shape to merge.
        alpha: amount of rotation to apply. 0 = identity, 1 = full
            alignment; non-integer values use a fractional matrix power.
        beta: interpolation factor for the residual ("deformation")
            relationship between the neurons.
        kwargs: must contain ``"cache"`` — a per-key dict (or ``None``)
            used to memoize the expensive SVD / eigendecomposition.

    Returns:
        The merged tensor, reshaped and cast back to ``a``'s dtype.

    Raises:
        ValueError: if the SVD produced non-finite values (can happen on CPU).
    """
    # Neither rotation nor deformation requested: nothing to do.
    if alpha == 0 and beta == 0:
        return a

    # Spatial (non-1x1) conv kernels and scalars are not rotated; tensors
    # that are already (half-precision) identical are also short-circuited
    # to a plain interpolation.
    is_conv = len(a.shape) == 4 and a.shape[-1] != 1
    if len(a.shape) == 0 or is_conv or torch.allclose(a.half(), b.half()):
        return weighted_sum(a, b, beta)

    # Flatten to 2D. For 4D (1x1 conv) weights, fold all trailing dims into
    # the feature axis; otherwise treat the last dim as features.
    if len(a.shape) == 4:
        shape_2d = (-1, functools.reduce(operator.mul, a.shape[1:]))
    else:
        shape_2d = (-1, a.shape[-1])

    # Work in float64 for numerical stability of the SVD / eig below.
    a_neurons = a.reshape(*shape_2d).double()
    b_neurons = b.reshape(*shape_2d).double()

    a_centroid = a_neurons.mean(0)
    b_centroid = b_neurons.mean(0)
    new_centroid = weighted_sum(a_centroid, b_centroid, alpha)
    # A single row has no rotational structure — the centroid interpolation
    # is the whole merge for 1D weights (and 1xN matrices).
    if len(a.shape) == 1 or len(a.shape) == 2 and a.shape[0] == 1:
        return new_centroid.reshape_as(a)

    # Center the neurons so the transform below is a pure rotation.
    a_neurons -= a_centroid
    b_neurons -= b_centroid

    alpha_is_float = alpha != round(alpha)

    if kwargs["cache"] is not None and "rotation" in kwargs["cache"]:
        # Reuse the orthogonal transform computed on a previous merge pass.
        rotation = transform = kwargs["cache"]["rotation"].to(a.device)
    else:
        # NOTE(review): "gesvd" is presumably chosen on cuda for precision
        # over the default driver — confirm against torch.linalg.svd docs.
        svd_driver = "gesvd" if a.is_cuda else None
        u, _, v_t = torch.linalg.svd(a_neurons.T @ b_neurons, driver=svd_driver)

        if alpha_is_float:
            # cancel reflection. without this, eigenvalues often have a complex component
            # and then we can't obtain a valid dtype for the merge
            u[:, -1] /= torch.det(u) * torch.det(v_t)

        rotation = transform = u @ v_t
        if not torch.isfinite(u).all():
            raise ValueError(
                textwrap.dedent(
                    f"""determinant error: {torch.det(rotation)}.
                    This can happen when merging on the CPU with the "rotate" method.
                    Consider merging on a cuda device, or try setting alpha to 1 for the problematic blocks.
                    See this related discussion for more info: https://github.com/s1dlx/meh/pull/50#discussion_r1429469484"""
                )
            )

    # Store on CPU so the cache does not pin GPU memory.
    if kwargs["cache"] is not None:
        kwargs["cache"]["rotation"] = rotation.cpu()

    # Raise the rotation to the requested (possibly fractional) power.
    if alpha_is_float:
        transform = fractional_matrix_power(transform, alpha, kwargs["cache"])
    elif alpha == 0:
        # beta != 0 here (see the early return above): keep only deformation.
        transform = torch.eye(
            len(transform),
            dtype=transform.dtype,
            device=transform.device,
        )
    elif alpha != 1:
        transform = torch.linalg.matrix_power(transform, round(alpha))

    if beta != 0:
        # interpolate the relationship between the neurons
        a_neurons = weighted_sum(a_neurons, b_neurons @ rotation.T, beta)

    a_neurons @= transform
    a_neurons += new_centroid
    return a_neurons.reshape_as(a).to(a.dtype)


def fractional_matrix_power(matrix: Tensor, power: float, cache: dict):
    """Compute ``matrix ** power`` for non-integer ``power`` via eigendecomposition.

    Args:
        matrix: a square (diagonalizable) matrix.
        power: the exponent to raise the matrix to.
        cache: optional dict memoizing the eigendecomposition across calls
            (tensors are stored on CPU); pass ``None`` to disable caching.

    Returns:
        The real part of ``V @ diag(lambda**power) @ V^-1``, cast to
        ``matrix``'s dtype.
    """
    if cache is not None and "eigenvalues" in cache:
        eigenvalues = cache["eigenvalues"].to(matrix.device)
        eigenvectors = cache["eigenvectors"].to(matrix.device)
        eigenvectors_inv = cache["eigenvectors_inv"].to(matrix.device)
    else:
        eigenvalues, eigenvectors = torch.linalg.eig(matrix)
        eigenvectors_inv = torch.linalg.inv(eigenvectors)
        if cache is not None:
            cache["eigenvalues"] = eigenvalues.cpu()
            cache["eigenvectors"] = eigenvectors.cpu()
            cache["eigenvectors_inv"] = eigenvectors_inv.cpu()

    # Raise eigenvalues to the power OUT of place: `.cpu()` / `.to(device)`
    # return the same tensor when it is already on that device, so on CPU the
    # cached tensors alias `eigenvalues`. The previous in-place `pow_` mutated
    # the cache, applying the power again on every cache hit.
    powered_eigenvalues = eigenvalues**power
    result = eigenvectors @ torch.diag(powered_eigenvalues) @ eigenvectors_inv
    return result.real.to(dtype=matrix.dtype)
6 changes: 2 additions & 4 deletions sd_meh/rebasin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,11 +2200,9 @@ def apply_permutation(ps: PermutationSpec, perm, params):
def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
for k in model_a:
try:
perm_params = get_permuted_param(
ps, perm, k, model_a
)
perm_params = get_permuted_param(ps, perm, k, model_a)
model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params
except RuntimeError: # dealing with pix2pix and inpainting models
except RuntimeError: # dealing with pix2pix and inpainting models
continue
return model_a

Expand Down