-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmake_stats.py
More file actions
41 lines (38 loc) · 1.6 KB
/
make_stats.py
File metadata and controls
41 lines (38 loc) · 1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import argparse
import os
import torch
import torch.backends.cudnn as cudnn
import datasets
from torchvision import transforms
from config import cfg, process_args
from data import fetch_dataset, make_data_loader
from utils import save, process_control, process_dataset, collate, Stats, makedir_exist_ok
# Let cuDNN benchmark the available convolution algorithms and pick the fastest.
cudnn.benchmark = True

# Expose every entry of the project config as a `--<key>` command-line flag,
# using the current config value as the default and its Python type as the
# argparse converter.
# NOTE(review): `type=type(value)` means bool entries go through `bool(str)`
# (any non-empty string is True) and None entries through `NoneType(str)`,
# which raises if the flag is actually passed on the command line — confirm
# that cfg holds only types argparse can convert.
parser = argparse.ArgumentParser(description='cfg')
for k in cfg:
    # Direct call instead of exec() on a formatted source string: same result
    # for string keys, without evaluating dynamically built code.
    parser.add_argument('--{}'.format(k), default=cfg[k], type=type(cfg[k]))
parser.add_argument('--control_name', default=None, type=str)
args = vars(parser.parse_args())
# presumably folds the parsed arguments back into cfg — verify in config.process_args
process_args(args)

# Directory the per-dataset statistics files are written to.
stats_path = os.path.join('res', 'stats')
# Passed to Stats(dim=...) below; presumably the channel axis — TODO confirm
# against utils.Stats.
dim = 1
if __name__ == "__main__":
    # For each dataset listed below: iterate over its training split,
    # accumulate statistics with utils.Stats, print the resulting
    # (mean, std) pair and save it to `<stats_path>/<data_name>.pt`.
    makedir_exist_ok(stats_path)
    process_control()
    cfg['seed'] = 0
    data_names = ['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'SVHN']
    # Only statistics are accumulated here, so gradient tracking is disabled.
    with torch.no_grad():
        for data_name in data_names:
            # The project helpers read the active dataset from cfg, so it is
            # set before any of them is called.
            cfg['data_name'] = data_name
            dataset = fetch_dataset(cfg['data_name'])
            process_dataset(dataset)
            # Replace the training transform with a bare ToTensor so that no
            # other preprocessing is applied by the transform pipeline.
            dataset['train'].transform = datasets.Compose([transforms.ToTensor()])
            data_loader = make_data_loader(dataset, 'global')
            running_stats = Stats(dim=dim)
            for batch in data_loader['train']:
                batch = collate(batch)
                running_stats.update(batch['data'])
            # Convert to plain Python lists before printing and saving.
            stats = (running_stats.mean.tolist(), running_stats.std.tolist())
            print(cfg['data_name'], stats)
            save(stats, os.path.join(stats_path, '{}.pt'.format(cfg['data_name'])))