Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 64 additions & 116 deletions cellfinder/core/detect/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,53 @@
from cellfinder.core.tools.tools import inference_wrapper


# -------------------------------------------------------------------------
# Validation helper for the detection pipeline
# -------------------------------------------------------------------------
def _validate_detection_inputs(
    signal_array,
    voxel_sizes,
    soma_diameter,
    ball_xy_size,
    ball_z_size,
    ball_overlap_fraction,
    batch_size,
):
    """
    Validate detection parameters before running the pipeline.

    The checks run in the order listed under ``Raises``; the first failing
    check raises and the remaining ones are not evaluated. Nothing is
    returned and none of the arguments are modified.

    Parameters
    ----------
    signal_array : array-like
        Object exposing ``dtype`` and ``ndim``. Must have a numeric dtype
        and exactly three dimensions.
    voxel_sizes : iterable of float or str
        Voxel sizes. Each element is coerced with ``float()`` for the
        check only, so numeric strings (e.g. ``"5"``) are accepted.
    soma_diameter : float
        Must be strictly positive.
    ball_xy_size : float
        Must be strictly positive.
    ball_z_size : float
        Must be strictly positive.
    ball_overlap_fraction : float
        Must satisfy ``0 < ball_overlap_fraction <= 1``.
    batch_size : int or None
        If not None, must be at least 1.

    Raises
    ------
    TypeError
        If ``signal_array`` does not have a numeric dtype, or if an element
        of ``voxel_sizes`` is of a type ``float()`` cannot handle.
    ValueError
        If ``signal_array`` is not 3D, if any size/diameter is not
        positive, if the overlap fraction is outside ``(0, 1]``, if
        ``batch_size`` is below 1, or if an element of ``voxel_sizes`` is a
        string that cannot be parsed as a number.
    """
    # Ensure numeric data
    if not np.issubdtype(signal_array.dtype, np.number):
        raise TypeError(
            "signal_array must contain numeric values, got dtype "
            f"{signal_array.dtype}"
        )

    # Ensure 3D volume
    if signal_array.ndim != 3:
        raise ValueError("Input data must be 3D")

    # Voxel sizes must be positive. Callers may hand these over as strings,
    # so coerce each one to float before comparing; comparing a str with 0
    # directly would raise an unrelated TypeError.
    if any(float(size) <= 0 for size in voxel_sizes):
        raise ValueError("voxel_sizes must contain positive values")

    # Soma diameter must be positive
    if soma_diameter <= 0:
        raise ValueError("soma_diameter must be positive")

    # Ball filter dimensions must be positive
    if ball_xy_size <= 0 or ball_z_size <= 0:
        raise ValueError("ball filter sizes must be positive")

    # Overlap fraction must lie in the half-open interval (0, 1]
    if not (0 < ball_overlap_fraction <= 1):
        raise ValueError("ball_overlap_fraction must be between 0 and 1")

    # Batch size is optional; None means "not provided" and skips the check
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be >= 1")


