diff --git a/cellSAM/model.py b/cellSAM/model.py index 23283be..617c001 100644 --- a/cellSAM/model.py +++ b/cellSAM/model.py @@ -144,9 +144,9 @@ def segment_cellular_image( model, img = model.to(device), img.to(device) preds = model.predict(img, x=None, boxes_per_heatmap=bounding_boxes, device=device, fast=fast) - if preds is None: + if preds[0] is None: warn("No cells detected in the image.") - return np.zeros(img.shape[1:], dtype=np.int32), None, None + return np.zeros(img.shape[-2:], dtype=np.uint8), None, torch.empty((1, 4)) segmentation_predictions, _, x, bounding_boxes = preds @@ -160,12 +160,12 @@ def segment_cellular_image( return mask, x.cpu().numpy(), bounding_boxes -def postprocess_predictions(mask: np.ndarray): - mask_values = np.unique(mask) +def postprocess_predictions(masks: np.ndarray): + mask_values = np.unique(masks) new_masks = [] selem = disk(2) for mask_value in mask_values[1:]: - mask = mask == mask_value + mask = masks == mask_value mask, _ = remove_small_regions(mask, 20, mode="holes") mask, _ = remove_small_regions(mask, 20, mode="islands") opened_mask = binary_opening(mask, selem)