-
Notifications
You must be signed in to change notification settings - Fork 88
Expand file tree
/
Copy pathbenchmark.py
More file actions
260 lines (221 loc) · 8.7 KB
/
benchmark.py
File metadata and controls
260 lines (221 loc) · 8.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import os
import subprocess
import multiprocessing
import argparse
import pathlib
import csv
from contextlib import nullcontext
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, StableDiffusionOnnxPipeline
# Benchmark target: the first CUDA GPU when PyTorch reports one, the CPU otherwise.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Text prompt used for every generated sample.
prompt = "a photo of an astronaut riding a horse on mars"
def get_inference_pipeline(precision, backend):
    """
    returns a HuggingFace diffusers Stable Diffusion pipeline with its safety checker disabled
    cf https://github.com/huggingface/diffusers#text-to-image-generation-with-stable-diffusion
    * precision: 'half' or 'single' (selects torch.float16 or torch.float32)
    * backend: 'pytorch' or 'onnx'
    note: the weights are fetched from CompVis/stable-diffusion-v1-4 using the token
    stored in the ACCESS_TOKEN environment variable (KeyError if it is not set)
    """
    assert precision in ("half", "single"), "precision in ['half', 'single']"
    assert backend in ("pytorch", "onnx"), "backend in ['pytorch', 'onnx']"

    model_id = "CompVis/stable-diffusion-v1-4"
    is_single = precision == "single"
    dtype = torch.float32 if is_single else torch.float16

    if backend == "onnx":
        access_token = os.environ["ACCESS_TOKEN"]
        # Pick the ONNX Runtime execution provider that matches the benchmark device.
        if device.type == "cpu":
            provider = "CPUExecutionProvider"
        else:
            provider = "CUDAExecutionProvider"
        pipe = StableDiffusionOnnxPipeline.from_pretrained(
            model_id,
            use_auth_token=access_token,
            revision="onnx",
            provider=provider,
            torch_dtype=dtype,
        )
    else:
        # The "fp16" revision is requested for half precision, "main" for single.
        revision = "main" if is_single else "fp16"
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            revision=revision,
            use_auth_token=os.environ["ACCESS_TOKEN"],
            torch_dtype=dtype,
        )
        pipe = pipe.to(device)

    # Disable safety: hand the images back untouched together with a False flag.
    def skip_safety_check(images, **kwargs):
        return images, False

    pipe.safety_checker = skip_safety_check
    return pipe
def do_inference(pipe, n_samples, use_autocast, num_inference_steps):
    """
    runs the pipeline once on a batch made of `n_samples` copies of the module-level prompt
    * pipe: pipeline object, called with `prompt` and `num_inference_steps` keyword arguments
    * n_samples: number of copies of the prompt in the batch
    * use_autocast: if truthy and the device is cuda, run under torch autocast
    * num_inference_steps: forwarded to the pipeline
    returns:
        the `.images` attribute of the pipeline output
    """
    torch.cuda.empty_cache()
    if device.type == "cuda" and use_autocast:
        make_context = autocast
    else:
        # nullcontext accepts (and ignores) the "cuda" argument passed below.
        make_context = nullcontext
    with make_context("cuda"):
        output = pipe(
            prompt=[prompt] * n_samples,
            num_inference_steps=num_inference_steps,
        )
        images = output.images
    return images
def get_inference_time(pipe, n_samples, n_repeats, use_autocast, num_inference_steps):
    """
    times `do_inference` with torch.utils.benchmark.Timer
    * pipe: pipeline passed through to `do_inference`
    * n_samples: number of samples to generate (~ batch size)
    * n_repeats: number of timed executions requested from the Timer
    * use_autocast: passed through to `do_inference`
    * num_inference_steps: passed through to `do_inference`
    returns:
        the `mean` of the Timer measurement, rounded to 2 decimals
    """
    from torch.utils.benchmark import Timer

    timer = Timer(
        stmt="do_inference(pipe, n_samples, use_autocast, num_inference_steps)",
        # Hand the function over through `globals` rather than importing it from
        # `__main__` in `setup`, so timing also works when this file is imported
        # as a module instead of being run as a script.
        globals={
            "do_inference": do_inference,
            "pipe": pipe,
            "n_samples": n_samples,
            "use_autocast": use_autocast,
            "num_inference_steps": num_inference_steps,
        },
        num_threads=multiprocessing.cpu_count(),
    )
    # benchmark.Timer performs 2 iterations for warmup
    profile_result = timer.timeit(n_repeats)
    return round(profile_result.mean, 2)
