#!/usr/bin/env python3
"""
General wrapper that samples few-shot tasks (episodes) from meta-train and
meta-test datasets. See the usage sketch at the bottom of this file.
"""
import random

import numpy as np
import torch
from torchvision import transforms

from utils import *


class TaskGenerator:
    def __init__(self,
                 meta_train_set,
                 meta_test_set,
                 meta_train_partitions,
                 meta_test_partitions,
                 args):
        self.meta_train_set = meta_train_set
        self.meta_test_set = meta_test_set
        self.meta_train_partitions = meta_train_partitions
        self.meta_test_partitions = meta_test_partitions
        # Resize every sampled image to the input resolution expected by the
        # meta-model's base learner.
        self.image_resize_transforms = transforms.Resize(
            (args.imgSizeToMetaModel, args.imgSizeToMetaModel))

    def _sample_task_idxs_labels(self, partition, n_way, n_train_samples, n_test_samples):
        """Sample an n_way task from one partition.

        Returns support/query indices into the meta-dataset, episode labels
        remapped to 0..n_way-1, and the original class ids.
        """
        train_idxs, train_labels, train_labels_orig = [], [], []
        test_idxs, test_labels, test_labels_orig = [], [], []
        clses_in_partition = list(partition.keys())
        assert len(clses_in_partition) >= n_way, \
            f"Partition has {len(clses_in_partition)} classes, while expecting at least {n_way}"
        sampled_clses = random.sample(clses_in_partition, n_way)
        random.shuffle(sampled_clses)
        for label, cls in enumerate(sampled_clses):
            assert len(partition[cls]) >= n_train_samples + n_test_samples, \
                f"Class {cls} has {len(partition[cls])} samples, while expecting {n_train_samples + n_test_samples}"
            # Draw support and query samples for this class without replacement.
            idxs = random.sample(partition[cls], n_train_samples + n_test_samples)
            train_idxs.extend(idxs[:n_train_samples])
            train_labels.extend([label] * n_train_samples)
            train_labels_orig.extend([cls] * n_train_samples)
            test_idxs.extend(idxs[n_train_samples:])
            test_labels.extend([label] * n_test_samples)
            test_labels_orig.extend([cls] * n_test_samples)
        return train_idxs, train_labels, train_labels_orig, test_idxs, test_labels, test_labels_orig

    def _sample_partition(self, partitions):
        # If there is only one partition (as in supervised task construction),
        # that partition is used for every task; otherwise pick one uniformly
        # at random.
        assert len(partitions) > 0
        return partitions[np.random.choice(len(partitions))]

    def sample_task(self, meta_split, args):
        n_way = args.NWay
        if meta_split == 'meta_train':
            partition_for_task = self._sample_partition(self.meta_train_partitions)
            n_train_samples, n_test_samples = args.KShot, args.KQuery
        elif meta_split == "meta_test":
            partition_for_task = self._sample_partition(self.meta_test_partitions)
            n_train_samples, n_test_samples = args.KShotTest, args.KQueryTest
        else:
            raise ValueError(f"Invalid meta_split argument: {meta_split!r}")
        # Sample episode labels alongside the original (true) labels; the idxs
        # index into the filtered meta-dataset.
        (
            train_idxs,
            train_labels,
            train_labels_orig,
            test_idxs,
            test_labels,
            test_labels_orig
        ) = self._sample_task_idxs_labels(partition_for_task,
                                          n_way,
                                          n_train_samples,
                                          n_test_samples)
        # Indexing the filtered meta-dataset with these idxs returns the
        # intended samples.
        if meta_split == "meta_train":
            meta_set_to_gather = self.meta_train_set
        else:
            meta_set_to_gather = self.meta_test_set
        # torch.utils.data.Subset could be used here, but we only want to
        # extract the images.
        train_data = torch.stack([meta_set_to_gather[idx][0] for idx in train_idxs], dim=0)
        test_data = torch.stack([meta_set_to_gather[idx][0] for idx in test_idxs], dim=0)
        # Apply the resize transformation for the input of the base learner
        # used in meta-learning.
        assert args.imgSizeToMetaModel > 0
        train_data, test_data = self.image_resize_transforms(train_data), \
            self.image_resize_transforms(test_data)
        train_labels, test_labels = torch.tensor(train_labels), torch.tensor(test_labels)
        return train_data, train_labels, train_labels_orig, test_data, test_labels, test_labels_orig
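

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original project):
    # exercises TaskGenerator on a tiny synthetic dataset so the sampling
    # logic can be smoke-tested without the real meta-datasets. The fake
    # dataset, partition, and args values below are hypothetical placeholders.
    from types import SimpleNamespace

    # 10 fake classes with 30 samples each; each sample is (image, class_id),
    # mimicking a torchvision-style dataset that returns (tensor, label).
    fake_set = [(torch.randn(3, 32, 32), c) for c in range(10) for _ in range(30)]
    # A partition maps class id -> list of sample indices into fake_set.
    partition = {c: list(range(c * 30, (c + 1) * 30)) for c in range(10)}
    args = SimpleNamespace(NWay=5, KShot=1, KQuery=15,
                           KShotTest=1, KQueryTest=15,
                           imgSizeToMetaModel=28)

    gen = TaskGenerator(fake_set, fake_set, [partition], [partition], args)
    train_x, train_y, _, test_x, test_y, _ = gen.sample_task("meta_train", args)
    # Expected: train_x of shape (NWay*KShot, 3, 28, 28) and
    # test_x of shape (NWay*KQuery, 3, 28, 28).
    print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)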