From 82481910fdf7515816217fa10b64d3df6ef56139 Mon Sep 17 00:00:00 2001 From: tanishaa7 Date: Sat, 28 Feb 2026 20:18:05 +0530 Subject: [PATCH 1/3] Handle empty cell candidates without crashing --- cellfinder/core/main.py | 7 +- cellfinder/core/train/train_yaml.py | 36 +++++++++- cellfinder/napari/detect/detect.py | 15 ++++ cellfinder/napari/detect/thread_worker.py | 19 +++-- cellfinder/napari/train/thread_worker.py | 71 +++++++++++++++++++ cellfinder/napari/train/train.py | 36 ++++------ tests/core/test_integration/test_detection.py | 41 +++++++++++ tests/napari/test_training.py | 6 +- 8 files changed, 202 insertions(+), 29 deletions(-) create mode 100644 cellfinder/napari/train/thread_worker.py diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index 25c0c793..21a86038 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -243,5 +243,10 @@ def main( callback=classify_callback, ) else: - logger.info("No candidates, skipping classification") + logger.warning( + "No cell candidates were detected. " + "Classification will be skipped. " + "This may occur with very small images or images " + "with no bright structures." + ) return points diff --git a/cellfinder/core/train/train_yaml.py b/cellfinder/core/train/train_yaml.py index f87a4203..c219ff5a 100644 --- a/cellfinder/core/train/train_yaml.py +++ b/cellfinder/core/train/train_yaml.py @@ -13,7 +13,7 @@ ) from datetime import datetime from pathlib import Path -from typing import Dict, Literal +from typing import Callable, Dict, Literal, Optional from brainglobe_utils.general.numerical import ( check_positive_float, @@ -26,6 +26,7 @@ from brainglobe_utils.IO.cells import find_relevant_tiffs from brainglobe_utils.IO.yaml import read_yaml_section from fancylog import fancylog +import keras from keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard from sklearn.model_selection import train_test_split @@ -297,6 +298,28 @@ def cli(): ) +class EpochEndCallback(keras.callbacks.Callback): + """Keras callback that reports epoch progress via a callback function.""" + + def __init__( + self, + progress_callback: Callable[[str, int, int], None], + epochs: int, + ): + super().__init__() + self._progress_callback = progress_callback + self._epochs = epochs + + def on_epoch_end(self, epoch, logs=None): + # epoch is 0-indexed in Keras + current = epoch + 1 + self._progress_callback( + f"Training epoch {current}/{self._epochs}", + current, + self._epochs, + ) + + def run( output_dir, yaml_file, @@ -316,9 +339,13 @@ def run( no_save_checkpoints=False, save_progress=False, epochs=100, + progress_callback: Optional[Callable[[str, int, int], None]] = None, ): start_time = datetime.now() + if progress_callback is not None: + progress_callback("Preparing training", 0, 1) + ensure_directory_exists(output_dir) model_weights = prep_model_weights( model_weights=model_weights, @@ -430,6 +457,10 @@ def run( csv_logger = CSVLogger(csv_filepath) callbacks.append(csv_logger) + if progress_callback is not None: + progress_callback("Beginning training", 0, epochs) + callbacks.append(EpochEndCallback(progress_callback, epochs)) + logger.info("Beginning training.") # Keras 3.0: `use_multiprocessing` input is set in the # `training_generator` (False by default) @@ -447,6 +478,9 @@ def run( logger.info("Saving model") model.save(output_dir / "model.keras") + if progress_callback is not None: + progress_callback("Training finished", epochs, epochs) + logger.info( "Finished training, " "Total time taken: %s", datetime.now() - start_time, diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index f801d462..499fc34e 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -157,6 +157,13 @@ def get_results_callback( if skip_classification: # after detection w/o classification, everything is unknown def done_func(points): + if not points: + show_info( + "No cell candidates were detected. " + "Try adjusting detection parameters or " + "using a larger image." + ) + return add_single_layer( points, viewer=viewer, @@ -168,6 +175,14 @@ def done_func(points): else: # after classification we have either cell or unknown def done_func(points): + if not points: + show_info( + "No cell candidates were detected. " + "Classification was skipped. " + "Try adjusting detection parameters or " + "using a larger image." + ) + return add_classified_layers( points, viewer=viewer, diff --git a/cellfinder/napari/detect/thread_worker.py b/cellfinder/napari/detect/thread_worker.py index d6a9431f..39ed8017 100644 --- a/cellfinder/napari/detect/thread_worker.py +++ b/cellfinder/napari/detect/thread_worker.py @@ -70,12 +70,23 @@ def detect_callback(plane: int) -> None: def detect_finished_callback(points: list) -> None: self.npoints_detected = len(points) if not self.classification_inputs.skip_classification: - self.update_progress_bar.emit( - "Setting up classification...", 1, 0 - ) + if self.npoints_detected == 0: + self.update_progress_bar.emit( + "No cell candidates detected, " + "skipping classification", + 1, + 1, + ) + else: + self.update_progress_bar.emit( + "Setting up classification...", 1, 0 + ) def classify_callback(batch: int) -> None: - if not self.classification_inputs.skip_classification: + if ( + not self.classification_inputs.skip_classification + and self.npoints_detected > 0 + ): self.update_progress_bar.emit( "Classifying cells", # Default cellfinder-core batch size is 64. diff --git a/cellfinder/napari/train/thread_worker.py b/cellfinder/napari/train/thread_worker.py new file mode 100644 index 00000000..8719c8c1 --- /dev/null +++ b/cellfinder/napari/train/thread_worker.py @@ -0,0 +1,71 @@ +from magicgui.widgets import ProgressBar +from napari.qt.threading import WorkerBase, WorkerBaseSignals +from qtpy.QtCore import Signal + +from cellfinder.core.train.train_yaml import run as train_yaml_run + +from .train_containers import ( + MiscTrainingInputs, + OptionalNetworkInputs, + OptionalTrainingInputs, + TrainingDataInputs, +) + + +class MyTrainingWorkerSignals(WorkerBaseSignals): + """ + Signals used by the TrainingWorker class below. + """ + + # Emits (label, max, value) for the progress bar + update_progress_bar = Signal(str, int, int) + + +class TrainingWorker(WorkerBase): + """ + Runs cellfinder training in a separate thread, to prevent GUI blocking. + + Also handles callbacks between the worker thread and main napari GUI + thread to update a progress bar. + """ + + def __init__( + self, + training_data_inputs: TrainingDataInputs, + optional_network_inputs: OptionalNetworkInputs, + optional_training_inputs: OptionalTrainingInputs, + misc_training_inputs: MiscTrainingInputs, + ): + super().__init__(SignalsClass=MyTrainingWorkerSignals) + self.training_data_inputs = training_data_inputs + self.optional_network_inputs = optional_network_inputs + self.optional_training_inputs = optional_training_inputs + self.misc_training_inputs = misc_training_inputs + + def connect_progress_bar_callback(self, progress_bar: ProgressBar): + """ + Connects the progress bar to the worker so that updates are shown + on the bar. + """ + + def update_progress_bar(label: str, max: int, value: int): + progress_bar.label = label + progress_bar.max = max + progress_bar.value = value + + self.update_progress_bar.connect(update_progress_bar) + + def work(self) -> None: + self.update_progress_bar.emit("Preparing training...", 1, 0) + + def progress_callback(label: str, value: int, max_val: int) -> None: + self.update_progress_bar.emit(label, max_val, value) + + train_yaml_run( + **self.training_data_inputs.as_core_arguments(), + **self.optional_network_inputs.as_core_arguments(), + **self.optional_training_inputs.as_core_arguments(), + **self.misc_training_inputs.as_core_arguments(), + progress_callback=progress_callback, + ) + self.update_progress_bar.emit("Training finished!", 1, 1) diff --git a/cellfinder/napari/train/train.py b/cellfinder/napari/train/train.py index 79d92b6b..fb7afadc 100644 --- a/cellfinder/napari/train/train.py +++ b/cellfinder/napari/train/train.py @@ -2,12 +2,10 @@ from typing import Optional from magicgui import magicgui -from magicgui.widgets import FunctionGui, PushButton -from napari.qt.threading import thread_worker +from magicgui.widgets import FunctionGui, ProgressBar, PushButton from napari.utils.notifications import show_info from qtpy.QtWidgets import QScrollArea -from cellfinder.core.train.train_yaml import run as train_yaml from cellfinder.napari.utils import cellfinder_header, html_label_widget from .train_containers import ( @@ -16,26 +14,12 @@ OptionalTrainingInputs, TrainingDataInputs, ) - - -@thread_worker -def run_training( - training_data_inputs: TrainingDataInputs, - optional_network_inputs: OptionalNetworkInputs, - optional_training_inputs: OptionalTrainingInputs, - misc_training_inputs: MiscTrainingInputs, -): - show_info("Running training...") - train_yaml( - **training_data_inputs.as_core_arguments(), - **optional_network_inputs.as_core_arguments(), - **optional_training_inputs.as_core_arguments(), - **misc_training_inputs.as_core_arguments(), - ) - show_info("Training finished!") +from .thread_worker import TrainingWorker def training_widget() -> FunctionGui: + progress_bar = ProgressBar() + @magicgui( training_label=html_label_widget("Network training", tag="h3"), **TrainingDataInputs.widget_representation(), @@ -143,16 +127,25 @@ def widget( show_info("Please select a YAML file for training") else: show_info("Starting training process...") - worker = run_training( + worker = TrainingWorker( training_data_inputs, optional_network_inputs, optional_training_inputs, misc_training_inputs, ) + + def on_finished(): + show_info("Training finished!") + + worker.returned.connect(on_finished) + worker.connect_progress_bar_callback(progress_bar) worker.start() widget.native.layout().insertWidget(0, cellfinder_header()) + # Insert progress bar before the call and reset buttons + widget.insert(widget.index("number_of_free_cpus") + 1, progress_bar) + @widget.reset_button.changed.connect def restore_defaults(): defaults = { @@ -171,3 +164,4 @@ def restore_defaults(): widget._widget._qwidget = scroll return widget + diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 5e145875..762637a9 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -336,3 +336,44 @@ def test_detection_plane_too_small(synthetic_spot_clusters, y, x): voxel_sizes=(1, 1, 1), ball_xy_size=50, ) + + +def test_detection_no_candidates(no_free_cpus): + """ + Test that detection on a very small / empty image returns an empty list + without raising an error (issue #344). + """ + # Create a small blank image with no bright structures + signal_array = np.zeros((15, 100, 100), dtype=np.uint16) + background_array = np.zeros((15, 100, 100), dtype=np.uint16) + + # Should return an empty list, not raise an error + detected = main( + signal_array, + background_array, + voxel_sizes, + n_free_cpus=no_free_cpus, + skip_classification=True, + ) + assert isinstance(detected, list) + assert len(detected) == 0 + + +def test_detection_no_candidates_with_classification(no_free_cpus): + """ + Test that detection + classification on an empty image returns an empty + list without raising an error (issue #344). + """ + signal_array = np.zeros((15, 100, 100), dtype=np.uint16) + background_array = np.zeros((15, 100, 100), dtype=np.uint16) + + # Should return an empty list even with classification enabled + detected = main( + signal_array, + background_array, + voxel_sizes, + n_free_cpus=no_free_cpus, + ) + assert isinstance(detected, list) + assert len(detected) == 0 + diff --git a/tests/napari/test_training.py b/tests/napari/test_training.py index 1ff7d509..300c0c57 100644 --- a/tests/napari/test_training.py +++ b/tests/napari/test_training.py @@ -61,7 +61,9 @@ def test_run_with_virtual_yaml_files(get_training_widget): """ Checks that training is run with expected set of parameters. """ - with patch("cellfinder.napari.train.train.run_training") as run_training: + with patch( + "cellfinder.napari.train.train.TrainingWorker" + ) as MockWorker: # make default input valid - need yaml files (they don't technically # have to exist) virtual_yaml_files = ( @@ -83,7 +85,7 @@ def test_run_with_virtual_yaml_files(get_training_widget): expected_network_args.trained_model = None expected_network_args.model_weights = None - run_training.assert_called_once_with( + MockWorker.assert_called_once_with( expected_training_args, expected_network_args, expected_optional_training_args, From 0b5e05e2b5e6807e8a43eb7e209c7f9ebae3508f Mon Sep 17 00:00:00 2001 From: tanishaa7 Date: Thu, 5 Mar 2026 00:08:00 +0530 Subject: [PATCH 2/3] Address review comments in train_yaml --- cellfinder/core/train/train_yaml.py | 33 +++++++---------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/cellfinder/core/train/train_yaml.py b/cellfinder/core/train/train_yaml.py index c219ff5a..733f3791 100644 --- a/cellfinder/core/train/train_yaml.py +++ b/cellfinder/core/train/train_yaml.py @@ -26,8 +26,7 @@ from brainglobe_utils.IO.cells import find_relevant_tiffs from brainglobe_utils.IO.yaml import read_yaml_section from fancylog import fancylog -import keras -from keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard +from keras.callbacks import CSVLogger, LambdaCallback, ModelCheckpoint, TensorBoard from sklearn.model_selection import train_test_split import cellfinder.core as package_for_log @@ -298,28 +297,6 @@ def cli(): ) -class EpochEndCallback(keras.callbacks.Callback): - """Keras callback that reports epoch progress via a callback function.""" - - def __init__( - self, - progress_callback: Callable[[str, int, int], None], - epochs: int, - ): - super().__init__() - self._progress_callback = progress_callback - self._epochs = epochs - - def on_epoch_end(self, epoch, logs=None): - # epoch is 0-indexed in Keras - current = epoch + 1 - self._progress_callback( - f"Training epoch {current}/{self._epochs}", - current, - self._epochs, - ) - - def run( output_dir, yaml_file, @@ -459,7 +436,13 @@ def run( if progress_callback is not None: progress_callback("Beginning training", 0, epochs) - callbacks.append(EpochEndCallback(progress_callback, epochs)) + callbacks.append( + LambdaCallback( + on_epoch_end=lambda epoch, logs: progress_callback( + f"Training epoch {epoch+1}/{epochs}", epoch + 1, epochs + ) + ) + ) logger.info("Beginning training.") # Keras 3.0: `use_multiprocessing` input is set in the From 730860ebe981071b48e26947e729e6b88eef5661 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:50:15 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cellfinder/core/train/train_yaml.py | 7 ++++++- cellfinder/napari/train/train.py | 3 +-- tests/core/test_integration/test_detection.py | 1 - tests/napari/test_training.py | 4 +--- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cellfinder/core/train/train_yaml.py b/cellfinder/core/train/train_yaml.py index 733f3791..19f299f7 100644 --- a/cellfinder/core/train/train_yaml.py +++ b/cellfinder/core/train/train_yaml.py @@ -26,7 +26,12 @@ from brainglobe_utils.IO.cells import find_relevant_tiffs from brainglobe_utils.IO.yaml import read_yaml_section from fancylog import fancylog -from keras.callbacks import CSVLogger, LambdaCallback, ModelCheckpoint, TensorBoard +from keras.callbacks import ( + CSVLogger, + LambdaCallback, + ModelCheckpoint, + TensorBoard, +) from sklearn.model_selection import train_test_split import cellfinder.core as package_for_log diff --git a/cellfinder/napari/train/train.py b/cellfinder/napari/train/train.py index fb7afadc..e12e4301 100644 --- a/cellfinder/napari/train/train.py +++ b/cellfinder/napari/train/train.py @@ -8,13 +8,13 @@ from cellfinder.napari.utils import cellfinder_header, html_label_widget +from .thread_worker import TrainingWorker from .train_containers import ( MiscTrainingInputs, OptionalNetworkInputs, OptionalTrainingInputs, TrainingDataInputs, ) -from .thread_worker import TrainingWorker def training_widget() -> FunctionGui: @@ -164,4 +164,3 @@ def restore_defaults(): widget._widget._qwidget = scroll return widget - diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 762637a9..22c5eb23 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -376,4 +376,3 @@ def test_detection_no_candidates_with_classification(no_free_cpus): ) assert isinstance(detected, list) assert len(detected) == 0 - diff --git a/tests/napari/test_training.py b/tests/napari/test_training.py index 300c0c57..0130aecd 100644 --- a/tests/napari/test_training.py +++ b/tests/napari/test_training.py @@ -61,9 +61,7 @@ def test_run_with_virtual_yaml_files(get_training_widget): """ Checks that training is run with expected set of parameters. """ - with patch( - "cellfinder.napari.train.train.TrainingWorker" - ) as MockWorker: + with patch("cellfinder.napari.train.train.TrainingWorker") as MockWorker: # make default input valid - need yaml files (they don't technically # have to exist) virtual_yaml_files = (