def get_inference_memory(pipe, n_samples, use_autocast, num_inference_steps):
    """
    runs one inference and reports the GPU memory reserved by PyTorch afterwards
    * pipe: pipeline passed through to `do_inference`
    * n_samples: number of samples to generate (~ batch size)
    * use_autocast: passed through to `do_inference`
    * num_inference_steps: passed through to `do_inference`
    returns:
        `torch.cuda.memory_reserved()` divided by 1e9 and rounded to 2 decimals,
        or 0 when CUDA is not available
    """
    if not torch.cuda.is_available():
        return 0
    # Reuse the shared inference helper (it empties the CUDA cache first) so that
    # the memory and latency measurements exercise exactly the same code path.
    do_inference(pipe, n_samples, use_autocast, num_inference_steps)
    mem = torch.cuda.memory_reserved()
    return round(mem / 1e9, 2)
def run_benchmark(
    n_repeats, n_samples, precision, use_autocast, backend, num_inference_steps
):
    """
    builds a pipeline, measures its memory use and latency, prints and returns both
    * n_repeats: nb datapoints for inference latency benchmark
    * n_samples: number of samples to generate (~ batch size)
    * precision: 'half' or 'single' (use fp16 or fp32 tensors)
    * use_autocast: whether inference is run under autocast
    * backend: 'pytorch' or 'onnx'
    * num_inference_steps: number of steps passed to the pipeline
    returns:
        dict like {'memory': 17.70, 'latency': 86.71}
    """
    pipe = get_inference_pipeline(precision, backend)

    # Memory is measured before latency; on cpu it is reported as 0.00.
    if device.type == "cpu":
        memory = 0.00
    else:
        memory = get_inference_memory(
            pipe, n_samples, use_autocast, num_inference_steps
        )
    latency = get_inference_time(
        pipe, n_samples, n_repeats, use_autocast, num_inference_steps
    )
    logs = {"memory": memory, "latency": latency}

    print(
        f"n_samples: {n_samples}\tprecision: {precision}\tautocast: {use_autocast}\tbackend: {backend}"
    )
    print(logs, "\n")
    return logs
def get_device_description():
    """
    returns descriptor of the benchmark device such as
    'NVIDIA RTX A6000'
    for a cuda device this is `torch.cuda.get_device_name()`; for cpu it is the value
    of the first 'model name' line of /proc/cpuinfo, falling back to the `platform`
    module when that file or line is not available
    """
    if device.type != "cpu":
        return torch.cuda.get_device_name()

    import platform

    # Read /proc/cpuinfo directly instead of spawning `grep` through a shell:
    # no subprocess is needed and a missing file or line no longer raises.
    try:
        with open("/proc/cpuinfo", encoding="utf-8", errors="replace") as cpuinfo:
            for line in cpuinfo:
                if "model name" in line:
                    # Line looks like "model name\t: <cpu description>".
                    return line.partition(":")[2].strip()
    except OSError:
        # No readable /proc/cpuinfo; use the fallback below.
        pass
    return platform.processor() or platform.machine()
