diff --git a/composer/algorithms/algorithm_registry.py b/composer/algorithms/algorithm_registry.py index 76c68c563a..1d159a0a15 100644 --- a/composer/algorithms/algorithm_registry.py +++ b/composer/algorithms/algorithm_registry.py @@ -10,6 +10,7 @@ SAMHparams, ScaleScheduleHparams, SelectiveBackpropHparams, SeqLengthWarmupHparams, SqueezeExciteHparams, StochasticDepthHparams, SWAHparams) +from composer.algorithms.saf.saf import SAF from composer.core.algorithm import Algorithm registry: Dict[str, Type[AlgorithmHparams]] = { @@ -35,6 +36,7 @@ 'sam': SAMHparams, 'alibi': AlibiHparams, 'selective_backprop': SelectiveBackpropHparams, + 'saf': SAF, } diff --git a/composer/algorithms/saf/README.md b/composer/algorithms/saf/README.md new file mode 100644 index 0000000000..66150831c7 --- /dev/null +++ b/composer/algorithms/saf/README.md @@ -0,0 +1,44 @@ +"# SAF Algorithm + +The SAF (Stochastic Average Flatness) algorithm is a training algorithm designed to find flat minima in the loss landscape of a neural network. The algorithm is implemented in the `SAF` class in the `saf.py` file. + +## Description + +The SAF algorithm works by adjusting the learning rate and using a temperature parameter to control the sharpness of the minima found. The algorithm also uses an Exponential Moving Average (EMA) to smooth the training process. + +## Usage + +To use the SAF algorithm, you need to create an instance of the `SAF` class and pass the necessary parameters to the constructor. Here is an example: + +```python +saf = SAF( + training_set=training_set, + network=network, + learning_rate=0.01, + epochs=10, + iterations_per_epoch=100, + saf_starting_epoch=5, + saf_coefficients=[0.1, 0.2, 0.3], + temperature=0.1, + saf_hyperparameter=0.5, + ema_decay_factor=0.9 +) +``` + +In this example, `training_set` is the dataset to be used for training, `network` is the neural network whose weights are to be optimized, `learning_rate` is the learning rate for the optimization, `epochs` is the number of epochs for the training, `iterations_per_epoch` is the number of iterations per epoch, `saf_starting_epoch` is the epoch at which to start SAF, `saf_coefficients` are the coefficients for the SAF algorithm, `temperature` is the temperature parameter for the SAF algorithm, `saf_hyperparameter` is the hyperparameter for the SAF algorithm, and `ema_decay_factor` is the decay factor for the EMA. + +After creating the `SAF` instance, you can call the `forward` method to start the training process: + +```python +saf.forward() +``` + +Please note that the `forward` method is currently a placeholder and will be implemented in the future. + +## Testing + +The `test_saf.py` file contains tests for the SAF algorithm. You can run these tests to verify the correct operation of the algorithm. + +## Future Work + +The `forward` method of the `SAF` class is currently a placeholder. In the future, this method will be implemented to perform the actual training process according to the SAF algorithm." diff --git a/composer/algorithms/saf/saf.py b/composer/algorithms/saf/saf.py new file mode 100644 index 0000000000..f48dae41fb --- /dev/null +++ b/composer/algorithms/saf/saf.py @@ -0,0 +1,43 @@ +# Copyright 2021 MosaicML. All Rights Reserved. + +from __future__ import annotations + +from typing import Optional, Type, Union + +import torch + +from composer.core import Algorithm, Event, State +from composer.loggers import Logger + +class SAF(Algorithm): + """Implements the SAF algorithm. + + Args: + training_set: The training set to be used. + network: The network with weights to be optimized. + learning_rate: The learning rate for the optimization. + epochs: The number of epochs for the training. + iterations_per_epoch: The number of iterations per epoch. + saf_starting_epoch: The epoch at which to start SAF. + saf_coefficients: The coefficients for the SAF algorithm. + temperature: The temperature parameter for the SAF algorithm. + saf_hyperparameter: The hyperparameter for the SAF algorithm. + ema_decay_factor: The decay factor for the Exponential Moving Average (EMA). + """ + + def __init__(self, training_set, network, learning_rate, epochs, iterations_per_epoch, saf_starting_epoch, saf_coefficients, temperature, saf_hyperparameter, ema_decay_factor): + self.training_set = training_set + self.network = network + self.learning_rate = learning_rate + self.epochs = epochs + self.iterations_per_epoch = iterations_per_epoch + self.saf_starting_epoch = saf_starting_epoch + self.saf_coefficients = saf_coefficients + self.temperature = temperature + self.saf_hyperparameter = saf_hyperparameter + self.ema_decay_factor = ema_decay_factor + + def forward(self): + # Steps 1 to 9 of the SAF algorithm will be implemented here. + # This method will output a flat minimum solution. + pass diff --git a/tests/algorithms/test_saf.py b/tests/algorithms/test_saf.py new file mode 100644 index 0000000000..f577e415bd --- /dev/null +++ b/tests/algorithms/test_saf.py @@ -0,0 +1,50 @@ +# Copyright 2021 MosaicML. All Rights Reserved. + +import pytest +from composer.algorithms.saf.saf import SAF + +def test_saf_initialization(): + # Test initialization of SAF class + saf = SAF( + training_set=None, + network=None, + learning_rate=0.01, + epochs=10, + iterations_per_epoch=100, + saf_starting_epoch=5, + saf_coefficients=[0.1, 0.2, 0.3], + temperature=0.1, + saf_hyperparameter=0.5, + ema_decay_factor=0.9 + ) + assert isinstance(saf, SAF) + assert saf.training_set is None + assert saf.network is None + assert saf.learning_rate == 0.01 + assert saf.epochs == 10 + assert saf.iterations_per_epoch == 100 + assert saf.saf_starting_epoch == 5 + assert saf.saf_coefficients == [0.1, 0.2, 0.3] + assert saf.temperature == 0.1 + assert saf.saf_hyperparameter == 0.5 + assert saf.ema_decay_factor == 0.9 + +def test_saf_forward(): + # Test forward method of SAF class + saf = SAF( + training_set=None, + network=None, + learning_rate=0.01, + epochs=10, + iterations_per_epoch=100, + saf_starting_epoch=5, + saf_coefficients=[0.1, 0.2, 0.3], + temperature=0.1, + saf_hyperparameter=0.5, + ema_decay_factor=0.9 + ) + # As the forward method is not implemented yet, it should not raise any exception + try: + saf.forward() + except Exception as e: + pytest.fail(f"SAF forward method raised an exception: {e}")