Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cellfinder/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,10 @@ def main(
callback=classify_callback,
)
else:
logger.info("No candidates, skipping classification")
logger.warning(
"No cell candidates were detected. "
"Classification will be skipped. "
"This may occur with very small images or images "
"with no bright structures."
)
return points
26 changes: 24 additions & 2 deletions cellfinder/core/train/train_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from datetime import datetime
from pathlib import Path
from typing import Dict, Literal
from typing import Callable, Dict, Literal, Optional

from brainglobe_utils.general.numerical import (
check_positive_float,
Expand All @@ -26,7 +26,12 @@
from brainglobe_utils.IO.cells import find_relevant_tiffs
from brainglobe_utils.IO.yaml import read_yaml_section
from fancylog import fancylog
from keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard
from keras.callbacks import (
CSVLogger,
LambdaCallback,
ModelCheckpoint,
TensorBoard,
)
from sklearn.model_selection import train_test_split

import cellfinder.core as package_for_log
Expand Down Expand Up @@ -316,9 +321,13 @@ def run(
no_save_checkpoints=False,
save_progress=False,
epochs=100,
progress_callback: Optional[Callable[[str, int, int], None]] = None,
):
start_time = datetime.now()

if progress_callback is not None:
progress_callback("Preparing training", 0, 1)

ensure_directory_exists(output_dir)
model_weights = prep_model_weights(
model_weights=model_weights,
Expand Down Expand Up @@ -430,6 +439,16 @@ def run(
csv_logger = CSVLogger(csv_filepath)
callbacks.append(csv_logger)

if progress_callback is not None:
progress_callback("Beginning training", 0, epochs)
callbacks.append(
LambdaCallback(
on_epoch_end=lambda epoch, logs: progress_callback(
f"Training epoch {epoch+1}/{epochs}", epoch + 1, epochs
)
)
)

logger.info("Beginning training.")
# Keras 3.0: `use_multiprocessing` input is set in the
# `training_generator` (False by default)
Expand All @@ -447,6 +466,9 @@ def run(
logger.info("Saving model")
model.save(output_dir / "model.keras")

if progress_callback is not None:
progress_callback("Training finished", epochs, epochs)

logger.info(
"Finished training, " "Total time taken: %s",
datetime.now() - start_time,
Expand Down
15 changes: 15 additions & 0 deletions cellfinder/napari/detect/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ def get_results_callback(
if skip_classification:
# after detection w/o classification, everything is unknown
def done_func(points):
if not points:
show_info(
"No cell candidates were detected. "
"Try adjusting detection parameters or "
"using a larger image."
)
return
add_single_layer(
points,
viewer=viewer,
Expand All @@ -168,6 +175,14 @@ def done_func(points):
else:
# after classification we have either cell or unknown
def done_func(points):
if not points:
show_info(
"No cell candidates were detected. "
"Classification was skipped. "
"Try adjusting detection parameters or "
"using a larger image."
)
return
add_classified_layers(
points,
viewer=viewer,
Expand Down
19 changes: 15 additions & 4 deletions cellfinder/napari/detect/thread_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,23 @@ def detect_callback(plane: int) -> None:
def detect_finished_callback(points: list) -> None:
self.npoints_detected = len(points)
if not self.classification_inputs.skip_classification:
self.update_progress_bar.emit(
"Setting up classification...", 1, 0
)
if self.npoints_detected == 0:
self.update_progress_bar.emit(
"No cell candidates detected, "
"skipping classification",
1,
1,
)
else:
self.update_progress_bar.emit(
"Setting up classification...", 1, 0
)

def classify_callback(batch: int) -> None:
if not self.classification_inputs.skip_classification:
if (
not self.classification_inputs.skip_classification
and self.npoints_detected > 0
):
self.update_progress_bar.emit(
"Classifying cells",
# Default cellfinder-core batch size is 64.
Expand Down
71 changes: 71 additions & 0 deletions cellfinder/napari/train/thread_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from magicgui.widgets import ProgressBar
from napari.qt.threading import WorkerBase, WorkerBaseSignals
from qtpy.QtCore import Signal

from cellfinder.core.train.train_yaml import run as train_yaml_run

from .train_containers import (
MiscTrainingInputs,
OptionalNetworkInputs,
OptionalTrainingInputs,
TrainingDataInputs,
)


class MyTrainingWorkerSignals(WorkerBaseSignals):
"""
Signals used by the TrainingWorker class below.
"""

# Emits (label, max, value) for the progress bar
update_progress_bar = Signal(str, int, int)


class TrainingWorker(WorkerBase):
"""
Runs cellfinder training in a separate thread, to prevent GUI blocking.

Also handles callbacks between the worker thread and main napari GUI
thread to update a progress bar.
"""

def __init__(
self,
training_data_inputs: TrainingDataInputs,
optional_network_inputs: OptionalNetworkInputs,
optional_training_inputs: OptionalTrainingInputs,
misc_training_inputs: MiscTrainingInputs,
):
super().__init__(SignalsClass=MyTrainingWorkerSignals)
self.training_data_inputs = training_data_inputs
self.optional_network_inputs = optional_network_inputs
self.optional_training_inputs = optional_training_inputs
self.misc_training_inputs = misc_training_inputs

def connect_progress_bar_callback(self, progress_bar: ProgressBar):
"""
Connects the progress bar to the worker so that updates are shown
on the bar.
"""

def update_progress_bar(label: str, max: int, value: int):
progress_bar.label = label
progress_bar.max = max
progress_bar.value = value

self.update_progress_bar.connect(update_progress_bar)

def work(self) -> None:
self.update_progress_bar.emit("Preparing training...", 1, 0)

def progress_callback(label: str, value: int, max_val: int) -> None:
self.update_progress_bar.emit(label, max_val, value)

train_yaml_run(
**self.training_data_inputs.as_core_arguments(),
**self.optional_network_inputs.as_core_arguments(),
**self.optional_training_inputs.as_core_arguments(),
**self.misc_training_inputs.as_core_arguments(),
progress_callback=progress_callback,
)
self.update_progress_bar.emit("Training finished!", 1, 1)
35 changes: 14 additions & 21 deletions cellfinder/napari/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
from typing import Optional

from magicgui import magicgui
from magicgui.widgets import FunctionGui, PushButton
from napari.qt.threading import thread_worker
from magicgui.widgets import FunctionGui, ProgressBar, PushButton
from napari.utils.notifications import show_info
from qtpy.QtWidgets import QScrollArea

from cellfinder.core.train.train_yaml import run as train_yaml
from cellfinder.napari.utils import cellfinder_header, html_label_widget

from .thread_worker import TrainingWorker
from .train_containers import (
MiscTrainingInputs,
OptionalNetworkInputs,
Expand All @@ -18,24 +17,9 @@
)


@thread_worker
def run_training(
training_data_inputs: TrainingDataInputs,
optional_network_inputs: OptionalNetworkInputs,
optional_training_inputs: OptionalTrainingInputs,
misc_training_inputs: MiscTrainingInputs,
):
show_info("Running training...")
train_yaml(
**training_data_inputs.as_core_arguments(),
**optional_network_inputs.as_core_arguments(),
**optional_training_inputs.as_core_arguments(),
**misc_training_inputs.as_core_arguments(),
)
show_info("Training finished!")


def training_widget() -> FunctionGui:
progress_bar = ProgressBar()

@magicgui(
training_label=html_label_widget("Network training", tag="h3"),
**TrainingDataInputs.widget_representation(),
Expand Down Expand Up @@ -143,16 +127,25 @@ def widget(
show_info("Please select a YAML file for training")
else:
show_info("Starting training process...")
worker = run_training(
worker = TrainingWorker(
training_data_inputs,
optional_network_inputs,
optional_training_inputs,
misc_training_inputs,
)

def on_finished():
show_info("Training finished!")

worker.returned.connect(on_finished)
worker.connect_progress_bar_callback(progress_bar)
worker.start()

widget.native.layout().insertWidget(0, cellfinder_header())

# Insert progress bar before the call and reset buttons
widget.insert(widget.index("number_of_free_cpus") + 1, progress_bar)

@widget.reset_button.changed.connect
def restore_defaults():
defaults = {
Expand Down
40 changes: 40 additions & 0 deletions tests/core/test_integration/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,43 @@ def test_detection_plane_too_small(synthetic_spot_clusters, y, x):
voxel_sizes=(1, 1, 1),
ball_xy_size=50,
)


def test_detection_no_candidates(no_free_cpus):
"""
Test that detection on a very small / empty image returns an empty list
without raising an error (issue #344).
"""
# Create a small blank image with no bright structures
signal_array = np.zeros((15, 100, 100), dtype=np.uint16)
background_array = np.zeros((15, 100, 100), dtype=np.uint16)

# Should return an empty list, not raise an error
detected = main(
signal_array,
background_array,
voxel_sizes,
n_free_cpus=no_free_cpus,
skip_classification=True,
)
assert isinstance(detected, list)
assert len(detected) == 0


def test_detection_no_candidates_with_classification(no_free_cpus):
"""
Test that detection + classification on an empty image returns an empty
list without raising an error (issue #344).
"""
signal_array = np.zeros((15, 100, 100), dtype=np.uint16)
background_array = np.zeros((15, 100, 100), dtype=np.uint16)

# Should return an empty list even with classification enabled
detected = main(
signal_array,
background_array,
voxel_sizes,
n_free_cpus=no_free_cpus,
)
assert isinstance(detected, list)
assert len(detected) == 0
4 changes: 2 additions & 2 deletions tests/napari/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_run_with_virtual_yaml_files(get_training_widget):
"""
Checks that training is run with expected set of parameters.
"""
with patch("cellfinder.napari.train.train.run_training") as run_training:
with patch("cellfinder.napari.train.train.TrainingWorker") as MockWorker:
# make default input valid - need yaml files (they don't technically
# have to exist)
virtual_yaml_files = (
Expand All @@ -83,7 +83,7 @@ def test_run_with_virtual_yaml_files(get_training_widget):
expected_network_args.trained_model = None
expected_network_args.model_weights = None

run_training.assert_called_once_with(
MockWorker.assert_called_once_with(
expected_training_args,
expected_network_args,
expected_optional_training_args,
Expand Down