def run_benchmark_grid(grid, n_repeats, num_inference_steps):
    """
    runs `run_benchmark` for every combination in `grid` and appends one csv row per run
    to benchmark_tmp.csv (located two directories above this file)
    * grid : dict like
        {
            "n_samples": (1, 2),
            "precision": ("single", "half"),
            "autocast" : ("yes", "no"),
            "backend"  : ("pytorch", "onnx")
        }
    * n_repeats: nb datapoints for inference latency benchmark
    * num_inference_steps: number of steps passed to the pipeline
    runs that fail with an out-of-memory error are recorded with latency and memory
    set to -1.00; any other exception is re-raised
    """
    csv_fpath = pathlib.Path(__file__).parent.parent / "benchmark_tmp.csv"
    header = [
        "device",
        "precision",
        "autocast",
        "runtime",
        "n_samples",
        "latency",
        "memory",
    ]
    # The header is only written when benchmark_tmp.csv does not exist yet;
    # otherwise new results are appended to the existing file.
    write_header = not os.path.isfile(csv_fpath)

    # newline="" lets the csv module control line endings, as its documentation
    # requires (avoids blank lines between rows on some platforms).
    with open(csv_fpath, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(header)

        device_desc = get_device_description()
        for n_samples in grid["n_samples"]:
            for precision in grid["precision"]:
                # restrict enabling autocast to half precision
                if precision == "single":
                    use_autocast_vals = ("no",)
                else:
                    use_autocast_vals = grid["autocast"]
                for use_autocast_val in use_autocast_vals:
                    use_autocast = use_autocast_val == "yes"
                    for backend in grid["backend"]:
                        try:
                            new_log = run_benchmark(
                                n_repeats=n_repeats,
                                n_samples=n_samples,
                                precision=precision,
                                use_autocast=use_autocast,
                                backend=backend,
                                num_inference_steps=num_inference_steps,
                            )
                        except Exception as e:
                            message = str(e)
                            out_of_memory = (
                                "CUDA out of memory" in message
                                or "Failed to allocate memory" in message
                            )
                            if not out_of_memory:
                                # Bare raise keeps the original traceback intact.
                                raise
                            print(message)
                            torch.cuda.empty_cache()
                            new_log = {"latency": -1.00, "memory": -1.00}

                        # Column order follows `header`; "runtime" holds the backend.
                        writer.writerow(
                            [
                                device_desc,
                                precision,
                                use_autocast,
                                backend,
                                n_samples,
                                new_log["latency"],
                                new_log["memory"],
                            ]
                        )
                        # Push each row to disk so finished runs survive a later crash.
                        f.flush()
if __name__ == "__main__":
    # Command-line entry point: build the benchmark grid from the arguments and run it.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--samples",
        default="1",
        type=str,
        help="Comma separated list of batch sizes (number of samples)",
    )
    parser.add_argument(
        "--steps", default=50, type=int, help="Number of diffusion steps."
    )
    parser.add_argument(
        "--repeats",
        default=3,
        type=int,
        help="Number of repeats.",
    )
    parser.add_argument(
        "--autocast",
        default="no",
        type=str,
        help="If 'yes', will perform additional runs with autocast activated for half precision inferences",
    )
    args = parser.parse_args()

    # Tolerate stray commas / blanks such as "1,2," and report bad input through
    # argparse instead of an uncaught ValueError traceback.
    try:
        n_samples = tuple(
            int(item) for item in args.samples.split(",") if item.strip()
        )
    except ValueError:
        parser.error(f"--samples must be a comma separated list of integers, got {args.samples!r}")
    if not n_samples:
        parser.error("--samples must contain at least one batch size")

    grid = {
        "n_samples": n_samples,
        # Only use single-precision for cpu because "LayerNormKernelImpl" is not implemented
        # for 'Half' on cpu, and removing autocast does not help. Ref:
        # https://github.com/CompVis/stable-diffusion/issues/307
        "precision": ("single",) if device.type == "cpu" else ("single", "half"),
        "autocast": ("no",) if args.autocast == "no" else ("yes", "no"),
        # Only use onnx for cpu, until issues are fixed by upstreams. Ref:
        # https://github.com/huggingface/diffusers/issues/489#issuecomment-1261577250
        # https://github.com/huggingface/diffusers/pull/440
        "backend": ("pytorch", "onnx") if device.type == "cpu" else ("pytorch",),
    }
    run_benchmark_grid(grid, n_repeats=args.repeats, num_inference_steps=args.steps)