Siniara/cifar10

Simple CNN for CIFAR10

TinyVGG-inspired CNN architecture for solving CIFAR10.

Design

Baseline Model: TinyVGG with data normalisation and early stopping, trained in mini-batches with the Adam optimizer (run 1). Model code. Since the dataset is balanced, accuracy is used as the primary training metric.
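
The early stopping used in the baseline can be sketched as a patience counter on the validation loss. This is a minimal illustration of the mechanism only; the patience and min-delta values below are assumptions, not the repo's settings.

```python
class EarlyStopping:
    """Stop training when validation loss stops improving.

    patience: epochs to wait after the last improvement.
    min_delta: minimum decrease that counts as an improvement.
    (Both defaults are illustrative, not the repo's settings.)
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```

In a training loop, `should_stop` would be called once per epoch with the epoch's validation loss, breaking out of the loop when it returns `True`.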

Additional experiments:

  • + Data augmentation (run 2)
  • + Batch normalisation (run 3)
  • + Dropout (run 4)

Validation results are compared across runs, and the best model is selected for evaluation on the test set. See Running Custom Experiments for instructions to reproduce these results.
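
The TinyVGG-style architecture can be sketched in PyTorch roughly as follows. This is a sketch of the standard TinyVGG layout (two conv blocks followed by a linear classifier); the repo's exact layer sizes live in its model code, and the 64 hidden units here are an assumption chosen to land near the ~154k parameter count quoted in the results.

```python
import torch
from torch import nn

class TinyVGG(nn.Module):
    """Two VGG-style conv blocks + linear classifier (layer sizes assumed)."""

    def __init__(self, hidden_units: int = 64, num_classes: int = 10):
        super().__init__()

        def block(in_ch: int) -> nn.Sequential:
            return nn.Sequential(
                nn.Conv2d(in_ch, hidden_units, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),  # halves the spatial size
            )

        self.block1 = block(3)
        self.block2 = block(hidden_units)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # CIFAR10 images are 32x32: two pools give 32 -> 16 -> 8
            nn.Linear(hidden_units * 8 * 8, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.block2(self.block1(x)))
```

With 64 hidden units, `sum(p.numel() for p in model.parameters())` gives about 153.5k trainable parameters for this sketch.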

Results

Validation

Screenshot from the simple run-comparison dashboard in compare_runs.ipynb: compare_runs.png

The experiment/run results are provided as .json files in the metrics directory and can be used to run the dashboard.
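
Loading those per-run metrics for comparison could look something like the sketch below. The actual .json schema is the repo's own; treat the structure of each file's contents as unknown here.

```python
import json
from pathlib import Path

def load_run_metrics(metrics_dir: str) -> dict:
    """Read every run's metrics .json into a dict keyed by filename stem.

    The schema inside each file is whatever the training script wrote;
    this helper only collects the files.
    """
    runs = {}
    for path in sorted(Path(metrics_dir).glob("*.json")):
        with open(path) as f:
            runs[path.stem] = json.load(f)
    return runs
```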

Validation results of the best model for each run:

| Run   | Train Loss | Val Loss | Train Accuracy | Val Accuracy | Train Time (s) | Epochs |
|-------|------------|----------|----------------|--------------|----------------|--------|
| Run 3 | 0.45       | 0.57     | 0.84           | 0.81         | 671.11         | 20     |
| Run 4 | 0.71       | 0.60     | 0.75           | 0.79         | 670.73         | 19     |
| Run 2 | 0.60       | 0.66     | 0.79           | 0.77         | 529.21         | 16     |
| Run 1 | 0.54       | 0.77     | 0.81           | 0.74         | 148.32         | 5      |

Run 3, being the most performant, was chosen for evaluation on the test set. Note that Run 3 didn't trigger early stopping, so further training would likely be beneficial. Run 4, by contrast, did end with early stopping (just barely), but since the dropout layers make learning more difficult, it might also benefit from longer training, perhaps with an adjusted learning rate. Then again, this might not hold: the TinyVGG architecture is quite simple, so dropout may simply hinder it, as the model may not have enough capacity/parameters to overfit the data in the first place.
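
For reference, runs 3 and 4 differ from the baseline only in where batch normalisation and dropout are inserted into the conv blocks. One common placement is sketched below; this is an assumption about layer ordering, not the repo's exact code.

```python
import torch
from torch import nn

def conv_block(in_ch: int, out_ch: int,
               batch_norm: bool = False, dropout: float = 0.0) -> nn.Sequential:
    """Conv -> (BatchNorm) -> ReLU -> (Dropout); the placement is illustrative.

    batch_norm=True roughly corresponds to run 3, dropout > 0 to run 4.
    """
    layers: list[nn.Module] = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_ch))
    layers.append(nn.ReLU())
    if dropout > 0:
        layers.append(nn.Dropout2d(dropout))  # drops whole feature maps
    return nn.Sequential(*layers)
```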

Test

TinyVGG -- experiment setup 3

| Dataset    | Accuracy | Loss |
|------------|----------|------|
| Train      | 0.84     | 0.45 |
| Validation | 0.81     | 0.57 |
| Test       | 0.83     | 0.52 |

conf_matrix.png

83% accuracy is decent for such a small model. The current state of the art, a giant AmoebaNet trained with GPipe, achieves 99% accuracy on CIFAR-10, so in practice CIFAR-10 is essentially “solved”. However, that model has hundreds of millions of parameters, while this TinyVGG architecture has only ~154k (Ref).

The model performs worst on the following classes (ordered worst first):

  • cats – confused with dogs
  • airplanes – confused with birds and ships
  • horses – confused with dogs and deer
  • dogs – confused with cats
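
These per-class confusions are read off the confusion matrix. The project itself uses TorchMetrics; the sketch below computes the same matrix with plain torch and is not the repo's code.

```python
import torch

def confusion_matrix(preds: torch.Tensor, targets: torch.Tensor,
                     num_classes: int = 10) -> torch.Tensor:
    """Rows = true class, columns = predicted class.

    Off-diagonal entry [t, p] counts images of class t predicted as class p,
    e.g. a large [cat, dog] entry means cats mistaken for dogs.
    """
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for t, p in zip(targets.tolist(), preds.tolist()):
        cm[t, p] += 1
    return cm
```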

The model most often confuses cats and dogs, which is understandable, as they sometimes do look quite alike:

catndog.png

Also, sometimes there can be data quality issues; can you tell that this is a ship?

ship

Improvements

Experiment more:

  • Hidden units
  • Learning rate
  • Batch size

Implementation ideas:

  • Continued training from saved checkpoints of the most promising models.
  • Learning rate scheduling
  • Different optimisers
  • Other model architectures, e.g., miniResNet

Running Custom Experiments

ENV

The environment used to develop this project was managed by Conda and is provided in env.lock.yaml. You can also clone the repo on a managed platform like Colab, or replicate the dependencies manually. On Colab you need to install torchmetrics separately.

Instructions for Colab:

  • Clone the repo.
  • Install torchmetrics. In a code cell: !pip install torchmetrics
  • Set the repo as the working directory so that the imports work. In a code cell: %cd /content/cifar10

Workflow

  • Define your run/experiment. Use the provided .yaml configuration files for reference.
  • Run the train_model.py script with the appropriate configuration file, e.g.
python train_model.py --config ./run_config/run_1.yaml
  • This saves model metrics into the metrics directory and model checkpoints into the checkpoints directory.
  • You can then compare multiple runs with compare_runs.ipynb, which reads the metrics from the metrics directory and plots the results.
  • Choose the best model according to the validation score and evaluate it on the test set with final_results.ipynb.

The data is downloaded automatically when you run the training script if it is not already present in the data folder. You can also use utils/data_loading.py to load the data manually for exploration or experimentation; this is the module the training script calls internally.
