From 572e77d7162263d2b998a6cac8590ced8a04c211 Mon Sep 17 00:00:00 2001
From: Matt Einhorn
Date: Fri, 13 Mar 2026 00:20:43 -0400
Subject: [PATCH] Add support for setting learning rate during training.

---
 cellfinder/core/train/train_yaml.py         | 48 ++++++++++++++++++++-
 cellfinder/napari/train/train.py            | 13 ++++++
 cellfinder/napari/train/train_containers.py |  8 ++++
 tests/core/test_integration/test_train.py   | 46 ++++++++++++++++++++
 tests/napari/test_training.py               |  3 ++
 5 files changed, 117 insertions(+), 1 deletion(-)

diff --git a/cellfinder/core/train/train_yaml.py b/cellfinder/core/train/train_yaml.py
index 2f6aa8a2..7047394f 100644
--- a/cellfinder/core/train/train_yaml.py
+++ b/cellfinder/core/train/train_yaml.py
@@ -12,8 +12,9 @@
     ArgumentTypeError,
 )
 from datetime import datetime
+from functools import partial
 from pathlib import Path
-from typing import Dict, Literal
+from typing import Dict, Literal, Sequence
 
 from brainglobe_utils.cells.cells import Cell
 from brainglobe_utils.general.numerical import (
@@ -29,6 +30,7 @@
 from fancylog import fancylog
 from keras.callbacks import (
     CSVLogger,
+    LearningRateScheduler,
     ModelCheckpoint,
     TensorBoard,
 )
@@ -63,6 +65,17 @@
 CUBE_DEPTH = 20
 
 
+def lr_scheduler(
+    epoch: int,
+    lr: float,
+    multiplier: float,
+    epoch_list: Sequence[int],
+) -> float:
+    if epoch in epoch_list:
+        return lr * multiplier
+    return lr
+
+
 def valid_model_depth(depth):
     """
     Ensures a correct existing_model is chosen
@@ -250,6 +263,26 @@ def training_parse():
         action="store_true",
         help="Save training progress to a .csv file",
     )
+    training_parser.add_argument(
+        "--lr-schedule",
+        dest="lr_schedule",
+        nargs="*",
+        type=partial(check_positive_int, none_allowed=False),
+        default=(),
+        help="If not empty, the list of epochs when to multiply the current "
+        "learning rate by the lr_multiplier. E.g. if it's [10, 25], we "
+        "start with a learning rate of 0.001, and lr_multiplier is "
+        "0.1, then the LR will be 0.001 for epochs 0-9, 0.0001 for 10-24,"
+        " and 0.00001 for epoch 25 and beyond.",
+    )
+    training_parser.add_argument(
+        "--lr-multiplier",
+        dest="lr_multiplier",
+        type=partial(check_positive_float, none_allowed=False),
+        default=0.1,
+        help="The multiplier by which to multiply the previous learning rate "
+        "at the epochs listed in lr_schedule.",
+    )
 
     training_parser = misc_parse(training_parser)
     training_parser = download_parser(training_parser)
@@ -346,6 +379,8 @@ def cli():
         no_save_checkpoints=args.no_save_checkpoints,
         save_progress=args.save_progress,
         epochs=args.epochs,
+        lr_schedule=args.lr_schedule,
+        lr_multiplier=args.lr_multiplier,
     )
 
 
@@ -407,6 +442,8 @@ def run(
     epochs=100,
     max_workers: int = 3,
     pin_memory: bool = True,
+    lr_schedule: Sequence[int] = (),
+    lr_multiplier: float = 0.1,
     augment_likelihood: float = 0.9,
 ):
     start_time = datetime.now()
@@ -521,6 +558,15 @@ def run(
         csv_logger = CSVLogger(csv_filepath)
         callbacks.append(csv_logger)
 
+    if lr_schedule:
+        # we need to drop the lr by a given schedule. This is called at the
+        # start of each epoch and is zero based. E.g. if epoch 10 is listed,
+        # it'll drop at the start of the 11th epoch.
+        lr_callback = partial(
+            lr_scheduler, multiplier=lr_multiplier, epoch_list=lr_schedule
+        )
+        callbacks.append(LearningRateScheduler(lr_callback))
+
     logger.info("Beginning training.")
     if n_processes:
         training_dataset.start_dataset_thread(n_processes)
diff --git a/cellfinder/napari/train/train.py b/cellfinder/napari/train/train.py
index 79d92b6b..7057bee8 100644
--- a/cellfinder/napari/train/train.py
+++ b/cellfinder/napari/train/train.py
@@ -64,6 +64,8 @@ def widget(
         save_progress: bool,
         epochs: int,
         learning_rate: float,
+        lr_schedule: list[int],
+        lr_multiplier: float,
         batch_size: int,
         test_fraction: float,
         misc_options: dict,
@@ -104,6 +106,15 @@ def widget(
             (How many times to use each training data point)
         learning_rate : float
             Learning rate for training the model
+        lr_schedule : list of ints
+            If not empty, the list of epochs when to multiply the current
+            learning rate by the lr_multiplier. E.g. if it's [10, 25], we start
+            with a learning rate of 0.001, and `lr_multiplier` is 0.1, then the
+            LR will be 0.001 for epochs 0-9, 0.0001 for 10-24, and 0.00001
+            for epoch 25 and beyond.
+        lr_multiplier : float
+            The multiplier by which to multiply the previous learning rate
+            at the epochs listed in `lr_schedule`.
         batch_size : int
             Training batch size
         test_fraction : float
@@ -135,6 +146,8 @@ def widget(
             learning_rate,
             batch_size,
             test_fraction,
+            lr_schedule,
+            lr_multiplier,
         )
 
         misc_training_inputs = MiscTrainingInputs(number_of_free_cpus)
diff --git a/cellfinder/napari/train/train_containers.py b/cellfinder/napari/train/train_containers.py
index c77ece05..c4b7f151 100644
--- a/cellfinder/napari/train/train_containers.py
+++ b/cellfinder/napari/train/train_containers.py
@@ -81,6 +81,8 @@ class OptionalTrainingInputs(InputContainer):
     learning_rate: float = 1e-4
     batch_size: int = 16
     test_fraction: float = 0.1
+    lr_schedule: list[int] | tuple[int, ...] = ()
+    lr_multiplier: float = 0.1
 
     def as_core_arguments(self) -> dict:
         arguments = super().as_core_arguments()
@@ -105,6 +107,12 @@ def widget_representation(cls) -> dict:
             test_fraction=cls._custom_widget(
                 "test_fraction", step=0.05, min=0.05, max=0.95
             ),
+            lr_schedule=cls._custom_widget(
+                "lr_schedule", custom_label="LR schedule"
+            ),
+            lr_multiplier=cls._custom_widget(
+                "lr_multiplier", custom_label="LR multiplier"
+            ),
         )
 
 
diff --git a/tests/core/test_integration/test_train.py b/tests/core/test_integration/test_train.py
index 64e0d443..519f64f6 100644
--- a/tests/core/test_integration/test_train.py
+++ b/tests/core/test_integration/test_train.py
@@ -2,6 +2,7 @@
 import sys
 
 import pytest
+from pytest_mock.plugin import MockerFixture
 
 from cellfinder.core.train.train_yaml import cli as train_run
 
@@ -37,3 +38,48 @@ def test_train(tmpdir):
 
     model_file = os.path.join(tmpdir, "model.keras")
     assert os.path.exists(model_file)
+
+
+@pytest.mark.parametrize("lr_schedule", [True, False])
+def test_train_lr_schedule(mocker: MockerFixture, tmpdir, lr_schedule):
+    tmpdir = str(tmpdir)
+
+    train_args = [
+        "cellfinder_train",
+        "-y",
+        training_yaml_file,
+        "-o",
+        tmpdir,
+        "--epochs",
+        EPOCHS,
+        "--lr-multiplier",
+        "0.3",
+    ]
+    if lr_schedule:
+        train_args.extend(["--lr-schedule", "10", "20"])
+
+    mocker.patch("sys.argv", train_args)
+    get_model = mocker.patch(
+        "cellfinder.core.train.train_yaml.get_model", autospec=True
+    )
+
+    train_run()
+    # get the callbacks passed to fit(). There's no clear name property of
+    # the mock fit call, so use its repr
+    (fit_mock,) = [
+        m for m in get_model.mock_calls if repr(m).startswith("call().fit(")
+    ]
+    callbacks = fit_mock.kwargs["callbacks"]
+
+    # locate the scheduler callback, if any
+    from keras.callbacks import LearningRateScheduler
+
+    callbacks = [c for c in callbacks if isinstance(c, LearningRateScheduler)]
+    if lr_schedule:
+        assert len(callbacks) == 1
+        # the callback is a partial function with these args
+        partial_callback = callbacks[0].schedule
+        assert partial_callback.keywords["multiplier"] == 0.3
+        assert partial_callback.keywords["epoch_list"] == [10, 20]
+    else:
+        assert not callbacks
diff --git a/tests/napari/test_training.py b/tests/napari/test_training.py
index 1ff7d509..5ea50380 100644
--- a/tests/napari/test_training.py
+++ b/tests/napari/test_training.py
@@ -76,6 +76,9 @@ def test_run_with_virtual_yaml_files(get_training_widget):
         expected_network_args = OptionalNetworkInputs()
         expected_optional_training_args = OptionalTrainingInputs()
         expected_misc_args = MiscTrainingInputs()
+        # run_training calls lr_schedule with empty list instead of tuple,
+        # so to do equality comparison, we need to set default to list also
+        expected_optional_training_args.lr_schedule = []
 
         # we expect the widget to make some changes to the defaults
         # displayed before calling the training backend