-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
70 lines (57 loc) · 2.25 KB
/
inference.py
File metadata and controls
70 lines (57 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
import numpy as np
import torch
import utils
from config import PATCH_SIZE
def generate_final_anomaly_heatmap(ind, X_test, Y_test, generator, siamese, *, device=None):
    """
    Generate an anomaly heatmap for a given test image index using PyTorch models.

    The generator's output for ``X_test[ind]`` is compared with ``Y_test[ind]``
    through ``utils.process_images_random`` (which uses the siamese model). The
    resulting map is min-max normalized and then multiplied element-wise by the
    absolute per-pixel difference between prediction and target on channel 0.

    Parameters
    ----------
    ind : int
        Index of the test image to process. Negative indices count from the
        end, as with ordinary sequence indexing.
    X_test, Y_test : np.ndarray
        Arrays of shape (N, H, W, C) in [-1, 1].
    generator : torch.nn.Module
        Generator model (expects NCHW input, returns NCHW output). The
        train/eval flag of every submodule is restored before returning.
    siamese : torch.nn.Module
        Siamese model used for patch-level anomaly scoring.
    device : torch.device or str, optional
        Device on which the generator is run. Defaults to the device of the
        generator's first parameter (or buffer). If it has neither, the first
        available of MPS, CUDA, CPU is used.

    Returns
    -------
    np.ndarray
        Array of shape (H, W): the normalized heat map multiplied by
        ``|predict - real|`` on channel 0.
    """
    # Run where the model already lives, so inputs and weights share a device.
    # The availability-based choice is only a fallback for parameter-less models.
    if device is None:
        device = _model_device(generator)
    if device is None:
        device = _default_device()

    # --------------------------------------------------
    # Prepare single test sample
    # --------------------------------------------------
    # Index (instead of slicing ind:ind+1) so that ind == -1 selects the last
    # sample rather than an empty batch. np.newaxis restores the batch axis.
    # np.array copies, so nothing below aliases the caller's data.
    inp = np.array(X_test[ind])[np.newaxis]   # (1, H, W, C)
    real = np.array(Y_test[ind])[np.newaxis]  # (1, H, W, C)

    # NHWC -> NCHW for PyTorch
    inp_t = torch.from_numpy(np.transpose(inp, (0, 3, 1, 2))).float().to(device)

    # --------------------------------------------------
    # Generate prediction with the PyTorch generator
    # --------------------------------------------------
    # Remember each submodule's mode so the caller's model state is left exactly
    # as it was (including mixed train/eval submodules).
    previous_modes = [(module, module.training) for module in generator.modules()]
    generator.eval()
    try:
        with torch.no_grad():
            pred_t = generator(inp_t)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    # Back to NumPy NHWC for utils functions
    predict = np.transpose(pred_t.detach().cpu().numpy(), (0, 2, 3, 1))  # (1, H, W, C)

    # --------------------------------------------------
    # Compute anomaly heatmap via utils
    # --------------------------------------------------
    # NOTE(review): the meaning of the third positional argument (10000) is
    # defined in utils.process_images_random. It is presumably a count of
    # randomly sampled patches; confirm there.
    reassembled_image = utils.process_images_random(
        real,
        predict,
        10000,
        siamese,
        patch_size=(PATCH_SIZE, PATCH_SIZE),
    )

    # Size the accumulator from the actual image instead of assuming 256x256.
    # NOTE(review): assumes reassembled_image broadcasts to (1, H, W, 1); confirm.
    height, width = real.shape[1:3]
    heat_map = np.zeros((1, height, width, 1), dtype=np.float32)
    heat_map += reassembled_image

    # Min-max normalize; the epsilon avoids division by zero on a constant map.
    heat_map = (heat_map - np.min(heat_map)) / (np.max(heat_map) - np.min(heat_map) + 1e-8)

    # Combine heat map with absolute pixel difference (channel 0 only)
    abs_diff = np.abs(predict[0, :, :, 0] - real[0, :, :, 0])
    final_map = heat_map[0, :, :, 0] * abs_diff
    return final_map


def _model_device(module):
    """Return the device of *module*'s first parameter or buffer, or None if it has neither."""
    tensor = next(module.parameters(), None)
    if tensor is None:
        tensor = next(module.buffers(), None)
    return None if tensor is None else tensor.device


def _default_device():
    """Return the first available device among MPS, CUDA and CPU, in that order."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")