-
Notifications
You must be signed in to change notification settings - Fork 0
[Droid] Implement SAF Algorithm #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| "# SAF Algorithm | ||
|
|
||
| The SAF (Stochastic Average Flatness) algorithm is a training algorithm designed to find flat minima in the loss landscape of a neural network. The algorithm is implemented in the `SAF` class in the `saf.py` file. | ||
|
|
||
| ## Description | ||
|
|
||
| The SAF algorithm works by adjusting the learning rate and using a temperature parameter to control the sharpness of the minima found. The algorithm also uses an Exponential Moving Average (EMA) to smooth the training process. | ||
|
|
||
| ## Usage | ||
|
|
||
| To use the SAF algorithm, you need to create an instance of the `SAF` class and pass the necessary parameters to the constructor. Here is an example: | ||
|
|
||
| ```python | ||
| saf = SAF( | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be beneficial to include more details about the expected types and shapes of the parameters for the |
||
| training_set=training_set, | ||
| network=network, | ||
| learning_rate=0.01, | ||
| epochs=10, | ||
| iterations_per_epoch=100, | ||
| saf_starting_epoch=5, | ||
| saf_coefficients=[0.1, 0.2, 0.3], | ||
| temperature=0.1, | ||
| saf_hyperparameter=0.5, | ||
| ema_decay_factor=0.9 | ||
| ) | ||
| ``` | ||
|
|
||
| In this example, `training_set` is the dataset to be used for training, `network` is the neural network whose weights are to be optimized, `learning_rate` is the learning rate for the optimization, `epochs` is the number of epochs for the training, `iterations_per_epoch` is the number of iterations per epoch, `saf_starting_epoch` is the epoch at which to start SAF, `saf_coefficients` are the coefficients for the SAF algorithm, `temperature` is the temperature parameter for the SAF algorithm, `saf_hyperparameter` is the hyperparameter for the SAF algorithm, and `ema_decay_factor` is the decay factor for the EMA. | ||
|
|
||
| After creating the `SAF` instance, you can call the `forward` method to start the training process: | ||
|
|
||
| ```python | ||
| saf.forward() | ||
| ``` | ||
|
|
||
| Please note that the `forward` method is currently a placeholder and will be implemented in the future. | ||
|
|
||
| ## Testing | ||
|
|
||
| The `test_saf.py` file contains tests for the SAF algorithm. You can run these tests to verify the correct operation of the algorithm. | ||
|
|
||
| ## Future Work | ||
|
|
||
| The `forward` method of the `SAF` class is currently a placeholder. In the future, this method will be implemented to perform the actual training process according to the SAF algorithm." | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # Copyright 2021 MosaicML. All Rights Reserved. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Optional, Type, Union | ||
|
|
||
| import torch | ||
|
|
||
| from composer.core import Algorithm, Event, State | ||
| from composer.loggers import Logger | ||
|
|
||
| class SAF(Algorithm): | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| """Implements the SAF algorithm. | ||
|
|
||
| Args: | ||
| training_set: The training set to be used. | ||
| network: The network with weights to be optimized. | ||
| learning_rate: The learning rate for the optimization. | ||
| epochs: The number of epochs for the training. | ||
| iterations_per_epoch: The number of iterations per epoch. | ||
| saf_starting_epoch: The epoch at which to start SAF. | ||
| saf_coefficients: The coefficients for the SAF algorithm. | ||
| temperature: The temperature parameter for the SAF algorithm. | ||
| saf_hyperparameter: The hyperparameter for the SAF algorithm. | ||
| ema_decay_factor: The decay factor for the Exponential Moving Average (EMA). | ||
| """ | ||
|
|
||
| def __init__(self, training_set, network, learning_rate, epochs, iterations_per_epoch, saf_starting_epoch, saf_coefficients, temperature, saf_hyperparameter, ema_decay_factor): | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding type hints for the parameters in the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| self.training_set = training_set | ||
| self.network = network | ||
| self.learning_rate = learning_rate | ||
| self.epochs = epochs | ||
| self.iterations_per_epoch = iterations_per_epoch | ||
| self.saf_starting_epoch = saf_starting_epoch | ||
| self.saf_coefficients = saf_coefficients | ||
| self.temperature = temperature | ||
| self.saf_hyperparameter = saf_hyperparameter | ||
| self.ema_decay_factor = ema_decay_factor | ||
|
|
||
| def forward(self): | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| # Steps 1 to 9 of the SAF algorithm will be implemented here. | ||
| # This method will output a flat minimum solution. | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| # Copyright 2021 MosaicML. All Rights Reserved. | ||
|
|
||
| import pytest | ||
| from composer.algorithms.saf.saf import SAF | ||
|
|
||
| def test_saf_initialization(): | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tests are currently only checking the initialization of the |
||
| # Test initialization of SAF class | ||
| saf = SAF( | ||
| training_set=None, | ||
| network=None, | ||
| learning_rate=0.01, | ||
| epochs=10, | ||
| iterations_per_epoch=100, | ||
| saf_starting_epoch=5, | ||
| saf_coefficients=[0.1, 0.2, 0.3], | ||
| temperature=0.1, | ||
| saf_hyperparameter=0.5, | ||
| ema_decay_factor=0.9 | ||
| ) | ||
| assert isinstance(saf, SAF) | ||
| assert saf.training_set is None | ||
| assert saf.network is None | ||
| assert saf.learning_rate == 0.01 | ||
| assert saf.epochs == 10 | ||
| assert saf.iterations_per_epoch == 100 | ||
| assert saf.saf_starting_epoch == 5 | ||
| assert saf.saf_coefficients == [0.1, 0.2, 0.3] | ||
| assert saf.temperature == 0.1 | ||
| assert saf.saf_hyperparameter == 0.5 | ||
| assert saf.ema_decay_factor == 0.9 | ||
|
|
||
| def test_saf_forward(): | ||
| # Test forward method of SAF class | ||
| saf = SAF( | ||
| training_set=None, | ||
| network=None, | ||
| learning_rate=0.01, | ||
| epochs=10, | ||
| iterations_per_epoch=100, | ||
| saf_starting_epoch=5, | ||
| saf_coefficients=[0.1, 0.2, 0.3], | ||
| temperature=0.1, | ||
| saf_hyperparameter=0.5, | ||
| ema_decay_factor=0.9 | ||
| ) | ||
| # As the forward method is not implemented yet, it should not raise any exception | ||
| try: | ||
| saf.forward() | ||
| except Exception as e: | ||
| pytest.fail(f"SAF forward method raised an exception: {e}") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ensure that the
SAFclass is correctly registered and can be retrieved from the registry as expected. This could be verified with a unit test.