-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
95 lines (78 loc) · 3.3 KB
/
predict.py
File metadata and controls
95 lines (78 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch, torchvision
import detectron2
import time
from detectron2.utils.logger import setup_logger
setup_logger()
import numpy as np
import pandas as pd
import os, json, cv2, random, glob
import matplotlib.pyplot as plt
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import ColorMode
if __name__ == '__main__':
    # --- Data registration -------------------------------------------------
    # Register the train/val splits (COCO-format JSON + image root) so
    # Detectron2 can resolve them by name. 'train' is unused below but
    # registered for parity with the training script.
    register_coco_instances("train", {}, "/src/train_df.json", "/src/Dataset")
    register_coco_instances("val", {}, "/src/val_df.json", "/src/Dataset")

    # --- Model configuration ----------------------------------------------
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/retinanet_R_101_FPN_3x.yaml"))
    # NOTE(review): OUTPUT_DIR is used as a directory despite its '.yaml'
    # suffix — presumably it matches the training script's layout; confirm.
    cfg.OUTPUT_DIR = '/src/models/retinanet_R101.yaml'
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.RETINANET.NUM_CLASSES = 7
    # Non-max-suppression and minimum score thresholds for test-time inference.
    cfg.MODEL.RETINANET.NMS_THRESH_TEST = 0.25
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.7

    # Build the predictor once. (The original constructed DefaultPredictor
    # twice, loading the full model weights a second time for no benefit.)
    predictor = DefaultPredictor(cfg)

    # --- Output directory for sample visualizations ------------------------
    # Empty it if it exists from a previous run; otherwise create it
    # (makedirs also creates missing parents, unlike os.mkdir).
    samples_dir = os.path.join(cfg.OUTPUT_DIR, 'samples')
    if os.path.exists(samples_dir):
        for path in glob.glob(os.path.join(samples_dir, '*')):
            os.remove(path)
    else:
        os.makedirs(samples_dir)

    val_metadata = MetadataCatalog.get('val')
    val_data = DatasetCatalog.get('val')

    # --- Inference over the whole validation set ---------------------------
    # Collect (file_name, [xyxy boxes]) pairs and time the prediction pass.
    pred_list = []
    start_time = time.time()
    samples = 0
    for d in val_data:
        img = cv2.imread(d['file_name'])
        outputs = predictor(img)
        print(outputs)
        print(outputs['instances'].pred_boxes.tensor)
        boxes = outputs['instances'].pred_boxes.tensor.tolist()
        pred_list.append((d['file_name'], boxes))
        samples += 1
    print(f"--- {(time.time() - start_time)} seconds for predicting {samples} samples ---")

    # Persist predictions as JSON: a list of [file_name, boxes] entries.
    with open('/src/predicts.json', 'w') as f:
        json.dump(pred_list, f)

    # --- Draw and save detection examples ----------------------------------
    # Re-runs inference per image to visualize; images are written as
    # <OUTPUT_DIR>/samples/<index>.jpg in BGR order for cv2.imwrite.
    for i, d in enumerate(val_data):
        print('Processing image', d['file_name'])
        img = cv2.imread(d['file_name'])
        outputs = predictor(img)
        print(outputs)
        print(outputs['instances'].pred_boxes.tensor)
        v = Visualizer(img[:, :, ::-1], metadata=val_metadata)
        out = v.draw_instance_predictions(outputs['instances'].to('cpu'))
        cv2.imwrite(os.path.join(samples_dir, '%d.jpg' % i),
                    out.get_image()[:, :, ::-1])