-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathclean_loader.py
More file actions
executable file
·84 lines (70 loc) · 3.51 KB
/
clean_loader.py
File metadata and controls
executable file
·84 lines (70 loc) · 3.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python
import os

# torch package
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
# Per-dataset metadata copied onto ``args`` by build_cleanset:
#   name -> (num_classes, img_size, channel, mean, std)
# mean/std of None means inputs are left unnormalized.
# NOTE(review): the CIFAR-10 std below differs from the commonly reported
# per-pixel std (~0.247, 0.243, 0.261); kept unchanged so existing checkpoints
# keep seeing identical inputs — confirm this is intended.
_DATASET_PARAMS = {
    "cifar10": (10, 32, 3, (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    "gtsrb": (43, 32, 3, None, None),
    "tiny-imagenet": (200, 64, 3, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    "imagenet200": (200, 224, 3, (0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)),
}

# Evaluation-time (Resize size, CenterCrop size), applied before ToTensor.
# Datasets not listed here are evaluated at their stored resolution.
_EVAL_RESIZE_CROP = {
    "gtsrb": (32, 32),
    "imagenet200": (256, 224),
}


def build_cleanset(args):
    """Configure ``args`` for ``args.dataset`` and build the clean train/test datasets.

    Side effects: sets ``args.num_classes``, ``args.img_size``, ``args.channel``,
    ``args.mean`` and ``args.std`` (``mean``/``std`` are ``None`` for datasets
    that are not normalized).

    Args:
        args: namespace with at least ``dataset`` (one of ``"cifar10"``,
            ``"gtsrb"``, ``"tiny-imagenet"``, ``"imagenet200"``) and
            ``data_root`` (read by ``dataset``).

    Returns:
        ``(dataset_train, dataset_test)`` as returned by ``dataset`` for the
        training split (augmented) and evaluation split (deterministic).

    Raises:
        ValueError: if ``args.dataset`` is not a supported name. This is raised
            before ``args`` is modified.
    """
    _set_dataset_attributes(args)
    transform_train, transform_test = _build_transforms(args)
    dataset_train = dataset(args, True, transform_train)
    dataset_test = dataset(args, False, transform_test)
    return dataset_train, dataset_test


def _set_dataset_attributes(args):
    """Copy the metadata for ``args.dataset`` onto ``args``; raise ValueError if unknown."""
    try:
        num_classes, img_size, channel, mean, std = _DATASET_PARAMS[args.dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}") from None
    args.num_classes = num_classes
    args.img_size = img_size
    args.channel = channel
    # Fresh lists per call so a caller mutating args.mean/std cannot corrupt the table.
    args.mean = None if mean is None else list(mean)
    args.std = None if std is None else list(std)


def _build_transforms(args):
    """Return the (train, test) ``transforms.Compose`` pipelines for an already-configured ``args``."""
    train_steps = []
    test_steps = []

    resize_crop = _EVAL_RESIZE_CROP.get(args.dataset)
    if resize_crop is not None:
        resize, crop = resize_crop
        test_steps += [transforms.Resize(resize), transforms.CenterCrop(crop)]

    if args.dataset == "imagenet200":
        train_steps.append(transforms.RandomResizedCrop(args.img_size))
    else:
        if args.dataset == "gtsrb":
            # The GTSRB eval split is resized to img_size, but training images were not:
            # RandomCrop(img_size, padding=4) raises on images smaller than
            # img_size - 8 px and only samples a small patch of large ones.
            # Presumably the raw ImageFolder images vary in size; for images that
            # are already img_size x img_size this Resize + CenterCrop is a no-op.
            train_steps += [
                transforms.Resize(args.img_size),
                transforms.CenterCrop(args.img_size),
            ]
        train_steps.append(transforms.RandomCrop(args.img_size, padding=4))
    train_steps.append(transforms.RandomHorizontalFlip())

    train_steps.append(transforms.ToTensor())
    test_steps.append(transforms.ToTensor())
    if args.mean is not None and args.std is not None:
        train_steps.append(transforms.Normalize(args.mean, args.std))
        test_steps.append(transforms.Normalize(args.mean, args.std))

    return transforms.Compose(train_steps), transforms.Compose(test_steps)
def dataset(args, train, transform):
    """Instantiate the clean training or evaluation split of ``args.dataset``.

    Args:
        args: namespace providing ``dataset`` and ``data_root`` (a ``str`` or
            path-like directory).
        train: ``True`` for the training split, ``False`` for the evaluation split.
        transform: per-image transform, passed straight to the torchvision dataset.

    Returns:
        ``torchvision.datasets.CIFAR10`` for ``"cifar10"`` (with
        ``download=True``, so it is fetched into ``data_root`` if missing).
        Otherwise a ``torchvision.datasets.ImageFolder`` rooted at the split
        directory under ``data_root``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name.
    """
    if args.dataset == "cifar10":
        return torchvision.datasets.CIFAR10(
            root=args.data_root, transform=transform, download=True, train=train)

    # (train path parts, eval path parts) relative to data_root, for datasets
    # stored in ImageFolder layout (one sub-directory per class).
    # NOTE(review): the stock tiny-imagenet-200/val is a flat images/ folder plus an
    # annotations file, so this presumably expects a re-organized val split — confirm.
    image_folder_splits = {
        "gtsrb": (("GTSRB", "Train"), ("GTSRB", "val4imagefolder")),
        "tiny-imagenet": (("tiny-imagenet-200", "train"), ("tiny-imagenet-200", "val")),
        "imagenet200": (("imagenet200", "train"), ("imagenet200", "val")),
    }
    try:
        train_parts, eval_parts = image_folder_splits[args.dataset]
    except KeyError:
        # Previously fell through and returned None, failing far from the cause.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}") from None

    root = os.path.join(args.data_root, *(train_parts if train else eval_parts))
    return torchvision.datasets.ImageFolder(root=root, transform=transform)