-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpreprocess.py
More file actions
202 lines (178 loc) · 8.6 KB
/
preprocess.py
File metadata and controls
202 lines (178 loc) · 8.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import tensorflow as tf
import numpy as np
from skimage.transform import resize
import os
from functools import partial
from zipfile import ZipFile, BadZipFile
from utils import ENV_NAME_TO_GYM_NAME
# tf.data sentinel that asks the runtime to tune the value dynamically;
# passed as the buffer size to Dataset.prefetch in prepare_dataset.
AUTOTUNE = tf.data.experimental.AUTOTUNE
def add_noise(img, noise_type, width=5):
    """
    Return a copy of ``img`` with one or two solid stripes painted across it.

    The stripe colour is RGB (128, 0, 255) / 255.0. For images whose channel
    count is not 3 (e.g. grayscale (h, w, 1) frames), every channel is filled
    with the mean of that colour, so the assignment broadcasts instead of
    failing on a (.., .., 3) -> (.., .., 1) shape mismatch.

    NOTE(review): the colour is on a [0, 1] scale; if ``img`` has an integer
    dtype the values are truncated on assignment - confirm callers pass floats.

    :param img: image as an (h, w, c) numpy array, c typically 1 or 3.
        The input array is NOT modified (callers may pass views into cached data).
    :param noise_type: "vertical", "horizontal" or "both". Any other value
        (e.g. "none") returns an unmodified copy.
    :param width: stripe thickness in pixels (default 5).
    :return: noisy copy of img with the same shape and dtype.
    :raises ValueError: if a striped dimension is smaller than ``width``.
    """
    noisy = np.array(img, copy=True)
    height, img_width, channels = noisy.shape

    line_colour = np.array([128, 0, 255]) / 255.0
    if channels != line_colour.shape[0]:
        # Non-RGB frames: use a single grey level on every channel.
        line_colour = np.full(channels, line_colour.mean())

    # Draw x before y so the RNG consumption order matches the "both" case of
    # the previous implementation. "+ 1" because randint's upper bound is
    # exclusive: a stripe may start at the last position where it still fits.
    if noise_type in ("vertical", "both"):
        x_loc = np.random.randint(0, img_width - width + 1)
        noisy[:, x_loc:x_loc + width, :] = line_colour
    if noise_type in ("horizontal", "both"):
        y_loc = np.random.randint(0, height - width + 1)
        noisy[y_loc:y_loc + width, :, :] = line_colour
    return noisy
def prepare_dataset(ds, batch_size, loss_to_use, split, shuffle_buffer_size=1000):
    """
    Shuffle, batch, optionally truncate and prefetch a dataset of examples.

    :param ds: Tensorflow Dataset object yielding individual examples.
    :param batch_size: batch size; the last incomplete batch is dropped.
    :param loss_to_use: "transporter" or "pkey"; only "transporter" datasets
        are truncated.
    :param split: data set split ("train", "valid" or "test").
    :param shuffle_buffer_size: Size of the buffer to use for shuffling.
    :return: the batched, prefetched tf.data.Dataset.
    """
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    ds = ds.batch(batch_size, drop_remainder=True)
    if loss_to_use == "transporter":
        # NOTE: take() is applied after batch(), so these limits count BATCHES,
        # not samples: at most 10**5 batches for train, 300 for valid/test.
        if split == "train":
            ds = ds.take(10 ** 5)
        elif split in ("valid", "test"):
            ds = ds.take(3 * 10 ** 2)
    # prefetch goes last so the background buffer never produces batches
    # that a subsequent take() would throw away.
    return ds.prefetch(buffer_size=AUTOTUNE)
def atari_generator_func(filenames_list, noise_type="vertical", rgb=False):
    """
    A python generator yielding single Atari frames from ``.npz`` files.

    Files that are not valid zip archives are reported and skipped.

    :param filenames_list: list of data filenames containing observations.
    :param noise_type: type of noise passed to ``add_noise`` for RGB frames;
        "none" disables it. NOTE(review): noise is never applied to
        grayscale frames here - confirm this is intended.
    :param rgb: if False, read the ``'observations'`` array and yield channel
        index 3 of each entry as an (h, w, 1) frame, values left unscaled;
        if True, read the ``'frames'`` array, scale it by 1/255 and resize each
        frame to (84, 84) with nearest-neighbour interpolation (order=0).
    :return: generator of numpy frames.
    """
    key = 'frames' if rgb else 'observations'
    for file in filenames_list:
        try:
            # ZipFile validates the archive first (np.load would raise a
            # different error for a non-zip file); the array is then read
            # eagerly so both handles are closed before any frame is yielded.
            with ZipFile(file, 'r'), np.load(file) as data:
                obs = data[key]
        except BadZipFile:
            print("Corrupted zip file ignored..")
            continue

        if not rgb:  # gray-scale
            for frame in obs:
                # Presumably the most recent frame of a DM-style frame stack - confirm.
                yield frame[:, :, 3][:, :, None]
        else:  # atari full-sized colored frames, e.g. (210, 160, 3)
            for frame in obs / 255.0:
                frame = resize(frame, (84, 84), order=0)
                yield frame if noise_type == "none" else add_noise(frame, noise_type)