@inference_wrapper
def main(
signal_array: types.array,
Expand Down Expand Up @@ -61,136 +108,44 @@ def main(
) -> List[Cell]:
"""
Perform cell candidate detection on a 3D signal array.

Parameters
----------
signal_array : numpy.ndarray or dask array
3D array representing the signal data in z, y, x order.
start_plane : int
First plane index to process (inclusive, to process a subset of the
data).
end_plane : int
Last plane index to process (exclusive, to process a subset of the
data).
voxel_sizes : 3-tuple of floats
Size of your voxels in the z, y, and x dimensions (microns).
soma_diameter : float
The expected in-plane (xy) soma diameter (microns).
max_cluster_size : float
Largest detected cell cluster (in cubic um) where splitting
should be attempted. Clusters above this size will be labeled
as artifacts.
ball_xy_size : float
3d filter's in-plane (xy) filter ball size (microns).
ball_z_size : float
3d filter's axial (z) filter ball size (microns).
ball_overlap_fraction : float
3d filter's fraction of the ball filter needed to be filled by
foreground voxels, centered on a voxel, to retain the voxel.
soma_spread_factor : float
Cell spread factor for determining the largest cell volume before
splitting up cell clusters. Structures with spherical volume of
diameter `soma_spread_factor * soma_diameter` or less will not be
split.
n_free_cpus : int
How many CPU cores to leave free.
log_sigma_size : float
Gaussian filter width (as a fraction of soma diameter) used during
2d in-plane Laplacian of Gaussian filtering.
n_sds_above_mean_thresh : float
Per-plane intensity threshold (the number of standard deviations
above the mean) of the filtered 2d planes used to mark pixels as
foreground or background.
outlier_keep : bool, optional
Whether to keep outliers during detection. Defaults to False.
artifact_keep : bool, optional
Whether to keep artifacts during detection. Defaults to False.
save_planes : bool, optional
Whether to save the planes during detection. Defaults to False.
plane_directory : str, optional
Directory path to save the planes. Defaults to None.
batch_size: int
The number of planes of the original data volume to process at
once. The GPU/CPU memory must be able to contain this many planes
for all the filters. For performance-critical applications, tune to
maximize memory usage without running out. Check your GPU/CPU memory
to verify it's not full.
torch_device : str, optional
The device on which to run the computation. If not specified (None),
"cuda" will be used if a GPU is available, otherwise "cpu".
You can also manually specify "cuda" or "cpu".
pin_memory: bool
Pins data to be sent to the GPU to the CPU memory. This allows faster
GPU data speeds, but can only be used if the data used by the GPU can
stay in the CPU RAM while the GPU uses it. I.e. there's enough RAM.
Otherwise, if there's a risk of the RAM being paged, it shouldn't be
used. Defaults to False.
split_ball_xy_size: float
Similar to `ball_xy_size`, except the value to use for the 3d
filter during cluster splitting.
split_ball_z_size: float
Similar to `ball_z_size`, except the value to use for the 3d filter
during cluster splitting.
split_ball_overlap_fraction: float
Similar to `ball_overlap_fraction`, except the value to use for the
3d filter during cluster splitting.
n_splitting_iter: int
The number of iterations to run the 3d filtering on a cluster. Each
iteration reduces the cluster size by the voxels not retained in
the previous iteration.
n_sds_above_mean_tiled_thresh : float
Per-plane, per-tile intensity threshold (the number of standard
deviations above the mean) for the filtered 2d planes used to mark
pixels as foreground or background. When used, (tile size is not zero)
a pixel is marked as foreground if its intensity is above both the
per-plane and per-tile threshold. I.e. it's above the set number of
standard deviations of the per-plane average and of the per-plane
per-tile average for the tile that contains it.
tiled_thresh_tile_size : float
The tile size used to tile the x, y plane to calculate the local
average intensity for the tiled threshold. The value is multiplied
by soma diameter (i.e. 1 means one soma diameter). If zero or None, the
tiled threshold is disabled and only the per-plane threshold is used.
Tiling is done with 50% overlap when striding.
callback : Callable[int], optional
A callback function that is called every time a plane has finished
being processed. Called with the plane number that has finished.

Returns
-------
List[Cell]
List of detected cell candidates.
"""

start_time = datetime.now()

if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

if batch_size is None:
if torch_device == "cpu":
batch_size = 4
else:
batch_size = 1

if not np.issubdtype(signal_array.dtype, np.number):
raise TypeError(
"signal_array must be a numpy datatype, but has datatype "
f"{signal_array.dtype}"
)

if signal_array.ndim != 3:
raise ValueError("Input data must be 3D")
# Validate detection parameters early
_validate_detection_inputs(
signal_array,
voxel_sizes,
soma_diameter,
ball_xy_size,
ball_z_size,
ball_overlap_fraction,
batch_size,
)

if end_plane < 0:
end_plane = len(signal_array)
end_plane = min(len(signal_array), end_plane)

torch_device = torch_device.lower()

# Use SciPy filtering on CPU (better performance); use PyTorch on GPU
if torch_device != "cuda":
use_scipy = True
else:
use_scipy = False

batch_size = max(batch_size, 1)

# brainmapper can pass them in as str
voxel_sizes = list(map(float, voxel_sizes))

Expand Down Expand Up @@ -221,20 +176,12 @@ def main(
n_splitting_iter=n_splitting_iter,
)

# replicate the settings specific to splitting, before we access anything
# of the original settings, causing cached properties
# replicate the settings specific to splitting
kwargs = dataclasses.asdict(settings)
kwargs["ball_z_size_um"] = split_ball_z_size
kwargs["ball_xy_size_um"] = split_ball_xy_size
kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction
# always run on cpu because copying to gpu overhead is likely slower than
# any benefit for detection on smallish volumes
kwargs["torch_device"] = "cpu"
# for splitting, we only do 3d filtering. Its input is a zero volume
# with cell voxels marked with threshold_value. So just use float32
# for input because the filters will also use float(32). So there will
# not be need to convert the input a different dtype before passing to
# the filters.
kwargs["plane_original_np_dtype"] = np.float32
splitting_settings = DetectionSettings(**kwargs)

Expand Down Expand Up @@ -268,4 +215,5 @@ def main(
time_elapsed = datetime.now() - start_time
s = f"Detection complete. Found {len(cells)} cells in {time_elapsed}"
logger.info(s)

return cells