-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclimbnet demo.py
More file actions
64 lines (50 loc) · 2.18 KB
/
climbnet demo.py
File metadata and controls
64 lines (50 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 17 09:39:32 2020
@author: Maggie
"""
import os

import cv2
import torch
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog
from detectron2.data import MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode
from detectron2.utils.visualizer import Visualizer
def run_inference(image_path, model_path, *, annotations_path="./mask.json",
                  score_thresh=0.75, device='cpu'):
    """Run the ClimbNet Mask R-CNN model on one image and display the result.

    Builds a detectron2 config from the model-zoo
    ``COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml`` baseline with 3
    classes, loads the given weights, predicts instances on the image, and
    shows the drawn predictions in an OpenCV window until a key is pressed.

    Args:
        image_path: Path of the image to segment (read with ``cv2.imread``).
        model_path: Path of the trained ClimbNet weights file.
        annotations_path: COCO-format annotation JSON registered under the
            name ``"climb_dataset"``; its categories supply the class labels
            drawn on the visualization. Defaults to ``"./mask.json"``, which
            is resolved relative to the current working directory.
        score_thresh: Minimum detection score kept at test time.
        device: detectron2 device string (for example ``'cpu'`` or
            ``'cuda'``).

    Raises:
        FileNotFoundError: If ``image_path`` cannot be read as an image.

    Note:
        The dataset is registered at most once per process. Later calls
        reuse that first registration, so a different ``annotations_path``
        passed in a later call has no effect.
    """
    dataset_name = "climb_dataset"
    # register_coco_instances refuses to register the same name twice; the
    # guard keeps this function callable more than once per process.
    if dataset_name not in DatasetCatalog.list():
        register_coco_instances(dataset_name, {}, annotations_path, "")

    # Read the image before building the model so a bad path fails fast.
    im = cv2.imread(image_path)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_path!r}")

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.DATALOADER.NUM_WORKERS = 1
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3  # 3 classes (hold, volume, downclimb)
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.DEVICE = device
    cfg.DATASETS.TEST = (dataset_name,)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh

    # setup inference
    predictor = DefaultPredictor(cfg)
    train_metadata = MetadataCatalog.get(dataset_name)
    # The dataset catalog entry must be loaded so the polygon classes show up
    # correctly: loading the COCO json fills in the metadata's class names.
    DatasetCatalog.get(dataset_name)

    outputs = predictor(im)
    # OpenCV images are BGR; the channel flip hands Visualizer an RGB image.
    v = Visualizer(im[:, :, ::-1],
                   metadata=train_metadata,
                   scale=0.75,
                   # gray out pixels that are not covered by a predicted mask
                   instance_mode=ColorMode.IMAGE_BW)
    v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    # Flip back to BGR for cv2.imshow.
    cv2.imshow('climbnet', v.get_image()[:, :, ::-1])
    cv2.waitKey(0)
    # Close the window once a key is pressed instead of leaving it open.
    cv2.destroyWindow('climbnet')
if __name__ == "__main__":
    import argparse

    # Command-line entry point: expects IMAGE_PATH then MODEL_PATH.
    cli = argparse.ArgumentParser(description='Climbnet demo')
    positional_specs = (
        ('image_path', 'image file'),
        ('model_path', 'climbnet model weights'),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=str, help=arg_help)
    parsed = cli.parse_args()
    run_inference(parsed.image_path, parsed.model_path)