diff --git a/script/extract_features_from_gt.py b/script/extract_features_from_gt.py index 10e3f83..0ba8a81 100644 --- a/script/extract_features_from_gt.py +++ b/script/extract_features_from_gt.py @@ -98,7 +98,7 @@ def get_batch_proposals(self, images, im_scales, im_infos, proposals): ).to("cuda") orig_image_size = (img_info["width"], img_info["height"]) boxes = BoxList(boxes_tensor, orig_image_size) - image_size = (images.image_sizes[idx][1], images.image_sizes[idx][0]) + image_size = (img_info["scale_width"], img_info["scale_height"]) boxes = boxes.resize(image_size) proposals_batch.append(boxes) return proposals_batch @@ -130,7 +130,12 @@ def _image_transform(self, path): ) img = torch.from_numpy(im).permute(2, 0, 1) - im_info = {"width": im_width, "height": im_height} + im_info = { + "width": im_width, + "height": im_height, + "scale_width": img.shape[2], + "scale_height": img.shape[1], + } return img, im_scale, im_info