def transporter_atari_gen(filenames_list, noise_type, rgb=False, num_samples=10**5):
    """
    Yield pairs of frames drawn from the same 100-frame window (Transporter data).

    For every file, ``num_samples // len(filenames_list)`` pairs are drawn. Each
    pair picks a random window of 100 consecutive frames, then frame t1 from
    [0, 49) and frame t2 from [50, 99) of that window, and stacks them on a new
    last axis, giving arrays of shape (h, w, c, 2).
    NOTE(review): window indices 49 and 99 are never drawn - confirm intended.

    Files that are not valid zip archives, or that hold fewer than 100 frames,
    are reported and skipped.

    :param filenames_list: list of data filenames containing atari frames.
    :param noise_type: type of noise passed to ``add_noise`` for both frames;
        "none" disables it.
    :param rgb: if False, use channel index 3 of the ``'observations'`` array
        as an (h, w, 1) frame, values left unscaled; if True, use the
        ``'frames'`` array scaled by 1/255 and resized to (84, 84).
    :param num_samples: total number of pairs to draw across all files
        (default 10**5, the previously hard-coded value).
    :return: generator of numpy arrays of shape (h, w, c, 2).
    """
    window = 100
    samples_per_file = num_samples // max(len(filenames_list), 1)
    key = 'frames' if rgb else 'observations'
    for file in filenames_list:
        try:
            # Validate the archive, read the array eagerly, and close both
            # handles before the (long) sampling loop starts.
            with ZipFile(file, 'r'), np.load(file) as data:
                obs = data[key]
        except BadZipFile:
            print("Corrupted zip file ignored...")
            continue
        if rgb:
            obs = obs / 255.0

        num_frames = obs.shape[0]
        if num_frames < window:
            print("File %s has fewer than %d frames; skipped." % (file, window))
            continue

        for _ in range(samples_per_file):
            # "+ 1": randint's high is exclusive, and a window may start at the
            # last index where all 100 frames still fit.
            start = np.random.randint(0, num_frames - window + 1)
            window_obs = obs[start:start + window]
            t1 = np.random.randint(0, 49)
            t2 = np.random.randint(50, 99)
            if rgb:
                image_a = resize(window_obs[t1], (84, 84), order=0)
                image_b = resize(window_obs[t2], (84, 84), order=0)
            else:
                # copy(): these are views into obs, which is reused for every
                # sample; noise must never be written back into it.
                image_a = window_obs[t1, :, :, 3][:, :, None].copy()
                image_b = window_obs[t2, :, :, 3][:, :, None].copy()
            if noise_type != "none":
                image_a = add_noise(image_a, noise_type)
                image_b = add_noise(image_b, noise_type)
            yield np.stack([image_a, image_b], axis=3)
def deepmind_atari(data_path, env_name, split, loss_to_use, batch_size,
                   noise_type, rgb_frames):
    """
    Create a batched TF Dataset of Atari frames stored as ``.npz`` files.

    Files are listed from ``<data_path>/train/<gym_name>`` for the "train"
    split and from ``<data_path>/test/<gym_name>`` for "valid"/"test". The
    sorted test file list is cut in half: the first half is "valid", the
    remainder is "test" (every file lands in exactly one of the two).

    :param data_path: location of root data dir.
    :param env_name: environment name; looked up in ENV_NAME_TO_GYM_NAME.
    :param split: {'train', 'valid', 'test'}
    :param loss_to_use: "pkey" (single frames) or "transporter" (frame pairs).
    :param batch_size: batch size.
    :param noise_type: noise forwarded to the frame generators; "none"
        disables it. NOTE(review): "back_flicker" was previously listed here,
        but add_noise only handles vertical/horizontal/both - confirm.
    :param rgb_frames: True for RGB frames resized to (84, 84, 3), False for
        single-channel (h, w, 1) frames.
    :return: tf.data.Dataset built by prepare_dataset.
    :raises ValueError: for an unknown environment, split or loss.
    :raises FileNotFoundError: if the resolved data directory does not exist.
    """
    try:
        gym_name = ENV_NAME_TO_GYM_NAME[env_name]
    except KeyError:
        gym_name = None
    if gym_name is None:
        # Report the caller's name, not the (missing) mapped value.
        raise ValueError("Unsupported environment: %s" % env_name)

    generators = {"pkey": atari_generator_func, "transporter": transporter_atari_gen}
    if loss_to_use not in generators:
        raise ValueError("Unknown loss to use: %s" % loss_to_use)

    if split == "train":
        data_dir = os.path.join(data_path, "train", gym_name)
    elif split in ("valid", "test"):
        data_dir = os.path.join(data_path, "test", gym_name)
    else:
        raise ValueError("Unknown dataset split: %s" % split)
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not os.path.isdir(data_dir):
        raise FileNotFoundError("%s does not exist" % data_dir)

    # Every file under data_dir (no extension filter). Sorted because os.walk
    # order is filesystem-dependent, which would make the valid/test halves
    # irreproducible and possibly overlapping across separate calls.
    filenames_list = sorted(
        os.path.join(subdir, file)
        for subdir, _, files in os.walk(data_dir)
        for file in files
    )
    half = round(len(filenames_list) / 2)
    if split == "valid":
        filenames_list = filenames_list[:half]
    elif split == "test":
        # Start at `half` (not half + 1) so no file is dropped from both halves.
        filenames_list = filenames_list[half:]

    gen_func = partial(generators[loss_to_use], filenames_list=filenames_list,
                       noise_type=noise_type, rgb=rgb_frames)
    ds = tf.data.Dataset.from_generator(gen_func, output_types=tf.float32)
    return prepare_dataset(ds, batch_size, loss_to_use, split)