This project exposes a simple image classification REST API built with FastAPI and PyTorch.
You can upload an image and get back the predicted CIFAR‑10 class label (e.g. cat, dog, truck) along with confidence scores for all 10 classes.
The underlying CNN is trained on CIFAR‑10 in the `classifier/classifier.ipynb` notebook, and its weights are saved as `classifier/model/cifar_net.pth`. The API loads this checkpoint once, caches it, and serves predictions via a `/predict` endpoint.
- FastAPI backend with automatic interactive docs (Swagger UI).
- CNN model for CIFAR‑10 (`Net` in `src/model.py`).
- Single `/predict` endpoint that:
  - Accepts an uploaded image file.
  - Resizes it to 32 × 32 pixels, normalizes it, and runs it through the CNN.
  - Returns:
    - `prediction`: the most likely CIFAR‑10 class name.
    - `confidence`: the probability of the predicted class.
    - `probabilities`: the per‑class probability distribution over all 10 classes.
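The three response fields are related: `confidence` is simply the largest entry in `probabilities`, and `prediction` is its class name. The sketch below illustrates that relationship under two assumptions that are not confirmed by `src/main.py`: the API turns the network's 10 raw output scores (logits) into probabilities with a softmax (in the real service this would typically be `torch.softmax` on the model output), and values are rounded to 4 decimals as in the example response. `CIFAR10_CLASSES`, `softmax`, and `build_response` are hypothetical names; the code uses only the standard library.

```python
import math

# Assumed label order (the order used in the example response); the real
# order is whatever the training notebook used.
CIFAR10_CLASSES = ("plane", "car", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def build_response(logits, classes=CIFAR10_CLASSES):
    """Hypothetical helper: turn 10 raw scores into the /predict JSON shape."""
    if len(logits) != len(classes):
        raise ValueError(f"expected {len(classes)} logits, got {len(logits)}")
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return {
        "prediction": classes[best],
        "confidence": round(probs[best], 4),  # 4 decimals is an assumption
        "probabilities": [
            {"class": c, "probability": round(p, 4)} for c, p in zip(classes, probs)
        ],
    }


example = build_response([0.1, -1.2, 0.5, 3.0, 0.2, 1.8, -0.4, 0.3, 0.0, -0.2])
print(example["prediction"])  # → cat
```

Subtracting the maximum logit before calling `math.exp` does not change the result but avoids overflow for large scores.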
- `src/main.py` – FastAPI application, request handling, preprocessing, `/predict` route.
- `src/model.py` – CIFAR‑10 `Net` architecture and checkpoint loading logic.
- `classifier/classifier.ipynb` – Notebook used to train and export `cifar_net.pth`.
- `classifier/model/cifar_net.pth` – Trained PyTorch model weights (checkpoint).
- Python 3.9+ (recommended)
- `pip` or another Python package manager
This repository includes a requirements.txt that pins all necessary dependencies (FastAPI, PyTorch, Jupyter tooling, etc.).
It is recommended to use a dedicated virtual environment in the project root:
```bash
# From the project root
python -m venv .venv

# Activate the virtual environment (macOS / Linux)
source .venv/bin/activate

# On Windows (PowerShell)
.venv\Scripts\Activate.ps1

# Install all dependencies
pip install --upgrade pip
pip install -r requirements.txt
```

The API expects a trained CIFAR‑10 model checkpoint at:
- Default path: `classifier/model/cifar_net.pth`

You can override this via the `CHECKPOINT_PATH` environment variable:

```bash
export CHECKPOINT_PATH=/absolute/or/relative/path/to/cifar_net.pth
```

If the checkpoint file is missing, the API responds with `503 Service Unavailable` and a descriptive error message telling you to train and save the model in `classifier/classifier.ipynb` first.
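The checkpoint lookup and the 503 behaviour can be sketched as follows. This is a minimal illustration, not the code in `src/main.py` or `src/model.py`: `resolve_checkpoint_path`, `get_model`, `status_for_model_load`, and `CheckpointMissingError` are hypothetical names, the loader only reads raw bytes as a stand-in for the real PyTorch loading (for example `torch.load` followed by `load_state_dict`), and caching with `functools.lru_cache` is an assumed way to load once and reuse the model.

```python
import functools
import os
from pathlib import Path

DEFAULT_CHECKPOINT = Path("classifier/model/cifar_net.pth")


class CheckpointMissingError(RuntimeError):
    """Hypothetical error that a route handler would translate into HTTP 503."""


def resolve_checkpoint_path() -> Path:
    """CHECKPOINT_PATH wins if set; relative paths resolve against the CWD."""
    return Path(os.environ.get("CHECKPOINT_PATH", DEFAULT_CHECKPOINT))


@functools.lru_cache(maxsize=None)
def get_model(path: Path):
    """Load the checkpoint on first use and cache the result per path.

    Stand-in loader: the real project would build Net and load the saved
    weights here with PyTorch; this sketch only reads the raw bytes.
    """
    if not path.is_file():
        raise CheckpointMissingError(
            f"Checkpoint not found at {path}. Train and save the model in "
            "classifier/classifier.ipynb first."
        )
    return path.read_bytes()


def status_for_model_load() -> int:
    """Return the HTTP status a request would get, based on checkpoint availability."""
    try:
        get_model(resolve_checkpoint_path())
    except CheckpointMissingError:
        return 503
    return 200
```

`lru_cache` does not cache calls that raise, so in this sketch a missing checkpoint is looked up again on the next request. Because the default path is relative, it only resolves correctly when the server is started from the project root.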
From the project root (the same directory as this `README.md`), run:

```bash
uvicorn src.main:app --reload
```

Useful options:

- `--host 0.0.0.0` to listen on all interfaces.
- `--port 8000` (default) or another port of your choice.
- `--reload` restarts the server when source files change; use it for development only.

Example:

```bash
uvicorn src.main:app --host 0.0.0.0 --port 8000 --reload
```

Once running, you can open the interactive docs at:
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
Use the “Try it out” button in Swagger UI to upload an image and see the raw JSON response.
- Method: `POST`
- Path: `/predict`
- Content type: `multipart/form-data` (file upload)
- Field name: `file`
```bash
curl -X POST "http://localhost:8000/predict" \
  -H "accept: application/json" \
  -H "Content-Type: multipart/form-data" \
  -F "file=@path/to/your/image.jpg"
```

Example response:

```json
{
  "prediction": "cat",
  "confidence": 0.8234,
  "probabilities": [
    {"class": "plane", "probability": 0.0001},
    {"class": "car", "probability": 0.0002},
    {"class": "bird", "probability": 0.0153},
    {"class": "cat", "probability": 0.8234},
    {"class": "deer", "probability": 0.0102},
    {"class": "dog", "probability": 0.1201},
    {"class": "frog", "probability": 0.0050},
    {"class": "horse", "probability": 0.0123},
    {"class": "ship", "probability": 0.0075},
    {"class": "truck", "probability": 0.0059}
  ]
}
```

Error responses:

- `400 Bad Request` – Non‑image file, empty file, or unreadable upload.
- `422 Unprocessable Entity` – Unexpected error while running inference.
- `503 Service Unavailable` – Model checkpoint file missing or not loadable.
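If you prefer calling the endpoint from Python rather than `curl`, a minimal standard-library client might look like the sketch below. It assumes the server is running at `http://localhost:8000` and that the upload field is named `file`, as described above; `encode_multipart` and `predict` are hypothetical helper names, and the error handling simply surfaces the HTTP status codes listed above together with whatever body the API returns.

```python
import json
import mimetypes
import urllib.error
import urllib.request
import uuid
from pathlib import Path


def encode_multipart(field, filename, data, content_type):
    """Build a multipart/form-data body containing a single file part."""
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def predict(image_path, url="http://localhost:8000/predict"):
    """POST an image to /predict and return the decoded JSON response."""
    path = Path(image_path)
    mime = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
    body, header = encode_multipart("file", path.name, path.read_bytes(), mime)
    request = urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": header, "Accept": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request) as response:
            return json.load(response)
    except urllib.error.HTTPError as err:
        # 400 / 422 / 503: include the response body, which usually explains the error.
        detail = err.read().decode(errors="replace")
        raise RuntimeError(f"/predict returned {err.code}: {detail}") from err
```

With the server running, `predict("path/to/your/image.jpg")["prediction"]` returns the predicted class name. The boundary is random per request, which is what `curl -F` does as well.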
To retrain or fine‑tune the CIFAR‑10 model:
- Open `classifier/classifier.ipynb` in Jupyter or VS Code.
- Run the notebook cells to:
  - Download and prepare CIFAR‑10.
  - Define and train a CNN whose architecture matches `src/model.py::Net`.
  - Save the trained weights to `classifier/model/cifar_net.pth`.
- Restart the FastAPI server so it reloads the updated checkpoint.
- The model is loaded lazily on the first request and cached for subsequent calls.
- Preprocessing in `src/main.py::preprocess_image` matches the training pipeline:
  - Convert to RGB.
  - Resize to 32 × 32 pixels.
  - Convert to a tensor and normalize each channel with mean (0.5, 0.5, 0.5) and standard deviation (0.5, 0.5, 0.5).
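Assuming the tensor conversion is torchvision's `ToTensor` (which scales 8-bit pixel values from [0, 255] to [0, 1]), normalizing with mean 0.5 and standard deviation 0.5 computes `(x - 0.5) / 0.5` and so maps every channel to [-1, 1]. The sketch below shows that arithmetic for a single RGB pixel using only the standard library. `normalize_pixel` is a hypothetical helper, and the commented torchvision pipeline is an assumption about what `preprocess_image` looks like, not a copy of it.

```python
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)

# Assumed torchvision equivalent (not copied from src/main.py):
#   transforms.Compose([
#       transforms.Resize((32, 32)),
#       transforms.ToTensor(),            # uint8 [0, 255] -> float [0, 1]
#       transforms.Normalize(MEAN, STD),  # (x - mean) / std -> [-1, 1]
#   ])


def normalize_pixel(rgb):
    """Scale one 8-bit RGB pixel to [0, 1], then apply (x - mean) / std per channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))
```

For example, `normalize_pixel((255, 0, 128))` gives 1.0 and -1.0 for the first two channels and roughly 0.0039 for the third. An image preprocessed differently at inference time than during training (for example, without normalization) would be fed values in [0, 1] instead of [-1, 1], which typically degrades predictions.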
Add your preferred license here (e.g. MIT, Apache‑2.0).