Prune stable diffusion checkpoints.
usage: prune.py [-h] [-p] [-e] [-c] [-a] [-d] [-u] input output

Prune a stable diffusion checkpoint

positional arguments:
  input           input checkpoint
  output          output checkpoint

optional arguments:
  -h, --help      show this help message and exit
  -p, --fp16      convert to float16
  -e, --ema       use EMA for weights
  -c, --no-clip   strip CLIP weights
  -a, --no-vae    strip VAE weights
  -d, --no-depth  strip depth model weights
  -u, --no-unet   strip UNet weights
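To make the options above concrete, here is a minimal sketch of how such a pruner could work. It is not the actual prune.py source. The helper names (`build_parser`, `prune_state_dict`, `main`) are hypothetical. The state_dict key prefixes (`model.diffusion_model.`, `first_stage_model.`, `cond_stage_model.`, `depth_model.`) and the LitEma-style `model_ema.` naming are assumptions based on the CompVis latent-diffusion layout, so check them against the keys in your own checkpoint. In this sketch, pruning means dropping the EMA copies (after optionally copying them over the live weights), dropping any stripped submodels, and optionally casting the remaining floating-point tensors to float16.

```python
import argparse

# ASSUMED key prefixes (CompVis latent-diffusion layout); not taken from prune.py.
PREFIXES = {
    "unet": "model.diffusion_model.",
    "vae": "first_stage_model.",
    "clip": "cond_stage_model.",
    "depth": "depth_model.",
}
EMA_PREFIX = "model_ema."  # ASSUMED LitEma-style EMA copies of the "model." weights


def build_parser() -> argparse.ArgumentParser:
    """Rebuild the command line documented above with argparse."""
    p = argparse.ArgumentParser(
        prog="prune.py", description="Prune a stable diffusion checkpoint"
    )
    p.add_argument("input", help="input checkpoint")
    p.add_argument("output", help="output checkpoint")
    p.add_argument("-p", "--fp16", action="store_true", help="convert to float16")
    p.add_argument("-e", "--ema", action="store_true", help="use EMA for weights")
    p.add_argument("-c", "--no-clip", action="store_true", help="strip CLIP weights")
    p.add_argument("-a", "--no-vae", action="store_true", help="strip VAE weights")
    p.add_argument("-d", "--no-depth", action="store_true", help="strip depth model weights")
    p.add_argument("-u", "--no-unet", action="store_true", help="strip UNet weights")
    return p


def prune_state_dict(sd, use_ema=False, strip=(), convert=None):
    """Return a pruned copy of `sd` (hypothetical helper).

    - drops every EMA entry, after optionally copying EMA values over the live weights
    - drops all keys belonging to the submodels named in `strip`
    - applies `convert` (e.g. a float16 cast) to every kept value
    """
    out = {}
    for key, value in sd.items():
        if key.startswith(EMA_PREFIX):
            continue  # EMA copies are never written out under their own names
        if any(key.startswith(PREFIXES[name]) for name in strip):
            continue  # submodel removed by -c / -a / -d / -u
        if use_ema and key.startswith("model."):
            # "model.diffusion_model.out.0.weight" -> "model_ema.diffusion_modelout0weight"
            ema_key = EMA_PREFIX + key[len("model."):].replace(".", "")
            value = sd.get(ema_key, value)
        out[key] = convert(value) if convert is not None else value
    return out


def main(argv=None):
    args = build_parser().parse_args(argv)
    import torch  # imported here so the helpers above can be used without torch

    # Recent torch releases default torch.load to weights_only=True, which may
    # reject checkpoints that pickle extra objects such as Lightning callbacks.
    ckpt = torch.load(args.input, map_location="cpu")
    sd = ckpt.get("state_dict", ckpt)
    flags = {"clip": args.no_clip, "vae": args.no_vae,
             "depth": args.no_depth, "unet": args.no_unet}
    strip = [name for name, on in flags.items() if on]

    def to_half(t):
        # Only cast floating-point tensors; leave integer buffers untouched.
        return t.half() if torch.is_tensor(t) and t.is_floating_point() else t

    convert = to_half if args.fp16 else None
    pruned = prune_state_dict(sd, use_ema=args.ema, strip=strip, convert=convert)
    torch.save({"state_dict": pruned}, args.output)
```

To run the sketch as a script, call `main()` under an `if __name__ == "__main__":` guard. Keeping the key filtering in a pure function over a plain dict means the EMA and strip logic can be checked without loading a multi-gigabyte checkpoint.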
Examples:

Convert to torch.float16, use EMA weights, and remove the CLIP model weights.
python3 prune.py -pec sd-v1-4-full-ema.ckpt pruned.ckpt

Keep the precision unchanged and use EMA weights.
python3 prune.py -e sd-v1-4-full-ema.ckpt pruned.ckpt

Convert to torch.float16 and remove the VAE and CLIP model weights.
python3 prune.py -pca sd-v1-4-full-ema.ckpt pruned.ckpt

Requirements:
numpy
torch!=1.13.0
Note that torch 1.13.0 has a bug in torch.load that requires pytorch_lightning to be installed when loading Stable Diffusion checkpoints that include pytorch_lightning callbacks (pytorch/pytorch#88438).
It should be fixed in the next torch release (1.13.1, 1.14.0-dev, or 2.0.0).
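If you want your own tooling to fail early on the affected release instead of failing inside torch.load, a stdlib-only check could look like the following. Both function names are hypothetical. The check assumes the only excluded version is 1.13.0 (matching the `torch!=1.13.0` requirement) and ignores local build suffixes such as `+cu117`.

```python
def is_supported_torch(version: str) -> bool:
    """Return False for torch 1.13.0, the release affected by pytorch/pytorch#88438.

    Local build suffixes such as "+cu117" are ignored.
    """
    release = version.split("+", 1)[0]
    return release != "1.13.0"


def require_supported_torch(version: str) -> None:
    """Raise early instead of failing later inside torch.load."""
    if not is_supported_torch(version):
        raise RuntimeError(
            f"torch {version} is not supported (see pytorch/pytorch#88438); "
            "install a different torch release"
        )
```

In practice you would pass it `torch.__version__`, for example `require_supported_torch(torch.__version__)` right after `import torch`.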