TinyVGG-inspired CNN architecture for solving CIFAR10.
Baseline model (run 1): TinyVGG with data normalisation and early stopping, trained in mini-batches with the Adam optimizer (model code). The dataset is balanced, so accuracy is used as the primary training metric.
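For orientation, here is a minimal sketch of what a TinyVGG-style CIFAR-10 model looks like in PyTorch; the hidden-unit count and layer sizes are illustrative, not the exact values from the repo's model code.

```python
import torch
from torch import nn

class TinyVGG(nn.Module):
    """Two convolutional blocks followed by a linear classifier (TinyVGG-style)."""

    def __init__(self, in_channels: int = 3, hidden_units: int = 10, num_classes: int = 10):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # CIFAR-10 images are 32x32, so two 2x2 max-pools leave an 8x8 feature map.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(hidden_units * 8 * 8, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.block_2(self.block_1(x)))
```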
Additional experiments:
Validation results are compared across runs, and the best model is selected for evaluation on the test set. See Running Custom Experiments for instructions to reproduce these results.
Screenshot from the simple dashboard for run comparison in `compare_runs.ipynb`:
The experiment/run results are provided in the `metrics` directory as `.json` files, and can be used to run the dashboard.
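As a rough idea of how runs can be compared from those files, the snippet below reads every `.json` file in the `metrics` directory and ranks the runs; the metric key names (`val_acc`, `val_loss`) are assumptions, not necessarily the exact keys the training script writes.

```python
import json
from pathlib import Path

# Assumed layout: one .json file per run in the metrics directory, each containing
# final validation metrics under keys like "val_loss" and "val_acc".
# These key names are hypothetical; adjust them to whatever train_model.py writes.
runs = {}
for metrics_file in sorted(Path("metrics").glob("*.json")):
    with open(metrics_file) as f:
        runs[metrics_file.stem] = json.load(f)

# Rank runs by their (assumed) final validation accuracy.
for run_name, m in sorted(runs.items(), key=lambda kv: kv[1]["val_acc"], reverse=True):
    print(f"{run_name}: val_acc={m['val_acc']:.3f}, val_loss={m['val_loss']:.3f}")
```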
Validation results of the best model for each run:
| Run | Train Loss | Val Loss | Train Accuracy | Val Accuracy | Train Time (s) | Epochs |
|---|---|---|---|---|---|---|
| Run 3 | 0.45 | 0.57 | 0.84 | 0.81 | 671.11 | 20 |
| Run 4 | 0.71 | 0.60 | 0.75 | 0.79 | 670.73 | 19 |
| Run 2 | 0.60 | 0.66 | 0.79 | 0.77 | 529.21 | 16 |
| Run 1 | 0.54 | 0.77 | 0.81 | 0.74 | 148.32 | 5 |
Run 3 was the most performant, so it was chosen and evaluated on the test set. Note that Run 3 did not trigger early stopping, so further training would likely be beneficial. Run 4 did end with early stopping (just barely), but since its dropout layers make learning more difficult, it might also benefit from longer training, perhaps with an adjusted learning rate. On the other hand, the TinyVGG architecture is quite simple, so dropout could simply hinder it: the model may not have enough capacity/parameters to overfit the data in the first place.
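For reference, "triggering early stopping" here presumably means patience-based monitoring of the validation loss; the sketch below is an assumption about that mechanism (with an arbitrary patience value), not the repo's actual implementation.

```python
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs.

    A minimal sketch of patience-based early stopping; the patience and
    threshold used in the actual runs are set in the run configs.
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            # Validation loss improved enough: reset the counter.
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience
```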
TinyVGG -- experiment setup 3
| Dataset | Accuracy | Loss |
|---|---|---|
| Train | 0.84 | 0.45 |
| Validation | 0.81 | 0.57 |
| Test | 0.83 | 0.52 |
83% accuracy is decent for such a small model. The current state of the art is Giant AmoebaNet with GPipe, achieving 99% accuracy on CIFAR-10, so in practice CIFAR-10 is essentially “solved”. However, that model has hundreds of millions of parameters, while this TinyVGG architecture has only around 154k parameters (Ref).
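If you want to check a parameter count like the ~154k figure yourself, summing the element counts of the model's parameters is enough; this snippet uses the illustrative TinyVGG sketch above, so its number will differ from the repo's configured model.

```python
# Count trainable parameters (here using the illustrative TinyVGG sketch above;
# the repo's configured model is what gives the ~154k figure).
model = TinyVGG(in_channels=3, hidden_units=10, num_classes=10)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")
```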
The model performs worst on the following classes (worst first):
- cats – confused with dogs
- airplanes – confused with birds and ships
- horses – confused with dogs and deer
- dogs – confused with cats
The model gets most confused between cats and dogs, which is understandable, as sometimes they do look quite alike:
Also, sometimes there can be data quality issues; can you tell that this is a ship?
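This kind of per-class error analysis typically comes from a confusion matrix; a possible sketch with torchmetrics is shown below, assuming a trained `model` and a CIFAR-10 `test_loader` already exist (both are placeholders here).

```python
import torch
from torchmetrics import ConfusionMatrix

# Assumes `model` and a CIFAR-10 `test_loader` have already been built elsewhere.
confmat = ConfusionMatrix(task="multiclass", num_classes=10)

model.eval()
with torch.inference_mode():
    for images, labels in test_loader:
        preds = model(images).argmax(dim=1)
        confmat.update(preds, labels)

# Rows are true classes, columns are predicted classes; off-diagonal peaks
# show which pairs (e.g. cat vs. dog) the model mixes up most.
print(confmat.compute())
```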
Experiment more:
- Hidden units
- Learning rate
- Batch size
Implementation ideas:
- Continued training from saved checkpoints of the most promising models
- Learning rate scheduling (both ideas are sketched after this list)
- Different optimisers
- Other model architectures, e.g., miniResNet
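As a sketch of the first two ideas (resuming from a saved checkpoint and adding a learning-rate scheduler), something like the following could work; the checkpoint file name and its keys are assumptions, not the repo's actual checkpoint format, and the model class is the illustrative sketch from above.

```python
import torch
from torch import optim

# Rebuild the model and optimizer, then restore their states from a checkpoint.
# The path "checkpoints/run_3.pt" and the keys "model_state"/"optimizer_state"
# are hypothetical; adapt them to whatever train_model.py actually saves.
model = TinyVGG(in_channels=3, hidden_units=10, num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

checkpoint = torch.load("checkpoints/run_3.pt")
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])

# Halve the learning rate every 5 epochs of continued training.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for epoch in range(10):
    # ... one training epoch over the train DataLoader would go here ...
    scheduler.step()
```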
The environment used for developing this project was managed with Conda and is provided in `env.lock.yaml`. You can also clone the repo on a managed platform like Colab, or replicate the dependencies manually. On Colab you need to install `torchmetrics` separately.
Instructions for Colab:
- Clone the repo.
- Install `torchmetrics`. In a code cell: `!pip install torchmetrics`
- Set the repo as root, so that the imports work. In a code cell: `%cd /content/cifar10`
- Define your run/experiment. Use the provided `.yaml` configuration files for reference.
- Run the `train_model.py` script with the appropriate configuration file, e.g. `python train_model.py --config ./run_config/run_1.yaml`. This will save model metrics into the `metrics` directory and model checkpoints into the `checkpoints` directory.
- You can then compare multiple runs with `compare_runs.ipynb`, which will read the metrics from the `metrics` directory and plot the results.
- Choose the best model according to the validation score and evaluate it on the test set with `final_results.ipynb`.
The data should be downloaded automatically when you run the training script if it is not present in the `data` folder. You can also use `utils/data_loading.py` if you wish to load the data manually for exploration or experimentation; this is the module the training script calls internally.
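If you prefer to fetch the data by hand, a self-contained torchvision sketch is shown below; the normalisation statistics are commonly used CIFAR-10 values and may differ from what `utils/data_loading.py` applies.

```python
import torch
from torchvision import datasets, transforms

# Stand-alone CIFAR-10 loading for exploration; the repo's utils/data_loading.py
# handles this during training, so the transform values here are illustrative.
transform = transforms.Compose([
    transforms.ToTensor(),
    # Commonly used per-channel mean/std for CIFAR-10; the repo's values may differ.
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
])

train_data = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([32, 3, 32, 32]) torch.Size([32])
```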


