-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpredict.py
More file actions
42 lines (32 loc) · 1.21 KB
/
predict.py
File metadata and controls
42 lines (32 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from src.utils.Model import MyModel
import pytorch_lightning as pl
from argparse import ArgumentParser
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
def get_args(parent_parser):
    """Return a parser that extends ``parent_parser`` with prediction options.

    Args:
        parent_parser: Parser whose arguments are inherited via ``parents=``.

    Returns:
        A new ``ArgumentParser`` adding ``--checkpoint_path``.
    """
    # add_help=False avoids a conflicting -h/--help when the parent parser
    # already provides one (the bare ArgumentParser() built in main does).
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument(
        '--checkpoint_path',
        type=str,
        help="path to the checkpoint file passed to MyModel.load_from_checkpoint",
    )
    return parser
def main(args=None):
    """Parse command-line options and restore ``MyModel`` from a checkpoint.

    Args:
        args: List of command-line tokens to parse; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The model returned by ``MyModel.load_from_checkpoint``.
    """
    # Fixed seed so any randomness during loading/prediction is reproducible.
    pl.seed_everything(52)
    # A single parser is built and parsed once. Pre-parsing with a separate bare
    # ArgumentParser() (which has its own -h) would swallow --help and print an
    # empty usage message before the real options below are registered.
    parser = ArgumentParser()
    parser = get_args(parser)
    parser = MyModel.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)
    # Every parsed option (checkpoint path, model options and trainer options)
    # is forwarded as a keyword argument to load_from_checkpoint.
    model = MyModel.load_from_checkpoint(**vars(args))
    return model
if __name__ == "__main__":
image_path = "C:/Users/Tobias/Downloads/HIDA-ufz_image_challenge/photos_annotated/2019_0626_080354_004.jpg.jpg"
model = main()
image = np.array(Image.open(image_path)) / 255
image = np.moveaxis(image, -1, 0)
image = np.expand_dims(image, 0)
input_tensor = torch.from_numpy(image).float()
output = model(input_tensor).argmax(dim=1).detach().cpu().numpy()
plt.imshow(output[0])
plt.show()
print("finished")