Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions benchmarks/filter_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setup_filter(
ball_z_size_um=15,
)
filtered = filtered.astype(settings.filtering_dtype)
filtered = torch.from_numpy(filtered).to(torch_device)
filtered_torch = torch.from_numpy(filtered).to(torch_device)
tiles = tiles.astype(np.bool_)
tiles = torch.from_numpy(tiles).to(torch_device)

Expand All @@ -62,17 +62,20 @@ def setup_filter(
use_mask=True,
)

return ball_filter, filtered, tiles, batch_size
return ball_filter, filtered, filtered_torch, tiles, batch_size


def run_filter(ball_filter, filtered, tiles, batch_size):
def run_filter(ball_filter, filtered, filtered_torch, tiles, batch_size):
for i in range(0, len(filtered), batch_size):
ball_filter.append(
filtered[i : i + batch_size], tiles[i : i + batch_size]
filtered_torch[i : i + batch_size],
tiles[i : i + batch_size],
filtered[i : i + batch_size],
)
if ball_filter.ready:
ball_filter.walk()
ball_filter.get_processed_planes()
ball_filter.get_raw_planes()


if __name__ == "__main__":
Expand Down
41 changes: 29 additions & 12 deletions benchmarks/filter_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
import tqdm
from brainglobe_utils.cells.cells import Cell
from brainglobe_utils.IO.cells import save_cells
from brainglobe_utils.IO.image.load import read_with_dask
from brainglobe_utils.IO.image.load import read_z_stack

from cellfinder.core.detect.filters.plane import TileProcessor
from cellfinder.core.detect.filters.setup_filters import DetectionSettings
Expand All @@ -72,7 +72,7 @@


