From c8a42d2aae5ff7f4ee3d38bc7f0f0e5e86e1b0bf Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Thu, 26 Mar 2026 14:50:41 -0400 Subject: [PATCH] Fix kernel to be symetric and properly spheroid. --- .../core/detect/filters/volume/ball_filter.py | 11 ++++--- cellfinder/core/tools/geometry.py | 29 ++++++++++--------- .../test_volume_filters/test_ball_filter.py | 19 +++++++++++- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/cellfinder/core/detect/filters/volume/ball_filter.py b/cellfinder/core/detect/filters/volume/ball_filter.py index f36d2aef..426ae9c8 100644 --- a/cellfinder/core/detect/filters/volume/ball_filter.py +++ b/cellfinder/core/detect/filters/volume/ball_filter.py @@ -36,12 +36,11 @@ def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray: upscale_factor * ball_xy_size, upscale_factor * ball_z_size, ) - upscaled_ball_centre_position = ( - np.floor(upscaled_kernel_shape[0] / 2), - np.floor(upscaled_kernel_shape[1] / 2), - np.floor(upscaled_kernel_shape[2] / 2), - ) - upscaled_ball_radius = upscaled_kernel_shape[0] / 2.0 + # subtract one b/c we need to shift from size/count to index + upscaled_ball_centre_position = [ + (u - 1) / 2 for u in upscaled_kernel_shape + ] + upscaled_ball_radius = [u / 2 for u in upscaled_kernel_shape] sphere_kernel = make_sphere( upscaled_kernel_shape, diff --git a/cellfinder/core/tools/geometry.py b/cellfinder/core/tools/geometry.py index 684c86ea..b103aef2 100644 --- a/cellfinder/core/tools/geometry.py +++ b/cellfinder/core/tools/geometry.py @@ -1,12 +1,12 @@ -from typing import Tuple +from numbers import Number import numpy as np def make_sphere( - ball_shape: Tuple[int, int, int], - radius: float, - position: Tuple[int, int, int], + ball_shape: tuple[int, int, int], + radius: tuple[float, float, float] | float, + position: tuple[float, float, float], ) -> np.ndarray: """ Return a boolean array, with array elements inside a sphere set @@ -17,23 +17,24 @@ def make_sphere( ball_shape : Shape of the output array. radius : - Radius of the sphere. + Radius of the sphere, either single radius for sphere or 3d radius for + spheroid. position : - Centre of the sphere. + Centre of the sphere (can be between voxels). """ - - half_sizes = (radius,) * 3 + if isinstance(radius, Number): + radius = (radius,) * 3 # generate the grid for the support points - # centered at the position indicated by position - grid = [slice(-x0, dim - x0) for x0, dim in zip(position, ball_shape)] + grid = [slice(dim) for dim in ball_shape] meshedgrid = np.ogrid[grid] # calculate the distance of all points from `position` center - # scaled by the radius + # proportional of the radius so 1 would mean at the radius arr = np.zeros(ball_shape, dtype=float) - for x_i, half_size in zip(meshedgrid, half_sizes): - arr += np.abs(x_i / half_size) ** 2 - # the inner part of the sphere will have distance below 1 + for x_i, centre_i, radius_i in zip(meshedgrid, position, radius): + arr += ((x_i - centre_i) / radius_i) ** 2 + + # the inner part of the sphere will have distance below 1 b/c pythagoras return arr <= 1.0 diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py index 34ef4775..30e83692 100644 --- a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_ball_filter.py @@ -1,7 +1,11 @@ +import numpy as np import pytest import torch -from cellfinder.core.detect.filters.volume.ball_filter import BallFilter +from cellfinder.core.detect.filters.volume.ball_filter import ( + BallFilter, + get_kernel, +) bf_kwargs = { "plane_height": 50, @@ -18,6 +22,19 @@ } +@pytest.mark.parametrize("xy_size", list(range(1, 7))) +@pytest.mark.parametrize("z_size", list(range(1, 7))) +def test_kernel_symetry(xy_size, z_size): + kernel = get_kernel(xy_size, z_size) + for axis in range(3): + flipped = np.flip(kernel, axis=axis) + assert np.allclose(flipped, kernel) + + assert kernel.shape[0] == xy_size + assert kernel.shape[1] == xy_size + assert kernel.shape[2] == z_size + + def test_filter_not_ready(): bf = BallFilter(**bf_kwargs) assert not bf.ready