willymromero/image-prediction-api

Image Classifier API (CIFAR‑10)

This project exposes a simple image classification REST API built with FastAPI and PyTorch.
You can upload an image and get back the predicted CIFAR‑10 class label (e.g. cat, dog, truck) along with confidence scores for all 10 classes.

The underlying CNN is trained on CIFAR‑10 in the classifier/classifier.ipynb notebook, and its weights are saved to classifier/model/cifar_net.pth. The API loads this checkpoint once, lazily on the first request (see Development Notes), keeps the model cached, and serves predictions via the /predict endpoint.


Features

  • FastAPI backend with automatic interactive docs (Swagger UI).
  • CNN model for CIFAR‑10 (Net in src/model.py).
  • Single /predict endpoint that:
    • Accepts an uploaded image file.
    • Resizes it to 32 × 32 pixels, normalizes it, and runs it through the CNN.
    • Returns:
      • prediction: the best CIFAR‑10 class name.
      • confidence: probability for the predicted class.
      • probabilities: per‑class probability distribution for all 10 classes.
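
As an illustration of how the three response fields relate, the sketch below turns raw model outputs (logits) into the documented layout with a softmax. It is written in plain Python so it runs without PyTorch; the helper name build_response is hypothetical, the class order is assumed from the example response in this README, and the actual code in src/main.py may differ.

```python
import math

# Class order assumed from the example response in this README.
CLASSES = ["plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def build_response(logits):
    """Hypothetical helper: shape 10 logits into the documented JSON layout."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {
        "prediction": CLASSES[best],
        "confidence": round(probs[best], 4),
        "probabilities": [
            {"class": name, "probability": round(p, 4)}
            for name, p in zip(CLASSES, probs)
        ],
    }
```

Because confidence is simply the largest entry of probabilities, a client can always recompute it from the full distribution.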

Project Structure

  • src/main.py – FastAPI application, request handling, preprocessing, /predict route.
  • src/model.py – CIFAR‑10 Net architecture and checkpoint loading logic.
  • classifier/classifier.ipynb – Notebook used to train and export cifar_net.pth.
  • classifier/model/cifar_net.pth – Trained PyTorch model weights (checkpoint).

Prerequisites

  • Python 3.9+ (recommended)
  • pip or another Python package manager

This repository includes a requirements.txt that pins all necessary dependencies (FastAPI, PyTorch, Jupyter tooling, etc.).
It is recommended to use a dedicated virtual environment in the project root:

# From the project root
python -m venv .venv

# Activate the virtual environment (macOS / Linux)
source .venv/bin/activate

# On Windows (PowerShell)
.venv\Scripts\Activate.ps1

# Install all dependencies
pip install --upgrade pip
pip install -r requirements.txt

Model Checkpoint

The API expects a trained CIFAR‑10 model checkpoint at:

  • Default path: classifier/model/cifar_net.pth

You can override this via the CHECKPOINT_PATH environment variable:

export CHECKPOINT_PATH=/absolute/or/relative/path/to/cifar_net.pth

If the checkpoint file is missing, the API responds with 503 Service Unavailable and an error message telling you to train and save the model in classifier/classifier.ipynb first.
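
The path override and the missing-file case could be handled roughly as follows. This is a dependency-free sketch, not the code in src/model.py: CheckpointUnavailable, resolve_checkpoint_path, and get_model_bytes are hypothetical names, and reading raw bytes stands in for building the network and loading its weights with PyTorch.

```python
import os
from functools import lru_cache
from pathlib import Path

DEFAULT_CHECKPOINT = "classifier/model/cifar_net.pth"


class CheckpointUnavailable(RuntimeError):
    """Hypothetical error that a route handler could map to HTTP 503."""


def resolve_checkpoint_path() -> Path:
    """Return CHECKPOINT_PATH if it is set, else the documented default."""
    return Path(os.environ.get("CHECKPOINT_PATH", DEFAULT_CHECKPOINT))


@lru_cache(maxsize=1)
def get_model_bytes() -> bytes:
    """Stand-in for model loading: read the file once and cache the result."""
    path = resolve_checkpoint_path()
    if not path.is_file():
        raise CheckpointUnavailable(
            f"Checkpoint not found at {path}; train and save the model "
            "in classifier/classifier.ipynb first."
        )
    return path.read_bytes()
```

One property of this design: lru_cache does not cache exceptions, so a request that fails because the file is missing does not prevent a later request from succeeding once the checkpoint exists.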


Running the API

From the project root (the same directory as this README.md), run:

uvicorn src.main:app --reload

Useful options:

  • --host 0.0.0.0 to listen on all interfaces.
  • --port 8000 (default) or another port of your choice.

Example:

uvicorn src.main:app --host 0.0.0.0 --port 8000 --reload

Once running, you can open the interactive docs at:

  • Swagger UI: http://localhost:8000/docs
  • ReDoc: http://localhost:8000/redoc

Use the “Try it out” button in Swagger UI to upload an image and see the raw JSON response.


/predict Endpoint

  • Method: POST
  • Path: /predict
  • Content type: multipart/form-data (file upload)
  • Field name: file

Request example (cURL)

curl -X POST "http://localhost:8000/predict" \
  -H "accept: application/json" \
  -H "Content-Type: multipart/form-data" \
  -F "file=@path/to/your/image.jpg"
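
The same request can be sent from Python. The sketch below uses only the standard library and builds the multipart/form-data body by hand; the function names build_predict_request and predict are hypothetical, and the URL assumes the server is running locally on the default port.

```python
import json
import mimetypes
import os
import urllib.request
import uuid


def build_predict_request(url, filename, data):
    """Build a multipart/form-data POST with a single `file` field."""
    boundary = uuid.uuid4().hex
    name = os.path.basename(filename)
    content_type = mimetypes.guess_type(name)[0] or "application/octet-stream"
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="file"; filename="{name}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return urllib.request.Request(
        url,
        data=head + data + tail,
        headers={
            "Accept": "application/json",
            "Content-Type": f"multipart/form-data; boundary={boundary}",
        },
        method="POST",
    )


def predict(image_path, url="http://localhost:8000/predict"):
    """Upload an image and return the decoded JSON response."""
    with open(image_path, "rb") as fh:
        request = build_predict_request(url, image_path, fh.read())
    with urllib.request.urlopen(request) as response:
        return json.load(response)
```

Calling predict("path/to/your/image.jpg") returns the parsed response as a dict. For non-2xx status codes, urlopen raises urllib.error.HTTPError, whose code attribute carries the HTTP status.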

Successful response example

{
  "prediction": "cat",
  "confidence": 0.8234,
  "probabilities": [
    {"class": "plane", "probability": 0.0001},
    {"class": "car", "probability": 0.0002},
    {"class": "bird", "probability": 0.0153},
    {"class": "cat", "probability": 0.8234},
    {"class": "deer", "probability": 0.0102},
    {"class": "dog", "probability": 0.1201},
    {"class": "frog", "probability": 0.0050},
    {"class": "horse", "probability": 0.0123},
    {"class": "ship", "probability": 0.0075},
    {"class": "truck", "probability": 0.0059}
  ]
}
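
On the client side, a response in this layout can be ranked and sanity-checked with a few lines of Python. The helpers top_k and is_consistent are hypothetical and not part of the API; the tolerance allows for probabilities that are rounded to four decimals, as in the example above.

```python
import math


def top_k(response, k=3):
    """Return the k most likely (class, probability) pairs, best first."""
    ranked = sorted(response["probabilities"],
                    key=lambda item: item["probability"], reverse=True)
    return [(item["class"], item["probability"]) for item in ranked[:k]]


def is_consistent(response, tolerance=1e-3):
    """Check that prediction, confidence, and probabilities agree."""
    by_class = {p["class"]: p["probability"] for p in response["probabilities"]}
    total = sum(by_class.values())
    return (
        math.isclose(total, 1.0, abs_tol=tolerance)
        and response["prediction"] == max(by_class, key=by_class.get)
        and math.isclose(response["confidence"],
                         by_class[response["prediction"]], abs_tol=tolerance)
    )
```

Applied to the example response, top_k(response, k=2) yields the cat and dog entries, in that order.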

Error responses

  • 400 Bad Request – Non‑image file, empty file, or unreadable upload.
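  • 422 Unprocessable Entity – Unexpected error while running inference. FastAPI's request validation also returns 422 by default, for example when the file field is missing from the form data.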
  • 503 Service Unavailable – Model checkpoint file missing or not loadable.

Training the Model (Optional)

To retrain or fine‑tune the CIFAR‑10 model:

  1. Open classifier/classifier.ipynb in Jupyter or VS Code.
  2. Run the notebook cells to:
    • Download and prepare CIFAR‑10.
    • Define and train the CNN architecture that matches src/model.py::Net.
    • Save the trained weights to classifier/model/cifar_net.pth.
  3. Restart the FastAPI server so it reloads the updated checkpoint.

Development Notes

  • The model is loaded lazily on first request and cached for subsequent calls.
  • Preprocessing in src/main.py::preprocess_image aligns with the training pipeline:
    • Convert to RGB.
    • Resize to 32 × 32 pixels.
    • Convert to tensor and normalize each channel with mean (0.5, 0.5, 0.5) and std (0.5, 0.5, 0.5).
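
To make the normalization step concrete, here is the per-channel arithmetic in plain Python. It assumes that the tensor conversion first scales 8-bit values to the range 0.0 to 1.0 (as torchvision's ToTensor does for 8-bit images) and that both mean and std are 0.5 per channel; the function names are hypothetical.

```python
def normalize_channel(value, mean=0.5, std=0.5):
    """Map an 8-bit channel value (0-255) to the normalized range."""
    scaled = value / 255.0        # tensor conversion: 0-255 -> 0.0-1.0
    return (scaled - mean) / std  # normalization: 0.0-1.0 -> -1.0-1.0


def normalize_pixel(rgb):
    """Normalize one (R, G, B) pixel channel by channel."""
    return tuple(normalize_channel(c) for c in rgb)
```

Under these assumptions, black (0) maps to -1.0 and white (255) maps to 1.0, so the network sees inputs centered around zero.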

License

Add your preferred license here (e.g. MIT, Apache‑2.0).
