-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsample_and_save.py
More file actions
109 lines (91 loc) · 3.78 KB
/
sample_and_save.py
File metadata and controls
109 lines (91 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import warnings
from argparse import ArgumentParser
from pathlib import Path
import torch
from accelerate import Accelerator
# NOTE: this shadows the builtin print with rich's pretty-printer (used for cfg).
from rich import print
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
import utils.inference
# Silence UserWarnings (e.g. torch deprecation chatter) for cleaner sampling logs.
warnings.filterwarnings("ignore", category=UserWarning)
# If torch.compile hits an unsupported op, fall back to eager instead of raising.
torch._dynamo.config.suppress_errors = True
# Do not auto-switch compiled graphs to dynamic shapes on shape changes.
# NOTE(review): with drop_last=False the final batch can be smaller and may
# trigger a recompile — confirm this is acceptable for the target workload.
torch._dynamo.config.automatic_dynamic_shapes = False
def sample(args):
    """Generate samples from a trained LiDAR DDPM and save each one to disk.

    Each sample index in ``range(args.num_samples)`` doubles as the RNG seed,
    so generation is reproducible and shardable across accelerator processes.
    Results are written as individual ``samples_<seed>.pth`` tensors under
    ``args.output_dir``.

    Args:
        args: Parsed CLI namespace with ``ckpt``, ``output_dir``,
            ``batch_size``, ``num_samples``, ``num_steps``, and ``mode``.
    """
    # Inference only — no autograd bookkeeping needed.
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    ddpm, lidar_utils, cfg = utils.inference.setup_model(args.ckpt)

    accelerator = Accelerator(
        mixed_precision=cfg.training.mixed_precision,
        dynamo_backend=cfg.training.dynamo_backend,
        split_batches=True,
        even_batches=False,
        step_scheduler_with_optimizer=True,
    )
    device = accelerator.device
    if accelerator.is_main_process:
        print(cfg)

    # The "dataset" is just the integer seed for every sample to generate.
    dataloader = DataLoader(
        TensorDataset(torch.arange(args.num_samples).long()),
        batch_size=args.batch_size,
        num_workers=cfg.training.num_workers,
        drop_last=False,
    )

    ddpm.to(device)
    sample_fn = torch.compile(ddpm.sample)
    # prepare() shards the dataloader across processes and moves lidar_utils.
    lidar_utils, dataloader = accelerator.prepare(lidar_utils, dataloader)

    save_dir = Path(args.output_dir)
    with accelerator.main_process_first():
        save_dir.mkdir(parents=True, exist_ok=True)

    def postprocess(sample):
        """Convert a raw model output into a channel stack of
        [depth, x, y, z, reflectance, (seg prob, seg label)] in metric units."""
        if not isinstance(sample, dict):  # original unconditional mode
            sample = sample.clamp(-1, 1)
            sample = lidar_utils.denormalize(sample)
            depth, rflct = sample[:, [0]], sample[:, [1]]
            depth = lidar_utils.revert_depth(depth)
            xyz = lidar_utils.to_xyz(depth)
            return torch.cat([depth, xyz, rflct], dim=1)
        # Conditional mode: the model additionally returns segmentation outputs.
        sample_dict = sample
        sample = sample_dict['sample'].clamp(-1, 1)
        seg_pred = sample_dict['seg_pred'].unsqueeze(1).to(sample.device)
        seg_pred_prob = sample_dict['seg_pred_prob'].unsqueeze(1).to(sample.device)
        # Sanity-check model outputs: labels in [0, 20), probabilities in [0, 1].
        # NOTE(review): asserts are stripped under `python -O`; fine for a
        # research script, but raise explicitly if this check becomes critical.
        assert (seg_pred < 20).all().item() and (seg_pred >= 0).all().item()
        assert (seg_pred_prob <= 1).all().item() and (seg_pred_prob >= 0).all().item()
        sample = lidar_utils.denormalize(sample)
        depth, rflct = sample[:, [0]], sample[:, [1]]
        depth = lidar_utils.revert_depth(depth)
        xyz = lidar_utils.to_xyz(depth)
        return torch.cat([depth, xyz, rflct, seg_pred_prob, seg_pred], dim=1)  # [B, 7, 64, 1024]

    for seeds in tqdm(
        dataloader,
        desc="saving...",
        dynamic_ncols=True,
        disable=not accelerator.is_local_main_process,
    ):
        if seeds is None:  # defensive guard; TensorDataset never yields None
            break
        # DataLoader wraps the single-tensor dataset in a 1-tuple.
        (seeds,) = seeds
        # torch.cuda.amp.autocast is deprecated; torch.autocast with
        # device_type="cuda" is the supported, behavior-identical form.
        with torch.autocast(
            device_type="cuda",
            enabled=cfg.training.mixed_precision is not None,
        ):
            samples = sample_fn(
                batch_size=len(seeds),
                num_steps=args.num_steps,
                mode=args.mode,
                rng=utils.inference.setup_rng(seeds.cpu().tolist(), device=device),
                progress=False,
            )
        samples = postprocess(samples)
        for i in range(len(samples)):
            # .clone() so each saved file holds only its own storage rather
            # than a view into the whole batch tensor.
            torch.save(samples[i].clone(), save_dir / f"samples_{seeds[i]:010d}.pth")
if __name__ == "__main__":
    # CLI entry point: parse arguments and run sampling.
    parser = ArgumentParser(description="Sample from a trained LiDAR DDPM and save results.")
    # ckpt and output_dir previously defaulted to None, which only surfaced as
    # an opaque crash deep inside sample(); require them for a clear CLI error.
    parser.add_argument("--ckpt", type=str, required=True, help="path to the model checkpoint")
    parser.add_argument("--output_dir", type=str, required=True, help="directory for samples_*.pth files")
    parser.add_argument("--batch_size", type=int, default=64, help="samples generated per batch")
    parser.add_argument("--num_samples", type=int, default=10_000, help="total number of samples to generate")
    parser.add_argument("--num_steps", type=int, default=256, help="number of diffusion steps")
    parser.add_argument("--mode", choices=["ddpm", "ddim"], default="ddpm", help="sampling scheme")
    args = parser.parse_args()
    sample(args)