def setup_filter(
signal_path: Path, # expect to load z, y, x
signal_path: Path | np.ndarray, # expect to load z, y, x
batch_size: int = 1,
torch_device="cpu",
dtype=np.uint16,
Expand All @@ -87,14 +87,18 @@ def setup_filter(
n_free_cpus: int = 2,
log_sigma_size: float = 0.2,
n_sds_above_mean_thresh: float = 10,
split_ball_xy_size: int = 3,
split_ball_z_size: int = 3,
detect_centre_of_intensity: bool = False,
split_ball_xy_size: int = 6,
split_ball_z_size: int = 15,
split_ball_overlap_fraction: float = 0.8,
n_splitting_iter: int = 10,
start_plane: int = 0,
end_plane: int = 0,
):
signal_array = read_with_dask(str(signal_path))
signal_array = signal_path
if not isinstance(signal_path, np.ndarray):
signal_array = read_z_stack(str(signal_path))

if end_plane <= 0:
end_plane = len(signal_array)
signal_array = signal_array[start_plane:end_plane, :, :]
Expand Down Expand Up @@ -123,6 +127,7 @@ def setup_filter(
batch_size=batch_size,
torch_device=torch_device,
n_splitting_iter=n_splitting_iter,
detect_centre_of_intensity=detect_centre_of_intensity,
)

kwargs = dataclasses.asdict(settings)
Expand All @@ -134,7 +139,7 @@ def setup_filter(
splitting_settings = DetectionSettings(**kwargs)

signal_array = settings.filter_data_converter_func(signal_array)
signal_array = torch.from_numpy(signal_array).to(torch_device)
signal_array_torch = torch.from_numpy(signal_array).to(torch_device)

tile_processor = TileProcessor(
plane_shape=shape[1:],
Expand Down Expand Up @@ -178,6 +183,7 @@ def setup_filter(
ball_filter,
cell_detector,
signal_array,
signal_array_torch,
batch_size,
)

Expand Down Expand Up @@ -244,9 +250,14 @@ def dump_structures(
writer = csv.writer(fh, delimiter=",")
writer.writerow(["id", "x", "y", "z", "volume", "volume_type"])

s_intensities = cell_detector.get_structures_intensities()
for cell_id, cell_points in cell_detector.get_structures().items():
intensity = None
if settings.detect_centre_of_intensity:
intensity = s_intensities[cell_id]

vol = len(cell_points)
x, y, z = get_structure_centre(cell_points)
x, y, z = get_structure_centre(cell_points, intensity)

if vol < max_vol:
tp = "maybe_cell"
Expand All @@ -270,7 +281,9 @@ def dump_structures(
struct_type_split[p[2], p[1], p[0]] = color

if tp == "needs_split":
centers = split_cells(cell_points, settings=splitting_settings)
centers = split_cells(
cell_points, splitting_settings, intensity
)
for x, y, z in centers:
x, y, z = map(int, [x, y, z])
if any(v < r1 for v in [x, y, z]):
Expand Down Expand Up @@ -347,6 +360,7 @@ def run_filter(
tile_processor: TileProcessor,
ball_filter: BallFilter,
cell_detector: CellDetector,
signal_array_np,
signal_array,
batch_size,
):
Expand All @@ -358,6 +372,7 @@ def run_filter(

for i in tqdm.tqdm(range(0, len(signal_array), batch_size)):
batch = signal_array[i : i + batch_size]
batch_np = signal_array_np[i : i + batch_size]
save_tiffs(output_root, "input", i, batch, n)

batch_clipped = torch.clone(batch)
Expand All @@ -373,10 +388,11 @@ def run_filter(
save_tiffs(output_root, "inside_brain", i, inside_brain_tiles, n)
save_tiffs(output_root, "filtered_2d", i, filtered_2d, n)

ball_filter.append(filtered_2d, inside_brain_tiles)
ball_filter.append(filtered_2d, inside_brain_tiles, batch_np)
if ball_filter.ready:
ball_filter.walk()
middle_planes = ball_filter.get_processed_planes()
raw_planes = ball_filter.get_raw_planes()
buff = middle_planes.copy()
buff[buff != settings.soma_centre_value] = 0
save_tiffs(
Expand All @@ -389,11 +405,11 @@ def run_filter(

detection_middle_planes = detection_converter(middle_planes)

for k, (plane, detection_plane) in enumerate(
zip(middle_planes, detection_middle_planes)
for k, (plane, raw_plane, detection_plane) in enumerate(
zip(middle_planes, raw_planes, detection_middle_planes)
):
previous_plane = cell_detector.process(
detection_plane, previous_plane
detection_plane, previous_plane, raw_plane
)
save_tiffs(
output_root,
Expand Down Expand Up @@ -429,6 +445,7 @@ def run_filter(
ball_overlap_fraction=0.8,
log_sigma_size=0.35,
n_sds_above_mean_thresh=1,
detect_centre_of_intensity=False,
soma_spread_factor=4,
max_cluster_size=10000,
voxel_sizes=(4, 2.03, 2.03),
Expand Down
8 changes: 8 additions & 0 deletions cellfinder/core/detect/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def main(
tiled_thresh_tile_size: float | None = None,
*,
callback: Optional[Callable[[int], None]] = None,
detect_centre_of_intensity: bool = False,
) -> List[Cell]:
"""
Perform cell candidate detection on a 3D signal array.
Expand Down Expand Up @@ -155,6 +156,12 @@ def main(
callback : Callable[int], optional
A callback function that is called every time a plane has finished
being processed. Called with the plane number that has finished.
detect_centre_of_intensity : bool
If False, a candidate cell's center is just the mean of the positions
of all voxels marked as above background, or bright, in that candidate.
The voxel intensity is not taken into account. If True, the center is
calculated similar to the center of mass, but using the intensity. So
the center gets pulled towards the brighter voxels in the volume.

Returns
-------
Expand Down Expand Up @@ -219,6 +226,7 @@ def main(
torch_device=torch_device,
pin_memory=pin_memory,
n_splitting_iter=n_splitting_iter,
detect_centre_of_intensity=detect_centre_of_intensity,
)

# replicate the settings specific to splitting, before we access anything
Expand Down
24 changes: 14 additions & 10 deletions cellfinder/core/detect/filters/plane/plane_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,13 @@ def get_tile_mask(
Parameters
----------
planes : torch.Tensor
Input planes (z-stack). Note, the input data is modified.
Input planes (z-stack). Note, the input data is modified by being
clipped between zero and `clipping_value`. Otherwise, it is
unchanged.

Returns
-------
planes : torch.Tensor
filtered_planes : torch.Tensor
Filtered and thresholded planes (z-stack).
inside_brain_tiles : torch.Tensor
Boolean mask indicating which tiles are inside (1) or
Expand All @@ -144,7 +146,7 @@ def get_tile_mask(
# Threshold the image
enhanced_planes = self.peak_enhancer.enhance_peaks(planes)

_threshold_planes(
filtered_planes = _threshold_planes(
planes,
enhanced_planes,
self.n_sds_above_mean_thresh,
Expand All @@ -154,7 +156,7 @@ def get_tile_mask(
self.torch_device,
)

return planes, inside_brain_tiles
return filtered_planes, inside_brain_tiles

def get_tiled_buffer(self, depth: int, device: str):
return self.tile_walker.get_tiled_buffer(depth, device)
Expand All @@ -169,11 +171,11 @@ def _threshold_planes(
local_threshold_tile_size_px: int,
threshold_value: int,
torch_device: str,
) -> None:
) -> torch.Tensor:
"""
Sets each plane (in-place) to threshold_value, where the corresponding
enhanced_plane > mean + n_sds_above_mean_thresh*std. Each plane will be
set to zero elsewhere.
Sets each pixel in the returned planes to threshold_value, where the
corresponding enhanced_plane > mean + n_sds_above_mean_thresh*std.
Each pixel will be set to zero elsewhere. Original `planes` is unchanged.
"""
z, y, x = enhanced_planes.shape

Expand Down Expand Up @@ -257,10 +259,12 @@ def _threshold_planes(
else:
above = above_global

planes[above] = threshold_value
# subsequent steps only care about the values that are set to threshold or
# above in planes. We set values in *planes* to threshold based on the
# value in *enhanced_planes*. So, there could be values in planes that are
# at threshold already, but in enhanced_planes they are not. So it's best
# to zero all other values, so voxels previously at threshold don't count
planes[torch.logical_not(above)] = 0
filtered_planes = torch.zeros_like(planes)
filtered_planes[above] = threshold_value

return filtered_planes
12 changes: 12 additions & 0 deletions cellfinder/core/detect/filters/setup_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ class DetectionSettings:
done with 50% overlap when striding.
"""

detect_centre_of_intensity: bool = False
"""
How to calculate the centre of a candidate cell, given a collection of
pixels marked as bright or above background.

If False, a candidate cell's center is just the mean of the positions
of all voxels marked as above background, or bright, in that candidate.
The voxel intensity is not taken into account. If True, the center is
calculated similar to the center of mass, but using the intensity. So
the center gets pulled towards the brighter voxels in the volume.
"""

outlier_keep: bool = False
"""Whether to keep outlier structures during detection."""

Expand Down
Loading
Loading