forked from Tandoan19/SONNET
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathinfer.py
More file actions
228 lines (183 loc) · 8.18 KB
/
infer.py
File metadata and controls
228 lines (183 loc) · 8.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import argparse
import glob
import math
import os
from collections import deque
import cv2
import numpy as np
from scipy import io as sio
import matplotlib.pyplot as plt
from skimage import measure
from scipy.ndimage import find_objects
from tensorpack.predict import OfflinePredictor, PredictConfig
from tensorpack.tfutils.sessinit import get_model_loader
from config import Config
from misc.utils import rm_n_mkdir
import json
import operator
####
def get_best_chkpts(path, metric_name, comparator='>'):
    """
    Return the best checkpoint according to some criteria.
    Note that it will only return valid path, so any checkpoint that has been
    removed wont be returned (i.e moving to next one that satisfies the criteria
    such as second best etc.)

    Args:
        path       : directory containing all checkpoints; the stats live in
                     the '02' subdirectory as "stats.json"
        metric_name: key in each epoch's stat dict used for comparison
        comparator : '>' to pick the maximum metric value, '<' for the minimum

    Returns:
        (best_chkpt, selected_stat): path to the best existing
        "model-<step>.index" file and its epoch stat dict, or (None, None)
        if no epoch both satisfies the criteria and still has its checkpoint
        file on disk.

    Raises:
        KeyError: if `comparator` is not '>' or '<'.
        FileNotFoundError: if "stats.json" does not exist.
    """
    stat_file = path + '/02' + '/stats.json'
    ops = {
        '>': operator.gt,
        '<': operator.lt,
    }
    op_func = ops[comparator]
    with open(stat_file) as f:
        info = json.load(f)

    # Start from the identity element of the chosen comparison so any
    # real metric value beats it.
    if comparator == '>':
        best_value = -math.inf
    else:
        best_value = math.inf

    # BUGFIX: initialize both outputs so we can return (None, None) cleanly
    # instead of raising UnboundLocalError when nothing qualifies.
    best_chkpt = None
    selected_stat = None
    for epoch_stat in info:
        epoch_value = epoch_stat[metric_name]
        if op_func(epoch_value, best_value):
            # Only accept epochs whose checkpoint file still exists on disk.
            chkpt_path = "%s/02/model-%d.index" % (path, epoch_stat['global_step'])
            if os.path.isfile(chkpt_path):
                selected_stat = epoch_stat
                best_value = epoch_value
                best_chkpt = chkpt_path
    return best_chkpt, selected_stat
####
class Inferer(Config):
    """Sliding-window inference over whole images with a tensorpack predictor.

    All configuration attributes read here (infer_mask_shape,
    infer_input_shape, inf_batch_size, inf_* paths/flags, data_type,
    save_dir, eval_inf_*_tensor_names) are inherited from ``Config``.
    """

    def __gen_prediction(self, x, predictor):
        """
        Using 'predictor' to generate the prediction of image 'x'
        Args:
            x : input image to be segmented. It will be split into patches
                to run the prediction upon before being assembled back
        Returns:
            (pred_coded, pred_ord): the two predictor outputs, each cropped
            and re-assembled back to the original image's height and width.
        """
        # Stride between consecutive window origins (rows, cols).
        step_size = [40, 40]
        msk_size = self.infer_mask_shape   # network output patch shape
        win_size = self.infer_input_shape  # network input window shape

        def get_last_steps(length, step_size):
            # Padded extent (a multiple of step_size covering `length`)
            # and the number of steps spanning it.
            nr_step = math.ceil((length - step_size) / step_size)
            last_step = (nr_step + 1) * step_size
            return int(last_step), int(nr_step + 1)

        im_h = x.shape[0]
        im_w = x.shape[1]
        padt_img, padb_img = 0, 0
        padl_img, padr_img = 0, 0
        # pad if image size smaller than msk_size (for monusac dataset)
        if im_h < msk_size[0]:
            diff_h_img = msk_size[0] - im_h
            padt_img = diff_h_img // 2
            padb_img = diff_h_img - padt_img
        if im_w < msk_size[1]:
            diff_w_img = msk_size[1] - im_w
            padl_img = diff_w_img // 2
            padr_img = diff_w_img - padl_img
        # Reflect-pad so even tiny images reach at least one mask patch.
        x_pad = np.lib.pad(x, ((padt_img, padb_img), (padl_img, padr_img), (0, 0)), 'reflect')
        im_h_pad = x_pad.shape[0]
        im_w_pad = x_pad.shape[1]
        last_h, nr_step_h = get_last_steps(im_h_pad, step_size[0])
        last_w, nr_step_w = get_last_steps(im_w_pad, step_size[1])
        # Extra border so every stride-aligned window of size win_size fits.
        diff_h = win_size[0] - step_size[0]
        padt = diff_h // 2
        padb = last_h + win_size[0] - im_h_pad
        diff_w = win_size[1] - step_size[1]
        padl = diff_w // 2
        padr = last_w + win_size[1] - im_w_pad
        x_pad = np.lib.pad(x_pad, ((padt, padb), (padl, padr), (0, 0)), 'reflect')
        #### TODO: optimize this
        sub_patches = []
        # generating subpatches from orginal
        for row in range(0, last_h, step_size[0]):
            for col in range (0, last_w, step_size[1]):
                win = x_pad[row:row+win_size[0],
                            col:col+win_size[1]]
                sub_patches.append(win)
        # Two parallel output streams from the predictor:
        # output[0] ("coded") and output[1] ("ord").
        pred_coded = deque()
        pred_ord = deque()
        while len(sub_patches) > self.inf_batch_size:
            mini_batch = sub_patches[:self.inf_batch_size]
            sub_patches = sub_patches[self.inf_batch_size:]
            mini_output = predictor(mini_batch)
            # Crop each prediction to its central 40x40 region; the 18:58
            # slice presumably matches step_size within the network's output
            # patch — TODO confirm against infer_mask_shape.
            mini_coded = mini_output[0][:, 18:58, 18:58,:]
            mini_ord = mini_output[1][:, 18:58, 18:58, :]
            # Split batch axis so each patch is stored individually.
            mini_coded = np.split(mini_coded, self.inf_batch_size, axis=0)
            pred_coded.extend(mini_coded)
            mini_ord = np.split(mini_ord, self.inf_batch_size, axis=0)
            pred_ord.extend(mini_ord)
        # Remainder batch (fewer than inf_batch_size patches).
        if len(sub_patches) != 0:
            mini_output = predictor(sub_patches)
            mini_coded = mini_output[0][:, 18:58, 18:58,:]
            mini_ord = mini_output[1][:, 18:58, 18:58, :]
            mini_coded = np.split(mini_coded, len(sub_patches), axis=0)
            pred_coded.extend(mini_coded)
            mini_ord = np.split(mini_ord, len(sub_patches), axis=0)
            pred_ord.extend(mini_ord)
        #### Assemble back into full image
        output_patch_shape = np.squeeze(pred_coded[0]).shape
        # ch == 1 when the squeezed patch is 2-D (single channel).
        ch = 1 if len(output_patch_shape) == 2 else output_patch_shape[-1]
        #### Assemble back into full image
        # Reshape the flat patch list into the (rows, cols) tile grid, swap
        # grid-col with patch-row axes, then merge to one big image.
        pred_coded = np.squeeze(np.array(pred_coded))
        pred_coded = np.reshape(pred_coded, (nr_step_h, nr_step_w) + pred_coded.shape[1:])
        pred_coded = np.transpose(pred_coded, [0, 2, 1, 3, 4]) if ch != 1 else \
                        np.transpose(pred_coded, [0, 2, 1, 3])
        pred_coded = np.reshape(pred_coded, (pred_coded.shape[0] * pred_coded.shape[1],
                                             pred_coded.shape[2] * pred_coded.shape[3], ch))
        pred_coded = np.squeeze(pred_coded[:im_h_pad, :im_w_pad]) # just crop back to original size
        # Undo the small-image padding applied at the top of this method.
        pred_coded = pred_coded[padt_img:padt_img+im_h, padl_img:padl_img+im_w]
        # Same assembly for the "ord" stream (always single-channel here).
        pred_ord = np.squeeze(np.array(pred_ord))
        pred_ord = np.reshape(pred_ord, (nr_step_h, nr_step_w) + pred_ord.shape[1:])
        pred_ord = np.transpose(pred_ord, [0, 2, 1, 3])
        pred_ord = np.reshape(pred_ord, (pred_ord.shape[0] * pred_ord.shape[1], pred_ord.shape[2] * pred_ord.shape[3]))
        pred_ord = np.squeeze(pred_ord[:im_h_pad, :im_w_pad])
        pred_ord = pred_ord[padt_img:padt_img+im_h, padl_img:padl_img+im_w]
        return pred_coded, pred_ord

    ####
    def run(self):
        """Run inference over every image in inf_data_dir and save one .mat
        result file per image into inf_output_dir (which is wiped first)."""
        # Either auto-select the best checkpoint from training stats, or use
        # the explicitly configured model path.
        if self.inf_auto_find_chkpt:
            print('-----Auto Selecting Checkpoint Basing On "%s" Through "%s" Comparison' % \
                        (self.inf_auto_metric, self.inf_auto_comparator))
            model_path, stat = get_best_chkpts(self.save_dir, self.inf_auto_metric, self.inf_auto_comparator)
            print('Selecting: %s' % model_path)
            print('Having Following Statistics:')
            for key, value in stat.items():
                print('\t%s: %s' % (key, value))
        else:
            model_path = self.inf_model_path
        # Build a tensorpack offline predictor from the configured model.
        model_constructor = self.get_model()
        pred_config = PredictConfig(
            model        = model_constructor(),
            session_init = get_model_loader(model_path),
            input_names  = self.eval_inf_input_tensor_names,
            output_names = self.eval_inf_output_tensor_names)
        predictor = OfflinePredictor(pred_config)
        save_dir = self.inf_output_dir
        file_list = glob.glob('%s/*%s' % (self.inf_data_dir, self.inf_imgs_ext))
        file_list.sort() # ensure same order
        # NOTE: removes and recreates the output directory — prior results
        # in save_dir are lost.
        rm_n_mkdir(save_dir)
        for filename in file_list:
            filename = os.path.basename(filename)
            basename = filename.split('.')[0]
            print(self.inf_data_dir, basename, end=' ', flush=True)
            ##
            # 'pannuke' images are stored as .npy arrays; everything else is
            # read with OpenCV (BGR) and converted to RGB.
            if self.data_type != 'pannuke':
                img = cv2.imread(self.inf_data_dir + '/' + filename)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            else:
                img = np.load(self.inf_data_dir + '/' + filename)
            ##
            pred_coded, pred_ord = self.__gen_prediction(img, predictor)
            sio.savemat('%s/%s.mat' % (save_dir, basename), {'result':[pred_coded], 'result-ord':[pred_ord]})
            print('FINISH')
####
if __name__ == '__main__':
    # CLI bootstrap: optionally restrict visible GPUs, then run inference.
    cli = argparse.ArgumentParser()
    cli.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parsed = cli.parse_args()
    if parsed.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = parsed.gpu
    Inferer().run()