
---
title: CIFAR-10 Image Classification
emoji: 🧠
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 5.29.0
app_file: app.py
pinned: false
license: mit
---



CIFAR-10 Image Classification — From Training to Deployment

An end-to-end deep learning project that designs, trains, and evaluates multiple architectures on the CIFAR-10 benchmark, quantifying how much transfer learning improves on a custom CNN baseline under controlled conditions.

This project is a portfolio-grade machine learning study that goes beyond model training. It includes data augmentation pipelines (RandomCrop, CutOut, MixUp, CutMix), cosine annealing learning rate scheduling, progressive unfreezing, INT8 model quantisation, Grad-CAM interpretability visualisations, CLI inference tools, and a Gradio demo deployed on Hugging Face Spaces — all documented in a structured Jupyter notebook.

📓 Explore the Notebook · 🚀 Live Demo · 📊 Key Results



🎯 Motivation & Research Question

How much does a pretrained backbone actually help compared to training from scratch when models are evaluated under controlled conditions?

Deep learning practitioners often default to transfer learning without quantifying its advantage under comparable settings. This project answers that question through a controlled experiment: identical dataset, optimiser family, learning-rate scheduling, epoch budget, and augmentation strategy, with architecture and transfer-learning strategy as the key variables.

The results have direct implications for:

  • Model selection in resource-constrained environments
  • Training efficiency when labelled data is limited
  • Deployment strategy when balancing latency, size, and accuracy


📊 Key Results & Performance Benchmarks

| Metric | Custom CNN | MobileNetV2 | ResNet-18 | Winner |
|---|---|---|---|---|
| Test Accuracy | 48.40% | 86.91% | 87.48% | 🏆 ResNet-18 |
| Trainable Params | 2,462,282 | 12,810 | 5,130 | 🏆 ResNet-18 |
| Model Size | 9.42 MB | 8.76 MB | 44.80 MB | 🏆 MobileNetV2 |
| CPU Latency (batch 1) | 1.38 ms | 17.22 ms | 9.80 ms | 🏆 Custom CNN |
| Throughput | ~724 FPS | ~58 FPS | ~102 FPS | 🏆 Custom CNN |

All numbers are measured empirically from the final evaluation checkpoints on the full 10,000-image CIFAR-10 test set (Custom CNN & MobileNetV2 verified 2026-04-10; ResNet-18 retrained and verified 2026-04-18 via cached-features linear probe; Custom CNN latency re-measured 2026-04-11 with 100 trials on Apple silicon CPU, batch size 1).

Key finding: ResNet-18 achieves 87.48% accuracy with just 0.2% of the Custom CNN's trainable parameters — a +39.1 percentage-point lift for a 480× reduction in trainable weights. MobileNetV2 lands within a fraction of a point at 86.91% with a different parameter/latency trade-off.
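The latency and throughput rows come from repeated single-batch CPU timing. A sketch of such a measurement is shown below; `cpu_latency_ms` is a hypothetical helper for illustration, not the project's `measure_model.py`:

```python
import time
import torch
import torch.nn as nn

def cpu_latency_ms(model: nn.Module, input_shape=(1, 3, 32, 32), trials: int = 100) -> float:
    """Median single-batch CPU inference latency in milliseconds."""
    model.eval()
    x = torch.randn(*input_shape)
    with torch.no_grad():
        for _ in range(10):                      # warm-up runs, excluded from timing
            model(x)
        times = []
        for _ in range(trials):
            t0 = time.perf_counter()
            model(x)
            times.append((time.perf_counter() - t0) * 1e3)
    times.sort()
    return times[len(times) // 2]                # median is robust to scheduler noise
```

Throughput then follows as `1000 / latency_ms` frames per second, which is consistent with the table (1000 / 1.38 ms ≈ 724 FPS).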

Training Progression — Convergence Comparison

Epoch   Custom CNN (Val Acc)     MobileNetV2 (Val Acc)     ResNet-18 (Val Acc)
─────   ────────────────────     ─────────────────────     ───────────────────
  1          21.0%                    85.88%                    84.21%
  2          27.8%                    86.80%                    85.64%
  3          32.4%                    86.91%                    86.27%
  …            …                        …                          …
 15          48.40%                   86.91%                    87.16%
 30           β€”                        β€”                        87.48%

Both transfer-learning models (MobileNetV2 and ResNet-18) reach strong accuracy within 1–3 epochs and plateau quickly, because their frozen backbones already encode powerful ImageNet features. The Custom CNN is still improving across the full 15-epoch budget, highlighting the value of pretrained feature representations.



🧠 Model Architectures

Custom CNN — 4-Block Design (Trained From Scratch)

Input (3 × 32 × 32)
  │
  ├── Block 1: Conv(3→64) ×2 → BatchNorm → ReLU → MaxPool → Dropout(0.25)
  ├── Block 2: Conv(64→128) ×2 → BatchNorm → ReLU → MaxPool → Dropout(0.25)
  ├── Block 3: Conv(128→256) ×2 → BatchNorm → ReLU → MaxPool → Dropout(0.25)
  ├── Block 4: Conv(256→512) → BatchNorm → ReLU → AdaptiveAvgPool
  │
  ├── Flatten → Dropout(0.5) → FC(512→256) → ReLU
  └── Dropout(0.5) → FC(256→10) → Output

Design decisions:

  • Kaiming He initialisation for stable gradient flow
  • Dual convolutions per block before downsampling
  • Global Average Pooling to reduce classifier size
  • Aggressive dropout for regularisation
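In PyTorch, the diagram above maps onto a compact `nn.Module`. The authoritative definition lives in `model_utils.py`; the sketch below reconstructs it from the diagram, and the exact placement of BatchNorm/ReLU inside each double-conv block is an assumption:

```python
import torch
import torch.nn as nn

class CustomCNN(nn.Module):
    """Sketch of the 4-block CNN from the diagram (model_utils.py is authoritative)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()

        def block(c_in: int, c_out: int) -> nn.Sequential:
            # two convolutions per block before downsampling
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
                nn.Conv2d(c_out, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
                nn.MaxPool2d(2), nn.Dropout(0.25),
            )

        self.features = nn.Sequential(
            block(3, 64), block(64, 128), block(128, 256),
            nn.Conv2d(256, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),                     # global average pooling to 1x1
        )
        self.classifier = nn.Sequential(
            nn.Flatten(), nn.Dropout(0.5), nn.Linear(512, 256), nn.ReLU(inplace=True),
            nn.Dropout(0.5), nn.Linear(256, num_classes),
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")  # Kaiming He init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))
```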

MobileNetV2 — Transfer Learning

Pretrained MobileNetV2 backbone (frozen)
  │
  └── Classifier Head: Dropout(0.2) → Linear(1280 → 10)

Strategy: Freeze the feature extractor and train a lightweight classification head, then apply progressive unfreezing for stronger adaptation.

ResNet-18 — Transfer Learning

Pretrained ResNet-18 backbone (frozen)
  │
  └── FC Head: Linear(512 → 10)

Strategy: Replace and train only the final fully connected layer. This model achieves strong accuracy with the fewest trainable parameters among the deployed models.



πŸ” Error Analysis & Confusion Patterns

The Custom CNN and MobileNetV2 both confuse visually similar classes, but MobileNetV2 makes significantly fewer mistakes:

| Confusion Pair | Custom CNN Errors | MobileNetV2 Errors | Error Reduction | Root Cause |
|---|---|---|---|---|
| 🚚 Truck ↔ 🚗 Automobile | 432 | 97 | 78% | Similar vehicle structure at 32×32 |
| 🚢 Ship ↔ ✈️ Airplane | 375 | 83 | 78% | Shared background cues |
| 🐱 Cat ↔ 🐕 Dog | 333 | 243 | 27% | Fine-grained mammal similarity |
| 🐴 Horse ↔ 🐕 Dog | 293 | 68 | 77% | Quadruped shape overlap |
| 🐦 Bird ↔ 🦌 Deer | 180 | 78 | 57% | Challenging low-resolution silhouettes |

Measured empirically from confusion matrices on the full CIFAR-10 test set.



✨ Key Features

╔══════════════════════════════════════════════════════════════════════════════╗
β•‘  🧠  3 deployed models in the live demo Β· 5 architectures explored in notebookβ•‘
β•‘  πŸ“ˆ  Full training pipeline with cosine annealing and progressive unfreezing β•‘
β•‘  🎲  Advanced augmentation: RandomCrop, CutOut, MixUp, CutMix                β•‘
β•‘  πŸ”¬  Grad-CAM interpretability for visual model explanations                 β•‘
β•‘  ⚑  INT8 dynamic quantisation experiments for deployment analysis            β•‘
β•‘  πŸ“Š  Confusion matrices, training curves, and efficiency benchmarks          β•‘
β•‘  πŸ–₯️  Gradio demo on HF Spaces with interactive image classification          β•‘
β•‘  πŸ› οΈ  CLI inference tools for single image, batch, and test-set evaluation    β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•


πŸ—οΈ Technical Architecture & Training Configuration

Runtime Stack

| Layer | Technology |
|---|---|
| Deep Learning | PyTorch 2.0+ |
| Pretrained Models | torchvision (MobileNetV2, ResNet-18, EfficientNet-B0) |
| Dataset | CIFAR-10 (60K images, 10 classes) |
| Evaluation | scikit-learn (classification reports, confusion matrices) |
| Visualization | Matplotlib, Seaborn |
| Interpretability | Grad-CAM with PyTorch hooks |
| Demo App | Gradio on Hugging Face Spaces |
| Environment | Jupyter Notebook, Python 3.11 |
| Hardware | Auto-detected: CUDA / Apple Silicon MPS / CPU |
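Hardware auto-detection typically follows a CUDA → MPS → CPU fallback. A sketch of such a helper (`pick_device` is a hypothetical name; the project's actual helper may differ):

```python
import torch

def pick_device() -> torch.device:
    """Prefer CUDA, then Apple Silicon MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)   # guard against older PyTorch builds
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```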

Training Hyperparameters

The core training budget was kept consistent across experiments, with architecture-specific adaptations where required:

Optimiser      : Adam
Learning Rate  : 0.001 (with Cosine Annealing decay)
Weight Decay   : 1e-4
Batch Size     : 128
Epochs         : 15
Loss Function  : CrossEntropyLoss
Training Set   : 50,000 images
Test Set       : 10,000 images
Augmentation   : RandomCrop(32,4), HFlip, CutOut(16), MixUp, CutMix
Random Seed    : 42
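For illustration, a sketch of how these settings wire together in PyTorch, with a minimal MixUp helper. The notebook's actual implementation may differ; the `nn.Linear` stand-in model and `alpha=1.0` are assumptions:

```python
import numpy as np
import torch
import torch.nn as nn

def mixup(x, y, alpha: float = 1.0):
    """MixUp: blend each batch with a shuffled copy of itself."""
    lam = float(np.random.beta(alpha, alpha))    # mixing coefficient in [0, 1]
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1.0 - lam) * x[idx]
    # loss = lam * CE(out, y) + (1 - lam) * CE(out, y[idx])
    return mixed_x, y, y[idx], lam

model = nn.Linear(3 * 32 * 32, 10)               # stand-in for any of the project's models
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)  # decay over 15 epochs
criterion = nn.CrossEntropyLoss()
```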


📓 Notebook Walkthrough — 14-Section ML Pipeline

| # | Section | Description |
|---|---|---|
| 1 | Environment & Config | Seed setup, device detection, hyperparameter configuration |
| 2 | Data Preparation & Augmentation | Dataset loading and augmentation pipeline |
| 3 | MixUp & CutMix | Batch-level augmentation experiments |
| 4 | Model Architectures | Custom CNN, MobileNetV2, ResNet-18, EfficientNet-B0, Vision Transformer |
| 5 | Training Pipeline | Unified training loop with cosine annealing and AMP support |
| 6 | Train All Models | Controlled comparisons across architectures |
| 7 | Progressive Unfreezing | MobileNetV2 fine-tuning schedule |
| 8 | Test Set Evaluation | Full test-set accuracy and class-level metrics |
| 9 | Confusion Matrices | Side-by-side error analysis |
| 10 | Training Curves | Loss, accuracy, and LR schedule visualisation |
| 11 | Error Analysis | Misclassification deep-dive |
| 12 | Efficiency Benchmarks | Parameters, size, latency, and throughput |
| 13 | Model Quantization | INT8 quantisation experiments |
| 14 | Save Artifacts | Export config, results, and metadata |
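Section 13's INT8 experiments rely on PyTorch dynamic quantisation, which rewrites `Linear` layers to use int8 weights at inference time (CPU). A minimal sketch on a stand-in model; the notebook's exact quantisation targets may differ:

```python
import torch
import torch.nn as nn

# Stand-in model for illustration, not the project's CNN.
model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10)).eval()

# Dynamically quantise all Linear layers to INT8 weights.
quantized = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
```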


🚀 Getting Started

Prerequisites

  • Python 3.11 (as specified in runtime.txt)
  • pip or conda
  • GPU recommended, but not required

Installation

git clone https://github.com/mrpouyaalavi/CIFAR-10-Image-Classification.git
cd CIFAR-10-Image-Classification

python -m venv .venv
source .venv/bin/activate        # macOS / Linux
# .venv\Scripts\activate         # Windows

pip install -r requirements.txt

Run the Notebook

jupyter notebook "cifar10 image classification.ipynb"

The CIFAR-10 dataset is downloaded automatically on first run via torchvision.datasets.



💻 Usage

CLI Inference

python predict.py --test-samples 10 --model all
python predict.py --image path/to/image.png --model mobilenet
python predict.py --image-dir path/to/images/ --model all --save results/predictions.png

Grad-CAM Visualisations

python gradcam.py --model all --num-images 6
python gradcam.py --model all --image-index 0 42 100 --save results/gradcam/
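Under the hood, Grad-CAM relies on PyTorch hooks to capture a layer's activations and gradients. A minimal, self-contained sketch; `gradcam.py`'s actual API may differ:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def grad_cam(model: nn.Module, target_layer: nn.Module, x: torch.Tensor, class_idx: int):
    """Minimal Grad-CAM: hook target_layer's activations and gradients."""
    acts, grads = {}, {}
    fwd = target_layer.register_forward_hook(lambda m, i, o: acts.update(v=o))
    bwd = target_layer.register_full_backward_hook(lambda m, gi, go: grads.update(v=go[0]))
    try:
        model.zero_grad()
        model(x)[0, class_idx].backward()                # gradient of the class score
    finally:
        fwd.remove()
        bwd.remove()
    weights = grads["v"].mean(dim=(2, 3), keepdim=True)  # global-average-pool the gradients
    cam = F.relu((weights * acts["v"]).sum(dim=1))       # weighted sum of activation maps
    return cam / (cam.max() + 1e-8)                      # normalise to [0, 1]
```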

Gradio Demo App

python app.py

Upload an image or select an example to compare deployed models with live predictions and confidence rankings.

The live demo is hosted on Hugging Face Spaces.



πŸ“ Project Structure

CIFAR-10-Image-Classification/
│
├── app.py                                # Gradio demo (HF Spaces entry point)
├── model_utils.py                        # Shared model architectures & inference
├── benchmark_data.py                     # Canonical benchmark metrics
├── predict.py                            # CLI inference tools
├── gradcam.py                            # Grad-CAM visualisations
│
├── cifar10 image classification.ipynb    # Main notebook
│
├── scripts/                              # Retraining & measurement scripts
│   ├── retrain_custom_cnn.py             #   Custom CNN training run
│   ├── retrain_mobilenetv2.py            #   MobileNetV2 frozen-backbone training
│   ├── retrain_resnet18.py               #   ResNet-18 frozen-backbone training
│   └── measure_model.py                  #   Accuracy, latency & confusion pairs
│
├── tests/                                # Pytest unit & integration tests
│   ├── conftest.py                       #   Shared fixtures
│   ├── test_models.py                    #   Architecture smoke tests
│   ├── test_inference.py                 #   predict() contract tests
│   ├── test_preprocessing.py             #   Transform pipeline tests
│   ├── test_gradcam.py                   #   Grad-CAM hook tests
│   ├── test_benchmark_data.py            #   Metric consistency tests
│   ├── test_checkpoint_remap.py          #   Checkpoint key migration tests
│   └── test_device.py                    #   Device selection tests
│
├── results/                              # Training results & analysis
├── artifacts/                            # Exported configs and run metadata
├── examples/                             # Example images for the live demo
├── data/                                 # CIFAR-10 dataset (auto-downloaded)
│
├── requirements.txt                      # Gradio / HF Spaces dependencies
├── requirements-dev.txt                  # Development dependencies
├── runtime.txt                           # Python version pin for HF Spaces
├── pytest.ini                            # Pytest configuration
├── LICENSE                               # MIT License
└── .gitignore                            # Git ignore rules

Model weights are hosted on the Hugging Face Hub and downloaded automatically at runtime.



💡 Key Takeaways

  1. Transfer learning is highly efficient — MobileNetV2 strongly outperforms a custom CNN while training only a tiny fraction of the parameters.
  2. Pretrained features transfer well — even from ImageNet-scale pretraining to CIFAR-10.
  3. More trainable parameters do not guarantee better results under a fixed training budget.
  4. Data efficiency matters — transfer learning reaches strong performance within a very small number of epochs.
  5. Speed vs. accuracy trade-offs remain real — the custom CNN is ~10× faster on CPU, while ResNet-18 (87.48%) and MobileNetV2 (86.91%) trade a little latency for ~+39 pp of accuracy.


📜 License

Released under the MIT License. See LICENSE for details.



> ping --author

> Target     : Pouya Alavi Naeini — AI & Full-Stack Developer
> University : Macquarie University, Sydney, NSW
> Major      : B.IT — Artificial Intelligence & Web/App Development
> Status     : [●] ONLINE — open to grad & junior opportunities

Live Demo LinkedIn GitHub Email


Built with PyTorch & Gradio · Deployed on Hugging Face Spaces · Designed for learning, research, and demonstration
