-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathpredict_from_file.py
More file actions
56 lines (47 loc) · 1.95 KB
/
predict_from_file.py
File metadata and controls
56 lines (47 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import sys
import traceback
import cv2
from sklearn.externals import joblib
from common.config import get_config
from common.image_transformation import apply_image_transformation
import warnings
def main(model_name='logistic'):
    """Evaluate a serialized classifier against the labelled test images.

    The labels file (config key ``testing_images_labels_path``) is read line
    by line; every non-blank line must hold ``<image_path> <label>``
    separated by whitespace. Each image is read with OpenCV, passed through
    ``apply_image_transformation``, flattened and classified. When all lines
    are processed, ``"<mismatches> <total>"`` is printed.

    Images that fail anywhere between transformation and prediction are
    reported with a traceback and skipped: they count towards the total but
    not towards the mismatches.

    Args:
        model_name: Which model to evaluate; one of ``'svm'``,
            ``'logistic'`` or ``'knn'``. Selects the config key
            ``model_<model_name>_serialized_path``.
    """
    # NOTE(review): this silences every warning for the rest of the process.
    warnings.filterwarnings('ignore')

    if model_name not in ('svm', 'logistic', 'knn'):
        print("Invalid model-name '{}'!".format(model_name))
        return

    model_serialized_path = get_config(
        'model_{}_serialized_path'.format(model_name))
    testing_images_labels_path = get_config('testing_images_labels_path')
    with open(testing_images_labels_path, 'r') as file:
        lines = file.readlines()

    # Deserialized on first use and then reused for every image. The load
    # stays inside the ``try`` below so a broken model file is still
    # reported per image instead of aborting the run.
    classifier_model = None
    total = 0
    cnt = 0
    for line in lines:
        if not line.strip():
            # A blank line (e.g. trailing newline) carries no sample.
            continue
        total += 1
        image_path, image_label = line.split()
        frame = cv2.imread(image_path)
        try:
            frame = apply_image_transformation(frame)
            frame_flattened = frame.flatten()
            if classifier_model is None:
                classifier_model = joblib.load(model_serialized_path)
            # Estimators expect a 2-D (n_samples, n_features) input, so the
            # single feature vector is wrapped as a one-sample batch.
            predicted_labels = classifier_model.predict([frame_flattened])
            predicted_label = predicted_labels[0]
            # NOTE(review): ``image_label`` is a string read from the file;
            # this presumes the model predicts string labels -- confirm.
            if image_label != predicted_label:
                cnt += 1
                # NOTE(review): no window is shown by this function, so this
                # pause looks like a leftover; kept to preserve behaviour.
                # Its nesting was inferred (source indentation was lost).
                cv2.waitKey(5000)
        except Exception:
            exception_traceback = traceback.format_exc()
            print("Error while applying image transformation on image path '{}' with the following exception trace:\n{}".format(
                image_path, exception_traceback))
            continue

    print(str(cnt)+" "+str(total))
    cv2.destroyAllWindows()
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()