-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathinference.py
More file actions
44 lines (33 loc) · 1.32 KB
/
inference.py
File metadata and controls
44 lines (33 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from common import N_TARGETS
from utils.torch import to_device, to_numpy
def infer_batch(inputs, model, device):
    """Run one forward pass and return sigmoid probabilities as float32.

    Parameters
    ----------
    inputs : sequence of tensors, unpacked positionally into ``model``.
    model : callable (e.g. an eval-mode ``nn.Module``) returning raw logits.
    device : torch.device that ``inputs`` are moved to before the pass.

    Returns
    -------
    numpy.ndarray of float32 — element-wise sigmoid of the model output.
    """
    inputs = to_device(inputs, device)
    predicted = model(*inputs)
    # NOTE: the original rebound `inputs` to `[x.cpu() for x in inputs]` here,
    # but that list was never read again — a dead local costing one
    # device->host copy per tensor per batch, so the line is removed.
    preds = torch.sigmoid(predicted)
    preds = to_numpy(preds).astype(np.float32)
    return preds
def infer(model, loader, checkpoint_file=None, device=torch.device('cuda')):
    """Run inference over every batch in ``loader``.

    Parameters
    ----------
    model : nn.Module whose forward returns logits of width ``N_TARGETS``.
    loader : torch.utils.data.DataLoader yielding ``(inputs, _)`` pairs;
        must not shuffle, since predictions are written by batch position.
    checkpoint_file : optional path; when given, ``model_state_dict`` from
        the checkpoint is loaded into ``model`` first.
    device : torch.device to run on (defaults to CUDA).

    Returns
    -------
    numpy.ndarray of shape ``(len(loader.dataset), N_TARGETS)`` with
    sigmoid probabilities.
    """
    n_obs = len(loader.dataset)
    batch_sz = loader.batch_size
    predictions = np.zeros((n_obs, N_TARGETS))
    # Force deterministic cuDNN for reproducible results; restore the
    # caller's setting even if loading or inference raises (the original
    # restored it only on the success path, leaking the global on error).
    currently_deterministic = torch.backends.cudnn.deterministic
    torch.backends.cudnn.deterministic = True
    try:
        if checkpoint_file is not None:
            print(f'Starting inference for model: {checkpoint_file}')
            # map_location lets a checkpoint saved on one device (e.g. CUDA)
            # be loaded when running on another (e.g. CPU-only host).
            checkpoint = torch.load(checkpoint_file, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        with torch.no_grad():
            for i, (inputs, _) in enumerate(tqdm(loader)):
                start_index = i * batch_sz
                # The final batch may be short; clamp to the dataset size.
                end_index = min(start_index + batch_sz, n_obs)
                batch_preds = infer_batch(inputs, model, device)
                predictions[start_index:end_index, :] += batch_preds
    finally:
        torch.backends.cudnn.deterministic = currently_deterministic
    return predictions