-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathutils.py
More file actions
72 lines (56 loc) · 1.92 KB
/
utils.py
File metadata and controls
72 lines (56 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import copy
import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
def nms(bounding_boxes, confidence_score, threshold=0.05, iou_threshold=0.3):
    """
    Greedy non-maximum suppression over axis-aligned boxes.

    bounding_boxes: Nx4 array-like or tensor of (xmin, ymin, xmax, ymax)
    confidence_score: length-N array-like or tensor of scores
    threshold: confidence threshold; only boxes with score > threshold are considered
    iou_threshold: a candidate is suppressed when its IoU with an already-kept,
        higher-scoring box is greater than this value
    return: (boxes, scores) -- a list of [xmin, ymin, xmax, ymax] float lists and
        a list of float scores for the kept boxes, in descending score order
    raises: ValueError if the number of boxes and the number of scores differ
    """
    # Bring tensors to host memory so one NumPy code path handles every input type.
    if torch.is_tensor(bounding_boxes):
        bounding_boxes = bounding_boxes.detach().cpu().numpy()
    if torch.is_tensor(confidence_score):
        confidence_score = confidence_score.detach().cpu().numpy()
    all_boxes = np.asarray(bounding_boxes, dtype=np.float64).reshape(-1, 4)
    all_scores = np.asarray(confidence_score, dtype=np.float64).reshape(-1)
    if all_boxes.shape[0] != all_scores.shape[0]:
        raise ValueError(
            f"got {all_boxes.shape[0]} boxes but {all_scores.shape[0]} scores"
        )

    # Drop low-confidence boxes before the pairwise suppression loop.
    candidates = all_scores > threshold
    cand_boxes = all_boxes[candidates]
    cand_scores = all_scores[candidates]

    x_min, y_min, x_max, y_max = cand_boxes.T
    areas = np.clip(x_max - x_min, 0, None) * np.clip(y_max - y_min, 0, None)
    # Stable sort keeps input order among equal scores, so ties are deterministic.
    order = np.argsort(-cand_scores, kind="stable")

    kept = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        kept.append(best)
        inter_w = np.clip(
            np.minimum(x_max[best], x_max[rest]) - np.maximum(x_min[best], x_min[rest]),
            0,
            None,
        )
        inter_h = np.clip(
            np.minimum(y_max[best], y_max[rest]) - np.maximum(y_min[best], y_min[rest]),
            0,
            None,
        )
        inter = inter_w * inter_h
        union = areas[best] + areas[rest] - inter
        # Zero-area unions get IoU 0 instead of producing a 0/0 warning.
        overlap = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
        order = rest[overlap <= iou_threshold]

    boxes = [cand_boxes[k].tolist() for k in kept]
    scores = [float(cand_scores[k]) for k in kept]
    return boxes, scores
def iou(box1, box2):
    """
    Calculates Intersection over Union for two bounding boxes (xmin, ymin, xmax, ymax).

    box1, box2: 4-element sequences, arrays or 1-D tensors of corner coordinates
    returns: IoU value as a Python float in [0, 1]; 0.0 when the union has zero
        area (e.g. two degenerate boxes)
    """
    # float() turns list items, NumPy scalars and 0-d tensors into plain floats.
    ax_min, ay_min, ax_max, ay_max = (float(v) for v in box1[:4])
    bx_min, by_min, bx_max, by_max = (float(v) for v in box2[:4])

    # Overlap extent is clamped at 0 so disjoint boxes give no intersection.
    inter_w = max(0.0, min(ax_max, bx_max) - max(ax_min, bx_min))
    inter_h = max(0.0, min(ay_max, by_max) - max(ay_min, by_min))
    intersection = inter_w * inter_h

    area1 = max(0.0, ax_max - ax_min) * max(0.0, ay_max - ay_min)
    area2 = max(0.0, bx_max - bx_min) * max(0.0, by_max - by_min)
    union = area1 + area2 - intersection
    if union <= 0.0:
        return 0.0
    return intersection / union
def tensor_to_PIL(image):
    """
    converts a tensor normalized image (imagenet mean & std) into a PIL RGB image

    image: CxHxW tensor normalized with the ImageNet statistics below; a batch
        of size one (1xCxHxW) is squeezed automatically. Larger batches are not
        supported.
    returns: PIL RGB image with values clamped to [0, 1] before conversion
    """
    # ImageNet statistics used by the forward normalization.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # Normalize computes (x - m') / s'; with m' = -m/s and s' = 1/s this is x*s + m,
    # i.e. the inverse of the forward normalization.
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
    # Detach/move to CPU so tensors from a model on GPU or with grad can be converted.
    image = image.detach().cpu()
    if image.dim() == 4 and image.size(0) == 1:
        image = image.squeeze(0)
    inv_tensor = inv_normalize(image)
    inv_tensor = torch.clamp(inv_tensor, 0, 1)
    original_image = transforms.ToPILImage()(inv_tensor).convert("RGB")
    return original_image
def get_box_data(classes, bbox_coordinates):
    """
    Build the list of boxes expected by the wandb bbox plotter.

    classes : sequence or tensor of N class ids (predictions/gt)
    bbox_coordinates: Nx4 sequence or tensor
        [[xmin0, ymin0, xmax0, ymax0], [xmin1, ymin1, ...]]
    return: list of N dicts
        {"position": {"minX", "minY", "maxX", "maxY"}, "class_id": int}

    Values are converted to plain Python int/float: wandb type-checks box data
    (class_id must be an int, coordinates must be numbers), and tensor elements
    would otherwise be rejected.
    """
    box_list = []
    for i, class_id in enumerate(classes):
        coords = bbox_coordinates[i]
        box_list.append({
            "position": {
                "minX": float(coords[0]),
                "minY": float(coords[1]),
                "maxX": float(coords[2]),
                "maxY": float(coords[3]),
            },
            "class_id": int(class_id),
        })
    return box_list