Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions composer/algorithms/algorithm_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SAMHparams, ScaleScheduleHparams, SelectiveBackpropHparams,
SeqLengthWarmupHparams, SqueezeExciteHparams, StochasticDepthHparams,
SWAHparams)
from composer.algorithms.saf.saf import SAF
from composer.core.algorithm import Algorithm

registry: Dict[str, Type[AlgorithmHparams]] = {
Expand All @@ -35,6 +36,7 @@
'sam': SAMHparams,
'alibi': AlibiHparams,
'selective_backprop': SelectiveBackpropHparams,
'saf': SAF,

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure that the SAF class is correctly registered and can be retrieved from the registry as expected. This could be verified with a unit test.

}


Expand Down
44 changes: 44 additions & 0 deletions composer/algorithms/saf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"# SAF Algorithm

The SAF (Stochastic Average Flatness) algorithm is a training algorithm designed to find flat minima in the loss landscape of a neural network. The algorithm is implemented in the `SAF` class in the `saf.py` file.

## Description

The SAF algorithm works by adjusting the learning rate and using a temperature parameter to control the sharpness of the minima found. The algorithm also uses an Exponential Moving Average (EMA) to smooth the training process.

## Usage

To use the SAF algorithm, you need to create an instance of the `SAF` class and pass the necessary parameters to the constructor. Here is an example:

```python
saf = SAF(

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be beneficial to include more details about the expected types and shapes of the parameters for the SAF class in the README.md file. This would make it easier for users to correctly use the class.

training_set=training_set,
network=network,
learning_rate=0.01,
epochs=10,
iterations_per_epoch=100,
saf_starting_epoch=5,
saf_coefficients=[0.1, 0.2, 0.3],
temperature=0.1,
saf_hyperparameter=0.5,
ema_decay_factor=0.9
)
```

In this example, `training_set` is the dataset to be used for training, `network` is the neural network whose weights are to be optimized, `learning_rate` is the learning rate for the optimization, `epochs` is the number of epochs for the training, `iterations_per_epoch` is the number of iterations per epoch, `saf_starting_epoch` is the epoch at which to start SAF, `saf_coefficients` are the coefficients for the SAF algorithm, `temperature` is the temperature parameter for the SAF algorithm, `saf_hyperparameter` is the hyperparameter for the SAF algorithm, and `ema_decay_factor` is the decay factor for the EMA.

After creating the `SAF` instance, you can call the `forward` method to start the training process:

```python
saf.forward()
```

Please note that the `forward` method is currently a placeholder and will be implemented in the future.

## Testing

The `test_saf.py` file contains tests for the SAF algorithm. You can run these tests to verify the correct operation of the algorithm.

## Future Work

The `forward` method of the `SAF` class is currently a placeholder. In the future, this method will be implemented to perform the actual training process according to the SAF algorithm."
43 changes: 43 additions & 0 deletions composer/algorithms/saf/saf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2021 MosaicML. All Rights Reserved.

from __future__ import annotations

from typing import Optional, Type, Union

import torch

from composer.core import Algorithm, Event, State
from composer.loggers import Logger

class SAF(Algorithm):

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The saf.py file does not include any docstrings for the SAF class or its methods. Adding docstrings would provide valuable context about what the class is and how its methods should be used.

"""Implements the SAF algorithm.

Args:
training_set: The training set to be used.
network: The network with weights to be optimized.
learning_rate: The learning rate for the optimization.
epochs: The number of epochs for the training.
iterations_per_epoch: The number of iterations per epoch.
saf_starting_epoch: The epoch at which to start SAF.
saf_coefficients: The coefficients for the SAF algorithm.
temperature: The temperature parameter for the SAF algorithm.
saf_hyperparameter: The hyperparameter for the SAF algorithm.
ema_decay_factor: The decay factor for the Exponential Moving Average (EMA).
"""

def __init__(self, training_set, network, learning_rate, epochs, iterations_per_epoch, saf_starting_epoch, saf_coefficients, temperature, saf_hyperparameter, ema_decay_factor):

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding type hints for the parameters in the __init__ method. This would improve code readability and maintainability.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SAF class currently accepts a large number of parameters in its constructor. Consider refactoring the class to use a configuration object or similar pattern to manage these parameters.

self.training_set = training_set
self.network = network
self.learning_rate = learning_rate
self.epochs = epochs
self.iterations_per_epoch = iterations_per_epoch
self.saf_starting_epoch = saf_starting_epoch
self.saf_coefficients = saf_coefficients
self.temperature = temperature
self.saf_hyperparameter = saf_hyperparameter
self.ema_decay_factor = ema_decay_factor

def forward(self):

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The forward method is currently a placeholder. It's crucial to implement this method for the functionality of the SAF algorithm.

# Steps 1 to 9 of the SAF algorithm will be implemented here.
# This method will output a flat minimum solution.
pass
50 changes: 50 additions & 0 deletions tests/algorithms/test_saf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2021 MosaicML. All Rights Reserved.

import pytest
from composer.algorithms.saf.saf import SAF

def test_saf_initialization():

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are currently only checking the initialization of the SAF class and the absence of exceptions when calling the forward method. More comprehensive tests should be added to ensure the correctness of the SAF algorithm.

# Test initialization of SAF class
saf = SAF(
training_set=None,
network=None,
learning_rate=0.01,
epochs=10,
iterations_per_epoch=100,
saf_starting_epoch=5,
saf_coefficients=[0.1, 0.2, 0.3],
temperature=0.1,
saf_hyperparameter=0.5,
ema_decay_factor=0.9
)
assert isinstance(saf, SAF)
assert saf.training_set is None
assert saf.network is None
assert saf.learning_rate == 0.01
assert saf.epochs == 10
assert saf.iterations_per_epoch == 100
assert saf.saf_starting_epoch == 5
assert saf.saf_coefficients == [0.1, 0.2, 0.3]
assert saf.temperature == 0.1
assert saf.saf_hyperparameter == 0.5
assert saf.ema_decay_factor == 0.9

def test_saf_forward():
# Test forward method of SAF class
saf = SAF(
training_set=None,
network=None,
learning_rate=0.01,
epochs=10,
iterations_per_epoch=100,
saf_starting_epoch=5,
saf_coefficients=[0.1, 0.2, 0.3],
temperature=0.1,
saf_hyperparameter=0.5,
ema_decay_factor=0.9
)
# As the forward method is not implemented yet, it should not raise any exception
try:
saf.forward()
except Exception as e:
pytest.fail(f"SAF forward method raised an exception: {e}")