Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cellfinder/core/classify/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CuboidBatchSampler,
)
from cellfinder.core.classify.tools import get_model
from cellfinder.core.tools.image_processing import dataset_mean_std
from cellfinder.core.train.train_yaml import depth_type, models


Expand All @@ -37,6 +38,8 @@ def main(
pin_memory: bool = False,
*,
callback: Optional[Callable[[int], None]] = None,
normalize_channels: bool = False,
normalization_down_sampling: int = 32,
) -> List[Cell]:
"""
Parameters
Expand Down Expand Up @@ -89,6 +92,14 @@ def main(
callback : Callable[int], optional
A callback function that is called during classification. Called with
the batch number once that batch has been classified.
normalize_channels : bool
If True, the signal and background data will be each normalized
to a mean of zero and standard deviation of 1. Defaults to False.
normalization_down_sampling : int
If `normalize_channels` is True, the data arrays will be down-sampled
in the first axis by this value before calculating their statistics.
E.g. a value of 2 means every second plane will be used. Defaults to
32.
"""
if signal_array.ndim != 3:
raise IOError("Signal data must be 3D")
Expand All @@ -102,10 +113,27 @@ def main(
start_time = datetime.now()

voxel_sizes = list(map(float, voxel_sizes))

signal_normalization = background_normalization = None
if normalize_channels:
logger.debug("Calculating channels norms")
signal_normalization = dataset_mean_std(
signal_array, normalization_down_sampling
)
background_normalization = dataset_mean_std(
background_array, normalization_down_sampling
)
logger.debug(
f"Signal channel norm is: {signal_normalization}. "
f"Background channel norm is: {background_normalization}"
)

logger.debug("Initialising cube generator")
dataset = CuboidArrayDataset(
signal_array=signal_array,
background_array=background_array,
signal_normalization=signal_normalization,
background_normalization=background_normalization,
points=points,
data_voxel_sizes=voxel_sizes,
network_voxel_sizes=network_voxel_sizes,
Expand Down
55 changes: 54 additions & 1 deletion cellfinder/core/classify/cube_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,13 +1143,21 @@ class CuboidArrayDataset(CuboidThreadedDatasetBase):
This determines how many multiples (or fractions) of `n` such planes
to buffer, in addition to `n` that is always buffered. So e.g. `1`
means `2n` and `0.5` means `1.5n`. `1` is a good default.
:param signal_normalization: None or a 2-tuple of `(mean, std)`.
If not None, the signal channel in the cubes will be normalized by
subtracting the provided mean and then dividing by the provided
standard deviation.
:param background_normalization: None or a 2-tuple of `(mean, std)`.
If not None, the background channel in the cubes will be normalized by
subtracting the provided mean and then dividing by the provided
standard deviation.
"""

def __init__(
self,
signal_array: types.array,
background_array: types.array | None,
max_axis_0_cuboids_buffered: float = 0,
signal_normalization: None | tuple[float, float] = None,
background_normalization: None | tuple[float, float] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -1210,6 +1218,9 @@ def __init__(
)
self._set_output_data_dim_reordering(self.src_image_data)

self.signal_normalization = signal_normalization
self.background_normalization = background_normalization

def point_has_full_cuboid(self, point: Sequence[float]) -> bool:
"""
Takes a 3d point and returns whether a cuboid centered on this point
Expand All @@ -1229,6 +1240,20 @@ def point_has_full_cuboid(self, point: Sequence[float]) -> bool:

return True

def get_points_data(self, points_key: Sequence[int]) -> torch.Tensor:
    """
    Returns the cubes of the requested points, with each channel optionally
    standardized.

    The cubes are fetched from the parent class. For every channel that
    has a `(mean, std)` normalization set, the mean is subtracted from
    that channel and the result is divided by the std, in place.

    :param points_key: The keys of the points whose cubes to return,
        passed unchanged to the parent class.
    :return: The tensor from the parent class. Index 0 of its last axis
        is adjusted by `signal_normalization` and index 1 by
        `background_normalization`, when these are not None.
    """
    cubes = super().get_points_data(points_key)

    # position in this tuple == channel index along the last axis
    channel_norms = (
        self.signal_normalization,
        self.background_normalization,
    )
    for channel, norm in enumerate(channel_norms):
        if norm is None:
            # no normalization requested for this channel
            continue
        channel_mean, channel_std = norm
        cubes[..., channel] -= channel_mean
        cubes[..., channel] /= channel_std

    return cubes


class CuboidTiffDataset(CuboidThreadedDatasetBase):
"""
Expand All @@ -1241,6 +1266,12 @@ class CuboidTiffDataset(CuboidThreadedDatasetBase):

The outer list has one entry per point/sample. Each inner list has one
filename per channel (e.g. signal/background) for the given point.
:param points_normalization: None or a sequence of sequences of 2-tuples
of `(mean, std)`.

If not None, each 2-tuple corresponds to a single filename in
`points_filenames` and that cube will be normalized by the given
mean and standard deviation before returning it.
:param max_cuboids_buffered: Integer
The number of the most recently accessed cuboids to cache so it isn't
read from disk again.
Expand All @@ -1249,6 +1280,9 @@ class CuboidTiffDataset(CuboidThreadedDatasetBase):
def __init__(
self,
points_filenames: Sequence[Sequence[str]],
points_normalization: (
Sequence[Sequence[tuple[float, float]]] | None
) = None,
max_cuboids_buffered: int = 0,
**kwargs,
):
Expand All @@ -1259,10 +1293,17 @@ def __init__(
raise ValueError(
"Points and filenames must have same number of elements"
)
if points_normalization is not None and len(points_filenames) != len(
points_normalization
):
raise ValueError("Must have normalizations for all elements")

self.num_channels = len(points_filenames[0])
filenames_arr = np.array(points_filenames).astype(np.str_)
self.filenames_arr = filenames_arr
self.points_norm_arr = None
if points_normalization is not None:
self.points_norm_arr = torch.tensor(points_normalization)

self.src_image_data = CachedTiffImage(
points_arr=self.points_arr[:, :3],
Expand All @@ -1273,6 +1314,18 @@ def __init__(
)
self._set_output_data_dim_reordering(self.src_image_data)

def get_points_data(self, points_key: Sequence[int]) -> torch.Tensor:
    """
    Returns the cubes of the requested points, optionally standardized
    with the per-point, per-channel `(mean, std)` values.

    :param points_key: The keys of the points whose cubes to return. They
        are passed to the parent class and also used to look up the
        matching rows of `points_norm_arr`.
    :return: The tensor from the parent class. If `points_norm_arr` is not
        None, the mean is subtracted and the result divided by the std,
        in place.
    """
    cubes = super().get_points_data(points_key)

    if self.points_norm_arr is None:
        # no normalization was provided, return the cubes untouched
        return cubes

    selected = self.points_norm_arr[tuple(points_key), ...]
    # Insert three singleton axes between the first (point) axis and the
    # channel axis so the per-channel values broadcast against `cubes`;
    # the trailing index picks mean (0) or std (1) from each pair.
    cube_mean = selected[:, None, None, None, :, 0]
    cube_std = selected[:, None, None, None, :, 1]
    cubes -= cube_mean
    cubes /= cube_std

    return cubes


class CuboidBatchSampler(Sampler):
"""
Expand All @@ -1292,7 +1345,7 @@ class CuboidBatchSampler(Sampler):

Or e.g. just::

dataset = CuboidStackDataset(...)
dataset = CuboidArrayDataset(...)
sampler = CuboidBatchSampler(dataset=dataset, ...)
for batch in sampler:
data, labels = dataset[batch]
Expand Down
13 changes: 13 additions & 0 deletions cellfinder/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def main(
classify_callback: Optional[Callable[[int], None]] = None,
detect_finished_callback: Optional[Callable[[list], None]] = None,
classification_max_workers: int = 3,
normalize_channels: bool = False,
normalization_down_sampling: int = 32,
) -> List[Cell]:
"""
Parameters
Expand Down Expand Up @@ -183,6 +185,15 @@ def main(
classification_max_workers : int
The max number of sub-processes to use for data loading / processing
during classification. Defaults to 3.
normalize_channels : bool
If True, the signal and background data will be each normalized
to a mean of zero and standard deviation of 1 before classification.
Defaults to False.
normalization_down_sampling : int
If `normalize_channels` is True, the data arrays will be down-sampled
in the first axis by this value before calculating their statistics
before classification. E.g. a value of 2 means every second plane will
be used. Defaults to 32.
"""
from cellfinder.core.classify import classify
from cellfinder.core.detect import detect
Expand Down Expand Up @@ -246,6 +257,8 @@ def main(
network_depth,
callback=classify_callback,
max_workers=classification_max_workers,
normalize_channels=normalize_channels,
normalization_down_sampling=normalization_down_sampling,
)
else:
logger.info("No candidates, skipping classification")
Expand Down
63 changes: 63 additions & 0 deletions cellfinder/core/tools/image_processing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
import tqdm
from brainglobe_utils.general.numerical import is_even

from cellfinder.core import types
from cellfinder.core.tools.tools import get_data_converter


def crop_center_2d(img, crop_x=None, crop_y=None):
"""
Expand Down Expand Up @@ -85,3 +89,62 @@ def pad_center_2d(img, x_size=None, y_size=None, pad_mode="edge"):
y_front = y_back = 0

return np.pad(img, ((y_front, y_back), (x_front, x_back)), pad_mode)


def dataset_mean_std(
    dataset: types.array,
    sampling_factor: int,
    show_progress: bool = True,
    progress_desc: str = "Estimating channel mean/std",
) -> tuple[float, float]:
    """
    Calculates the mean and sample standard deviation of a 3d dataset using
    Welford's online algorithm, sampling it along its first dimension.

    The statistics are updated one sampled plane at a time; each plane is
    converted to float64 before it is folded into the running values.

    :param dataset: A 3d dataset, such as a numpy or dask array.
    :param sampling_factor: The sampling factor to sample along the first
        dimension. Must be at least 1. E.g. if the dataset is
        10 x 100 x 100 and `sampling_factor` is 3, then we'll use planes
        0, 3, 6, 9 for the calculation (40_000 data points).
    :param show_progress: Whether to show a progress bar during the
        calculation.
    :param progress_desc: If showing a progress bar, the description to use in
        it.
    :return: A 2-tuple of `(mean, std)` estimate of the dataset.
    :raises ValueError: If `sampling_factor` is less than 1, or if fewer
        than two data points were sampled (e.g. an empty dataset).
    """
    # based on https://en.wikipedia.org/wiki/
    # Algorithms_for_calculating_variance#Welford's_online_algorithm
    # and https://stackoverflow.com/q/56402955
    if sampling_factor < 1:
        # Fail early with a clear message: zero would otherwise surface as
        # an opaque range() error and a negative value as a misleading
        # "not enough data" error further down.
        raise ValueError(
            f"sampling_factor must be a positive integer, "
            f"got {sampling_factor}"
        )

    plane_n = dataset.shape[1] * dataset.shape[2]
    # get data converter from dataset to float64
    converter = get_data_converter(dataset.dtype, np.float64)

    count = 0
    mean = np.array(0, dtype=np.float64)
    sq_dist = np.array(0, dtype=np.float64)

    # a range reports its own length, so tqdm can display the total number
    # of planes without materializing the indices into a list
    samples = range(0, len(dataset), sampling_factor)
    if show_progress:
        it = tqdm.tqdm(samples, desc=progress_desc, unit="planes")
    else:
        it = samples

    for i in it:
        plane = converter(dataset[i, ...])
        # flatten it
        new_value = plane.reshape((plane_n,))

        # batched Welford update: `delta` is measured against the mean
        # before this plane is included, `delta2` against the mean after
        count += plane_n
        delta = new_value - mean
        mean += np.sum(delta) / count
        delta2 = new_value - mean
        sq_dist += np.sum(np.multiply(delta, delta2))

    if count <= 1:
        raise ValueError("Not enough data to compute the variance")

    # sample variance, hence the `count - 1` denominator
    var_sample = sq_dist / (count - 1)
    std = np.sqrt(var_sample)

    return mean.item(), std.item()
11 changes: 10 additions & 1 deletion cellfinder/core/tools/tiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def __init__(
self,
ch1_list: list[str],
channels: list[int],
channels_metadata: list[dict],
label: str | None = None,
):
self.ch1_list = natsort.natsorted(ch1_list)
self.label = label
self.channels = channels
self.channels_metadata = channels_metadata

def make_tifffile_list(self) -> list["TiffFile"]:
"""
Expand All @@ -46,7 +48,10 @@ def make_tifffile_list(self) -> list["TiffFile"]:
]

tiff_files = [
TiffFile(tiffFile, self.channels, self.label) for tiffFile in files
TiffFile(
tiffFile, self.channels, self.channels_metadata, self.label
)
for tiffFile in files
]
return tiff_files

Expand All @@ -60,6 +65,7 @@ def __init__(
self,
tiff_dir: str,
channels: list[int],
channels_metadata: list[dict],
label: str | None = None,
):
super(TiffDir, self).__init__(
Expand All @@ -69,6 +75,7 @@ def __init__(
if f.lower().endswith("ch" + str(channels[0]) + ".tif")
],
channels,
channels_metadata,
label,
)

Expand All @@ -88,10 +95,12 @@ def __init__(
self,
path: str,
channels: list[int],
channels_metadata: list[dict],
label: str | None = None,
):
self.path = path
self.channels = channels
self.channels_metadata = channels_metadata
self.label = label

def files_exist(self):
Expand Down
Loading
Loading