Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion cellfinder/core/train/train_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
ArgumentTypeError,
)
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Dict, Literal
from typing import Dict, Literal, Sequence

from brainglobe_utils.cells.cells import Cell
from brainglobe_utils.general.numerical import (
Expand All @@ -29,6 +30,7 @@
from fancylog import fancylog
from keras.callbacks import (
CSVLogger,
LearningRateScheduler,
ModelCheckpoint,
TensorBoard,
)
Expand Down Expand Up @@ -63,6 +65,17 @@
CUBE_DEPTH = 20


def lr_scheduler(
epoch: int,
lr: float,
multiplier: float,
epoch_list: Sequence[int],
) -> float:
if epoch in epoch_list:
return lr * multiplier
return lr


def valid_model_depth(depth):
"""
Ensures a correct existing_model is chosen
Expand Down Expand Up @@ -250,6 +263,26 @@ def training_parse():
action="store_true",
help="Save training progress to a .csv file",
)
training_parser.add_argument(
"--lr-schedule",
dest="lr_schedule",
nargs="*",
type=partial(check_positive_int, none_allowed=False),
default=(),
help="If not empty, the list of epochs when to multiply the current "
"learning rate by the lr_multiplier. E.g. if it's [10, 25], we "
"start with a learning rate of 0.001, and lr_multiplier is "
"0.1, then the LR will be 0.001 for epochs 0-9, 0.0001 for 10-24,"
" and 00001 for epoch 25 and beyond.",
)
training_parser.add_argument(
"--lr-multiplier",
dest="lr_multiplier",
type=partial(check_positive_float, none_allowed=False),
default=0.1,
help="The multiplier by which to multiply the previous learning rate "
"at the epochs listed in lr_schedule.",
)

training_parser = misc_parse(training_parser)
training_parser = download_parser(training_parser)
Expand Down Expand Up @@ -346,6 +379,8 @@ def cli():
no_save_checkpoints=args.no_save_checkpoints,
save_progress=args.save_progress,
epochs=args.epochs,
lr_schedule=args.lr_schedule,
lr_multiplier=args.lr_multiplier,
)


Expand Down Expand Up @@ -407,6 +442,8 @@ def run(
epochs=100,
max_workers: int = 3,
pin_memory: bool = True,
lr_schedule: Sequence[int] = (),
lr_multiplier: float = 0.1,
augment_likelihood: float = 0.9,
):
start_time = datetime.now()
Expand Down Expand Up @@ -521,6 +558,15 @@ def run(
csv_logger = CSVLogger(csv_filepath)
callbacks.append(csv_logger)

if lr_schedule:
# we need to drop the lr by a given schedule. This is called at the
# start of each epoch and is zero based. E.g. if epoch 10 is listed,
# it'll drop at the start of the 11th epoch.
lr_callback = partial(
lr_scheduler, multiplier=lr_multiplier, epoch_list=lr_schedule
)
callbacks.append(LearningRateScheduler(lr_callback))

logger.info("Beginning training.")
if n_processes:
training_dataset.start_dataset_thread(n_processes)
Expand Down
13 changes: 13 additions & 0 deletions cellfinder/napari/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def widget(
save_progress: bool,
epochs: int,
learning_rate: float,
lr_schedule: list[int],
lr_multiplier: float,
batch_size: int,
test_fraction: float,
misc_options: dict,
Expand Down Expand Up @@ -104,6 +106,15 @@ def widget(
(How many times to use each training data point)
learning_rate : float
Learning rate for training the model
lr_schedule : list of ints
If not empty, the list of epochs when to multiply the current
learning rate by the lr_multiplier. E.g. if it's [10, 25], we start
with a learning rate of 0.001, and `lr_multiplier` is 0.1, then the
LR will be 0.001 for epochs 0-9, 0.0001 for 10-24, and 00001
for epoch 25 and beyond.
lr_multiplier : float
The multiplier by which to multiply the previous learning rate
at the epochs listed in `lr_schedule`.
batch_size : int
Training batch size
test_fraction : float
Expand Down Expand Up @@ -135,6 +146,8 @@ def widget(
learning_rate,
batch_size,
test_fraction,
lr_schedule,
lr_multiplier,
)

misc_training_inputs = MiscTrainingInputs(number_of_free_cpus)
Expand Down
8 changes: 8 additions & 0 deletions cellfinder/napari/train/train_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class OptionalTrainingInputs(InputContainer):
learning_rate: float = 1e-4
batch_size: int = 16
test_fraction: float = 0.1
lr_schedule: list[int] | tuple[int, ...] = ()
lr_multiplier: float = 0.1

def as_core_arguments(self) -> dict:
arguments = super().as_core_arguments()
Expand All @@ -105,6 +107,12 @@ def widget_representation(cls) -> dict:
test_fraction=cls._custom_widget(
"test_fraction", step=0.05, min=0.05, max=0.95
),
lr_schedule=cls._custom_widget(
"lr_schedule", custom_label="LR schedule"
),
lr_multiplier=cls._custom_widget(
"lr_multiplier", custom_label="LR multiplier"
),
)


Expand Down
46 changes: 46 additions & 0 deletions tests/core/test_integration/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys

import pytest
from pytest_mock.plugin import MockerFixture

from cellfinder.core.train.train_yaml import cli as train_run

Expand Down Expand Up @@ -37,3 +38,48 @@

model_file = os.path.join(tmpdir, "model.keras")
assert os.path.exists(model_file)


@pytest.mark.parametrize("lr_schedule", [True, False])
def test_train_lr_schedule(mocker: MockerFixture, tmpdir, lr_schedule):
tmpdir = str(tmpdir)

train_args = [
"cellfinder_train",
"-y",
training_yaml_file,
"-o",
tmpdir,
"--epochs",
EPOCHS,
"--lr-multiplier",
"0.3",
]
if lr_schedule:
train_args.extend(["--lr-schedule", "10", "20"])

mocker.patch("sys.argv", train_args)
get_model = mocker.patch(
"cellfinder.core.train.train_yaml.get_model", autospec=True
)

train_run()
# get the data sets passed to fit(). There's no clear name property of
# the mock fit call, so use its repr
(fit_mock,) = [
m for m in get_model.mock_calls if repr(m).startswith("call().fit(")
]
callbacks = fit_mock.kwargs["callbacks"]

# locate the scheduler callback, if any
from keras.callbacks import LearningRateScheduler

callbacks = [c for c in callbacks if isinstance(c, LearningRateScheduler)]
if lr_schedule:
assert len(callbacks) == 1
# the callback is a partial function with these args
partial_callback = callbacks[0].schedule
assert partial_callback.keywords["multiplier"] == 0.3

Check warning on line 82 in tests/core/test_integration/test_train.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=brainglobe_cellfinder&issues=AZxQO-7PyiTGTwyTonYM&open=AZxQO-7PyiTGTwyTonYM&pullRequest=589
assert partial_callback.keywords["epoch_list"] == [10, 20]
else:
assert not callbacks
3 changes: 3 additions & 0 deletions tests/napari/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def test_run_with_virtual_yaml_files(get_training_widget):
expected_network_args = OptionalNetworkInputs()
expected_optional_training_args = OptionalTrainingInputs()
expected_misc_args = MiscTrainingInputs()
# run_training calls lr_schedule with empty list instead of tuple,
# so to do equality comparison, we need to set default to list also
expected_optional_training_args.lr_schedule = []

# we expect the widget to make some changes to the defaults
# displayed before calling the training backend
Expand Down
Loading