# -------------------------------------------------------------------------
# Validation helper for the detection pipeline
# -------------------------------------------------------------------------
def _validate_detection_inputs(
    signal_array,
    voxel_sizes,
    soma_diameter,
    ball_xy_size,
    ball_z_size,
    ball_overlap_fraction,
    batch_size,
):
    """
    Validate detection parameters before running the pipeline.

    Parameters
    ----------
    signal_array : array-like
        Object exposing ``dtype`` and ``ndim``. Must have a numeric dtype
        and exactly three dimensions.
    voxel_sizes : iterable of float or str
        Voxel sizes. Each entry is converted with ``float`` before being
        checked, because callers may pass the sizes as strings. Every
        entry must be strictly positive.
    soma_diameter : float
        Must be strictly positive.
    ball_xy_size : float
        Must be strictly positive.
    ball_z_size : float
        Must be strictly positive.
    ball_overlap_fraction : float
        Must lie in the half-open interval (0, 1].
    batch_size : int or None
        If not None, must be at least 1.

    Raises
    ------
    TypeError
        If ``signal_array`` does not have a numeric dtype.
    ValueError
        If ``signal_array`` is not 3D, if any size parameter is not
        strictly positive (NaN is rejected as well), if
        ``ball_overlap_fraction`` is outside (0, 1], if ``batch_size`` is
        smaller than 1, or if a voxel size string cannot be parsed as a
        float.
    """
    # Ensure numeric data
    if not np.issubdtype(signal_array.dtype, np.number):
        raise TypeError(
            "signal_array must contain numeric values, got dtype "
            f"{signal_array.dtype}"
        )

    # Ensure 3D volume
    if signal_array.ndim != 3:
        raise ValueError("Input data must be 3D")

    # Voxel sizes may arrive as strings, so coerce them before comparing;
    # comparing a raw str against 0 would raise an unrelated TypeError.
    # `not (v > 0)` is used instead of `v <= 0` so that NaN is rejected too.
    numeric_voxel_sizes = [float(v) for v in voxel_sizes]
    if not all(v > 0 for v in numeric_voxel_sizes):
        raise ValueError("voxel_sizes must contain positive values")

    # Soma diameter must be positive (and not NaN)
    if not (soma_diameter > 0):
        raise ValueError("soma_diameter must be positive")

    # Ball filter dimensions must be positive (and not NaN)
    if not (ball_xy_size > 0 and ball_z_size > 0):
        raise ValueError("ball filter sizes must be positive")

    # Overlap fraction must be in (0, 1]
    if not (0 < ball_overlap_fraction <= 1):
        raise ValueError("ball_overlap_fraction must be between 0 and 1")

    # Batch size must be valid if provided
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be >= 1")
- start_plane : int - First plane index to process (inclusive, to process a subset of the - data). - end_plane : int - Last plane index to process (exclusive, to process a subset of the - data). - voxel_sizes : 3-tuple of floats - Size of your voxels in the z, y, and x dimensions (microns). - soma_diameter : float - The expected in-plane (xy) soma diameter (microns). - max_cluster_size : float - Largest detected cell cluster (in cubic um) where splitting - should be attempted. Clusters above this size will be labeled - as artifacts. - ball_xy_size : float - 3d filter's in-plane (xy) filter ball size (microns). - ball_z_size : float - 3d filter's axial (z) filter ball size (microns). - ball_overlap_fraction : float - 3d filter's fraction of the ball filter needed to be filled by - foreground voxels, centered on a voxel, to retain the voxel. - soma_spread_factor : float - Cell spread factor for determining the largest cell volume before - splitting up cell clusters. Structures with spherical volume of - diameter `soma_spread_factor * soma_diameter` or less will not be - split. - n_free_cpus : int - How many CPU cores to leave free. - log_sigma_size : float - Gaussian filter width (as a fraction of soma diameter) used during - 2d in-plane Laplacian of Gaussian filtering. - n_sds_above_mean_thresh : float - Per-plane intensity threshold (the number of standard deviations - above the mean) of the filtered 2d planes used to mark pixels as - foreground or background. - outlier_keep : bool, optional - Whether to keep outliers during detection. Defaults to False. - artifact_keep : bool, optional - Whether to keep artifacts during detection. Defaults to False. - save_planes : bool, optional - Whether to save the planes during detection. Defaults to False. - plane_directory : str, optional - Directory path to save the planes. Defaults to None. - batch_size: int - The number of planes of the original data volume to process at - once. 
The GPU/CPU memory must be able to contain this many planes - for all the filters. For performance-critical applications, tune to - maximize memory usage without running out. Check your GPU/CPU memory - to verify it's not full. - torch_device : str, optional - The device on which to run the computation. If not specified (None), - "cuda" will be used if a GPU is available, otherwise "cpu". - You can also manually specify "cuda" or "cpu". - pin_memory: bool - Pins data to be sent to the GPU to the CPU memory. This allows faster - GPU data speeds, but can only be used if the data used by the GPU can - stay in the CPU RAM while the GPU uses it. I.e. there's enough RAM. - Otherwise, if there's a risk of the RAM being paged, it shouldn't be - used. Defaults to False. - split_ball_xy_size: float - Similar to `ball_xy_size`, except the value to use for the 3d - filter during cluster splitting. - split_ball_z_size: float - Similar to `ball_z_size`, except the value to use for the 3d filter - during cluster splitting. - split_ball_overlap_fraction: float - Similar to `ball_overlap_fraction`, except the value to use for the - 3d filter during cluster splitting. - n_splitting_iter: int - The number of iterations to run the 3d filtering on a cluster. Each - iteration reduces the cluster size by the voxels not retained in - the previous iteration. - n_sds_above_mean_tiled_thresh : float - Per-plane, per-tile intensity threshold (the number of standard - deviations above the mean) for the filtered 2d planes used to mark - pixels as foreground or background. When used, (tile size is not zero) - a pixel is marked as foreground if its intensity is above both the - per-plane and per-tile threshold. I.e. it's above the set number of - standard deviations of the per-plane average and of the per-plane - per-tile average for the tile that contains it. 
- tiled_thresh_tile_size : float - The tile size used to tile the x, y plane to calculate the local - average intensity for the tiled threshold. The value is multiplied - by soma diameter (i.e. 1 means one soma diameter). If zero or None, the - tiled threshold is disabled and only the per-plane threshold is used. - Tiling is done with 50% overlap when striding. - callback : Callable[int], optional - A callback function that is called every time a plane has finished - being processed. Called with the plane number that has finished. - - Returns - ------- - List[Cell] - List of detected cell candidates. """ + start_time = datetime.now() + if torch_device is None: torch_device = "cuda" if torch.cuda.is_available() else "cpu" + if batch_size is None: if torch_device == "cpu": batch_size = 4 else: batch_size = 1 - if not np.issubdtype(signal_array.dtype, np.number): - raise TypeError( - "signal_array must be a numpy datatype, but has datatype " - f"{signal_array.dtype}" - ) - - if signal_array.ndim != 3: - raise ValueError("Input data must be 3D") + # Validate detection parameters early + _validate_detection_inputs( + signal_array, + voxel_sizes, + soma_diameter, + ball_xy_size, + ball_z_size, + ball_overlap_fraction, + batch_size, + ) if end_plane < 0: end_plane = len(signal_array) end_plane = min(len(signal_array), end_plane) torch_device = torch_device.lower() + # Use SciPy filtering on CPU (better performance); use PyTorch on GPU if torch_device != "cuda": use_scipy = True @@ -191,6 +145,7 @@ def main( use_scipy = False batch_size = max(batch_size, 1) + # brainmapper can pass them in as str voxel_sizes = list(map(float, voxel_sizes)) @@ -221,20 +176,12 @@ def main( n_splitting_iter=n_splitting_iter, ) - # replicate the settings specific to splitting, before we access anything - # of the original settings, causing cached properties + # replicate the settings specific to splitting kwargs = dataclasses.asdict(settings) kwargs["ball_z_size_um"] = split_ball_z_size 
kwargs["ball_xy_size_um"] = split_ball_xy_size kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction - # always run on cpu because copying to gpu overhead is likely slower than - # any benefit for detection on smallish volumes kwargs["torch_device"] = "cpu" - # for splitting, we only do 3d filtering. Its input is a zero volume - # with cell voxels marked with threshold_value. So just use float32 - # for input because the filters will also use float(32). So there will - # not be need to convert the input a different dtype before passing to - # the filters. kwargs["plane_original_np_dtype"] = np.float32 splitting_settings = DetectionSettings(**kwargs) @@ -268,4 +215,5 @@ def main( time_elapsed = datetime.now() - start_time s = f"Detection complete. Found {len(cells)} cells in {time_elapsed}" logger.info(s) + return cells