diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index 25c0c793..21a86038 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -243,5 +243,10 @@ def main( callback=classify_callback, ) else: - logger.info("No candidates, skipping classification") + logger.warning( + "No cell candidates were detected. " + "Classification will be skipped. " + "This may occur with very small images or images " + "with no bright structures." + ) return points diff --git a/cellfinder/core/train/train_yaml.py b/cellfinder/core/train/train_yaml.py index f87a4203..19f299f7 100644 --- a/cellfinder/core/train/train_yaml.py +++ b/cellfinder/core/train/train_yaml.py @@ -13,7 +13,7 @@ ) from datetime import datetime from pathlib import Path -from typing import Dict, Literal +from typing import Callable, Dict, Literal, Optional from brainglobe_utils.general.numerical import ( check_positive_float, @@ -26,7 +26,12 @@ from brainglobe_utils.IO.cells import find_relevant_tiffs from brainglobe_utils.IO.yaml import read_yaml_section from fancylog import fancylog -from keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard +from keras.callbacks import ( + CSVLogger, + LambdaCallback, + ModelCheckpoint, + TensorBoard, +) from sklearn.model_selection import train_test_split import cellfinder.core as package_for_log @@ -316,9 +321,13 @@ def run( no_save_checkpoints=False, save_progress=False, epochs=100, + progress_callback: Optional[Callable[[str, int, int], None]] = None, ): start_time = datetime.now() + if progress_callback is not None: + progress_callback("Preparing training", 0, 1) + ensure_directory_exists(output_dir) model_weights = prep_model_weights( model_weights=model_weights, @@ -430,6 +439,16 @@ def run( csv_logger = CSVLogger(csv_filepath) callbacks.append(csv_logger) + if progress_callback is not None: + progress_callback("Beginning training", 0, epochs) + callbacks.append( + LambdaCallback( + on_epoch_end=lambda epoch, logs: progress_callback( + f"Training epoch {epoch+1}/{epochs}", epoch + 1, epochs + ) + ) + ) + logger.info("Beginning training.") # Keras 3.0: `use_multiprocessing` input is set in the # `training_generator` (False by default) @@ -447,6 +466,9 @@ def run( logger.info("Saving model") model.save(output_dir / "model.keras") + if progress_callback is not None: + progress_callback("Training finished", epochs, epochs) + logger.info( "Finished training, " "Total time taken: %s", datetime.now() - start_time, diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index f801d462..499fc34e 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -157,6 +157,13 @@ def get_results_callback( if skip_classification: # after detection w/o classification, everything is unknown def done_func(points): + if not points: + show_info( + "No cell candidates were detected. " + "Try adjusting detection parameters or " + "using a larger image." + ) + return add_single_layer( points, viewer=viewer, @@ -168,6 +175,14 @@ def done_func(points): else: # after classification we have either cell or unknown def done_func(points): + if not points: + show_info( + "No cell candidates were detected. " + "Classification was skipped. " + "Try adjusting detection parameters or " + "using a larger image." + ) + return add_classified_layers( points, viewer=viewer, diff --git a/cellfinder/napari/detect/thread_worker.py b/cellfinder/napari/detect/thread_worker.py index d6a9431f..39ed8017 100644 --- a/cellfinder/napari/detect/thread_worker.py +++ b/cellfinder/napari/detect/thread_worker.py @@ -70,12 +70,23 @@ def detect_callback(plane: int) -> None: def detect_finished_callback(points: list) -> None: self.npoints_detected = len(points) if not self.classification_inputs.skip_classification: - self.update_progress_bar.emit( - "Setting up classification...", 1, 0 - ) + if self.npoints_detected == 0: + self.update_progress_bar.emit( + "No cell candidates detected, " + "skipping classification", + 1, + 1, + ) + else: + self.update_progress_bar.emit( + "Setting up classification...", 1, 0 + ) def classify_callback(batch: int) -> None: - if not self.classification_inputs.skip_classification: + if ( + not self.classification_inputs.skip_classification + and self.npoints_detected > 0 + ): self.update_progress_bar.emit( "Classifying cells", # Default cellfinder-core batch size is 64. diff --git a/cellfinder/napari/train/thread_worker.py b/cellfinder/napari/train/thread_worker.py new file mode 100644 index 00000000..8719c8c1 --- /dev/null +++ b/cellfinder/napari/train/thread_worker.py @@ -0,0 +1,71 @@ +from magicgui.widgets import ProgressBar +from napari.qt.threading import WorkerBase, WorkerBaseSignals +from qtpy.QtCore import Signal + +from cellfinder.core.train.train_yaml import run as train_yaml_run + +from .train_containers import ( + MiscTrainingInputs, + OptionalNetworkInputs, + OptionalTrainingInputs, + TrainingDataInputs, +) + + +class MyTrainingWorkerSignals(WorkerBaseSignals): + """ + Signals used by the TrainingWorker class below. + """ + + # Emits (label, max, value) for the progress bar + update_progress_bar = Signal(str, int, int) + + +class TrainingWorker(WorkerBase): + """ + Runs cellfinder training in a separate thread, to prevent GUI blocking. + + Also handles callbacks between the worker thread and main napari GUI + thread to update a progress bar. + """ + + def __init__( + self, + training_data_inputs: TrainingDataInputs, + optional_network_inputs: OptionalNetworkInputs, + optional_training_inputs: OptionalTrainingInputs, + misc_training_inputs: MiscTrainingInputs, + ): + super().__init__(SignalsClass=MyTrainingWorkerSignals) + self.training_data_inputs = training_data_inputs + self.optional_network_inputs = optional_network_inputs + self.optional_training_inputs = optional_training_inputs + self.misc_training_inputs = misc_training_inputs + + def connect_progress_bar_callback(self, progress_bar: ProgressBar): + """ + Connects the progress bar to the worker so that updates are shown + on the bar. + """ + + def update_progress_bar(label: str, max: int, value: int): + progress_bar.label = label + progress_bar.max = max + progress_bar.value = value + + self.update_progress_bar.connect(update_progress_bar) + + def work(self) -> None: + self.update_progress_bar.emit("Preparing training...", 1, 0) + + def progress_callback(label: str, value: int, max_val: int) -> None: + self.update_progress_bar.emit(label, max_val, value) + + train_yaml_run( + **self.training_data_inputs.as_core_arguments(), + **self.optional_network_inputs.as_core_arguments(), + **self.optional_training_inputs.as_core_arguments(), + **self.misc_training_inputs.as_core_arguments(), + progress_callback=progress_callback, + ) + self.update_progress_bar.emit("Training finished!", 1, 1) diff --git a/cellfinder/napari/train/train.py b/cellfinder/napari/train/train.py index 79d92b6b..e12e4301 100644 --- a/cellfinder/napari/train/train.py +++ b/cellfinder/napari/train/train.py @@ -2,14 +2,13 @@ from typing import Optional from magicgui import magicgui -from magicgui.widgets import FunctionGui, PushButton -from napari.qt.threading import thread_worker +from magicgui.widgets import FunctionGui, ProgressBar, PushButton from napari.utils.notifications import show_info from qtpy.QtWidgets import QScrollArea -from cellfinder.core.train.train_yaml import run as train_yaml from cellfinder.napari.utils import cellfinder_header, html_label_widget +from .thread_worker import TrainingWorker from .train_containers import ( MiscTrainingInputs, OptionalNetworkInputs, @@ -18,24 +17,9 @@ ) -@thread_worker -def run_training( - training_data_inputs: TrainingDataInputs, - optional_network_inputs: OptionalNetworkInputs, - optional_training_inputs: OptionalTrainingInputs, - misc_training_inputs: MiscTrainingInputs, -): - show_info("Running training...") - train_yaml( - **training_data_inputs.as_core_arguments(), - **optional_network_inputs.as_core_arguments(), - **optional_training_inputs.as_core_arguments(), - **misc_training_inputs.as_core_arguments(), - ) - show_info("Training finished!") - - def training_widget() -> FunctionGui: + progress_bar = ProgressBar() + @magicgui( training_label=html_label_widget("Network training", tag="h3"), **TrainingDataInputs.widget_representation(), @@ -143,16 +127,25 @@ def widget( show_info("Please select a YAML file for training") else: show_info("Starting training process...") - worker = run_training( + worker = TrainingWorker( training_data_inputs, optional_network_inputs, optional_training_inputs, misc_training_inputs, ) + + def on_finished(): + show_info("Training finished!") + + worker.returned.connect(on_finished) + worker.connect_progress_bar_callback(progress_bar) worker.start() widget.native.layout().insertWidget(0, cellfinder_header()) + # Insert progress bar before the call and reset buttons + widget.insert(widget.index("number_of_free_cpus") + 1, progress_bar) + @widget.reset_button.changed.connect def restore_defaults(): defaults = { diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 5e145875..22c5eb23 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -336,3 +336,43 @@ def test_detection_plane_too_small(synthetic_spot_clusters, y, x): voxel_sizes=(1, 1, 1), ball_xy_size=50, ) + + +def test_detection_no_candidates(no_free_cpus): + """ + Test that detection on a very small / empty image returns an empty list + without raising an error (issue #344). + """ + # Create a small blank image with no bright structures + signal_array = np.zeros((15, 100, 100), dtype=np.uint16) + background_array = np.zeros((15, 100, 100), dtype=np.uint16) + + # Should return an empty list, not raise an error + detected = main( + signal_array, + background_array, + voxel_sizes, + n_free_cpus=no_free_cpus, + skip_classification=True, + ) + assert isinstance(detected, list) + assert len(detected) == 0 + + +def test_detection_no_candidates_with_classification(no_free_cpus): + """ + Test that detection + classification on an empty image returns an empty + list without raising an error (issue #344). + """ + signal_array = np.zeros((15, 100, 100), dtype=np.uint16) + background_array = np.zeros((15, 100, 100), dtype=np.uint16) + + # Should return an empty list even with classification enabled + detected = main( + signal_array, + background_array, + voxel_sizes, + n_free_cpus=no_free_cpus, + ) + assert isinstance(detected, list) + assert len(detected) == 0 diff --git a/tests/napari/test_training.py b/tests/napari/test_training.py index 1ff7d509..0130aecd 100644 --- a/tests/napari/test_training.py +++ b/tests/napari/test_training.py @@ -61,7 +61,7 @@ def test_run_with_virtual_yaml_files(get_training_widget): """ Checks that training is run with expected set of parameters. """ - with patch("cellfinder.napari.train.train.run_training") as run_training: + with patch("cellfinder.napari.train.train.TrainingWorker") as MockWorker: # make default input valid - need yaml files (they don't technically # have to exist) virtual_yaml_files = ( @@ -83,7 +83,7 @@ def test_run_with_virtual_yaml_files(get_training_widget): expected_network_args.trained_model = None expected_network_args.model_weights = None - run_training.assert_called_once_with( + MockWorker.assert_called_once_with( expected_training_args, expected_network_args, expected_optional_training_args,