-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAE_training.py
More file actions
237 lines (165 loc) · 7.22 KB
/
AE_training.py
File metadata and controls
237 lines (165 loc) · 7.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# Some aspects of this program were taken from this repository:
# https://github.com/L1aoXingyu/pytorch-beginner/blob/master/08-AutoEncoder/conv_autoencoder.py
#
# Structure of the autoencoder and saving of the output while testing as an image were picked up.
import time
import torch
import numpy as np
from torch import cuda, nn, optim, cosine_similarity
from torch.utils import data
from torchvision.utils import save_image
from model_def import FeatureExtractorDataSet, AutoEncoder, plot_roc
def main(train_im_folder: str, test_im_folder: str, pair_file: str, tag: str, t_stmp: str, verbose: bool = False):
    """
    Train and evaluate the Part-1 convolutional autoencoder.

    Trains an ``AutoEncoder`` on images from *train_im_folder* using an MSE
    reconstruction loss, then evaluates writer-verification accuracy on the
    image pairs listed in *pair_file*: two images are predicted to match when
    the cosine similarity of their encodings meets ``threshold``. An ROC curve
    is plotted and the model/criterion/optimizer states are saved to disk.

    :param train_im_folder: directory of training images.
    :param test_im_folder: directory of validation images.
    :param pair_file: CSV listing (left image, right image, label) pairs.
    :param tag: short run label used in output file names (e.g. "SN").
    :param t_stmp: timestamp string used in output file names.
    :param verbose: if True, print per-batch progress during training.
    :return: the trained ``AutoEncoder`` model.
    """
    # -------------------------------- Hyper-parameters --------------------------------
    learning_rate = 0.005
    lamb = 0.04           # L2 weight-decay coefficient
    epochs = 25
    batch_size = 32
    batch_print = 20      # report running loss every this many batches
    threshold = 0.85      # cosine-similarity cut-off for a "match" prediction
    sgd_momentum = 0.3
    op_dir = "models/"
    # ----------------------------------------------------------------------------------
    # --------------------------------- Fetching model ---------------------------------
    device = torch.device("cuda:0" if cuda.is_available() else "cpu")
    training_set = FeatureExtractorDataSet(img_folder=train_im_folder)
    train_loader = data.DataLoader(dataset=training_set,
                                   batch_size=batch_size,
                                   shuffle=True)
    test_set = FeatureExtractorDataSet(img_folder=test_im_folder,
                                       pair_file=pair_file)
    test_loader = data.DataLoader(dataset=test_set,
                                  batch_size=batch_size,
                                  shuffle=True)
    if verbose:
        print("Dataset and data loaders acquired.")
    model = AutoEncoder(bias=True).to(device)
    model.train(True)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(params=model.parameters(),
                          lr=learning_rate,
                          momentum=sgd_momentum,
                          weight_decay=lamb,
                          nesterov=True
                          )
    running_loss = 0.0
    past_loss = np.inf    # best window loss seen so far (starts at +inf)
    total_len = len(train_loader)
    # ----------------------------------------------------------------------------------
    # ------------------------------- Start of training --------------------------------
    print("\nStart of Part 1 training.")
    print("Total batches in an epoch: {0}".format(total_len))
    start_time = time.time()
    for epoch in range(epochs):
        if verbose:
            print("")
        for i, image in enumerate(train_loader):
            # Change variable type to match GPU requirements
            inp = image.to(device)
            # Reset gradients before processing
            optimizer.zero_grad()
            # Get model output
            out = model(inp)
            # Reconstruction loss: output is compared against the input itself
            loss = criterion(out, inp)
            # Update weights
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if (i + 1) % batch_print == 0:
                if verbose:
                    print(
                        "\rEpoch: {0}, step: {1}/{2} Running Loss (avg): {3:.06f}, Past: {4:.06f} ".format(
                            epoch + 1, i + 1, total_len, (running_loss / batch_print), (past_loss / batch_print)),
                        end="")
                # Track the best running-loss window for progress reporting
                if running_loss < past_loss:
                    past_loss = running_loss
                running_loss = 0.0
            # On the final batch of the final epoch, dump a sample of the
            # reconstructions as an image grid. Previously hard-coded to
            # ``epoch == 24``, which broke whenever ``epochs`` was changed.
            if i + 1 == total_len and epoch == epochs - 1:
                # NOTE(review): newer torchvision renamed ``filename`` to ``fp``
                # in save_image — confirm against the installed version.
                save_image(tensor=out[:4],
                           filename="visualization/{0}_T{1}.png".format(tag, t_stmp),
                           nrow=2,
                           normalize=True)
    train_time = time.time()
    print("\nTraining completed in {0} sec\n".format(train_time - start_time))
    # ----------------------------------------------------------------------------------
    # ------------------------------ Start of evaluation -------------------------------
    print("Starting Part 1 evaluation\n")
    model.eval()
    count = 0             # correctly classified pairs
    tot = 0               # total pairs evaluated
    total_score = np.array([])
    total_label = np.array([])
    for i, (left_im, right_im, label) in enumerate(test_loader):
        try:
            left = left_im.to(device)
            right = right_im.to(device)
            lab = np.asarray(label)
            total_label = np.append(total_label, lab)
            out1 = model(left)
            out2 = model(right)
            # Encode both images and score their similarity on the CPU
            score = cosine_similarity(out1, out2).detach().cpu().numpy()
            total_score = np.append(total_score, score)
            prediction = score >= threshold
            count += (prediction == lab).sum()
            tot += left_im.shape[0]
        except FileNotFoundError as err:
            # Bug fix: the original read ``FileNotFoundError.filename`` off the
            # exception *class* (an attribute descriptor), not the caught
            # instance, so the message never showed the actual file name.
            print("File not found: {0}. Skipping iteration {1}".format(err.filename, i))
    plot_roc(y_true=total_label, y_score=total_score, tag=tag, tstmp=t_stmp)
    acc = (count * 100) / tot
    # ----------------------------------------------------------------------------------
    # --------------------------------- Saving model -----------------------------------
    filename = op_dir + "P2_A-{2:.03f}_{1}_T_{0}.pt".format(t_stmp, tag, acc)
    # Idea for named saved file was picked up from here:
    # https://github.com/quiltdata/pytorch-examples/blob/master/imagenet/main.py
    save_file = {"model": model.state_dict(),
                 "criterion": criterion.state_dict(),
                 "optimizer": optimizer.state_dict()
                 }
    torch.save(obj=save_file, f=filename)
    # ----------------------------------------------------------------------------------
    return model
if __name__ == "__main__":
    # Sample input for a standalone test run; the real driver is main.py.
    print("Program run started at", time.asctime())
    tstmp = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
    # Each run differs only in its dataset root folder and its tag.
    dataset_runs = (
        ("seen-dataset", "SN"),
        ("shuffled-dataset", "SH"),
        ("unseen-dataset", "UN"),
    )
    for root, run_tag in dataset_runs:
        main(train_im_folder=root + "/TrainingSet/",
             test_im_folder=root + "/ValidationSet/",
             pair_file=root + "/dataset_seen_validation_siamese.csv",
             tag=run_tag,
             t_stmp=tstmp)