forked from centuryglass/IntraPaint
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgenerate.py
More file actions
81 lines (71 loc) · 2.54 KB
/
generate.py
File metadata and controls
81 lines (71 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Simplified script for glid-3-xl for image generation only, no inpainting functionality.
import argparse
import sys
import gc
import os
from PIL import Image
import torch
from torchvision.transforms import functional as TF
import numpy as np
from startup.load_models import loadModels
from startup.create_sample_function import createSampleFunction
from startup.generate_samples import generateSamples
from startup.utils import *
from startup.ml_utils import *
# --- Argument parsing -------------------------------------------------------
# buildArgParser comes from startup.utils (star import above); edit-specific
# parameters are disabled because this script does generation only.
parser = buildArgParser(defaultModel='finetune.pt', includeEditParams=False)
args = parser.parse_args()

# generate.py cannot inpaint: refuse the inpainting checkpoint up front and
# point the user at the scripts that do support it.
if args.model_path == 'inpaint.pt':
    # Errors belong on stderr so stdout stays clean for scripted use.
    print("Error: generate.py does not support inpainting. Use one of the following:", file=sys.stderr)
    print("\tquickEdit.py: To perform quick inpainting operations with a minimal UI.", file=sys.stderr)
    print("\tinpainting_ui.py: To use the inpainting UI, running both UI and generation on the same machine.", file=sys.stderr)
    print("\tinpainting_server.py: To run inpainting operations for a remote UI client", file=sys.stderr)
    print("\tautoedit.py: To run experimental automated random inpainting operations", file=sys.stderr)
    # Exit nonzero: bare sys.exit() reports success (status 0) even though
    # this is an error, which breaks shell pipelines and CI checks.
    sys.exit(1)

# --- Device and reproducibility ---------------------------------------------
device = getDevice(args.cpu)
# Negative seed means "don't seed": any seed >= 0 makes sampling reproducible.
if args.seed >= 0:
    torch.manual_seed(args.seed)

# --- Model loading -----------------------------------------------------------
# Loads the diffusion model, latent autoencoder (ldm), BERT text encoder, and
# (optionally) CLIP for guidance/scoring, all placed on `device`.
model_params, model, diffusion, ldm, bert, clip_model, clip_preprocess, normalize = loadModels(
    device,
    model_path=args.model_path,
    bert_path=args.bert_path,
    kl_path=args.kl_path,
    steps=args.steps,
    clip_guidance=args.clip_guidance,
    cpu=args.cpu,
    ddpm=args.ddpm,
    ddim=args.ddim)

# --- Sampling setup ----------------------------------------------------------
# Builds the diffusion sampling closure and (if CLIP guidance is enabled) a
# scoring function used to rank generated samples.
sample_fn, clip_score_fn = createSampleFunction(
    device,
    model,
    model_params,
    bert,
    clip_model,
    clip_preprocess,
    ldm,
    diffusion,
    normalize,
    prompt=args.text,
    negative=args.negative,
    image=args.init_image,
    guidance_scale=args.guidance_scale,
    batch_size=args.batch_size,
    width=args.width,
    height=args.height,
    cutn=args.cutn,
    clip_guidance=args.clip_guidance,
    clip_guidance_scale=args.clip_guidance_scale,
    skip_timesteps=args.skip_timesteps,
    ddpm=args.ddpm,
    ddim=args.ddim)

# Free any intermediate objects from setup before the memory-heavy sampling run.
gc.collect()

# --- Generation ---------------------------------------------------------------
# Runs num_batches batches of batch_size samples, saving each via the save
# function; CLIP scoring is only applied when --clip_score was requested.
generateSamples(device,
                ldm,
                diffusion,
                sample_fn,
                getSaveFn(args.prefix, args.batch_size, ldm, clip_model, clip_preprocess, device),
                args.batch_size,
                args.num_batches,
                width=args.width,
                height=args.height,
                init_image=args.init_image,
                clip_score_fn=clip_score_fn if args.clip_score else None)