From f95ac70d8ac796ffb92af24073b65c8e5283539b Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Wed, 24 Sep 2025 19:39:26 -0400 Subject: [PATCH 1/5] Add support for using center of intensity to locate cell centers. --- benchmarks/filter_3d.py | 11 +- benchmarks/filter_debug.py | 41 ++- cellfinder/core/detect/detect.py | 8 + .../core/detect/filters/plane/plane_filter.py | 24 +- .../core/detect/filters/setup_filters.py | 12 + .../core/detect/filters/volume/ball_filter.py | 60 +++- .../filters/volume/structure_detection.py | 283 ++++++++++++++---- .../filters/volume/structure_splitting.py | 133 +++++--- .../detect/filters/volume/volume_filter.py | 66 ++-- cellfinder/core/main.py | 9 + cellfinder/napari/detect/detect.py | 9 + cellfinder/napari/detect/detect_containers.py | 5 + tests/core/conftest.py | 121 +++++++- tests/core/test_integration/test_detection.py | 69 +++++ .../test_detection_structure_splitting.py | 4 +- .../test_volume_filters/test_ball_filter.py | 39 ++- 16 files changed, 720 insertions(+), 174 deletions(-) diff --git a/benchmarks/filter_3d.py b/benchmarks/filter_3d.py index 0a8dd0af..9dade7c5 100644 --- a/benchmarks/filter_3d.py +++ b/benchmarks/filter_3d.py @@ -42,7 +42,7 @@ def setup_filter( ball_z_size_um=15, ) filtered = filtered.astype(settings.filtering_dtype) - filtered = torch.from_numpy(filtered).to(torch_device) + filtered_torch = torch.from_numpy(filtered).to(torch_device) tiles = tiles.astype(np.bool_) tiles = torch.from_numpy(tiles).to(torch_device) @@ -62,17 +62,20 @@ def setup_filter( use_mask=True, ) - return ball_filter, filtered, tiles, batch_size + return ball_filter, filtered, filtered_torch, tiles, batch_size -def run_filter(ball_filter, filtered, tiles, batch_size): +def run_filter(ball_filter, filtered, filtered_torch, tiles, batch_size): for i in range(0, len(filtered), batch_size): ball_filter.append( - filtered[i : i + batch_size], tiles[i : i + batch_size] + filtered_torch[i : i + batch_size], + tiles[i : i + batch_size], + filtered[i : i + batch_size], ) if ball_filter.ready: ball_filter.walk() ball_filter.get_processed_planes() + ball_filter.get_raw_planes() if __name__ == "__main__": diff --git a/benchmarks/filter_debug.py b/benchmarks/filter_debug.py index a869b769..77ee4c63 100644 --- a/benchmarks/filter_debug.py +++ b/benchmarks/filter_debug.py @@ -57,7 +57,7 @@ import tqdm from brainglobe_utils.cells.cells import Cell from brainglobe_utils.IO.cells import save_cells -from brainglobe_utils.IO.image.load import read_with_dask +from brainglobe_utils.IO.image.load import read_z_stack from cellfinder.core.detect.filters.plane import TileProcessor from cellfinder.core.detect.filters.setup_filters import DetectionSettings @@ -72,7 +72,7 @@ def setup_filter( - signal_path: Path, # expect to load z, y, x + signal_path: Path | np.ndarray, # expect to load z, y, x batch_size: int = 1, torch_device="cpu", dtype=np.uint16, @@ -87,14 +87,18 @@ def setup_filter( n_free_cpus: int = 2, log_sigma_size: float = 0.2, n_sds_above_mean_thresh: float = 10, - split_ball_xy_size: int = 3, - split_ball_z_size: int = 3, + detect_centre_of_intensity: bool = False, + split_ball_xy_size: int = 6, + split_ball_z_size: int = 15, split_ball_overlap_fraction: float = 0.8, n_splitting_iter: int = 10, start_plane: int = 0, end_plane: int = 0, ): - signal_array = read_with_dask(str(signal_path)) + signal_array = signal_path + if not isinstance(signal_path, np.ndarray): + signal_array = read_z_stack(str(signal_path)) + if end_plane <= 0: end_plane = len(signal_array) signal_array = signal_array[start_plane:end_plane, :, :] @@ -123,6 +127,7 @@ def setup_filter( batch_size=batch_size, torch_device=torch_device, n_splitting_iter=n_splitting_iter, + detect_centre_of_intensity=detect_centre_of_intensity, ) kwargs = dataclasses.asdict(settings) @@ -134,7 +139,7 @@ def setup_filter( splitting_settings = DetectionSettings(**kwargs) signal_array = settings.filter_data_converter_func(signal_array) - signal_array = torch.from_numpy(signal_array).to(torch_device) + signal_array_torch = torch.from_numpy(signal_array).to(torch_device) tile_processor = TileProcessor( plane_shape=shape[1:], @@ -178,6 +183,7 @@ def setup_filter( ball_filter, cell_detector, signal_array, + signal_array_torch, batch_size, ) @@ -244,9 +250,14 @@ def dump_structures( writer = csv.writer(fh, delimiter=",") writer.writerow(["id", "x", "y", "z", "volume", "volume_type"]) + s_intensities = cell_detector.get_structures_intensities() for cell_id, cell_points in cell_detector.get_structures().items(): + intensity = None + if settings.detect_centre_of_intensity: + intensity = s_intensities[cell_id] + vol = len(cell_points) - x, y, z = get_structure_centre(cell_points) + x, y, z = get_structure_centre(cell_points, intensity) if vol < max_vol: tp = "maybe_cell" @@ -270,7 +281,9 @@ def dump_structures( struct_type_split[p[2], p[1], p[0]] = color if tp == "needs_split": - centers = split_cells(cell_points, settings=splitting_settings) + centers = split_cells( + cell_points, splitting_settings, intensity + ) for x, y, z in centers: x, y, z = map(int, [x, y, z]) if any(v < r1 for v in [x, y, z]): @@ -347,6 +360,7 @@ def run_filter( tile_processor: TileProcessor, ball_filter: BallFilter, cell_detector: CellDetector, + signal_array_np, signal_array, batch_size, ): @@ -358,6 +372,7 @@ def run_filter( for i in tqdm.tqdm(range(0, len(signal_array), batch_size)): batch = signal_array[i : i + batch_size] + batch_np = signal_array_np[i : i + batch_size] save_tiffs(output_root, "input", i, batch, n) batch_clipped = torch.clone(batch) @@ -373,10 +388,11 @@ def run_filter( save_tiffs(output_root, "inside_brain", i, inside_brain_tiles, n) save_tiffs(output_root, "filtered_2d", i, filtered_2d, n) - ball_filter.append(filtered_2d, inside_brain_tiles) + ball_filter.append(filtered_2d, inside_brain_tiles, batch_np) if ball_filter.ready: ball_filter.walk() middle_planes = ball_filter.get_processed_planes() + raw_planes = ball_filter.get_raw_planes() buff = middle_planes.copy() buff[buff != settings.soma_centre_value] = 0 save_tiffs( @@ -389,11 +405,11 @@ def run_filter( detection_middle_planes = detection_converter(middle_planes) - for k, (plane, detection_plane) in enumerate( - zip(middle_planes, detection_middle_planes) + for k, (plane, raw_plane, detection_plane) in enumerate( + zip(middle_planes, raw_planes, detection_middle_planes) ): previous_plane = cell_detector.process( - detection_plane, previous_plane + detection_plane, previous_plane, raw_plane ) save_tiffs( output_root, @@ -429,6 +445,7 @@ def run_filter( ball_overlap_fraction=0.8, log_sigma_size=0.35, n_sds_above_mean_thresh=1, + detect_centre_of_intensity=False, soma_spread_factor=4, max_cluster_size=10000, voxel_sizes=(4, 2.03, 2.03), diff --git a/cellfinder/core/detect/detect.py b/cellfinder/core/detect/detect.py index 82bc8512..13bb880a 100644 --- a/cellfinder/core/detect/detect.py +++ b/cellfinder/core/detect/detect.py @@ -58,6 +58,7 @@ def main( tiled_thresh_tile_size: float | None = None, *, callback: Optional[Callable[[int], None]] = None, + detect_centre_of_intensity: bool = False, ) -> List[Cell]: """ Perform cell candidate detection on a 3D signal array. @@ -155,6 +156,12 @@ def main( callback : Callable[int], optional A callback function that is called every time a plane has finished being processed. Called with the plane number that has finished. + detect_centre_of_intensity : bool + If False, a candidate cell's center is just the mean of the positions + of all voxels marked as above background, or bright, in that candidate. + The voxel intensity is not taken into account. If True, the center is + calculated similar to the center of mass, but using the intensity. So + the center gets pulled towards the brighter voxels in the volume. Returns ------- @@ -219,6 +226,7 @@ def main( torch_device=torch_device, pin_memory=pin_memory, n_splitting_iter=n_splitting_iter, + detect_centre_of_intensity=detect_centre_of_intensity, ) # replicate the settings specific to splitting, before we access anything diff --git a/cellfinder/core/detect/filters/plane/plane_filter.py b/cellfinder/core/detect/filters/plane/plane_filter.py index 29606666..fe0ab816 100644 --- a/cellfinder/core/detect/filters/plane/plane_filter.py +++ b/cellfinder/core/detect/filters/plane/plane_filter.py @@ -126,11 +126,13 @@ def get_tile_mask( Parameters ---------- planes : torch.Tensor - Input planes (z-stack). Note, the input data is modified. + Input planes (z-stack). Note, the input data is modified by being + clipped between zero and `clipping_value`. Otherwise, it is + unchanged. Returns ------- - planes : torch.Tensor + filtered_planes : torch.Tensor Filtered and thresholded planes (z-stack). inside_brain_tiles : torch.Tensor Boolean mask indicating which tiles are inside (1) or @@ -144,7 +146,7 @@ def get_tile_mask( # Threshold the image enhanced_planes = self.peak_enhancer.enhance_peaks(planes) - _threshold_planes( + filtered_planes = _threshold_planes( planes, enhanced_planes, self.n_sds_above_mean_thresh, @@ -154,7 +156,7 @@ def get_tile_mask( self.torch_device, ) - return planes, inside_brain_tiles + return filtered_planes, inside_brain_tiles def get_tiled_buffer(self, depth: int, device: str): return self.tile_walker.get_tiled_buffer(depth, device) @@ -169,11 +171,11 @@ def _threshold_planes( local_threshold_tile_size_px: int, threshold_value: int, torch_device: str, -) -> None: +) -> torch.Tensor: """ - Sets each plane (in-place) to threshold_value, where the corresponding - enhanced_plane > mean + n_sds_above_mean_thresh*std. Each plane will be - set to zero elsewhere. + Sets each pixel in the returned planes to threshold_value, where the + corresponding enhanced_plane > mean + n_sds_above_mean_thresh*std. + Each pixel will be set to zero elsewhere. Original `planes` is unchanged. """ z, y, x = enhanced_planes.shape @@ -257,10 +259,12 @@ def _threshold_planes( else: above = above_global - planes[above] = threshold_value # subsequent steps only care about the values that are set to threshold or # above in planes. We set values in *planes* to threshold based on the # value in *enhanced_planes*. So, there could be values in planes that are # at threshold already, but in enhanced_planes they are not. So it's best # to zero all other values, so voxels previously at threshold don't count - planes[torch.logical_not(above)] = 0 + filtered_planes = torch.zeros_like(planes) + filtered_planes[above] = threshold_value + + return filtered_planes diff --git a/cellfinder/core/detect/filters/setup_filters.py b/cellfinder/core/detect/filters/setup_filters.py index 0b75cb5f..cdd8d8c8 100644 --- a/cellfinder/core/detect/filters/setup_filters.py +++ b/cellfinder/core/detect/filters/setup_filters.py @@ -158,6 +158,18 @@ class DetectionSettings: done with 50% overlap when striding. """ + detect_centre_of_intensity: bool = False + """ + How to calculate the centre of a candidate cell, given a collection of + pixels marked as bright or above background. + + If False, a candidate cell's center is just the mean of the positions + of all voxels marked as above background, or bright, in that candidate. + The voxel intensity is not taken into account. If True, the center is + calculated similar to the center of mass, but using the intensity. So + the center gets pulled towards the brighter voxels in the volume. + """ + outlier_keep: bool = False """Whether to keep outlier structures during detection.""" diff --git a/cellfinder/core/detect/filters/volume/ball_filter.py b/cellfinder/core/detect/filters/volume/ball_filter.py index f36d2aef..96bb9fa0 100644 --- a/cellfinder/core/detect/filters/volume/ball_filter.py +++ b/cellfinder/core/detect/filters/volume/ball_filter.py @@ -185,6 +185,8 @@ def __init__( ) # Index of the middle plane in the volume self.middle_z_idx = int(np.floor(self.ball_z_size / 2)) + # the raw input planes, we track along with the filtered planes + self.raw_planes: list[np.ndarray] = [] if not use_mask: return @@ -236,25 +238,30 @@ def ready(self) -> bool: return self.volume.shape[0] >= self.kernel_z_size def append( - self, planes: torch.Tensor, masks: Optional[torch.Tensor] = None + self, + planes: torch.Tensor, + masks: Optional[torch.Tensor] = None, + raw_planes: Optional[np.ndarray] = None, ) -> None: """ Add a new z-stack to the filter. Previous stacks passed to `append` are removed, except enough planes at the top of the previous z-stack to provide padding so we can filter - starting from the first plane in `planes`. The first time we call - `append`, `first_valid_plane` is the first plane to actually be - filtered in the z-stack due to lack of padding. + starting from the first plane in `planes` (or continue filtering planes + at the end of the last stack). The first time we call `append`, + `first_valid_plane` is the first plane to actually be filtered in the + z-stack due to lack of padding before the zeroth plane. - So make sure to call `walk`/`get_processed_planes` before calling - `append` again. + So make sure to call `walk` and + `get_processed_planes`/`get_raw_planes` to get the results before + calling `append` again. Parameters ---------- planes : torch.Tensor - The z-stack. There can be one or more planes in the stack, but it - must have 3 dimensions. Input data is not modified. + The 2d filtered z-stack. There can be one or more planes in the + stack, but it must have 3 dimensions. Input data is not modified. masks : torch.Tensor A z-stack tile mask, indicating for each tile whether it's in or outside the brain. If the latter it's excluded. @@ -263,6 +270,12 @@ def append( parameter will be ignored. Input data is not modified. + raw_planes : np.ndarray or None + The original input data z-stack. There can be one or more planes in + the stack, but it must have 3 dimensions. Input data is not + modified. + + If provided, the planes can be gotten via `get_raw_planes`. """ if self.volume.shape[0]: if self.volume.shape[0] < self.kernel_z_size: @@ -277,6 +290,11 @@ def append( dim=0, ) + if raw_planes is not None: + self.raw_planes = self.raw_planes[remaining_start:] + list( + raw_planes + ) + if self.inside_brain_tiles is not None: self.inside_brain_tiles = torch.cat( [ @@ -287,13 +305,15 @@ def append( ) else: self.volume = planes.clone() + if raw_planes is not None: + self.raw_planes = list(raw_planes) if self.inside_brain_tiles is not None: self.inside_brain_tiles = masks.clone() def get_processed_planes(self) -> np.ndarray: """ After passing enough planes to `append`, and after `walk`, this returns - a copy of the processed planes as a numpy z-stack. + a copy of the processed planes as a single numpy z-stack. It only starts returning planes corresponding to plane `first_valid_plane` relative to the first planes passed to `append`. @@ -315,8 +335,30 @@ def get_processed_planes(self) -> np.ndarray: .numpy() .copy() ) + return planes + def get_raw_planes(self) -> list[np.ndarray]: + """ + Same as `get_processed_planes`, except that it returns the raw input + planes, corresponding to the planes in the z-stack of + `get_processed_planes` for the same batch. + + The planes are returned as a list of 2d numpy arrays, one per plane. + + This will only return valid results if the raw planes were provided + in *all* calls to `append`. + """ + if not self.ready: + raise TypeError("Not enough planes were appended") + + num_processed = self.volume.shape[0] - self.kernel_z_size + 1 + assert num_processed + middle = self.middle_z_idx + + raw_planes = self.raw_planes[middle : middle + num_processed] + return raw_planes + def walk(self) -> None: """ Applies the filter to all the planes passed to `append`. diff --git a/cellfinder/core/detect/filters/volume/structure_detection.py b/cellfinder/core/detect/filters/volume/structure_detection.py index bcd76074..4aef04ea 100644 --- a/cellfinder/core/detect/filters/volume/structure_detection.py +++ b/cellfinder/core/detect/filters/volume/structure_detection.py @@ -7,7 +7,6 @@ from numba import njit, objmode, typed from numba.core import types from numba.experimental import jitclass -from numba.types import DictType from cellfinder.core.tools.tools import get_max_possible_int_value @@ -57,44 +56,76 @@ def traverse_dict(d: Dict[T, T], a: T) -> T: @njit -def get_structure_centre(structure: np.ndarray) -> np.ndarray: +def get_structure_centre( + structure: np.ndarray, intensity: Optional[np.ndarray] = None +) -> np.ndarray: """ - Get the pixel coordinates of the centre of a structure. - - Centre calculated as the mean of each pixel coordinate, - rounded to the nearest integer. + Get the voxel coordinates of the centre of a structure. + + :param structure: A 2D of Nx3 array of the coordinates. + :param intensity: If provided, an 1D N array containing the intensity of + each corresponding voxel in `structure`. + :return: The center positions of the coordinates, rounded to the nearest + integer. If `intensity` is not provided or it's all zeros, centre is + calculated as the mean position for each dimension. Otherwise, we use + the center of mass calculation, but using the intensity. Pulling the + center towards the brighter coordinates. """ # numba support axis for sum, but not mean - return np.round(np.sum(structure, axis=0) / structure.shape[0]) + structure = structure.astype(np.float64) + if intensity is None: + return np.round(np.sum(structure, axis=0) / structure.shape[0]) + + intensity_f = intensity.astype(np.float64) + intensity_sum = np.sum(intensity_f) + # in case they are all zero + if np.isclose(intensity_sum, 0): + return np.round(np.sum(structure, axis=0) / structure.shape[0]) + + intensity_f = intensity_f / intensity_sum + # make it 2d + intensity_f = intensity_f[:, None] + pos = structure * intensity_f + return np.round(np.sum(pos, axis=0)) @njit -def _get_structure_centre(structure: types.ListType) -> np.ndarray: +def _get_structure_centre( + structure: types.ListType, use_centre_of_intensity: bool = False +) -> np.ndarray: + """Same as get_structure_centre, but for internal use.""" # See get_structure_centre. # this is for our own points stored as list optimized by numba - a_sum = 0.0 - b_sum = 0.0 - c_sum = 0.0 - for a, b, c in structure: - a_sum += a - b_sum += b - c_sum += c - - return np.round( - np.array( - [ - a_sum / len(structure), - b_sum / len(structure), - c_sum / len(structure), - ] - ) - ) + intensity_sum = 0.0 + if use_centre_of_intensity: + for i in range(len(structure)): + intensity_sum += structure[i][3] + + a_centre = 0.0 + b_centre = 0.0 + c_centre = 0.0 + # in case they are all zero + if np.isclose(intensity_sum, 0): + n = len(structure) + for a, b, c, _ in structure: + a_centre += a / n + b_centre += b / n + c_centre += c / n + else: + for a, b, c, intensity in structure: + intensity_frac = intensity / intensity_sum + a_centre += a * intensity_frac + b_centre += b * intensity_frac + c_centre += c * intensity_frac + + return np.round(np.array([a_centre, b_centre, c_centre])) # Type declaration has to come outside of the class, # see https://github.com/numba/numba/issues/8808 +# The point is the 3 coordinates, and the intensity tuple_point_type = types.Tuple( - (vol_numba_type, vol_numba_type, vol_numba_type) + (vol_numba_type, vol_numba_type, vol_numba_type, types.float64) ) list_of_points_type = types.ListType(tuple_point_type) @@ -104,8 +135,8 @@ def _get_structure_centre(structure: types.ListType) -> np.ndarray: ("next_structure_id", sid_numba_type), ("soma_centre_value", sid_numba_type), # as large as possible ("shape", types.UniTuple(vol_numba_type, 2)), - ("obsolete_ids", DictType(sid_numba_type, sid_numba_type)), - ("coords_maps", DictType(sid_numba_type, list_of_points_type)), + ("obsolete_ids", types.DictType(sid_numba_type, sid_numba_type)), + ("coords_maps", types.DictType(sid_numba_type, list_of_points_type)), ] @@ -131,8 +162,8 @@ class CellDetector: are scanned. coords_maps : Mapping from structure ID to the coordinates of pixels within that - structure. Coordinates are stored in a list of (x, y, z) tuples of - the coordinates. + structure. Coordinates are stored in a list of (x, y, z, intensity) + tuples of the coordinates. Use `get_structures` to get it as a dict whose values are each a 2D array, where rows are points, and columns x, y, z of the @@ -177,20 +208,46 @@ def _set_soma(self, soma_centre_value: sid_numba_type): self.soma_centre_value = soma_centre_value def process( - self, plane: np.ndarray, previous_plane: Optional[np.ndarray] + self, + plane: np.ndarray, + previous_plane: Optional[np.ndarray], + raw_plane: Optional[np.ndarray] = None, ) -> np.ndarray: """ - Process a new plane (should be in Y, X axis order). + Process the next plane in the z-stack. + + Parameters + ---------- + plane : np.ndarray + The 3D filtered plane (2D array) to be processed for structures. + It's in Y, X axis order. + previous_plane : np.ndarray or None + The plane returned in the last call to `process`, or None for + the first call. + raw_plane : np.ndarray or None + The original raw data before it was processed in the 2D/3D filters. + If provided, the intensity of the detected structures points will + be saved. Otherwise the intensity will be recorded as zero. + + Returns + ------- + plane : np.ndarray + Processed plane as described in `connect_four`. + + It is to be passed in the next iteration of `process`. """ if plane.shape[:2] != self.shape: raise ValueError("plane does not have correct shape") - plane = self.connect_four(plane, previous_plane) + plane = self.connect_four(plane, previous_plane, raw_plane) self.z += 1 return plane def connect_four( - self, plane: np.ndarray, previous_plane: Optional[np.ndarray] + self, + plane: np.ndarray, + previous_plane: Optional[np.ndarray], + raw_plane: Optional[np.ndarray] = None, ) -> np.ndarray: """ Perform structure labelling. @@ -201,6 +258,15 @@ def connect_four( found, they are added to the structure manager and the pixel labeled accordingly. + Parameters + ---------- + plane : np.ndarray + See `process`. + previous_plane : np.ndarray or None + See `process`. + raw_plane : np.ndarray or None + See `process`. + Returns ------- plane : @@ -229,7 +295,11 @@ def connect_four( ) neighbour_ids[0] = self.next_structure_id self.next_structure_id += 1 - struct_id = self.add(x, y, self.z, neighbour_ids) + + intensity = 0 if raw_plane is None else raw_plane[y, x] + struct_id = self.add( + x, y, self.z, neighbour_ids, intensity + ) else: # reset so that grayscale value does not count as # structure in next iterations @@ -239,16 +309,33 @@ def connect_four( return plane - def get_cell_centres(self) -> np.ndarray: + def get_cell_centres( + self, use_centre_of_intensity: bool = False + ) -> np.ndarray: """ - Returns the 2D array of cell centers. It's (N, 3) with X, Y, Z columns. + Returns the 2D Nx3 array of cell centers for all the detected + structures. + + Parameters + ---------- + use_centre_of_intensity: bool + If False, the centres are calculated as the mean position of all + coordinates for each dimension. Otherwise, we use the center of + mass calculation, but using the intensity. Pulling the center + towards the brighter coordinates. + + Returns + ------- + centres : np.ndarray + Nx3 array in X, Y, Z column order, rounded to nearest int. + """ - return self.structures_to_cells() + return self.structures_to_cells(use_centre_of_intensity) def get_structures(self) -> Dict[int, np.ndarray]: """ - Gets the structures as a dict of structure IDs mapped to the 2D array - of structure points (points vs x, y, z columns). + Gets the detected structures as a dict of structure IDs mapped to the + 2D array of structure points (Nx3 in x, y, z column order). """ d = {} for sid, points in self.coords_maps.items(): @@ -262,15 +349,42 @@ def get_structures(self) -> Dict[int, np.ndarray]: d[types.int64(sid)] = item for i, point in enumerate(points): - item[i, :] = point + item[i, :] = point[:3] + + return d + + def get_structures_intensities(self) -> Dict[int, np.ndarray]: + """ + Similar to `get_structures`, but instead of the dict values being the + coordinates of the structure, it's an array of size N with the + intensity of the corresponding coordinates. + + `get_structures_intensities[s][i]` is the corresponding intensity of + the coordinate `get_structures[s][i]` + """ + d = {} + for sid, points in self.coords_maps.items(): + # see get_structures + intensity = np.empty(len(points), dtype=np.float64) + d[types.int64(sid)] = intensity + + for i, point in enumerate(points): + intensity[i] = point[3] return d def add_point( - self, sid: int, point: Union[tuple, list, np.ndarray] + self, + sid: int, + point: Union[tuple, list, np.ndarray], + intensity: float = 0, ) -> None: """ - Add single 3d (x, y, z) *point* to the structure with the given *sid*. + Add single 3d (x, y, z) *point* with intensity `intensity` to the + structure with the given *sid*. + + If all points have the same intensity, computing their centre with and + without the center of intensity weighing leads to the same result. """ # cast in case user passes in int64 (default type for int in python) # and numba complains @@ -278,12 +392,31 @@ def add_point( if key not in self.coords_maps: self.coords_maps[key] = typed.List.empty_list(tuple_point_type) - self._add_point(key, (int(point[0]), int(point[1]), int(point[2]))) + self._add_point( + key, + ( + int(point[0]), + int(point[1]), + int(point[2]), + np.float64(intensity), + ), + ) - def add_points(self, sid: int, points: np.ndarray): + def add_points( + self, + sid: int, + points: np.ndarray, + intensity: Optional[np.ndarray] = None, + ): """ - Adds ndarray of *points* to the structure with the given *sid*. - Each row is a 3-column (x, y, z) point. + Adds Nx3 array of `points` with corresponding `intensity` array of size + N to the structure with the given *sid*. + + Each row in `points` is a 3-column (x, y, z) point. If intensity is + None they all default to intensity of zero. + + If all points have the same intensity, computing their centre with and + without the center of intensity weighing leads to the same result. """ # cast in case user passes in int64 (default type for int in python) # and numba complains @@ -293,17 +426,31 @@ def add_points(self, sid: int, points: np.ndarray): append = self.coords_maps[key].append pts = np.round(points).astype(vol_np_type) - for point in pts: - append((point[0], point[1], point[2])) + if intensity is None: + its = np.zeros(len(pts), dtype=np.float64) + else: + its = np.asarray(intensity, dtype=np.float64) + + for i in range(len(pts)): + x, y, z = pts[i, :] + append((x, y, z, its[i])) def _add_point( - self, sid: sid_numba_type, point: Tuple[int, int, int] + self, + sid: sid_numba_type, + point: Tuple[int, int, int, np.float64], ) -> None: + """Point is `(x, y, z, intensity).`""" # sid must exist self.coords_maps[sid].append(point) def add( - self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[sid_np_type] + self, + x: int, + y: int, + z: int, + neighbour_ids: npt.NDArray[sid_np_type], + intensity: float = 0.0, ) -> sid_numba_type: """ For the current coordinates takes all the neighbours and find the @@ -311,9 +458,9 @@ def add( the neighbours recursively. Once the correct structure id is found, append a point with the - current coordinates to the coordinates map entry for the correct - structure. Hence each entry of the map will be a vector of all the - pertaining points. + current coordinates (and its input intensity) to the coordinates map + entry for the correct structure. Hence each entry of the map will be a + vector of all the pertaining points. """ updated_id = self.sanitise_ids(neighbour_ids) if updated_id not in self.coords_maps: @@ -323,7 +470,9 @@ def add( self.merge_structures(updated_id, neighbour_ids) # Add point for that structure - self._add_point(updated_id, (int(x), int(y), int(z))) + self._add_point( + updated_id, (int(x), int(y), int(z), np.float64(intensity)) + ) return updated_id def sanitise_ids( @@ -372,10 +521,30 @@ def merge_structures( self.coords_maps.pop(neighbour_id) self.obsolete_ids[neighbour_id] = updated_id - def structures_to_cells(self) -> np.ndarray: + def structures_to_cells( + self, use_centre_of_intensity: bool = False + ) -> np.ndarray: + """ + Returns the centre coordinates of all the structures. + + Parameters + ---------- + use_centre_of_intensity : bool + If False or if all coordinates in a structure have the same + intensity, centre is calculated as the mean position for each + dimension. Otherwise, we use the center of mass calculation, but + using the intensity. Pulling the center towards the brighter + coordinates. + + Returns + ------- + center : np.ndarray + A Nx3 array with the centre x, y, z position of the N structures + detected. + """ cell_centres = np.empty((len(self.coords_maps), 3)) for idx, structure in enumerate(self.coords_maps.values()): - p = _get_structure_centre(structure) + p = _get_structure_centre(structure, use_centre_of_intensity) cell_centres[idx] = p return cell_centres diff --git a/cellfinder/core/detect/filters/volume/structure_splitting.py b/cellfinder/core/detect/filters/volume/structure_splitting.py index 79956da6..4dcb3440 100644 --- a/cellfinder/core/detect/filters/volume/structure_splitting.py +++ b/cellfinder/core/detect/filters/volume/structure_splitting.py @@ -1,5 +1,5 @@ from copy import copy -from typing import List, Tuple, Type +from typing import List, Optional, Tuple, Type import numpy as np import torch @@ -36,11 +36,12 @@ def coords_to_volume( xs: np.ndarray, ys: np.ndarray, zs: np.ndarray, + intensity: Optional[np.ndarray], volume_shape: Tuple[int, int, int], ball_radius: int, dtype: Type[np.number], threshold_value: int, -) -> torch.Tensor: +) -> tuple[torch.Tensor, Optional[np.ndarray]]: """ Takes the series of x, y, z points along with the shape of the volume that fully enclose them (also x, y, z order). It than expands the @@ -48,6 +49,9 @@ def coords_to_volume( by the radius internally is set to the threshold value. The volume is then transposed and returned in the Z, Y, X order. + + If `intensity` is not None, it also returns a volume, with the intensity + of every point set to the corresponding intensity value. """ # it's faster doing the work in numpy and then returning as torch array, # than doing the work in torch @@ -56,6 +60,11 @@ def coords_to_volume( expanded_shape = [dim_size + ball_diameter for dim_size in volume_shape] # volume is now x, y, z order volume = np.zeros(expanded_shape, dtype=dtype) + # use largest type. These are small volumes and not processed much except + # to find center, so it's not a large memory/cpu cost + raw_volume = None + if intensity is not None: + raw_volume = np.zeros(volume.shape, dtype=np.float64) x_min, y_min, z_min = xs.min(), ys.min(), zs.min() @@ -67,13 +76,19 @@ def coords_to_volume( # set each point as the center with a value of threshold volume[relative_xs, relative_ys, relative_zs] = threshold_value + if intensity is not None: + raw_volume[relative_xs, relative_ys, relative_zs] = intensity volume = volume.swapaxes(0, 2) - return torch.from_numpy(volume) + if intensity is not None: + raw_volume = raw_volume.swapaxes(0, 2) + return torch.from_numpy(volume), raw_volume def ball_filter_imgs( - volume: torch.Tensor, settings: DetectionSettings + volume: torch.Tensor, + settings: DetectionSettings, + raw_volume: Optional[np.ndarray], ) -> np.ndarray: """ Apply ball filtering to a 3D volume and detect cell centres. @@ -81,12 +96,18 @@ def ball_filter_imgs( Uses the `BallFilter` class to perform ball filtering on the volume and the `CellDetector` class to detect cell centres. - Args: - volume (torch.Tensor): The 3D volume to be filtered (Z, Y, X order). - settings (DetectionSettings): - The settings to use. + Parameters + ---------- + volume : torch.Tensor + The 3D volume to be 3D filtered (Z, Y, X order). Edited in place. + settings : DetectionSettings + The settings to use. + raw_volume : np.ndarray or None + The original input data of the same shape as `volume`, if provided. - Returns: + Returns + ------- + centre : np.ndarray The 2D array of cell centres (N, 3) - X, Y, Z order. """ @@ -123,12 +144,16 @@ def ball_filter_imgs( previous_plane = None for z in range(0, volume.shape[0], batch_size): - bf.append(volume[z : z + batch_size, :, :]) + raw_planes_in = None + if raw_volume is not None: + raw_planes_in = raw_volume[z : z + batch_size, :, :] + bf.append(volume[z : z + batch_size, :, :], raw_planes=raw_planes_in) if bf.ready: bf.walk() middle_planes = bf.get_processed_planes() + raw_planes = None if raw_volume is None else bf.get_raw_planes() n = middle_planes.shape[0] # we edit volume, but only for planes already processed that won't @@ -140,34 +165,44 @@ def ball_filter_imgs( # convert to type needed for detection middle_planes = detection_convert(middle_planes) - for plane in middle_planes: - previous_plane = cell_detector.process(plane, previous_plane) + for i, plane in enumerate(middle_planes): + raw_plane = None if raw_volume is None else raw_planes[i] + previous_plane = cell_detector.process( + plane, previous_plane, raw_plane + ) - return cell_detector.get_cell_centres() + return cell_detector.get_cell_centres(settings.detect_centre_of_intensity) def iterative_ball_filter( - volume: torch.Tensor, settings: DetectionSettings + volume: torch.Tensor, + settings: DetectionSettings, + raw_volume: Optional[np.ndarray], ) -> Tuple[List[int], List[np.ndarray]]: """ Apply iterative ball filtering to the given volume. The volume is eroded at each iteration, by subtracting 1 from the volume. - Parameters: - volume (torch.Tensor): The input volume. It is edited inplace. - Of shape Z, Y, X. - settings (DetectionSettings): The settings to use. + Parameters + ---------- + volume : torch.Tensor + The input volume. It is edited inplace. Of shape Z, Y, X. + settings : DetectionSettings + The settings to use. + raw_volume : np.ndarray or None + The original input data of the same shape as `volume`, if provided. - Returns: - tuple: A tuple containing two lists: - The number of structures found in each iteration. - The cell centres found in each iteration. + Returns + ------- + tuple: A tuple containing two lists: + The number of structures found in each iteration. + The cell centres found in each iteration. """ ns = [] centres = [] for i in range(settings.n_splitting_iter): - cell_centres = ball_filter_imgs(volume, settings) + cell_centres = ball_filter_imgs(volume, settings, raw_volume) volume.sub_(1) n_structures = len(cell_centres) @@ -185,7 +220,6 @@ def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool: Parameters ---------- - centre : np.ndarray x, y, z coordinate. max_coords : np.ndarray @@ -207,26 +241,41 @@ def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool: def split_cells( - cell_points: np.ndarray, settings: DetectionSettings + cell_points: np.ndarray, + settings: DetectionSettings, + intensity: Optional[np.ndarray] = None, ) -> np.ndarray: """ - Split the given cell points into individual cell centres. - - Args: - cell_points (np.ndarray): Array of cell points with shape (N, 3), - where N is the number of cell points and each point is represented - by its x, y, and z coordinates. - settings (DetectionSettings) : The settings to use for splitting. It is - modified inplace. - - Returns: - np.ndarray: Array of absolute cell centres with shape (M, 3), - where M is the number of individual cells and each centre is - represented by its x, y, and z coordinates. + Split the given structure built from the given cell coordinates into + smaller structures with their own cell centres. + + Parameters + ---------- + cell_points : np.ndarray + Array of cell points with shape (N, 3), + where N is the number of cell points and each point is represented + by its x, y, and z coordinates. + settings : DetectionSettings + The settings to use for splitting. + intensity : np.ndarray or None + An array of size N with the intensity of each point. Needed for + computing the cell centre using the center of mass method, if selected. + + Returns + ------- + np.ndarray: Array of absolute cell centres with shape (M, 3), + where M is the number of individual cells and each centre is + represented by its x, y, and z coordinates. """ settings = copy(settings) + if settings.detect_centre_of_intensity and intensity is None: + raise ValueError( + "Using center of intensity, but intensity no provided" + ) + # these points are in x, y, z order columnwise, in absolute pixels - orig_centre = get_structure_centre(cell_points) + # get real unweighed center to start from + orig_centre = get_structure_centre(cell_points, intensity=None) xs = cell_points[:, 0] ys = cell_points[:, 1] @@ -252,10 +301,11 @@ def split_cells( # set both to float32) assert settings.filtering_dtype == settings.plane_original_np_dtype # volume will now be z, y, x order - vol = coords_to_volume( + vol, raw_vol = coords_to_volume( xs, ys, zs, + intensity, volume_shape=original_bounding_cuboid_shape, ball_radius=ball_radius, dtype=settings.filtering_dtype, @@ -279,7 +329,8 @@ def split_cells( # centres is a list of arrays of centres (1 array of centres per ball run) # in x, y, z order - ns, centres = iterative_ball_filter(vol, settings) + ns, centres = iterative_ball_filter(vol, settings, raw_vol) + # add original centre. That's valid, even if using centre of intensity ns.insert(0, 1) centres.insert(0, np.array([relative_orig_centre])) diff --git a/cellfinder/core/detect/filters/volume/volume_filter.py b/cellfinder/core/detect/filters/volume/volume_filter.py index 0ba8e816..19a1d29a 100644 --- a/cellfinder/core/detect/filters/volume/volume_filter.py +++ b/cellfinder/core/detect/filters/volume/volume_filter.py @@ -177,12 +177,12 @@ def _feed_signal_batches( # should only have 2d filter processors on the cpu assert bool(processors) == cpu - # seed the queue with tokens for the buffers + # seed our queue with tokens that give access to their buffers for token in range(len(buffers)): thread.send_msg_to_thread(token) for z in range(start_plane, end_plane, batch_size): - # convert the data to the right type + # convert the input data to the right type np_data = data_converter(data[z : z + batch_size, :, :]) # if we ran out of batches, we are done! n = np_data.shape[0] @@ -190,19 +190,28 @@ def _feed_signal_batches( # thread/underlying queues get first crack at msg. Unless we get # eof, this will block until a buffer is returned from the main - # thread for reuse + # thread for reuse, or give us a buffer if we have unused buffers token = thread.get_msg_from_mainthread() if token is EOFSignal: return - # buffer is free, get it from token + # buffer is free, get it from its token tensor, masks = buffers[token] # for last batch, it can be smaller than normal so only set up to n tensor[:n, :, :] = torch.from_numpy(np_data) tensor = tensor[:n, :, :] + if masks is not None: + masks = masks[:n] + # For CPU, we send the tensor token to the 2d filtering processes + # who have access to the tensors already. They overwrite the data + # in the tensor with the filtered data and let the main thread + # know. For GPU, we don't use these external processes, so we send + # the tensor directly, after moving it to the device. Either way, + # we don't reuse the tensor until the main thread let us know its + # free if not cpu: - # send to device - it won't block here because we pinned memory + # send to device - it won't block here if we pinned memory tensor = tensor.to(device=device, non_blocking=True) # if used, send each plane in batch to processor @@ -213,7 +222,7 @@ def _feed_signal_batches( process.send_msg_to_thread((token, i)) # tell the main thread to wait for processors (if used) - msg = token, tensor, masks, used_processors, n + msg = token, np_data, tensor, masks, used_processors, n if n < batch_size: # on last batch, we are also done after this @@ -326,27 +335,28 @@ def _process( # feeder thread exits at the end, causing a eof to be sent if msg is EOFSignal: break - token, tensor, masks, used_processors, n = msg - # this token is in use until we return it + token, np_data, tensor, masks, used_processors, n = msg + # this token is in use until we return it, meaning tensor is in use processing_tokens.append(token) if cpu: # we did 2d filtering in different process. Make sure all the - # planes are done filtering. Each msg from feeder thread has - # corresponding msg for each used processor (unless exception) + # planes in the tensor/masks are done filtering. Each msg from + # feeder thread has corresponding msg for each used processor + # (unless exception), which corresponds to a plane for process in used_processors: process.get_msg_from_thread() - # batch size can change at the end so resize buffer - planes = tensor[:n, :, :] - masks = masks[:n, :, :] + planes = tensor else: # we're not doing 2d filtering in different process planes, masks = tile_processor.get_tile_mask(tensor) - self.ball_filter.append(planes, masks) + self.ball_filter.append(planes, masks, np_data) if self.ball_filter.ready: self.ball_filter.walk() middle_planes = self.ball_filter.get_processed_planes() + # we always include the raw planes, even if we don't use it + raw_planes = self.ball_filter.get_raw_planes() # at this point we know input tensor can be reused - return # it so feeder thread can load more data into it @@ -364,7 +374,9 @@ def _process( if token is EOFSignal: break # send it more data and return the token - cells_thread.send_msg_to_thread((middle_planes, token)) + cells_thread.send_msg_to_thread( + (middle_planes, raw_planes, token) + ) @inference_wrapper def _run_filter_thread( @@ -404,18 +416,18 @@ def _run_filter_thread( # convert plane to the type needed by detection system # we should not need scaling because throughout # filtering we make sure result fits in this data type - middle_planes, token = msg + middle_planes, raw_planes, token = msg detection_middle_planes = detection_converter(middle_planes) logger.debug(f"🏫 Detecting structures for planes {self.z}+") - for plane, detection_plane in zip( - middle_planes, detection_middle_planes + for raw_plane, plane, detection_plane in zip( + raw_planes, middle_planes, detection_middle_planes ): if save_planes: self.save_plane(plane.astype(original_dtype)) previous_plane = detector.process( - detection_plane, previous_plane + detection_plane, previous_plane, raw_plane ) if callback is not None: @@ -452,26 +464,32 @@ def get_results(self, settings: DetectionSettings) -> List[Cell]: root_settings = self.settings max_cell_volume = settings.max_cell_volume + use_centre_of_intensity = settings.detect_centre_of_intensity # valid cells cells = [] # regions that must be split into cells needs_split = [] structures = self.cell_detector.get_structures().items() + intensities = self.cell_detector.get_structures_intensities() logger.debug(f"Processing {len(structures)} found cells") # first get all the cells that are not clusters for cell_id, cell_points in structures: + intensity = None + # if we don't use COI, don't pass on intensity. That'll disable it + if use_centre_of_intensity: + intensity = intensities[cell_id] cell_volume = len(cell_points) if cell_volume < max_cell_volume: - cell_centre = get_structure_centre(cell_points) + cell_centre = get_structure_centre(cell_points, intensity) cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN)) else: if cell_volume < settings.max_cluster_size: - needs_split.append((cell_id, cell_points)) + needs_split.append((cell_id, cell_points, intensity)) else: - cell_centre = get_structure_centre(cell_points) + cell_centre = get_structure_centre(cell_points, intensity) cells.append(Cell(cell_centre.tolist(), Cell.ARTIFACT)) if not needs_split: @@ -516,8 +534,8 @@ def _split_cells(arg, settings: DetectionSettings): # likely small and using multiple threads would cost more in overhead than # is worth. num threads can be set only at processes level. torch.set_num_threads(1) - cell_id, cell_points = arg + cell_id, cell_points, intensity = arg try: - return split_cells(cell_points, settings=settings) + return split_cells(cell_points, settings=settings, intensity=intensity) except (ValueError, AssertionError) as err: raise StructureSplitException(f"Cell {cell_id}, error; {err}") diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index c28281e2..90b424c2 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -49,6 +49,7 @@ def main( classify_callback: Optional[Callable[[int], None]] = None, detect_finished_callback: Optional[Callable[[list], None]] = None, classification_max_workers: int = 3, + detect_centre_of_intensity: bool = False, ) -> List[Cell]: """ Parameters @@ -183,6 +184,13 @@ def main( classification_max_workers : int The max number of sub-processes to use for data loading / processing during classification. Defaults to 3. + detect_centre_of_intensity : bool + If False, a candidate cell's center is just the mean of the positions + of all voxels marked as above background, or bright, in that candidate. + The voxel intensity is not taken into account. If True, the center is + calculated similar to the center of mass, but using the intensity. So + the center gets pulled towards the brighter voxels in the volume. + Defaults to False. """ from cellfinder.core.classify import classify from cellfinder.core.detect import detect @@ -215,6 +223,7 @@ def main( split_ball_xy_size=split_ball_xy_size, split_ball_overlap_fraction=split_ball_overlap_fraction, n_splitting_iter=n_splitting_iter, + detect_centre_of_intensity=detect_centre_of_intensity, ) if detect_finished_callback is not None: diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index 51ef67e4..4e6f1e31 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -256,6 +256,7 @@ def widget( ball_z_size: float, ball_overlap_fraction: float, detection_batch_size: int, + detect_centre_of_intensity: bool, soma_spread_factor: float, max_cluster_size: float, classification_options, @@ -324,6 +325,13 @@ def widget( for all the filters. For performance-critical applications, tune to maximize memory usage without running out. Check your GPU/CPU memory to verify it's not full + detect_centre_of_intensity : bool + If False, a candidate cell's center is just the mean of the + positions of all voxels marked as above background, or bright, in + that candidate. The voxel intensity is not taken into account. If + True, the center is calculated similar to the center of mass, but + using the intensity. So the center gets pulled towards the brighter + voxels in the volume. soma_spread_factor : float Cell spread factor for determining the largest cell volume before splitting up cell clusters. Structures with spherical volume of @@ -421,6 +429,7 @@ def widget( soma_spread_factor, max_cluster_size, detection_batch_size, + detect_centre_of_intensity, ) if use_pre_trained_weights: diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py index ac6e76e0..6bdc027a 100644 --- a/cellfinder/napari/detect/detect_containers.py +++ b/cellfinder/napari/detect/detect_containers.py @@ -74,6 +74,7 @@ class DetectionInputs(InputContainer): soma_spread_factor: float = 1.4 max_cluster_size: float = 100000 detection_batch_size: int = 1 + detect_centre_of_intensity: bool = False def as_core_arguments(self) -> dict: return super().as_core_arguments() @@ -106,6 +107,10 @@ def widget_representation(cls) -> dict: "tiled_thresh_tile_size", custom_label="Thresholding tile size", ), + detect_centre_of_intensity=cls._custom_widget( + "detect_centre_of_intensity", + custom_label="Use centre of intensity", + ), soma_spread_factor=cls._custom_widget( "soma_spread_factor", custom_label="Split cell spread" ), diff --git a/tests/core/conftest.py b/tests/core/conftest.py index eb4d1e8a..c67f8839 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -68,7 +68,11 @@ def test_data_registry(): def mark_sphere( - data_zyx: np.ndarray, center_xyz, radius: int, fill_value: int + data_zyx: np.ndarray, + center_xyz, + radius: int, + center_fill_value: int, + outside_fill_value: int, ) -> None: shape_zyx = data_zyx.shape @@ -80,8 +84,17 @@ def mark_sphere( + (y - center_xyz[1]) ** 2 + (z - center_xyz[2]) ** 2 ) - # 100 seems to be the right size so std is not too small for filters - data_zyx[dist <= radius] = fill_value + inside = dist <= radius + + if center_fill_value == outside_fill_value: + data_zyx[inside] = center_fill_value + else: + diff = center_fill_value - outside_fill_value + assert diff > 0 + max_val = np.max(dist[inside]) + dist /= max_val + dist = (1 - dist) * diff + outside_fill_value + data_zyx[inside] = dist[inside] @pytest.fixture(scope="session") @@ -161,7 +174,7 @@ def synthetic_single_spot() -> ( signal_array = np.zeros(shape_zyx) background_array = np.zeros_like(signal_array) - mark_sphere(signal_array, center_xyz=c_xyz, radius=2, fill_value=100) + mark_sphere(signal_array, c_xyz, 2, 100, 100) # 1 std should be larger, so it can be considered bright assert np.mean(signal_array) + np.std(signal_array) > 1 @@ -195,13 +208,107 @@ def synthetic_spot_clusters() -> ( background_array = np.zeros_like(signal_array) for center in centers_xyz: - mark_sphere( - signal_array, center_xyz=center, radius=radius, fill_value=100 - ) + mark_sphere(signal_array, center, radius, 100, 100) return signal_array, background_array, centers_xyz +@pytest.fixture(scope="session") +def synthetic_intensity_dropoff_spot() -> tuple[ + np.ndarray, + np.ndarray, + tuple[int, int, int], + tuple[int, int, int], + tuple[int, int, int], +]: + """ + Creates a synthetic signal array with a single spherical spot + that slowly drops off in intensity in the x direction + as a 3d numpy array to be used for cell detection testing. + + The center value is bright, less bright at the edges of the + sphere, and darker further at some distance in the x direction. + The array is a floating type and you must convert it to the right + data type for your tests. Also, `n_sds_above_mean_thresh` must be 0 + to exclude any non-zero areas. + + It returns the signal and background np arrays and 3 x, y, z position + tuples `center`, `mid`, `end`. `center` is the center of the sphere. + `mid` is the mid-point of all non-zero voxels. `end` is the of the non-zero + voxels in the x-direction. I.e. centered in y, z. But in x it's where the + voxels are the least bright at the end of the falloff. + """ + # overall shape and center of sphere + shape_zyx = 20, 50, 50 + c_zyx = 10, 25, 15 + # radius of the sphere + r = 7 + # the dist from the center, along which the brightness will dropoff *after* + # the end of the sphere is reached in x + x_r = 22 + # brightness of sphere center + center_val = 1000 + # brightness of the sphere at its radius + bright_fill = 100 + # brightness of the end of falloff at x_r. + mute_fill = 10 + # center of all the non-zero voxels + c_overall_zyx = 10, 25, 15 - r + (r + x_r - 1) // 2 + # pos of the end of the falloff + end_zyx = 10, 25, 15 + x_r + + signal_array = np.zeros(shape_zyx) + background_array = np.zeros_like(signal_array) + # mark the sphere + mark_sphere(signal_array, c_zyx[::-1], r, center_val, bright_fill) + + # add the brightness dropoff. Start with overall grid + z, y, x = np.mgrid[ + 0 : shape_zyx[0] : 1, 0 : shape_zyx[1] : 1, 0 : shape_zyx[2] : 1 + ] + + # locate voxels only within the z and y radius + within_z_rad = np.abs(z - c_zyx[0]) <= r + within_y_rad = np.abs(y - c_zyx[1]) <= r + # and the voxels that is on the right half of the sphere, but not past the + # end of the dropoff area + within_x_pos_rad = np.logical_and(x - c_zyx[2] >= 0, x - c_zyx[2] <= x_r) + within_yz_rad = np.logical_and(within_y_rad, within_z_rad) + # voxels that are in z, and y radius, are positive x, but below dropoff end + within_xyz_rad = np.logical_and(within_yz_rad, within_x_pos_rad) + # only get the voxels in the above volume, but *outside* the marked sphere. + # These voxels will be updated with gradient dropoff + valid_and_outside_x_r = np.logical_and( + within_xyz_rad, np.logical_not(signal_array) + ) + + # mark the dropoff voxels in the range between bright_fill and mute_fill, + # with lower intensity further from the sphere center (outside the sphere) + dist = np.sqrt( + (x - c_zyx[2]) ** 2 + (y - c_zyx[1]) ** 2 + (z - c_zyx[0]) ** 2 + ) + r_dist = dist[c_zyx[0], c_zyx[1], c_zyx[2] + r] + # get the distance relative to dist at the sphere radius + dist -= r_dist + # the max distance at the dropoff + x_r_dist = dist[c_zyx[0], c_zyx[1], c_zyx[2] + x_r] + + # normalize ratio to 0-1 so we can subtract and go from bright_fill to + # mute_fill at the furthest dist. We use sqrt to make the brightness drop + # off faster + ratio = np.sqrt(np.abs(dist / x_r_dist)) + dist += bright_fill - ratio * (bright_fill - mute_fill) + signal_array[valid_and_outside_x_r] = dist[valid_and_outside_x_r] + + return ( + signal_array, + background_array, + c_zyx[::-1], + c_overall_zyx[::-1], + end_zyx[::-1], + ) + + @pytest.fixture(scope="session") def repo_data_path() -> Path: """ diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 5e145875..6f6b790f 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -336,3 +336,72 @@ def test_detection_plane_too_small(synthetic_spot_clusters, y, x): voxel_sizes=(1, 1, 1), ball_xy_size=50, ) + + +@pytest.mark.parametrize("split", [True, False]) +@pytest.mark.parametrize("use_coi", [True, False]) +def test_center_of_intensity_gradient( + synthetic_intensity_dropoff_spot, no_free_cpus, use_coi, split +): + """Checks that using the sphere with dropoff, its center is properly found + when using the center of intensity calculation, but not otherwise - where + we just get the overall center of all non-zero values. + + When splitting, the overall bright areas are split in two in either case, + so we verify it. + """ + signal_array, background_array, c_xyz, c_overall_xyz, end_xyz = ( + synthetic_intensity_dropoff_spot + ) + signal_array = signal_array.astype(np.uint16) + background_array = background_array.astype(np.uint16) + + detected = main( + signal_array, + background_array, + soma_diameter=8, + ball_xy_size=8, + ball_z_size=8, + ball_overlap_fraction=0.7, + log_sigma_size=0.8, + n_sds_above_mean_thresh=0, + soma_spread_factor=0.5 if split else 10, + n_splitting_iter=3, + split_ball_z_size=8, + split_ball_overlap_fraction=0.7, + voxel_sizes=voxel_sizes, + n_free_cpus=no_free_cpus, + skip_classification=True, + detect_centre_of_intensity=use_coi, + ) + + if split: + # if splitting, we get two volumes with minimal splitting + _, y, z = c_xyz + assert len(detected) == 2 + for cell in detected: + assert abs(y - cell.y) <= 2 + assert abs(z - cell.z) <= 2 + + # we get points centered in y and z, but on the two end in x + cx0, cx1 = [c.x for c in detected] + assert cx0 != cx1 + if cx0 > cx1: + cx0, cx1 = cx1, cx0 + + assert abs(c_xyz[0] - cx0) <= 6 + assert abs(end_xyz[0] - cx1) <= 6 + + else: + # if not splitting, we get a single cell with center depending on coi + if use_coi: + c = c_xyz + else: + c = c_overall_xyz + x, y, z = c + + assert len(detected) == 1 + cell = detected[0] + assert abs(x - cell.x) <= 2 + assert abs(y - cell.y) <= 1 + assert abs(z - cell.z) <= 1 diff --git a/tests/core/test_integration/test_detection_structure_splitting.py b/tests/core/test_integration/test_detection_structure_splitting.py index 5fe74a77..fb79c4f2 100644 --- a/tests/core/test_integration/test_detection_structure_splitting.py +++ b/tests/core/test_integration/test_detection_structure_splitting.py @@ -47,7 +47,7 @@ def background_array(repo_data_path): ) -def test_structure_splitting(signal_array, background_array): +def test_structure_splitting(signal_array, background_array, no_free_cpus): """ Smoke test to ensure structure splitting code doesn't break. """ @@ -55,7 +55,7 @@ def test_structure_splitting(signal_array, background_array): signal_array, background_array, voxel_sizes, - n_free_cpus=0, + n_free_cpus=no_free_cpus, ) diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py index 34ef4775..47dfefd8 100644 --- a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -25,6 +26,9 @@ def test_filter_not_ready(): with pytest.raises(TypeError): bf.get_processed_planes() + with pytest.raises(TypeError): + bf.get_raw_planes() + with pytest.raises(TypeError): bf.walk() @@ -54,19 +58,26 @@ def test_filtered_planes(kernel_size, batch_size): kwargs["ball_z_size"] = kernel_size bf = BallFilter(**kwargs, use_mask=False) - data = torch.empty( - (batch_size, kwargs["plane_height"], kwargs["plane_width"]), - dtype=getattr(torch, kwargs["dtype"]), - device=kwargs["torch_device"], - ) - num_planes = 20 + n_batches = num_planes // batch_size + total_planes = n_batches * batch_size sent_planes = 0 gotten_planes = 0 num_padded_planes = kernel_size - 1 - for _ in range(num_planes // batch_size): - bf.append(data) + h, w = kwargs["plane_height"], kwargs["plane_width"] + data = torch.arange(total_planes * h * w).reshape((total_planes, h, w)) + data = data.to( + dtype=getattr(torch, kwargs["dtype"]), device=kwargs["torch_device"] + ) + data_np = data.numpy() + + all_raw_planes = [] + for i in range(n_batches): + bf.append( + data[i * batch_size : (i + 1) * batch_size], + raw_planes=data_np[i * batch_size : (i + 1) * batch_size], + ) sent_planes += batch_size # volume only includes batch and some padding from end of last batch assert bf.volume.shape[0] <= batch_size + kernel_size - 1 @@ -75,13 +86,25 @@ def test_filtered_planes(kernel_size, batch_size): # no need to walk because walking only modifies the contents not # size of volume planes = bf.get_processed_planes() + raw_planes = bf.get_raw_planes() + all_raw_planes.extend(raw_planes) # first batch is 1 or batch minus padding. Remaining is batch size assert planes.shape[0] in ( 1, batch_size, batch_size - num_padded_planes, ) + assert len(raw_planes) == planes.shape[0] + + for raw_plane in raw_planes: + assert raw_plane.shape == planes.shape[1:] gotten_planes += planes.shape[0] assert gotten_planes == sent_planes - num_padded_planes + all_raw_planes_np = np.stack(all_raw_planes, axis=0) + p1 = bf.first_valid_plane + data_np_unpadded = data_np[p1 : total_planes - (num_padded_planes - p1)] + + assert data_np_unpadded.shape == all_raw_planes_np.shape + assert np.array_equal(data_np_unpadded, all_raw_planes_np) From f1de7bb76dd0e2115d0120a88ddcfa6bee7fd442 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Thu, 12 Feb 2026 21:58:32 -0500 Subject: [PATCH 2/5] Correct splitting coi calc and add improved test. --- .../filters/volume/structure_splitting.py | 4 +- tests/core/conftest.py | 53 ++++++-- tests/core/test_integration/test_detection.py | 115 ++++++++++++------ 3 files changed, 126 insertions(+), 46 deletions(-) diff --git a/cellfinder/core/detect/filters/volume/structure_splitting.py b/cellfinder/core/detect/filters/volume/structure_splitting.py index 4dcb3440..d95a6a27 100644 --- a/cellfinder/core/detect/filters/volume/structure_splitting.py +++ b/cellfinder/core/detect/filters/volume/structure_splitting.py @@ -274,8 +274,8 @@ def split_cells( ) # these points are in x, y, z order columnwise, in absolute pixels - # get real unweighed center to start from - orig_centre = get_structure_centre(cell_points, intensity=None) + # get center to start from in case we find no split points + orig_centre = get_structure_centre(cell_points, intensity=intensity) xs = cell_points[:, 0] ys = cell_points[:, 1] diff --git a/tests/core/conftest.py b/tests/core/conftest.py index c67f8839..00bcf2b1 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -213,8 +213,7 @@ def synthetic_spot_clusters() -> ( return signal_array, background_array, centers_xyz -@pytest.fixture(scope="session") -def synthetic_intensity_dropoff_spot() -> tuple[ +def make_intensity_comet_spot(linear: bool) -> tuple[ np.ndarray, np.ndarray, tuple[int, int, int], @@ -234,9 +233,17 @@ def synthetic_intensity_dropoff_spot() -> tuple[ It returns the signal and background np arrays and 3 x, y, z position tuples `center`, `mid`, `end`. `center` is the center of the sphere. - `mid` is the mid-point of all non-zero voxels. `end` is the of the non-zero - voxels in the x-direction. I.e. centered in y, z. But in x it's where the - voxels are the least bright at the end of the falloff. + `mid` is the mid-point of all non-zero voxels. `end` is the end of the + non-zero voxels in the x-direction. I.e. centered in y, z. But in x it's + where the voxels are the least bright at the end of the falloff. + + If linear, the falloff happens linearly across the whole non-zero voxels, + in particular across the tail relative to the center of the sphere. + Otherwise, it drops in 2 segments, once within the sphere, and once in the + comet's tail, each with a different slope and with the tail being much less + bright. + + Original PR has visual illustrations of the comet. """ # overall shape and center of sphere shape_zyx = 20, 50, 50 @@ -249,9 +256,9 @@ def synthetic_intensity_dropoff_spot() -> tuple[ # brightness of sphere center center_val = 1000 # brightness of the sphere at its radius - bright_fill = 100 + bright_fill = 700 if linear else 100 # brightness of the end of falloff at x_r. - mute_fill = 10 + mute_fill = 50 if linear else 10 # center of all the non-zero voxels c_overall_zyx = 10, 25, 15 - r + (r + x_r - 1) // 2 # pos of the end of the falloff @@ -309,6 +316,38 @@ def synthetic_intensity_dropoff_spot() -> tuple[ ) +@pytest.fixture(scope="session") +def synthetic_intensity_comet_spot() -> tuple[ + np.ndarray, + np.ndarray, + tuple[int, int, int], + tuple[int, int, int], + tuple[int, int, int], +]: + """ + Creates a comet shaped volume where there's a sphere with a tail that is + much less bright than the sphere. See make_intensity_comet_spot for + details. + """ + return make_intensity_comet_spot(linear=False) + + +@pytest.fixture(scope="session") +def synthetic_linear_intensity_comet_spot() -> tuple[ + np.ndarray, + np.ndarray, + tuple[int, int, int], + tuple[int, int, int], + tuple[int, int, int], +]: + """ + Creates a comet shaped volume where there's a sphere with a tail whose + intensity drops off linearly from the center of the sphere. See + make_intensity_comet_spot for details. + """ + return make_intensity_comet_spot(linear=True) + + @pytest.fixture(scope="session") def repo_data_path() -> Path: """ diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 6f6b790f..0b9b5a4d 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -338,20 +338,16 @@ def test_detection_plane_too_small(synthetic_spot_clusters, y, x): ) -@pytest.mark.parametrize("split", [True, False]) @pytest.mark.parametrize("use_coi", [True, False]) -def test_center_of_intensity_gradient( - synthetic_intensity_dropoff_spot, no_free_cpus, use_coi, split +def test_center_of_intensity_comet_spot( + synthetic_intensity_comet_spot, no_free_cpus, use_coi ): """Checks that using the sphere with dropoff, its center is properly found when using the center of intensity calculation, but not otherwise - where we just get the overall center of all non-zero values. - - When splitting, the overall bright areas are split in two in either case, - so we verify it. """ signal_array, background_array, c_xyz, c_overall_xyz, end_xyz = ( - synthetic_intensity_dropoff_spot + synthetic_intensity_comet_spot ) signal_array = signal_array.astype(np.uint16) background_array = background_array.astype(np.uint16) @@ -365,43 +361,88 @@ def test_center_of_intensity_gradient( ball_overlap_fraction=0.7, log_sigma_size=0.8, n_sds_above_mean_thresh=0, - soma_spread_factor=0.5 if split else 10, - n_splitting_iter=3, - split_ball_z_size=8, - split_ball_overlap_fraction=0.7, + # 10 ensures we don't split because it's under split limit + soma_spread_factor=10, voxel_sizes=voxel_sizes, n_free_cpus=no_free_cpus, skip_classification=True, detect_centre_of_intensity=use_coi, ) - if split: - # if splitting, we get two volumes with minimal splitting - _, y, z = c_xyz - assert len(detected) == 2 - for cell in detected: - assert abs(y - cell.y) <= 2 - assert abs(z - cell.z) <= 2 + if use_coi: + # using coi the cell should be close to sphere center + c = c_xyz + else: + # without coi, the cell is near the overall non-zero volume center + c = c_overall_xyz + x, y, z = c + + assert len(detected) == 1 + cell = detected[0] + assert abs(x - cell.x) <= 2 + assert abs(y - cell.y) <= 1 + assert abs(z - cell.z) <= 1 - # we get points centered in y and z, but on the two end in x - cx0, cx1 = [c.x for c in detected] - assert cx0 != cx1 - if cx0 > cx1: - cx0, cx1 = cx1, cx0 - assert abs(c_xyz[0] - cx0) <= 6 - assert abs(end_xyz[0] - cx1) <= 6 +@pytest.mark.parametrize("use_coi", [True, False]) +def synthetic_center_of_intensity_linear_intensity_comet_spot( + synthetic_intensity_comet_spot, no_free_cpus, use_coi, split +): + """Checks that when splitting cell clusters, using the sphere with dropoff, + its center is properly found when using the center of intensity + calculation, but not otherwise - where we just get the overall center of + all non-zero values. + + We have to use a more linear comet because otherwise the 2d/3d filtering + hollows out the original filtered volume creating multiple spots during + splitting, but we don't actually want to split. Using a linear comet keeps + the "detected" volume intact so splitting leaves it also as a single spot + so we can test where it puts the cell. + """ + signal_array, background_array, c_xyz, c_overall_xyz, end_xyz = ( + synthetic_intensity_comet_spot + ) + signal_array = signal_array.astype(np.uint16) + background_array = background_array.astype(np.uint16) + + detected = main( + signal_array, + background_array, + soma_diameter=8, + ball_xy_size=8, + ball_z_size=8, + ball_overlap_fraction=0.7, + log_sigma_size=0.8, + n_sds_above_mean_thresh=0, + # 0.5 ensures we split because cell volume is over the split limit + soma_spread_factor=0.5, + max_cluster_size=100000, + split_ball_xy_size=6, + split_ball_z_size=8, + split_ball_overlap_fraction=0.2, + n_splitting_iter=1, + voxel_sizes=voxel_sizes, + n_free_cpus=no_free_cpus, + skip_classification=True, + detect_centre_of_intensity=use_coi, + ) + grace = 2 + if use_coi: + # using coi the cell should be close to sphere center + c = c_xyz + # using the linear comet, the center needs more grace because the tail + # weighs more and pulls the cell towards it + grace = 4 + # make sure we don't put the overall center within grace amount + assert c_xyz[0] + grace < c_overall_xyz[0] else: - # if not splitting, we get a single cell with center depending on coi - if use_coi: - c = c_xyz - else: - c = c_overall_xyz - x, y, z = c - - assert len(detected) == 1 - cell = detected[0] - assert abs(x - cell.x) <= 2 - assert abs(y - cell.y) <= 1 - assert abs(z - cell.z) <= 1 + # without coi, the cell is near the overall non-zero volume center + c = c_overall_xyz + x, y, z = c + + assert len(detected) == 1 + cell = detected[0] + assert abs(x - cell.x) <= grace + assert abs(y - cell.y) <= 1 + assert abs(z - cell.z) <= 1 From 325997be28a94509619a35cbfc56cab0bea572ff Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Thu, 12 Feb 2026 22:34:33 -0500 Subject: [PATCH 3/5] Round out tests for coverage. --- .../test_detection_structure_splitting.py | 37 +++++++++++++++++ .../test_structure_detection.py | 41 ++++++++++++++++--- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/tests/core/test_integration/test_detection_structure_splitting.py b/tests/core/test_integration/test_detection_structure_splitting.py index fb79c4f2..b784bd9b 100644 --- a/tests/core/test_integration/test_detection_structure_splitting.py +++ b/tests/core/test_integration/test_detection_structure_splitting.py @@ -8,10 +8,13 @@ import numpy as np import pytest +import torch from brainglobe_utils.IO.image.load import read_with_dask from cellfinder.core.detect.filters.setup_filters import DetectionSettings from cellfinder.core.detect.filters.volume.structure_splitting import ( + ball_filter_imgs, + check_centre_in_cuboid, split_cells, ) from cellfinder.core.main import main @@ -97,3 +100,37 @@ def test_underflow_issue_435(): expected = {(10, 11, 11), (20, 11, 11)} got = set(map(tuple, centers.tolist())) assert expected == got + + +def test_ball_filter_imgs_invalid_volume(): + """Checks that an invalid volume returns empty array instead.""" + settings = DetectionSettings( + plane_shape=(100, 30), + plane_original_np_dtype=np.float32, + voxel_sizes=(1, 1, 1), + ball_xy_size_um=50, + ) + + vol = ball_filter_imgs(torch.zeros((5, 100, 30)), settings, None) + assert not vol.shape[0] + + +@pytest.mark.parametrize("inside", [True, False]) +def test_check_centre_in_cuboid(inside): + corner = np.array([5, 5, 5]) + if inside: + assert check_centre_in_cuboid(np.array([2, 2, 2]), corner) + else: + assert not check_centre_in_cuboid(np.array([8, 8, 8]), corner) + + +def test_using_coi_without_intensity(): + cell_points = np.zeros((30, 20, 20), dtype=np.bool_) + settings = DetectionSettings( + plane_shape=(100, 100), + plane_original_np_dtype=np.float32, + detect_centre_of_intensity=True, + ) + + with pytest.raises(ValueError): + split_cells(cell_points, settings, None) diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py index f38656be..cdbb243b 100644 --- a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py @@ -57,8 +57,17 @@ def structure(three_d_cross: np.ndarray) -> np.ndarray: return coords -def test_get_structure_centre(structure: np.ndarray) -> None: - result_point = get_structure_centre(structure) +@pytest.mark.parametrize("use_intensity", [True, False]) +def test_get_structure_centre(structure: np.ndarray, use_intensity) -> None: + """ + Check that get_structure_centre works and that it works the same if we + provide it a zeroed intensity array. + """ + intensity = None + if use_intensity: + intensity = np.zeros(len(structure)) + + result_point = get_structure_centre(structure, intensity) assert (result_point[0], result_point[1], result_point[2]) == ( 1, 1, @@ -224,15 +233,35 @@ def test_add_point(): detector.add_point(1, (7, 5, 5)) -def test_add_points(): +@pytest.mark.parametrize("use_intensity", [True, False]) +def test_add_points(use_intensity): detector = CellDetector(50, 50, 0, 0) points = np.array([(5, 5, 5), (6, 6, 6)], dtype=np.uint32) points2 = np.array([(7, 5, 5), (8, 6, 6)], dtype=np.uint32) points3 = np.array([(8, 5, 5), (8, 6, 6)], dtype=np.uint32) - detector.add_points(0, points) - detector.add_points(0, points2) - detector.add_points(1, points3) + + intensity = intensity2 = intensity3 = None + if use_intensity: + intensity = np.ones(2) + intensity2 = np.ones(2) * 2 + intensity3 = np.ones(2) * 3 + + detector.add_points(0, points, intensity) + detector.add_points(0, points2, intensity2) + detector.add_points(1, points3, intensity3) + + structures = detector.get_structures() + assert np.all(np.concatenate([points, points2], axis=0) == structures[0]) + assert np.all(points3 == structures[1]) + + intensities = detector.get_structures_intensities() + if use_intensity: + assert np.all(np.array([1, 1, 2, 2]) == intensities[0]) + assert np.all(3 == intensities[1]) + else: + assert np.all(0 == intensities[0]) + assert np.all(0 == intensities[1]) def test_change_plane_size(): From 246acfee208031e024a29329d8ceac742cffdfee Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Fri, 20 Mar 2026 14:18:18 -0400 Subject: [PATCH 4/5] split_cells should return the structs of the cells as well. Fix test. --- .../filters/volume/structure_detection.py | 5 + .../filters/volume/structure_splitting.py | 113 +++++++++--------- .../detect/filters/volume/volume_filter.py | 4 +- .../test_detection_structure_splitting.py | 8 +- 4 files changed, 68 insertions(+), 62 deletions(-) diff --git a/cellfinder/core/detect/filters/volume/structure_detection.py b/cellfinder/core/detect/filters/volume/structure_detection.py index 4aef04ea..763810be 100644 --- a/cellfinder/core/detect/filters/volume/structure_detection.py +++ b/cellfinder/core/detect/filters/volume/structure_detection.py @@ -309,6 +309,11 @@ def connect_four( return plane + @property + def n_structures(self) -> int: + """The number of structures detected.""" + return len(self.coords_maps) + def get_cell_centres( self, use_centre_of_intensity: bool = False ) -> np.ndarray: diff --git a/cellfinder/core/detect/filters/volume/structure_splitting.py b/cellfinder/core/detect/filters/volume/structure_splitting.py index d95a6a27..6f945ba9 100644 --- a/cellfinder/core/detect/filters/volume/structure_splitting.py +++ b/cellfinder/core/detect/filters/volume/structure_splitting.py @@ -89,12 +89,13 @@ def ball_filter_imgs( volume: torch.Tensor, settings: DetectionSettings, raw_volume: Optional[np.ndarray], -) -> np.ndarray: +) -> CellDetector: """ - Apply ball filtering to a 3D volume and detect cell centres. + Apply ball filtering to a 3D volume and detect new structures from the + original single structure. Uses the `BallFilter` class to perform ball filtering on the volume - and the `CellDetector` class to detect cell centres. + and the `CellDetector` class to detect the new cells. Parameters ---------- @@ -107,8 +108,8 @@ def ball_filter_imgs( Returns ------- - centre : np.ndarray - The 2D array of cell centres (N, 3) - X, Y, Z order. + cell_detector : CellDetector + The `CellDetector` that tracks the newly detected structures. """ detection_convert = settings.detection_data_converter_func @@ -132,7 +133,12 @@ def ball_filter_imgs( use_mask=False, # we don't need a mask here ) except InvalidVolume: - return np.empty((0, 3)) + return CellDetector( + settings.plane_height, + settings.plane_width, + start_z=0, + soma_centre_value=settings.detection_soma_centre_value, + ) start_z = bf.first_valid_plane cell_detector = CellDetector( @@ -171,14 +177,14 @@ def ball_filter_imgs( plane, previous_plane, raw_plane ) - return cell_detector.get_cell_centres(settings.detect_centre_of_intensity) + return cell_detector def iterative_ball_filter( volume: torch.Tensor, settings: DetectionSettings, raw_volume: Optional[np.ndarray], -) -> Tuple[List[int], List[np.ndarray]]: +) -> List[CellDetector]: """ Apply iterative ball filtering to the given volume. The volume is eroded at each iteration, by subtracting 1 from the volume. @@ -194,24 +200,21 @@ def iterative_ball_filter( Returns ------- - tuple: A tuple containing two lists: - The number of structures found in each iteration. - The cell centres found in each iteration. + cell_detectors: List of CellDetector. + A list of CellDetector instances, each corresponding to the result + of one iteration, in that order. """ - ns = [] - centres = [] + cell_detectors = [] for i in range(settings.n_splitting_iter): - cell_centres = ball_filter_imgs(volume, settings, raw_volume) + cell_detector = ball_filter_imgs(volume, settings, raw_volume) volume.sub_(1) - n_structures = len(cell_centres) - ns.append(n_structures) - centres.append(cell_centres) - if n_structures == 0: + cell_detectors.append(cell_detector) + if not cell_detector.n_structures: break - return ns, centres + return cell_detectors def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool: @@ -244,7 +247,7 @@ def split_cells( cell_points: np.ndarray, settings: DetectionSettings, intensity: Optional[np.ndarray] = None, -) -> np.ndarray: +) -> tuple[np.ndarray, tuple[CellDetector, np.ndarray] | None]: """ Split the given structure built from the given cell coordinates into smaller structures with their own cell centres. @@ -263,9 +266,17 @@ def split_cells( Returns ------- - np.ndarray: Array of absolute cell centres with shape (M, 3), - where M is the number of individual cells and each centre is - represented by its x, y, and z coordinates. + centres, (cell_detector, offset): A 2-tuple of, + np.ndarray: Array of absolute cell centres with shape (M, 3), + where M is the number of individual cells and each centre is + represented by its x, y, and z coordinates. + (CellDetector, np.ndarray) or None: If None, then we didn't find any + better cell candidates during splitting than the original single cell. + Otherwise, it's the `CellDetector` with the structs from the best + iteration and a size 3 `np.ndarray` with the offset of the structs in + the cell detector. I.e. the cell detector uses coordinates relative to + the size of the cube containing the input voxels so the offset must be + added to convert to absolute voxel indices. """ settings = copy(settings) if settings.detect_centre_of_intensity and intensity is None: @@ -281,17 +292,6 @@ def split_cells( ys = cell_points[:, 1] zs = cell_points[:, 2] - # corner coordinates in absolute pixels - orig_corner = np.array([xs.min(), ys.min(), zs.min()]) - # volume center relative to corner - relative_orig_centre = np.array( - [ - orig_centre[0] - orig_corner[0], - orig_centre[1] - orig_corner[1], - orig_centre[2] - orig_corner[2], - ] - ) - # total volume enclosing all points original_bounding_cuboid_shape = get_shape(xs, ys, zs) @@ -329,30 +329,27 @@ def split_cells( # centres is a list of arrays of centres (1 array of centres per ball run) # in x, y, z order - ns, centres = iterative_ball_filter(vol, settings, raw_vol) + cell_detectors = iterative_ball_filter(vol, settings, raw_vol) + struct_counts = [d.n_structures for d in cell_detectors] + + # if best split only resulted in one (or no) struct, stick with original + if not struct_counts or max(struct_counts) <= 1: + return orig_centre, None + + best_iteration = struct_counts.index(max(struct_counts)) + cell_detector = cell_detectors[best_iteration] + relative_centres = cell_detector.get_cell_centres( + settings.detect_centre_of_intensity + ) # add original centre. That's valid, even if using centre of intensity - ns.insert(0, 1) - centres.insert(0, np.array([relative_orig_centre])) - - best_iteration = ns.index(max(ns)) - # TODO: put constraint on minimum centres distance ? - relative_centres = centres[best_iteration] - - if not settings.outlier_keep: - # TODO: change to checking whether in original cluster shape - original_max_coords = np.array(original_bounding_cuboid_shape) - relative_centres = np.array( - [ - x - for x in relative_centres - if check_centre_in_cuboid(x, original_max_coords) - ] - ) - absolute_centres = np.empty((len(relative_centres), 3)) - # convert centers to absolute pixels - absolute_centres[:, 0] = orig_corner[0] + relative_centres[:, 0] - absolute_centres[:, 1] = orig_corner[1] + relative_centres[:, 1] - absolute_centres[:, 2] = orig_corner[2] + relative_centres[:, 2] + original_max_coords = np.array(original_bounding_cuboid_shape) + for x in relative_centres: + assert (x < original_max_coords).all() + + # corner coordinates in absolute pixels + orig_corner = np.array([xs.min(), ys.min(), zs.min()]) + # convert centers to absolute voxels in original vol + absolute_centres = relative_centres + orig_corner[None, :] - return absolute_centres + return absolute_centres, (cell_detector, orig_corner) diff --git a/cellfinder/core/detect/filters/volume/volume_filter.py b/cellfinder/core/detect/filters/volume/volume_filter.py index 19a1d29a..bbcdc1ea 100644 --- a/cellfinder/core/detect/filters/volume/volume_filter.py +++ b/cellfinder/core/detect/filters/volume/volume_filter.py @@ -536,6 +536,8 @@ def _split_cells(arg, settings: DetectionSettings): torch.set_num_threads(1) cell_id, cell_points, intensity = arg try: - return split_cells(cell_points, settings=settings, intensity=intensity) + return split_cells( + cell_points, settings=settings, intensity=intensity + )[0] except (ValueError, AssertionError) as err: raise StructureSplitException(f"Cell {cell_id}, error; {err}") diff --git a/tests/core/test_integration/test_detection_structure_splitting.py b/tests/core/test_integration/test_detection_structure_splitting.py index b784bd9b..3b8f9400 100644 --- a/tests/core/test_integration/test_detection_structure_splitting.py +++ b/tests/core/test_integration/test_detection_structure_splitting.py @@ -94,13 +94,15 @@ def test_underflow_issue_435(): ball_overlap_fraction=0.8, soma_diameter_um=7, ) - centers = split_cells(bright_indices, settings) + centers, (detector, offset) = split_cells(bright_indices, settings) # for some reason, same with pytorch, it's shifted by 1. Probably rounding expected = {(10, 11, 11), (20, 11, 11)} got = set(map(tuple, centers.tolist())) assert expected == got + assert detector.n_structures == 2 + def test_ball_filter_imgs_invalid_volume(): """Checks that an invalid volume returns empty array instead.""" @@ -111,8 +113,8 @@ def test_ball_filter_imgs_invalid_volume(): ball_xy_size_um=50, ) - vol = ball_filter_imgs(torch.zeros((5, 100, 30)), settings, None) - assert not vol.shape[0] + cell_detector = ball_filter_imgs(torch.zeros((5, 100, 30)), settings, None) + assert not cell_detector.n_structures @pytest.mark.parametrize("inside", [True, False]) From 725761d8a47d8a5fa1e83fce2b9a8a36c7250996 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Sun, 22 Mar 2026 12:35:54 -0400 Subject: [PATCH 5/5] Fix incorrect padding and centering of volume of split cells. --- .../core/detect/filters/volume/ball_filter.py | 49 ++++++- .../filters/volume/structure_splitting.py | 124 ++++++++++-------- .../test_detection_structure_splitting.py | 13 +- .../test_volume_filters/test_ball_filter.py | 12 ++ 4 files changed, 126 insertions(+), 72 deletions(-) diff --git a/cellfinder/core/detect/filters/volume/ball_filter.py b/cellfinder/core/detect/filters/volume/ball_filter.py index 96bb9fa0..9d274ffd 100644 --- a/cellfinder/core/detect/filters/volume/ball_filter.py +++ b/cellfinder/core/detect/filters/volume/ball_filter.py @@ -78,10 +78,10 @@ class BallFilter: ---------- plane_height, plane_width : int Height/width of the planes. - ball_xy_size : float - Diameter of the spherical kernel (in microns) in the x/y dimensions. - ball_z_size : float - Diameter of the spherical kernel in the z dimension in microns. + ball_xy_size : int + Diameter of the spherical kernel (in voxels) in the x/y dimensions. + ball_z_size : int + Diameter of the spherical kernel in the z dimension in voxels. Determines the number of planes stacked to filter the central plane of the stack. overlap_fraction : float @@ -215,6 +215,47 @@ def first_valid_plane(self) -> int: """ return int(math.floor(self.ball_z_size / 2)) + @classmethod + def min_xy_padding(cls, ball_xy_size: int) -> tuple[int, int]: + """ + For a given ball x/y kernel size, it returns the padding needed for the + plane so the full input data is filtered. Otherwise, data on the + edges of the plane will be zeros because the center of the ball will + never be over it. Adding padding will ensure the ball center will be + over all input data. + + Parameters + ---------- + ball_xy_size: int + The x/y kernel (ball) size in voxels. + + return + ---------- + padding: 2-tuple of ints + The padding to add to the start and end of the data to have all + input data considered. + """ + n = ball_xy_size + # e.g. if starting with just 1 voxel: if kernel is even, say 4, then + # left padding will be 1 and right 2. This gives a single voxel output. + # If odd, say 5, it'll be 2 on each side + left = (n - 1) // 2 + right = n - 1 - left + return left, right + + @classmethod + def min_z_padding(cls, ball_z_size: int) -> tuple[int, int]: + """ + Same as for `min_xy_padding`, except in the z-dimension. + """ + n = ball_z_size + # e.g. if starting with just 1 voxel: if kernel is even, say 4, then + # left padding will be 1 and right 2. This gives a single voxel output. + # If odd, say 5, it'll be 2 on each side + bottom = (n - 1) // 2 + top = n - 1 - bottom + return bottom, top + @property def remaining_planes(self) -> int: """ diff --git a/cellfinder/core/detect/filters/volume/structure_splitting.py b/cellfinder/core/detect/filters/volume/structure_splitting.py index 6f945ba9..e4745bcb 100644 --- a/cellfinder/core/detect/filters/volume/structure_splitting.py +++ b/cellfinder/core/detect/filters/volume/structure_splitting.py @@ -4,7 +4,6 @@ import numpy as np import torch -from cellfinder.core import logger from cellfinder.core.detect.filters.setup_filters import DetectionSettings from cellfinder.core.detect.filters.volume.ball_filter import ( BallFilter, @@ -25,7 +24,10 @@ def get_shape( ) -> Tuple[int, int, int]: """ Takes a list of x, y, z coordinates and returns a volume size such that - all the points will fit into it. With axis order = x, y, z. + all the points will fit into it (once the min of each dim is subtracted - + i.e. the smallest point in each dim falls on zero). + + Axis order is x, y, z. """ # +1 because difference. TEST: shape = tuple(int((dim.max() - dim.min()) + 1) for dim in (xs, ys, zs)) @@ -38,15 +40,23 @@ def coords_to_volume( zs: np.ndarray, intensity: Optional[np.ndarray], volume_shape: Tuple[int, int, int], - ball_radius: int, + ball_xy_padding: tuple[int, int], + ball_z_padding: tuple[int, int], dtype: Type[np.number], threshold_value: int, ) -> tuple[torch.Tensor, Optional[np.ndarray]]: """ Takes the series of x, y, z points along with the shape of the volume - that fully enclose them (also x, y, z order). It than expands the - shape by the ball diameter in each axis. Then, each point, shifted - by the radius internally is set to the threshold value. + (also x, y, z order) that fully enclose them, relative to the minimum + point in each dim of the data. It than expands the volume in each dim to + account for the filtering ball diameter by adding padding + before / after the dim. So each point in the expanded volume is shifted + by the start padding internally and is then set to the threshold value. + + The result is that after ball filtering, all the points we get will be + fully contained in the original volume and not in the padding. Of course + the start padding will need to be subtracted from the point indices in the + expanded volume. The volume is then transposed and returned in the Z, Y, X order. @@ -55,9 +65,15 @@ def coords_to_volume( """ # it's faster doing the work in numpy and then returning as torch array, # than doing the work in torch - ball_diameter = ball_radius * 2 - # Expanded to ensure the ball fits even at the border - expanded_shape = [dim_size + ball_diameter for dim_size in volume_shape] + # Expanded to ensure the ball fits at all borders of input cuboid + xy_add = sum(ball_xy_padding) + z_add = sum(ball_z_padding) + expanded_shape = ( + volume_shape[0] + xy_add, + volume_shape[1] + xy_add, + volume_shape[2] + z_add, + ) + # volume is now x, y, z order volume = np.zeros(expanded_shape, dtype=dtype) # use largest type. These are small volumes and not processed much except @@ -69,10 +85,11 @@ def coords_to_volume( x_min, y_min, z_min = xs.min(), ys.min(), zs.min() # shift the points so any sphere centered on it would not have its - # radius expand beyond the volume - relative_xs = np.array((xs - x_min + ball_radius), dtype=np.int64) - relative_ys = np.array((ys - y_min + ball_radius), dtype=np.int64) - relative_zs = np.array((zs - z_min + ball_radius), dtype=np.int64) + # radius expand beyond the volume and so center of sphere would be in + # original volume + relative_xs = np.array((xs - x_min + ball_xy_padding[0]), dtype=np.int64) + relative_ys = np.array((ys - y_min + ball_xy_padding[0]), dtype=np.int64) + relative_zs = np.array((zs - z_min + ball_z_padding[0]), dtype=np.int64) # set each point as the center with a value of threshold volume[relative_xs, relative_ys, relative_zs] = threshold_value @@ -217,32 +234,6 @@ def iterative_ball_filter( return cell_detectors -def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool: - """ - Checks whether a coordinate is in a cuboid. - - Parameters - ---------- - centre : np.ndarray - x, y, z coordinate. - max_coords : np.ndarray - Far corner of cuboid. - - Returns - ------- - True if within cuboid, otherwise False. - """ - relative_coords = centre - if (relative_coords > max_coords).all(): - logger.info( - 'Relative coordinates "{}" exceed maximum volume ' - 'dimension of "{}"'.format(relative_coords, max_coords) - ) - return False - else: - return True - - def split_cells( cell_points: np.ndarray, settings: DetectionSettings, @@ -274,9 +265,11 @@ def split_cells( better cell candidates during splitting than the original single cell. Otherwise, it's the `CellDetector` with the structs from the best iteration and a size 3 `np.ndarray` with the offset of the structs in - the cell detector. I.e. the cell detector uses coordinates relative to - the size of the cube containing the input voxels so the offset must be - added to convert to absolute voxel indices. + the cell detector. I.e. all coordinates in the cell detector is + relative to the size of the cuboid containing only the `cell_points`. + So the offset must be added to convert any voxel indices accessed via + the cell detector (e.g. `cell_detector.get_structures`) to absolute + voxel indices. """ settings = copy(settings) if settings.detect_centre_of_intensity and intensity is None: @@ -292,22 +285,24 @@ def split_cells( ys = cell_points[:, 1] zs = cell_points[:, 2] - # total volume enclosing all points + # total volume enclosing all points from the input original_bounding_cuboid_shape = get_shape(xs, ys, zs) - ball_radius = settings.ball_xy_size // 2 # they should be the same dtype so as to not need a conversion before # passing the input data with marked cells to the filters (we currently # set both to float32) assert settings.filtering_dtype == settings.plane_original_np_dtype - # volume will now be z, y, x order + # Volume will be padded so not lose points on the edges. It's z, y, x order + ball_xy_padding = BallFilter.min_xy_padding(settings.ball_xy_size) + ball_z_padding = BallFilter.min_z_padding(settings.ball_z_size) vol, raw_vol = coords_to_volume( xs, ys, zs, intensity, volume_shape=original_bounding_cuboid_shape, - ball_radius=ball_radius, + ball_xy_padding=ball_xy_padding, + ball_z_padding=ball_z_padding, dtype=settings.filtering_dtype, threshold_value=settings.threshold_value, ) @@ -334,22 +329,39 @@ def split_cells( # if best split only resulted in one (or no) struct, stick with original if not struct_counts or max(struct_counts) <= 1: - return orig_centre, None + return orig_centre[None, :], None best_iteration = struct_counts.index(max(struct_counts)) cell_detector = cell_detectors[best_iteration] - relative_centres = cell_detector.get_cell_centres( + # centers come in where zero is relative to expanded vol corner + expanded_relative_centres = cell_detector.get_cell_centres( settings.detect_centre_of_intensity ) - # add original centre. That's valid, even if using centre of intensity - original_max_coords = np.array(original_bounding_cuboid_shape) - for x in relative_centres: - assert (x < original_max_coords).all() - - # corner coordinates in absolute pixels + # corner coordinates of original unexpended vol in absolute voxels orig_corner = np.array([xs.min(), ys.min(), zs.min()]) - # convert centers to absolute voxels in original vol + # shape of the original unexpanded volume + orig_cuboid_shape = np.array(original_bounding_cuboid_shape) + # start padding added to start of original vol to gain expanded volume + start_padding = np.array( + [ball_xy_padding[0], ball_xy_padding[0], ball_z_padding[0]], + dtype=np.int_, + ) + + # remove padding to get indices relative to original vol corner + relative_centres = expanded_relative_centres - start_padding + for x in relative_centres: + # we allow to be sticking out by one on each side due to rounding, but + # more than one should be impossible + assert (x <= orig_cuboid_shape).all() + assert (x >= -1).all() + # but if they do stick out by one, clip so it's in the valid original vol + relative_centres = np.clip(relative_centres, 0, orig_cuboid_shape - 1) + + # convert relative centers to absolute voxels in original vol absolute_centres = relative_centres + orig_corner[None, :] - return absolute_centres, (cell_detector, orig_corner) + # any indices stored in cell detector is relative to the expanded vol so a + # zero there (i.e. offset) should be relative to where the padding starts + offset = orig_corner - start_padding + return absolute_centres, (cell_detector, offset) diff --git a/tests/core/test_integration/test_detection_structure_splitting.py b/tests/core/test_integration/test_detection_structure_splitting.py index 3b8f9400..a00bc82e 100644 --- a/tests/core/test_integration/test_detection_structure_splitting.py +++ b/tests/core/test_integration/test_detection_structure_splitting.py @@ -14,7 +14,6 @@ from cellfinder.core.detect.filters.setup_filters import DetectionSettings from cellfinder.core.detect.filters.volume.structure_splitting import ( ball_filter_imgs, - check_centre_in_cuboid, split_cells, ) from cellfinder.core.main import main @@ -96,8 +95,7 @@ def test_underflow_issue_435(): ) centers, (detector, offset) = split_cells(bright_indices, settings) - # for some reason, same with pytorch, it's shifted by 1. Probably rounding - expected = {(10, 11, 11), (20, 11, 11)} + expected = set(map(tuple, [p1.tolist(), p2.tolist()])) got = set(map(tuple, centers.tolist())) assert expected == got @@ -117,15 +115,6 @@ def test_ball_filter_imgs_invalid_volume(): assert not cell_detector.n_structures -@pytest.mark.parametrize("inside", [True, False]) -def test_check_centre_in_cuboid(inside): - corner = np.array([5, 5, 5]) - if inside: - assert check_centre_in_cuboid(np.array([2, 2, 2]), corner) - else: - assert not check_centre_in_cuboid(np.array([8, 8, 8]), corner) - - def test_using_coi_without_intensity(): cell_points = np.zeros((30, 20, 20), dtype=np.bool_) settings = DetectionSettings( diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py index 47dfefd8..002724f9 100644 --- a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py @@ -51,6 +51,18 @@ def test_filter_plane_params(sizes): assert bf.remaining_planes == remaining +@pytest.mark.parametrize( + "sizes", [(1, 0, 0), (2, 0, 1), (3, 1, 1), (4, 1, 2), (5, 2, 2), (6, 2, 3)] +) +def test_filter_padding(sizes): + # checks that for a given kernel size, the start / end padding matches as + # expected. The start padding is always the lessor (when even) + kernel_size, *padding = sizes + + assert BallFilter.min_xy_padding(kernel_size) == tuple(padding) + assert BallFilter.min_z_padding(kernel_size) == tuple(padding) + + @pytest.mark.parametrize("batch_size", [1, 2, 5, 10]) @pytest.mark.parametrize("kernel_size", [1, 2, 3, 5]) def test_filtered_planes(kernel_size, batch_size):