Adds support for varying learning rate during training#589
Conversation
a2cb647 to
fe795cd
Compare
alessandrofelder
left a comment
There was a problem hiding this comment.
I ran a few replicates to see how replicable modifying the learning rate is and here are the plots. Raw/Norm refer to whether the training data was normalized, f=n refers to the percent data used for testing so 50% simulates low training data, and N=n refers to the number of lines in the plot indicating the replicates.
Interesting that in one run there seems to be a dip/peak in the accuracy/loss after 45 iterations (I presume this is when you changed the learning rate for those plots?). I don't know enough about the topic to judge whether this should be expected sometimes? What do you think @matham ? Maybe @adamltyson @IgorTatarnikov or @aymuos15 have some insight - see the second set of plots with the red + blue lines in the PR description.
Apart from this curiosity, the code changes are sensible as far as I can tell, and I have confirmed results and performance are backwards-compatible for our usual "large" test dataset.
Successfully installed cellfinder-1.10.1.dev1+ge127361a0
Running cellfinder from the main branch...
[...]
Finished. Total time taken: 0:29:51.219290
Successfully installed cellfinder-1.9.1.dev5+g572e77d71
Running cellfinder from the PR branch...
[..]
Finished. Total time taken: 0:29:57.329079 main.py:95
Comparing the results...
Number of cells only found with cellfinder from main: 0
Number of cells only found with cellfinder from 589: 0
Number of matched cells: 70411
Matching proportion: 1.0
|
Correct, that is when the learning rate changes. This is a very typical thing, that changing learning rate changes the accuracy/loss. And it's because taking smaller steps in the learning space lets us find exact peaks, which is harder to with a a larger step that typically will over/undershoot the exact peaks.
This question e.g. talks about this: https://datascience.stackexchange.com/questions/23213/why-does-decreasing-the-sgd-learning-rate-cause-a-massive-increase-in-accuracy
On Thu, May 21, 2026, at 12:49 PM, Alessandro Felder wrote:
***@***.**** commented on this pull request.
> I ran a few replicates to see how replicable modifying the learning rate is and here are the plots. Raw/Norm refer to whether the training data was normalized, f=n refers to the percent data used for testing so 50% simulates low training data, and N=n refers to the number of lines in the plot indicating the replicates.
>
Interesting that in one run there seems to be a dip/peak in the accuracy/loss after 45 iterations (I presume this is when you changed the learning rate for those plots?). I don't know enough about the topic to judge whether this should be expected sometimes? What do you think @matham <https://github.com/matham> ? Maybe @adamltyson <https://github.com/adamltyson> @IgorTatarnikov <https://github.com/IgorTatarnikov> or @aymuos15 <https://github.com/aymuos15> have some insight - see the second set of plots with the red + blue lines in the PR description.
Apart from this curiosity, the code changes are sensible as far as I can tell, and I have confirmed results and performance are backwards-compatible for our usual "large" test dataset.
`Successfully installed cellfinder-1.10.1.dev1+ge127361a0
Running cellfinder from the main branch...
[...]
Finished. Total time taken: 0:29:51.219290
Successfully installed cellfinder-1.9.1.dev5+g572e77d71
Running cellfinder from the PR branch...
[..]
Finished. Total time taken: 0:29:57.329079 main.py:95
Comparing the results...
Number of cells only found with cellfinder from main: 0
Number of cells only found with cellfinder from 589: 0
Number of matched cells: 70411
Matching proportion: 1.0
`
… —
Reply to this email directly, view it on GitHub <#589 (review)>, or unsubscribe <https://github.com/notifications/unsubscribe-auth/AAMRN7URKXGL4PHXEXH6ARL434XRHAVCNFSM6AAAAACU3BBX7CVHI2DSMVQWIX3LMV43YUDVNRWFEZLROVSXG5CSMV3GSZLXHM2DGMZYHA4DANZXGI>.
Triage notifications on the go with GitHub Mobile for iOS <https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675> or Android <https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub>.
You are receiving this because you were mentioned.Message ID: ***@***.***>
|
|
Keras ships with some common lr-schedulers by default. https://keras.io/api/optimizers/learning_rate_schedules/learning_rate_schedule/ Would be nice to have some baselines with that too? Repeating the experiments as above? On that, maybe the arg should accept something like this instead? Or is that overkill? |
|
The network used and the task to do (binary classification of cuboids) are simple enough I think that more complicated learning rate schedulers would be overkill. Scheduling based on epochs is a first approach/dead simple approach and something typical users can understand and know how to adjust and troubleshoot - do some initial training, see approx which epochs it stabilizes and then reurn training and drop the learning rate around those epochs. But more complex approaches would potentially be harder to fine tune and leads to more complexity. Of less importance, I think dropping keras and going native pytorch should be a high priority (I have seen recent keras versions have memory leaks with pytorch). But as you said this could also have been done in pytorch. |
Out of interest, are there any other reasons to motivate this? The reason we continued to use keras is to try and mitigate the effects of needing to migrate frameworks every few years. cellfinder has already gone from caffe -> NiftyNet -> Keras (TensorFlow) -> Keras (pytorch) in the last 10 years. |
Description
This PR depends on #493. Only the last commit in this PR is unique to this PR.
What is this PR
Why is this PR needed?
It's pretty common to adjust the learning rate as training progresses. A common pattern is to start with a higher rate and then as it starts converging, you drop the learning rate so it can find the best minima around the global minima until it's fully stable. There are many approaches. A very basic one is to drop the learning rate by a set amount at specific epochs. You would then run a few training sessions to find which epochs are good choices to reduce the learning rate.
E.g. when I originally trained my _final_model I used to classify our data I got the following curves for train/test accuracy/loss. The labels indicate if the training data was raw or normalized (as in #588). The number following that indicates the epoch when the learning rate was dropped. E.g.
Norm 45means the training data was normalized and that the learning rate was dropped from 1e-3 to 1e-4 at epoch 45. No decay means the learning rate stayed the same (although I typically stopped those early)The following is the plotted data from that training (data augmentation was
25%usingresnet50_tvusing a batch size of32and with--continue-training). Based on that I chose the model that used normalized data and with learning rate decay at epoch 45 - that had the best results while minimizing over-fitting:What does this PR do?
This adds the following two parameters to the training code (copied from the docs):
lr_schedule : list of ints: If not empty, the list of epochs when to multiply the current learning rate by thelr_multiplier. E.g. if it's[10, 25], we start with a learning rate of0.001, andlr_multiplieris0.1, then the LR will be0.001for epochs 0-9,0.0001for 10-24, and0.00001for epoch 25 and beyond.lr_multiplier : float: The multiplier by which to multiply the previous learning rate at the epochs listed inlr_schedule.I ran a few replicates to see how replicable modifying the learning rate is and here are the plots. Raw/Norm refer to whether the training data was normalized,
f=nrefers to the percent data used for testing so 50% simulates low training data, andN=nrefers to the number of lines in the plot indicating the replicates.References
#493.
How has this PR been tested?
I added tests and tested it on my own models.
Is this a breaking change?
No. By default the options are off.
Does this PR require an update to the documentation?
There are new parameters.
Checklist: