-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
154 lines (125 loc) · 6.84 KB
/
eval.py
File metadata and controls
154 lines (125 loc) · 6.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# ---------------------------------------------------------------------------------------------------------------------
# Evaluation entry point for Astro-DSB (astrophysical observational inversion)
# ---------------------------------------------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import sys
import random
import argparse
import copy
from pathlib import Path
from datetime import datetime
import numpy as np
import torch
from torch.multiprocessing import Process
from logger import Logger
from distributed_util import init_processes
from corruption import build_corruption
import datasets
from astrodsb import Runner, download_ckpt
import colored_traceback.always
from ipdb import set_trace as debug
# Root output directory (relative to the current working directory).
# create_training_options() creates RESULT_DIR / <name> for this run and, when
# --ckpt is given, expects the checkpoint at RESULT_DIR / <ckpt> / "latest.pt".
RESULT_DIR = Path("results")
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` so that child processes launched after this
    call inherit the same hash seed.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Torch seeders in order: CPU generator, current CUDA device, every CUDA device.
    torch_seeders = (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seeder in torch_seeders:
        seeder(seed)
def create_training_options(argv=None):
    """Parse Astro-DSB command-line options and derive runtime settings.

    Args:
        argv: Optional list of argument strings. ``None`` (default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets other code
            build options programmatically.

    Returns:
        argparse.Namespace holding the parsed options plus derived fields:
        ``device``, ``distributed``, ``use_fp16``, ``ckpt_path`` and ``load``,
        and ``ngc_job_id`` when ``NGC_JOB_ID`` is set in the environment.

    Side effects:
        Creates ``opt.log_dir`` and ``RESULT_DIR / opt.name`` if missing.
        Invalid options terminate via ``parser.error`` (SystemExit, status 2)
        before anything is written to disk.
    """
    # --------------- basic ---------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--name", type=str, default=None, help="experiment ID")
    parser.add_argument("--ckpt", type=str, default=None, help="resumed checkpoint name")
    parser.add_argument("--gpu", type=int, default=None, help="set only if you wish to run on a particular device")
    parser.add_argument("--n-gpu-per-node", type=int, default=1, help="number of gpu on each node")
    parser.add_argument("--master-address", type=str, default='localhost', help="address for master")
    parser.add_argument("--node-rank", type=int, default=0, help="the index of node")
    parser.add_argument("--num-proc-node", type=int, default=1, help="The number of nodes in multi node env")
    # parser.add_argument("--amp", action="store_true")

    # --------------- DSB model ---------------
    parser.add_argument("--image-size", type=int, default=128)
    parser.add_argument("--t0", type=float, default=1e-4, help="sigma start time in network parametrization")
    parser.add_argument("--T", type=float, default=1., help="sigma end time in network parametrization")
    parser.add_argument("--interval", type=int, default=1000, help="number of interval")
    parser.add_argument("--beta-max", type=float, default=0.3, help="max diffusion for the diffusion model")
    # parser.add_argument("--beta-min", type=float, default=0.1)
    parser.add_argument("--ot-ode", action="store_true", help="use OT-ODE model")
    # configs for M2 conditional enhancement
    parser.add_argument("--cond-x1", action="store_true", help="conditional the network with enhanced observation")

    # --------------- optimizer and loss ---------------
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--microbatch", type=int, default=2, help="accumulate gradient over microbatch until full batch-size")
    parser.add_argument("--num-itr", type=int, default=35000, help="training iteration")
    parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
    parser.add_argument("--lr-gamma", type=float, default=0.99, help="learning rate decay ratio")
    parser.add_argument("--lr-step", type=int, default=1000, help="learning rate decay step size")
    parser.add_argument("--l2-norm", type=float, default=0.0)
    parser.add_argument("--ema", type=float, default=0.99)

    # --------------- path and logging ---------------
    parser.add_argument("--dataset-dir", type=Path, default="/dataset", help="path to preprocessed dataset")
    parser.add_argument("--log-dir", type=Path, default=".log", help="path to log std outputs and writer data")
    parser.add_argument("--log-writer", type=str, default=None, help="log writer: can be tensorboard, wandb, or None")

    opt = parser.parse_args(argv)

    # ========= validation (before touching the filesystem) =========
    # parser.error instead of assert: asserts are stripped under `python -O`.
    # microbatch <= 0 is rejected explicitly; 0 would raise ZeroDivisionError below.
    if opt.microbatch <= 0 or opt.batch_size % opt.microbatch != 0:
        parser.error(
            f"--batch-size ({opt.batch_size}) must be a positive multiple "
            f"of --microbatch ({opt.microbatch})"
        )

    opt.load = None
    if opt.ckpt is not None:
        ckpt_file = RESULT_DIR / opt.ckpt / "latest.pt"
        if not ckpt_file.exists():
            parser.error(f"checkpoint not found: {ckpt_file}")
        opt.load = ckpt_file

    # ========= auto setup =========
    opt.device = 'cuda' if opt.gpu is None else f'cuda:{opt.gpu}'
    if opt.name is None:
        # This parser defines no --corrupt option, so fall back to a
        # timestamped experiment ID rather than reading opt.corrupt.
        opt.name = datetime.now().strftime("eval-%Y%m%d-%H%M%S")
    opt.distributed = opt.n_gpu_per_node > 1
    opt.use_fp16 = False  # disable fp16

    # log ngc meta data
    if "NGC_JOB_ID" in os.environ:
        opt.ngc_job_id = os.environ["NGC_JOB_ID"]

    # ========= path handle =========
    os.makedirs(opt.log_dir, exist_ok=True)
    opt.ckpt_path = RESULT_DIR / opt.name
    os.makedirs(opt.ckpt_path, exist_ok=True)

    return opt
def main(opt):
    """Run Astro-DSB evaluation on the validation split for one process.

    Args:
        opt: Namespace from ``create_training_options`` whose ``global_rank``
            (plus ``local_rank`` / ``global_size``) has been filled in by the
            launcher in the ``__main__`` block.
    """
    log = Logger(opt.global_rank, opt.log_dir)
    log.info("===================================================================================")
    log.info(" Diffusion Schrodinger Bridge Solver for Astrophysical Observational Inversion ")
    log.info("===================================================================================")
    log.info("Command used:\n{}".format(" ".join(sys.argv)))
    log.info(f"Experiment ID: {opt.name}")

    # Offset the seed by rank so every GPU process gets a different RNG stream.
    if opt.seed is not None:
        set_seed(opt.seed + opt.global_rank)

    dataset = datasets.AllData(opt)
    _train_dataset, val_dataset = dataset.get_loaders()  # only the validation split is used here

    # `run` must be created before use; without this line run.eval raises NameError.
    # NOTE(review): Runner(opt, log) constructor signature is assumed — confirm
    # against astrodsb.Runner.
    run = Runner(opt, log)
    run.eval(opt, val_dataset)
    log.info("Finish!")
if __name__ == '__main__':
    opt = create_training_options()

    if opt.distributed:
        # One worker process per local GPU; each gets its own deep copy of opt
        # so per-rank fields (local_rank, global_rank, ...) are independent.
        size = opt.n_gpu_per_node
        processes = []
        for rank in range(size):
            rank_opt = copy.deepcopy(opt)
            rank_opt.local_rank = rank
            global_rank = rank + opt.node_rank * opt.n_gpu_per_node
            global_size = opt.num_proc_node * opt.n_gpu_per_node
            rank_opt.global_rank = global_rank
            rank_opt.global_size = global_size
            print('Node rank %d, local proc %d, global proc %d, global_size %d' % (opt.node_rank, rank, global_rank, global_size))
            p = Process(target=init_processes, args=(global_rank, global_size, main, rank_opt))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Propagate worker failures; otherwise a crashed rank still leaves the
        # parent exiting with status 0.
        failed = [p.exitcode for p in processes if p.exitcode != 0]
        if failed:
            sys.exit(f"{len(failed)} worker process(es) failed with exit codes {failed}")
    else:
        # NOTE(review): always selects device 0 and ignores --gpu; init_processes
        # may also set the device — confirm intended behavior.
        torch.cuda.set_device(0)
        opt.global_rank = 0
        opt.local_rank = 0
        opt.global_size = 1
        init_processes(0, opt.n_gpu_per_node, main, opt)