From 59db9a49282d96533a0fb26b3053044afcef3f99 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Thu, 12 Mar 2026 16:17:57 -0400 Subject: [PATCH] Add support for dataset normalization. --- cellfinder/core/classify/classify.py | 28 +++++ cellfinder/core/classify/cube_generator.py | 55 +++++++++- cellfinder/core/main.py | 13 +++ cellfinder/core/tools/image_processing.py | 63 +++++++++++ cellfinder/core/tools/tiff.py | 11 +- cellfinder/core/train/train_yaml.py | 57 +++++++++- cellfinder/napari/curation.py | 50 +++++++-- cellfinder/napari/detect/detect.py | 13 +++ cellfinder/napari/detect/detect_containers.py | 10 ++ cellfinder/napari/train/train.py | 6 ++ cellfinder/napari/train/train_containers.py | 2 + pyproject.toml | 1 + tests/core/test_integration/test_detection.py | 4 +- tests/core/test_integration/test_train.py | 59 ++++++++++- .../test_unit/test_classify/test_cube_gen.py | 100 ++++++++++++++++++ .../test_tools/test_image_processing.py | 14 +++ .../training/training_with_stats.yaml | 19 ++++ tests/napari/test_curation.py | 20 ++++ 18 files changed, 508 insertions(+), 17 deletions(-) create mode 100644 tests/data/integration/training/training_with_stats.yaml diff --git a/cellfinder/core/classify/classify.py b/cellfinder/core/classify/classify.py index 896d017f..1a50993b 100644 --- a/cellfinder/core/classify/classify.py +++ b/cellfinder/core/classify/classify.py @@ -16,6 +16,7 @@ CuboidBatchSampler, ) from cellfinder.core.classify.tools import get_model +from cellfinder.core.tools.image_processing import dataset_mean_std from cellfinder.core.train.train_yaml import depth_type, models @@ -37,6 +38,8 @@ def main( pin_memory: bool = False, *, callback: Optional[Callable[[int], None]] = None, + normalize_channels: bool = False, + normalization_down_sampling: int = 32, ) -> List[Cell]: """ Parameters @@ -89,6 +92,14 @@ def main( callback : Callable[int], optional A callback function that is called during classification. Called with the batch number once that batch has been classified. + normalize_channels : bool + If True, the signal and background data will be each normalized + to a mean of zero and standard deviation of 1. Defaults to False. + normalization_down_sampling : int + If `normalize_channels` is True, the data arrays will be down-sampled + in the first axis by this value before calculating their statistics. + E.g. a value of 2 means every second plane will be used. Defaults to + 32. """ if signal_array.ndim != 3: raise IOError("Signal data must be 3D") @@ -102,10 +113,27 @@ def main( start_time = datetime.now() voxel_sizes = list(map(float, voxel_sizes)) + + signal_normalization = background_normalization = None + if normalize_channels: + logger.debug("Calculating channels norms") + signal_normalization = dataset_mean_std( + signal_array, normalization_down_sampling + ) + background_normalization = dataset_mean_std( + background_array, normalization_down_sampling + ) + logger.debug( + f"Signal channel norm is: {signal_normalization}. " + f"Background channel norm is: {background_normalization}" + ) + logger.debug("Initialising cube generator") dataset = CuboidArrayDataset( signal_array=signal_array, background_array=background_array, + signal_normalization=signal_normalization, + background_normalization=background_normalization, points=points, data_voxel_sizes=voxel_sizes, network_voxel_sizes=network_voxel_sizes, diff --git a/cellfinder/core/classify/cube_generator.py b/cellfinder/core/classify/cube_generator.py index 33dde7dc..e18e024f 100644 --- a/cellfinder/core/classify/cube_generator.py +++ b/cellfinder/core/classify/cube_generator.py @@ -1143,6 +1143,12 @@ class CuboidArrayDataset(CuboidThreadedDatasetBase): This determines how many multiples (or fractions) of `n` such planes to buffer, in addition to `n` that is always buffered. So e.g. `1` means `2n` and `0.5` means `1.5n`. `1` is a good default. + :param signal_normalization: None or a 2-tuple of `(mean, std)`. + If not None, the signal channel in the cubes will be normalized to the + provided mean and standard deviation. + :param background_normalization: None or a 2-tuple of `(mean, std)`. + If not None, the background channel in the cubes will be normalized to + the provided mean and standard deviation. """ def __init__( @@ -1150,6 +1156,8 @@ def __init__( signal_array: types.array, background_array: types.array | None, max_axis_0_cuboids_buffered: float = 0, + signal_normalization: None | tuple[float, float] = None, + background_normalization: None | tuple[float, float] = None, **kwargs, ): super().__init__(**kwargs) @@ -1210,6 +1218,9 @@ def __init__( ) self._set_output_data_dim_reordering(self.src_image_data) + self.signal_normalization = signal_normalization + self.background_normalization = background_normalization + def point_has_full_cuboid(self, point: Sequence[float]) -> bool: """ Takes a 3d point and returns whether a cuboid centered on this point @@ -1229,6 +1240,20 @@ def point_has_full_cuboid(self, point: Sequence[float]) -> bool: return True + def get_points_data(self, points_key: Sequence[int]) -> torch.Tensor: + data = super().get_points_data(points_key) + + if self.signal_normalization is not None: + mean, std = self.signal_normalization + data[..., 0] -= mean + data[..., 0] /= std + if self.background_normalization is not None: + mean, std = self.background_normalization + data[..., 1] -= mean + data[..., 1] /= std + + return data + class CuboidTiffDataset(CuboidThreadedDatasetBase): """ @@ -1241,6 +1266,12 @@ class CuboidTiffDataset(CuboidThreadedDatasetBase): The outer list is the number of points/samples. The inner lists is the number of channels (e.g. signal/background) for the given point. + :param points_normalization: None or a sequence of sequences of 2-tuples + of `(mean, std)`. + + If not None, each 2-tuple corresponds to a single filename in + `points_filenames` and that cube will be normalized by the given + mean and standard deviation before returning it. :param max_cuboids_buffered: Integer The number of the most recently accessed cuboids to cache so it isn't read from disk again. @@ -1249,6 +1280,9 @@ class CuboidTiffDataset(CuboidThreadedDatasetBase): def __init__( self, points_filenames: Sequence[Sequence[str]], + points_normalization: ( + Sequence[Sequence[tuple[float, float]]] | None + ) = None, max_cuboids_buffered: int = 0, **kwargs, ): @@ -1259,10 +1293,17 @@ def __init__( raise ValueError( "Points and filenames must have same number of elements" ) + if points_normalization is not None and len(points_filenames) != len( + points_normalization + ): + raise ValueError("Must have normalizations for all elements") self.num_channels = len(points_filenames[0]) filenames_arr = np.array(points_filenames).astype(np.str_) self.filenames_arr = filenames_arr + self.points_norm_arr = None + if points_normalization is not None: + self.points_norm_arr = torch.tensor(points_normalization) self.src_image_data = CachedTiffImage( points_arr=self.points_arr[:, :3], @@ -1273,6 +1314,18 @@ def __init__( ) self._set_output_data_dim_reordering(self.src_image_data) + def get_points_data(self, points_key: Sequence[int]) -> torch.Tensor: + data = super().get_points_data(points_key) + + if self.points_norm_arr is not None: + norms = self.points_norm_arr[tuple(points_key), ...] + mean = norms[:, :, 0].unsqueeze(1).unsqueeze(1).unsqueeze(1) + std = norms[:, :, 1].unsqueeze(1).unsqueeze(1).unsqueeze(1) + data -= mean + data /= std + + return data + class CuboidBatchSampler(Sampler): """ @@ -1292,7 +1345,7 @@ class CuboidBatchSampler(Sampler): Or e.g. just:: - dataset = CuboidStackDataset(...) + dataset = CuboidArrayDataset(...) sampler = CuboidBatchSampler(dataset=dataset, ...) for batch in sampler: data, labels = dataset[batch] diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index c28281e2..c5531513 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -49,6 +49,8 @@ def main( classify_callback: Optional[Callable[[int], None]] = None, detect_finished_callback: Optional[Callable[[list], None]] = None, classification_max_workers: int = 3, + normalize_channels: bool = False, + normalization_down_sampling: int = 32, ) -> List[Cell]: """ Parameters @@ -183,6 +185,15 @@ def main( classification_max_workers : int The max number of sub-processes to use for data loading / processing during classification. Defaults to 3. + normalize_channels : bool + If True, the signal and background data will be each normalized + to a mean of zero and standard deviation of 1 before classification. + Defaults to False. + normalization_down_sampling : int + If `normalize_channels` is True, the data arrays will be down-sampled + in the first axis by this value before calculating their statistics + before classification. E.g. a value of 2 means every second plane will + be used. Defaults to 32. """ from cellfinder.core.classify import classify from cellfinder.core.detect import detect @@ -246,6 +257,8 @@ def main( network_depth, callback=classify_callback, max_workers=classification_max_workers, + normalize_channels=normalize_channels, + normalization_down_sampling=normalization_down_sampling, ) else: logger.info("No candidates, skipping classification") diff --git a/cellfinder/core/tools/image_processing.py b/cellfinder/core/tools/image_processing.py index 742220d6..bf3811ac 100644 --- a/cellfinder/core/tools/image_processing.py +++ b/cellfinder/core/tools/image_processing.py @@ -1,6 +1,10 @@ import numpy as np +import tqdm from brainglobe_utils.general.numerical import is_even +from cellfinder.core import types +from cellfinder.core.tools.tools import get_data_converter + def crop_center_2d(img, crop_x=None, crop_y=None): """ @@ -85,3 +89,62 @@ def pad_center_2d(img, x_size=None, y_size=None, pad_mode="edge"): y_front = y_back = 0 return np.pad(img, ((y_front, y_back), (x_front, x_back)), pad_mode) + + +def dataset_mean_std( + dataset: types.array, + sampling_factor: int, + show_progress: bool = True, + progress_desc="Estimating channel mean/std", +) -> tuple[float, float]: + """ + Calculates the mean and sample standard deviation of a 3d dataset using + Welford's online algorithm, sampling it along its first dimension. + + :param dataset: A 3d dataset, such as a numpy or dask array. + :param sampling_factor: The sampling factor to sample along the first + dimension. E.g. if the dataset is 10 x 100 x 100 and `sampling_factor` + is 3, then we'll use planes 0, 3, 6, 9 for the calculation (40_000 + data points). + :param show_progress: Whether to show a progress bar during the + calculation. + :param progress_desc: If showing a progress bar, the description to use in + it. + :return: A 2-tuple of `(mean, std)` estimate of the dataset. + """ + # based on https://en.wikipedia.org/wiki/ + # Algorithms_for_calculating_variance#Welford's_online_algorithm + # and https://stackoverflow.com/q/56402955 + plane_n = dataset.shape[1] * dataset.shape[2] + # get data converter from dataset to float64 + converter = get_data_converter(dataset.dtype, np.float64) + + count = 0 + mean = np.array(0, dtype=np.float64) + sq_dist = np.array(0, dtype=np.float64) + + # make it a list so tqdm will know its full size + samples = list(range(0, len(dataset), sampling_factor)) + if show_progress: + it = tqdm.tqdm(samples, desc=progress_desc, unit="planes") + else: + it = samples + + for i in it: + plane = converter(dataset[i, ...]) + # flatten it + new_value = plane.reshape((plane_n,)) + + count += plane_n + delta = new_value - mean + mean += np.sum(delta) / count + delta2 = new_value - mean + sq_dist += np.sum(np.multiply(delta, delta2)) + + if count <= 1: + raise ValueError("Not enough data to compute the variance") + + var_sample = sq_dist / (count - 1) + std = np.sqrt(var_sample) + + return mean.item(), std.item() diff --git a/cellfinder/core/tools/tiff.py b/cellfinder/core/tools/tiff.py index 4f058ec6..b0925e17 100644 --- a/cellfinder/core/tools/tiff.py +++ b/cellfinder/core/tools/tiff.py @@ -26,11 +26,13 @@ def __init__( self, ch1_list: list[str], channels: list[int], + channels_metadata: list[dict], label: str | None = None, ): self.ch1_list = natsort.natsorted(ch1_list) self.label = label self.channels = channels + self.channels_metadata = channels_metadata def make_tifffile_list(self) -> list["TiffFile"]: """ @@ -46,7 +48,10 @@ def make_tifffile_list(self) -> list["TiffFile"]: ] tiff_files = [ - TiffFile(tiffFile, self.channels, self.label) for tiffFile in files + TiffFile( + tiffFile, self.channels, self.channels_metadata, self.label + ) + for tiffFile in files ] return tiff_files @@ -60,6 +65,7 @@ def __init__( self, tiff_dir: str, channels: list[int], + channels_metadata: list[dict], label: str | None = None, ): super(TiffDir, self).__init__( @@ -69,6 +75,7 @@ def __init__( if f.lower().endswith("ch" + str(channels[0]) + ".tif") ], channels, + channels_metadata, label, ) @@ -88,10 +95,12 @@ def __init__( self, path: str, channels: list[int], + channels_metadata: list[dict], label: str | None = None, ): self.path = path self.channels = channels + self.channels_metadata = channels_metadata self.label = label def files_exist(self): diff --git a/cellfinder/core/train/train_yaml.py b/cellfinder/core/train/train_yaml.py index 2f6aa8a2..c46f846b 100644 --- a/cellfinder/core/train/train_yaml.py +++ b/cellfinder/core/train/train_yaml.py @@ -250,6 +250,13 @@ def training_parse(): action="store_true", help="Save training progress to a .csv file", ) + training_parser.add_argument( + "--normalize-channels", + dest="normalize_channels", + action="store_true", + help="Normalize the training data to the mean/std of the datasets " + "from which the cubes came from", + ) training_parser = misc_parse(training_parser) training_parser = download_parser(training_parser) @@ -276,8 +283,25 @@ def get_tiff_files(yaml_contents: list[dict]) -> list[list[TiffFile]]: for d in yaml_contents: if d["bg_channel"] < 0: channels = [d["signal_channel"]] + channels_metadata = [ + {}, + ] else: channels = [d["signal_channel"], d["bg_channel"]] + channels_metadata = [{}, {}] + + if "signal_mean" in d: + channels_metadata[0] = { + "mean": float(d["signal_mean"]), + "std": float(d["signal_std"]), + } + # if we have norm for signal we must have for background + if "signal_mean" in d and d["bg_channel"] >= 0: + channels_metadata[1] = { + "mean": float(d["bg_mean"]), + "std": float(d["bg_std"]), + } + if "cell_def" in d and d["cell_def"]: ch1_tiffs = [ os.path.join(d["cube_dir"], f) @@ -288,11 +312,14 @@ def get_tiff_files(yaml_contents: list[dict]) -> list[list[TiffFile]]: TiffList( find_relevant_tiffs(ch1_tiffs, d["cell_def"]), channels, + channels_metadata, d["type"], ) ) else: - tiff_lists.append(TiffDir(d["cube_dir"], channels, d["type"])) + tiff_lists.append( + TiffDir(d["cube_dir"], channels, channels_metadata, d["type"]) + ) tiff_files = [tiff_dir.make_tifffile_list() for tiff_dir in tiff_lists] return tiff_files @@ -300,14 +327,14 @@ def get_tiff_files(yaml_contents: list[dict]) -> list[list[TiffFile]]: def make_tiff_lists( tiff_files: list[list[TiffFile]], -) -> tuple[list[list[str]], list[Cell]]: +) -> tuple[list[tuple[list[str], list[dict]]], list[Cell]]: cells = [] filenames = [] for group in tiff_files: for image in group: - filenames.append(image.img_files) + filenames.append((image.img_files, image.channels_metadata)) cells.append(image.as_cell()) return filenames, cells @@ -346,22 +373,39 @@ def cli(): no_save_checkpoints=args.no_save_checkpoints, save_progress=args.save_progress, epochs=args.epochs, + normalize_channels=args.normalize_channels, ) def get_dataloader( cells: list[Cell], - filenames: list[list[str]], + filenames: list[tuple[list[str], list[dict]]], batch_size: int, n_processes: int, pin_memory: bool, auto_shuffle: bool, augment: bool, augment_likelihood: float, + normalize_channels: bool, ) -> tuple[DataLoader, CuboidTiffDataset]: + points_filenames = [f[0] for f in filenames] + + points_norm = None + if normalize_channels: + points_norm = [] + for names, channels_norm in filenames: + # check the first channel for metadata. We expect all or none + # of the channels to have metadata + if not channels_norm[0]: + raise ValueError(f"Data mean and std not found for {names}") + + norms = [(ch["mean"], ch["std"]) for ch in channels_norm] + points_norm.append(norms) + dataset = CuboidTiffDataset( points=cells, - points_filenames=filenames, + points_filenames=points_filenames, + points_normalization=points_norm, data_voxel_sizes=(1, 1, 1), network_voxel_sizes=(1, 1, 1), network_cuboid_voxels=(CUBE_DEPTH, CUBE_HEIGHT, CUBE_WIDTH), @@ -408,6 +452,7 @@ def run( max_workers: int = 3, pin_memory: bool = True, augment_likelihood: float = 0.9, + normalize_channels: bool = False, ): start_time = datetime.now() @@ -465,6 +510,7 @@ def run( auto_shuffle=False, augment=False, augment_likelihood=augment_likelihood, + normalize_channels=normalize_channels, ) # for saving checkpoints @@ -485,6 +531,7 @@ def run( auto_shuffle=True, augment=not no_augment, augment_likelihood=augment_likelihood, + normalize_channels=normalize_channels, ) callbacks = [] diff --git a/cellfinder/napari/curation.py b/cellfinder/napari/curation.py index 78b59e1f..e90a3f60 100644 --- a/cellfinder/napari/curation.py +++ b/cellfinder/napari/curation.py @@ -17,6 +17,7 @@ add_button, add_combobox, add_float_box, + add_int_box, ) from qtpy import QtCore from qtpy.QtWidgets import ( @@ -32,6 +33,7 @@ CuboidArrayDataset, CuboidBatchSampler, ) +from cellfinder.core.tools.image_processing import dataset_mean_std # Constants used throughout WINDOW_HEIGHT = 750 @@ -68,6 +70,7 @@ def __init__( self.save_empty_cubes = save_empty_cubes self.max_ram = max_ram self.voxel_sizes = [5, 2, 2] + self.normalization_down_sampling = 32 self.batch_size = 64 self.viewer = viewer @@ -176,6 +179,9 @@ def setup_main_layout(self): def _set_voxel_size(self, value: float, index: int) -> None: self.voxel_sizes[index] = value + def _set_normalization_down_sampling(self, value: int) -> None: + self.normalization_down_sampling = value + def add_loading_panel(self, row: int, column: int = 0): self.load_data_panel = QGroupBox("Load data") self.load_data_layout = QGridLayout() @@ -233,32 +239,46 @@ def add_loading_panel(self, row: int, column: int = 0): ) box_x.valueChanged.connect(partial(self._set_voxel_size, index=2)) self.voxel_sizes_boxes = box_z, box_y, box_x + + box_norm = add_int_box( + self.load_data_layout, + self.normalization_down_sampling, + 1, + 1000, + "Normalization down-sampling", + 6, + tooltip="Down-sampling factor of the z-dimension used to calculate" + " the mean and std of the dataset. Used to normalize the " + "channels during training.", + ) + box_norm.valueChanged.connect(self._set_normalization_down_sampling) + self.norm_sampling_box = box_norm self.training_data_cell_choice, _ = add_combobox( self.load_data_layout, "Training data (cells)", self.point_layer_names, - 6, + 7, callback=self.set_training_data_cell, ) self.training_data_non_cell_choice, _ = add_combobox( self.load_data_layout, "Training_data (non_cells)", self.point_layer_names, - row=7, + row=8, callback=self.set_training_data_non_cell, ) self.mark_as_cell_button = add_button( "Mark as cell(s)", self.load_data_layout, self.mark_as_cell, - row=8, + row=9, tooltip="Mark all selected points as non cell. Shortcut: 'c'", ) self.mark_as_non_cell_button = add_button( "Mark as non cell(s)", self.load_data_layout, self.mark_as_non_cell, - row=8, + row=9, column=1, tooltip="Mark all selected points as non cell. Shortcut: 'x'", ) @@ -266,13 +286,13 @@ def add_loading_panel(self, row: int, column: int = 0): "Add training data layers", self.load_data_layout, self.add_training_data, - row=9, + row=10, ) self.save_training_data_button = add_button( "Save training data", self.load_data_layout, self.save_training_data, - row=9, + row=10, column=1, ) self.load_data_layout.setColumnMinimumWidth(0, COLUMN_WIDTH) @@ -632,7 +652,17 @@ def convert_layers_to_cells(self): self.cells_to_extract = list(set(self.cells_to_extract)) self.non_cells_to_extract = list(set(self.non_cells_to_extract)) + def _calculate_channel_stats(self): + signal_stat = dataset_mean_std( + self.signal_layer.data, self.normalization_down_sampling + ) + bg_stat = dataset_mean_std( + self.background_layer.data, self.normalization_down_sampling + ) + return signal_stat, bg_stat + def __save_yaml_file(self): + signal_stat, bg_stat = self._calculate_channel_stats() yaml_section = [ { "cube_dir": str(self.cell_cube_dir), @@ -640,6 +670,10 @@ def __save_yaml_file(self): "type": "cell", "signal_channel": 0, "bg_channel": 1, + "signal_mean": signal_stat[0], + "signal_std": signal_stat[1], + "bg_mean": bg_stat[0], + "bg_std": bg_stat[1], }, { "cube_dir": str(self.no_cell_cube_dir), @@ -647,6 +681,10 @@ def __save_yaml_file(self): "type": "no_cell", "signal_channel": 0, "bg_channel": 1, + "signal_mean": signal_stat[0], + "signal_std": signal_stat[1], + "bg_mean": bg_stat[0], + "bg_std": bg_stat[1], }, ] diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index 51ef67e4..323ed737 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -263,6 +263,8 @@ def widget( use_pre_trained_weights: bool, trained_model: Optional[Path], classification_batch_size: int, + normalize_channels: bool, + normalization_down_sampling: int, misc_options, start_plane: int, end_plane: int, @@ -347,6 +349,15 @@ def widget( the models. For performance-critical applications, tune to maximize memory usage without running out. Check your GPU/CPU memory to verify it's not full + normalize_channels : bool + For classification only - whether to normalize the cubes to the + mean/std of the image channels before classification. If the model + used for classification was trained on normalized data, this should + be enabled. + normalization_down_sampling : int + If normalizing the cubes is enabled, the input channels will be + down-sampled in z by this value before calculating their mean/std. + E.g. a value of 2 means every second z plane will be used. start_plane : int First plane to process (to process a subset of the data) end_plane : int @@ -430,6 +441,8 @@ def widget( use_pre_trained_weights, trained_model, classification_batch_size, + normalize_channels, + normalization_down_sampling, ) if analyse_local: diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py index ac6e76e0..f1d36cc0 100644 --- a/cellfinder/napari/detect/detect_containers.py +++ b/cellfinder/napari/detect/detect_containers.py @@ -129,6 +129,8 @@ class ClassificationInputs(InputContainer): use_pre_trained_weights: bool = True trained_model: Optional[Path] = Path.home() classification_batch_size: int = 64 + normalize_channels: bool = False + normalization_down_sampling: int = 32 def as_core_arguments(self) -> dict: args = super().as_core_arguments() @@ -150,6 +152,14 @@ def widget_representation(cls) -> dict: value=cls.defaults()["classification_batch_size"], label="Batch size (classification)", ), + normalize_channels=dict( + value=cls.defaults()["normalize_channels"], + label="Normalize data", + ), + normalization_down_sampling=dict( + value=cls.defaults()["normalization_down_sampling"], + label="Normalization down-sampling", + ), ) diff --git a/cellfinder/napari/train/train.py b/cellfinder/napari/train/train.py index 79d92b6b..830bc9a6 100644 --- a/cellfinder/napari/train/train.py +++ b/cellfinder/napari/train/train.py @@ -59,6 +59,7 @@ def widget( training_options: dict, continue_training: bool, augment: bool, + normalize_channels: bool, tensorboard: bool, save_checkpoints: bool, save_progress: bool, @@ -93,6 +94,10 @@ def widget( this will continue from the pretrained model augment : bool Augment the training data to improve generalisation + normalize_channels : bool + Whether to normalize the cubes by the mean/std of their origin + dataset. If True, the yaml files must include the mean/std of + the origin dataset. tensorboard : bool Log to output_directory/tensorboard save_checkpoints : bool @@ -135,6 +140,7 @@ def widget( learning_rate, batch_size, test_fraction, + normalize_channels, ) misc_training_inputs = MiscTrainingInputs(number_of_free_cpus) diff --git a/cellfinder/napari/train/train_containers.py b/cellfinder/napari/train/train_containers.py index c77ece05..98929a0f 100644 --- a/cellfinder/napari/train/train_containers.py +++ b/cellfinder/napari/train/train_containers.py @@ -81,6 +81,7 @@ class OptionalTrainingInputs(InputContainer): learning_rate: float = 1e-4 batch_size: int = 16 test_fraction: float = 0.1 + normalize_channels: bool = False def as_core_arguments(self) -> dict: arguments = super().as_core_arguments() @@ -105,6 +106,7 @@ def widget_representation(cls) -> dict: test_fraction=cls._custom_widget( "test_fraction", step=0.05, min=0.05, max=0.95 ), + normalize_channels=cls._custom_widget("normalize_channels"), ) diff --git a/pyproject.toml b/pyproject.toml index b6d357d2..2c9eb66f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dev = [ "pytest-qt", "pytest-timeout", "pytest", + "PyYAML", "tox", "pooch >= 1", ] diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 5e145875..544d2b7d 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -178,13 +178,15 @@ def detect_finished_callback(points): assert npoints == 120, f"Expected 120 points, found {npoints}" -def test_synthetic_data(synthetic_bright_spots, no_free_cpus): +@pytest.mark.parametrize("normalize", [True, False]) +def test_synthetic_data(synthetic_bright_spots, no_free_cpus, normalize): signal_array, background_array = synthetic_bright_spots detected = main( signal_array, background_array, voxel_sizes, n_free_cpus=no_free_cpus, + normalize_channels=normalize, ) assert len(detected) == 8 diff --git a/tests/core/test_integration/test_train.py b/tests/core/test_integration/test_train.py index 64e0d443..b7c710e5 100644 --- a/tests/core/test_integration/test_train.py +++ b/tests/core/test_integration/test_train.py @@ -1,7 +1,7 @@ import os -import sys import pytest +from pytest_mock.plugin import MockerFixture from cellfinder.core.train.train_yaml import cli as train_run @@ -11,6 +11,7 @@ cell_cubes = os.path.join(data_dir, "cells") non_cell_cubes = os.path.join(data_dir, "non_cells") training_yaml_file = os.path.join(data_dir, "training.yaml") +training_yaml_file_stats = os.path.join(data_dir, "training_with_stats.yaml") EPOCHS = "2" @@ -20,7 +21,7 @@ @pytest.mark.slow -def test_train(tmpdir): +def test_train(mocker, tmpdir): tmpdir = str(tmpdir) train_args = [ @@ -32,8 +33,60 @@ def test_train(tmpdir): "--epochs", EPOCHS, ] - sys.argv = train_args + mocker.patch("sys.argv", train_args) + train_run() model_file = os.path.join(tmpdir, "model.keras") assert os.path.exists(model_file) + + +@pytest.mark.parametrize("normalize", [True, False]) +@pytest.mark.parametrize("has_norms", [True, False]) +def test_train_normalization_missing_stats( + mocker: MockerFixture, tmpdir, has_norms, normalize +): + tmpdir = str(tmpdir) + + train_args = [ + "cellfinder_train", + "-y", + training_yaml_file_stats if has_norms else training_yaml_file, + "-o", + tmpdir, + "--epochs", + EPOCHS, + ] + if normalize: + train_args.append("--normalize-channels") + + mocker.patch("sys.argv", train_args) + get_model = mocker.patch( + "cellfinder.core.train.train_yaml.get_model", autospec=True + ) + + if normalize and not has_norms: + # if the yaml doesn't have normalization info an error will be raised + with pytest.raises(ValueError): + train_run() + else: + train_run() + # get the data sets passed to fit() to verify if it has norm data + # there's no clear name property of the mock fit call, so use its repr + (fit_mock,) = [ + m + for m in get_model.mock_calls + if repr(m).startswith("call().fit(") + ] + train_dataset = fit_mock.kwargs["x"].dataset + val_dataset = fit_mock.kwargs["validation_data"].dataset + + if normalize: + # if we normalize, the normalization data should be in dataset + assert train_dataset.points_norm_arr is not None + assert val_dataset.points_norm_arr is not None + else: + # otherwise, no normalization data should have been passed, even if + # the yaml has it + assert train_dataset.points_norm_arr is None + assert val_dataset.points_norm_arr is None diff --git a/tests/core/test_unit/test_classify/test_cube_gen.py b/tests/core/test_unit/test_classify/test_cube_gen.py index 8261666c..4d1ab756 100644 --- a/tests/core/test_unit/test_classify/test_cube_gen.py +++ b/tests/core/test_unit/test_classify/test_cube_gen.py @@ -1246,3 +1246,103 @@ def test_points_unchanged(): assert (x, y, z) == (30, 33, 15) assert tp == Cell.CELL assert i == 1 + + +def _get_volume_with_stats(normalize): + sig_norm = back_norm = None + sig_mean, sig_std = 222, 20 + back_mean, back_std = 555, 5 + if normalize: + sig_norm = sig_mean, sig_std + back_norm = back_mean, back_std + + volume = np.empty((20, 20, 30, 2), dtype=np.float32) + volume[..., 0] = np.random.normal(sig_mean, sig_std, (20, 20, 30)) + volume[..., 1] = np.random.normal(back_mean, back_std, (20, 20, 30)) + + return ( + volume, + sig_norm, + back_norm, + (sig_mean, sig_std), + (back_mean, back_std), + ) + + +def _check_cube_normalization( + stack, sig_mean, sig_std, back_mean, back_std, normalize +): + # get a single and a batch of 2 cubes + cube_1 = stack[0] + cube_2 = stack[[0, 0]] + for cube in [cube_1, cube_2]: + for ch, ex_mean, ex_std in [ + (0, sig_mean, sig_std), + (1, back_mean, back_std), + ]: + # get output stats + std, mean = torch.std_mean(cube[..., ch]) + + lower_mean = ex_mean * 0.8 + upper_mean = ex_mean * 1.2 + # if normalized, it should be standard normal + if normalize: + ex_mean, ex_std = 0, 1 + lower_mean = -0.2 + upper_mean = 0.2 + + assert lower_mean <= mean.item() < upper_mean + assert ex_std * 0.8 <= std.item() < ex_std * 1.2 + + +@pytest.mark.parametrize("normalize", [True, False]) +def test_array_image_data_normalization(normalize): + """ + Checks that the data returned by the CuboidArrayDataset is normalized + if requested, otherwise it shouldn't be normalized. + """ + volume, sig_norm, back_norm, sig_stat, back_stat = _get_volume_with_stats( + normalize + ) + + stack = CuboidArrayDataset( + points=[Cell((10, 10, 10), Cell.UNKNOWN)], + data_voxel_sizes=(1, 1, 1), + network_voxel_sizes=(1, 1, 1), + network_cuboid_voxels=(5, 5, 8), + axis_order=("x", "y", "z"), + output_axis_order=("x", "y", "z", "c"), + signal_array=volume[..., 0], + background_array=volume[..., 1], + signal_normalization=sig_norm, + background_normalization=back_norm, + ) + _check_cube_normalization(stack, *sig_stat, *back_stat, normalize) + + +@pytest.mark.parametrize("normalize", [True, False]) +def test_tiff_image_data_normalization(normalize, tmp_path): + """ + Checks that the data returned by the CuboidTiffDataset is normalized + if requested, otherwise it shouldn't be normalized. + """ + volume, sig_norm, back_norm, sig_stat, back_stat = _get_volume_with_stats( + normalize + ) + + points = [(10, 10, 10)] + cube_size = 5, 5, 8 + filenames, _, _ = to_tiff_cubes(volume, cube_size, points, tmp_path) + + tiffs = CuboidTiffDataset( + points=[Cell(p, Cell.UNKNOWN) for p in points], + data_voxel_sizes=(1, 1, 1), + network_voxel_sizes=(1, 1, 1), + network_cuboid_voxels=(5, 5, 8), + axis_order=("x", "y", "z"), + output_axis_order=("x", "y", "z", "c"), + points_filenames=filenames, + points_normalization=[[sig_norm, back_norm]] if normalize else None, + ) + + _check_cube_normalization(tiffs, *sig_stat, *back_stat, normalize) diff --git a/tests/core/test_unit/test_tools/test_image_processing.py b/tests/core/test_unit/test_tools/test_image_processing.py index 64ad4891..d56d189e 100644 --- a/tests/core/test_unit/test_tools/test_image_processing.py +++ b/tests/core/test_unit/test_tools/test_image_processing.py @@ -1,6 +1,7 @@ import random import numpy as np +import pytest from cellfinder.core.tools import image_processing as img_tools @@ -35,3 +36,16 @@ def test_pad_centre_2d(): img, x_size=new_x_shape, y_size=new_y_shape ) assert (new_y_shape, new_x_shape) == pad_img.shape + + +@pytest.mark.parametrize("progress", [True, False]) +def test_dataset_mean_std(progress): + # checks that dataset_mean_std correctly computes the std/mean + data = np.random.normal(100, 10, (10, 10, 10)) + + mean, std = img_tools.dataset_mean_std( + data, sampling_factor=2, show_progress=progress + ) + # give it enough room for estimation error + assert 90 < mean < 110 + assert 8 < std < 12 diff --git a/tests/data/integration/training/training_with_stats.yaml b/tests/data/integration/training/training_with_stats.yaml new file mode 100644 index 00000000..b2ea7b72 --- /dev/null +++ b/tests/data/integration/training/training_with_stats.yaml @@ -0,0 +1,19 @@ +data: +- bg_channel: 1 + cell_def: '' + cube_dir: tests/data/integration/training/cells + signal_channel: 0 + type: cell + signal_mean: 241.31 + signal_std: 154.92 + bg_mean: 650.94 + bg_std: 217.90 +- bg_channel: 1 + cell_def: '' + cube_dir: tests/data/integration/training/cells + signal_channel: 0 + type: no_cell + signal_mean: 231.28 + signal_std: 79.60 + bg_mean: 836.21 + bg_std: 348.35 diff --git a/tests/napari/test_curation.py b/tests/napari/test_curation.py index 7e12cc0f..a4fc5d15 100644 --- a/tests/napari/test_curation.py +++ b/tests/napari/test_curation.py @@ -4,6 +4,7 @@ import napari import numpy as np import pytest +import yaml from napari.layers import Image, Points from cellfinder.napari import sample_data @@ -46,6 +47,12 @@ def test_update_voxel_size(curation_widget: CurationWidget): assert curation_widget.voxel_sizes == [3, 4, 5] +def test_update_normalization_down_sampling(curation_widget: CurationWidget): + assert curation_widget.normalization_down_sampling == 32 + curation_widget.norm_sampling_box.setValue(8) + assert curation_widget.normalization_down_sampling == 8 + + def test_cell_marking(curation_widget, tmp_path): """ Check that marking cells and non-cells works as expected. @@ -96,6 +103,19 @@ def test_cell_marking(curation_widget, tmp_path): assert len(list((tmp_path / "non_cells").glob("*.tif"))) == 2 assert len(list((tmp_path / "cells").glob("*.tif"))) == 2 + with open(tmp_path / "training.yaml", "r") as fh: + yaml_data = yaml.safe_load(fh) + + for item in yaml_data["data"]: + assert "cube_dir" in item + assert "signal_channel" in item + assert "bg_channel" in item + assert "type" in item + assert "signal_mean" in item + assert "signal_std" in item + assert "bg_mean" in item + assert "bg_std" in item + @pytest.fixture def valid_curation_widget(make_napari_viewer) -> CurationWidget: