Skip to content

Adds support for varying learning rate during training#589

Merged
alessandrofelder merged 1 commit into
brainglobe:mainfrom
matham:caching_lr
May 22, 2026
Merged

Adds support for varying learning rate during training#589
alessandrofelder merged 1 commit into
brainglobe:mainfrom
matham:caching_lr

Conversation

@matham
Copy link
Copy Markdown
Contributor

@matham matham commented Feb 12, 2026

Description

This PR depends on #493. Only the last commit in this PR is unique to this PR.

What is this PR

  • Bug fix
  • Addition of a new feature
  • Other

Why is this PR needed?

It's pretty common to adjust the learning rate as training progresses. A common pattern is to start with a higher rate and then as it starts converging, you drop the learning rate so it can find the best minima around the global minima until it's fully stable. There are many approaches. A very basic one is to drop the learning rate by a set amount at specific epochs. You would then run a few training sessions to find which epochs are good choices to reduce the learning rate.

E.g. when I originally trained my _final_model I used to classify our data I got the following curves for train/test accuracy/loss. The labels indicate if the training data was raw or normalized (as in #588). The number following that indicates the epoch when the learning rate was dropped. E.g. Norm 45 means the training data was normalized and that the learning rate was dropped from 1e-3 to 1e-4 at epoch 45. No decay means the learning rate stayed the same (although I typically stopped those early)

The following is the plotted data from that training (data augmentation was 25% using resnet50_tv using a batch size of 32 and with --continue-training). Based on that I chose the model that used normalized data and with learning rate decay at epoch 45 - that had the best results while minimizing over-fitting:

ours_train_accuracy
ours_test_accuracy
ours_train_loss
ours_test_loss

What does this PR do?

This adds the following two parameters to the training code (copied from the docs):

  • lr_schedule : list of ints: If not empty, the list of epochs when to multiply the current learning rate by the lr_multiplier. E.g. if it's [10, 25], we start with a learning rate of 0.001, and lr_multiplier is 0.1, then the LR will be 0.001 for epochs 0-9, 0.0001 for 10-24, and 0.00001 for epoch 25 and beyond.
  • lr_multiplier : float: The multiplier by which to multiply the previous learning rate at the epochs listed in lr_schedule.

I ran a few replicates to see how replicable modifying the learning rate is and here are the plots. Raw/Norm refer to whether the training data was normalized, f=n refers to the percent data used for testing so 50% simulates low training data, and N=n refers to the number of lines in the plot indicating the replicates.

train_accuracy
test_accuracy
train_loss
test_loss

References

#493.

How has this PR been tested?

I added tests and tested it on my own models.

Is this a breaking change?

No. By default the options are off.

Does this PR require an update to the documentation?

There are new parameters.

Checklist:

  • The code has been tested locally
  • Tests have been added to cover all new functionality (unit & integration)
  • The documentation has been updated to reflect any changes
  • The code has been formatted with pre-commit

@github-project-automation github-project-automation Bot moved this to Priorities in Core development Mar 12, 2026
@matham matham force-pushed the caching_lr branch 2 times, most recently from a2cb647 to fe795cd Compare March 13, 2026 03:15
Copy link
Copy Markdown
Member

@alessandrofelder alessandrofelder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran a few replicates to see how replicable modifying the learning rate is and here are the plots. Raw/Norm refer to whether the training data was normalized, f=n refers to the percent data used for testing so 50% simulates low training data, and N=n refers to the number of lines in the plot indicating the replicates.

Interesting that in one run there seems to be a dip/peak in the accuracy/loss after 45 iterations (I presume this is when you changed the learning rate for those plots?). I don't know enough about the topic to judge whether this should be expected sometimes? What do you think @matham ? Maybe @adamltyson @IgorTatarnikov or @aymuos15 have some insight - see the second set of plots with the red + blue lines in the PR description.

Apart from this curiosity, the code changes are sensible as far as I can tell, and I have confirmed results and performance are backwards-compatible for our usual "large" test dataset.

Successfully installed cellfinder-1.10.1.dev1+ge127361a0
Running cellfinder from the main branch...
[...]
Finished. Total time taken: 0:29:51.219290  

Successfully installed cellfinder-1.9.1.dev5+g572e77d71
Running cellfinder from the PR branch...
[..]
Finished. Total time taken: 0:29:57.329079                            main.py:95

Comparing the results...
Number of cells only found with cellfinder from main: 0
Number of cells only found with cellfinder from 589: 0
Number of matched cells: 70411
Matching proportion: 1.0        

@matham
Copy link
Copy Markdown
Contributor Author

matham commented May 21, 2026 via email

@aymuos15
Copy link
Copy Markdown

aymuos15 commented May 21, 2026

Keras ships with some common lr-schedulers by default. https://keras.io/api/optimizers/learning_rate_schedules/learning_rate_schedule/

Would be nice to have some baselines with that too? Repeating the experiments as above?

On that, maybe the arg should accept something like this instead? Or is that overkill?

@matham
Copy link
Copy Markdown
Contributor Author

matham commented May 21, 2026

The network used and the task to do (binary classification of cuboids) are simple enough I think that more complicated learning rate schedulers would be overkill.

Scheduling based on epochs is a first approach/dead simple approach and something typical users can understand and know how to adjust and troubleshoot - do some initial training, see approx which epochs it stabilizes and then reurn training and drop the learning rate around those epochs. But more complex approaches would potentially be harder to fine tune and leads to more complexity.

Of less importance, I think dropping keras and going native pytorch should be a high priority (I have seen recent keras versions have memory leaks with pytorch). But as you said this could also have been done in pytorch.

@alessandrofelder alessandrofelder self-requested a review May 22, 2026 09:45
@adamltyson
Copy link
Copy Markdown
Member

Of less importance, I think dropping keras and going native pytorch should be a high priority (I have seen recent keras versions have memory leaks with pytorch).

Out of interest, are there any other reasons to motivate this? The reason we continued to use keras is to try and mitigate the effects of needing to migrate frameworks every few years. cellfinder has already gone from caffe -> NiftyNet -> Keras (TensorFlow) -> Keras (pytorch) in the last 10 years.

@alessandrofelder
Copy link
Copy Markdown
Member

Cool, thanks for your insights @matham and @aymuos15 - I understand now and I am happy with this specific PR so will merge.

Have moved related big-picture discussion to #618 🙂

@alessandrofelder alessandrofelder merged commit 98c0daa into brainglobe:main May 22, 2026
15 of 18 checks passed
@github-project-automation github-project-automation Bot moved this from Backlog to Archive in Core development May 22, 2026
@matham matham deleted the caching_lr branch May 22, 2026 19:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Status: Archive

Development

Successfully merging this pull request may close these issues.

4 participants