-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
104 lines (85 loc) · 4.09 KB
/
inference.py
File metadata and controls
104 lines (85 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from __future__ import annotations
import logging
from dataclasses import dataclass
import numpy as np
from .config import Settings
from .model import ModelLoadError, Sam2Model
from .parser import SamRequest
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
@dataclass(eq=True, frozen=True)
class InferenceMetadata:
    """Immutable record describing how a prediction was produced."""

    # Backend that actually produced the mask; this module emits "heuristic" or "sam2".
    backend: str
    # True when SAM2 inference failed and the heuristic mask was returned instead.
    used_fallback: bool
class InferenceService:
    """Produce segmentation masks for a ``SamRequest`` with SAM2 or a heuristic.

    ``settings.inference_mode == "heuristic"`` always uses the geometric
    heuristic. Any other mode runs SAM2 and, when ``settings.allow_fallback``
    is set, falls back to the heuristic if SAM2 inference raises (except
    ``ModelLoadError``, which always propagates).
    """

    def __init__(self, settings: Settings):
        self._settings = settings
        # The model object is created here; loading is triggered via
        # ensure_loaded() in warmup() (presumably lazy until then — confirm in Sam2Model).
        self._model = Sam2Model(settings)

    @property
    def backend(self) -> str:
        """Name of the active backend.

        Returns the model's own backend string once it is no longer
        ``"uninitialized"``; before that, returns the configured ``inference_mode``.
        """
        model_backend = self._model.backend
        if model_backend != "uninitialized":
            return model_backend
        return self._settings.inference_mode

    def warmup(self) -> None:
        """Load the SAM2 model ahead of the first request; a no-op unless mode is ``"sam2"``."""
        if self._settings.inference_mode != "sam2":
            return
        self._model.ensure_loaded()

    def predict(self, request: SamRequest) -> tuple[np.ndarray, InferenceMetadata]:
        """Segment ``request`` and report which backend produced the mask.

        Returns:
            ``(mask, metadata)``: ``mask`` is a uint8 0/1 array of shape
            ``(shape_x, shape_y, shape_z)``.

        Raises:
            ModelLoadError: always re-raised, never masked by the fallback.
            Exception: any SAM2 error when ``allow_fallback`` is false, and any
                error raised by the heuristic itself.
        """
        if self._settings.inference_mode == "heuristic":
            # Heuristic mode is not wrapped by the fallback handler: falling back
            # to the very same heuristic would only repeat the failure and log a
            # misleading "SAM2 inference failed" message.
            return self._heuristic_predict(request), InferenceMetadata(backend="heuristic", used_fallback=False)
        try:
            mask = self._sam2_predict(request)
        except ModelLoadError:
            raise
        except Exception:
            if not self._settings.allow_fallback:
                raise
            logger.exception("SAM2 inference failed, falling back to heuristic inference")
            return self._heuristic_predict(request), InferenceMetadata(backend="heuristic", used_fallback=True)
        return mask, InferenceMetadata(backend="sam2", used_fallback=False)

    def _sam2_predict(self, request: SamRequest) -> np.ndarray:
        """Run SAM2 independently on every z-slice and stack the thresholded masks.

        Assumes ``request.image`` is indexed ``(x, y, z)``, matching the output
        layout. Each slice is transposed to ``(y, x)`` (height, width) for the
        model, and the model's mask is transposed back.

        Raises:
            RuntimeError: if the model returns a mask whose shape is not
                ``(shape_y, shape_x)``.
        """
        mask = np.zeros((request.shape_x, request.shape_y, request.shape_z), dtype=np.uint8)
        expected_hw = (request.shape_y, request.shape_x)
        threshold = self._settings.threshold
        for z in range(request.shape_z):
            image_rgb = _prepare_sam_input(request.image[:, :, z], request)
            raw_mask = self._model.predict(image_rgb=image_rgb, point=request.point, bbox=request.bbox)
            binary_mask = (raw_mask > threshold).astype(np.uint8)
            if binary_mask.shape != expected_hw:
                raise RuntimeError(
                    f"Unexpected mask shape {binary_mask.shape}; expected {expected_hw}."
                )
            mask[:, :, z] = binary_mask.T
        return mask

    def _heuristic_predict(self, request: SamRequest) -> np.ndarray:
        """Build a geometric mask from the prompt and repeat it on every z-slice.

        For ``interaction_type == "bbox"`` the inclusive rectangle
        ``(x0, y0, x1, y1)`` is filled. Corners may be given in either order,
        and negative coordinates are clamped to 0 instead of wrapping around.
        Otherwise a disc of radius ``max(3, min(shape_x, shape_y) // 32)`` is
        drawn around ``request.point``.

        Raises:
            ValueError: if the prompt required by ``interaction_type`` is missing.
        """
        # The footprint does not depend on z, so compute it once.
        footprint = np.zeros((request.shape_x, request.shape_y), dtype=np.uint8)
        if request.interaction_type == "bbox":
            if request.bbox is None:
                raise ValueError("bbox interaction requires request.bbox to be set.")
            x0, y0, x1, y1 = request.bbox
            x_lo, x_hi = sorted((x0, x1))
            y_lo, y_hi = sorted((y0, y1))
            # Clamp starts and (exclusive) stops at 0 so negative coordinates
            # cannot select rows/columns from the far end of the array.
            footprint[max(x_lo, 0) : max(x_hi + 1, 0), max(y_lo, 0) : max(y_hi + 1, 0)] = 1
        else:
            if request.point is None:
                raise ValueError("point interaction requires request.point to be set.")
            x, y = request.point
            radius = max(3, min(request.shape_x, request.shape_y) // 32)
            xx, yy = np.ogrid[: request.shape_x, : request.shape_y]
            footprint[(xx - x) ** 2 + (yy - y) ** 2 <= radius ** 2] = 1
        # Copy the 2-D footprint onto every slice: result is (shape_x, shape_y, shape_z).
        return np.repeat(footprint[:, :, None], request.shape_z, axis=2)
def _prepare_sam_input(slice_data: np.ndarray, request: SamRequest) -> np.ndarray:
    """Turn one 2-D slice into the 3-channel uint8 image passed to the SAM2 model.

    The slice is normalized with ``_normalize_slice``, scaled to 0..255, clipped,
    and cast to ``uint8``. The cast truncates rather than rounds. The result is
    transposed, and the single gray channel is copied into three identical channels.

    Returns:
        Array of shape ``(slice_data.shape[1], slice_data.shape[0], 3)``, dtype ``uint8``.
    """
    gray = np.clip(_normalize_slice(slice_data, request) * 255.0, 0.0, 255.0).astype(np.uint8)
    # Swap axes so the second slice axis becomes image rows (assumes the slice is indexed (x, y)).
    gray_hw = gray.T
    return np.stack((gray_hw, gray_hw, gray_hw), axis=-1)
def _normalize_slice(slice_data: np.ndarray, request: SamRequest) -> np.ndarray:
data = slice_data.astype(np.float32, copy=False)
if request.element_class_name == "uint8":
return data / 255.0
if request.element_class_name == "uint16":
return data / 65535.0
if request.element_class_name == "float32" or request.element_class_name == "float64":
min_value = request.intensity_min
max_value = request.intensity_max
else:
info = np.iinfo(request.dtype)
min_value = float(info.min)
max_value = float(info.max)
if max_value <= min_value:
raise ValueError(f"Invalid intensity range {min_value}..{max_value}.")
data = np.clip(data, min_value, max_value)
return (data - min_value) / (max_value - min_value)