An implementation of Capsule Networks using PyTorch.
A Capsule Network (CapsNet) is a neural network architecture that uses capsules to represent and process visual information. A capsule is a group of neurons whose output is a vector rather than a single number: the vector's length indicates how likely a feature is to be present, and its direction encodes the feature's properties (such as pose). The idea was introduced by Geoffrey Hinton and colleagues in 2011 ("Transforming Auto-encoders") and popularized by the 2017 paper "Dynamic Routing Between Capsules" (Sabour, Frosst, and Hinton).
.
├── capsnet
│ ├── __init__.py
│ └── capsnet.py # The PyTorch CapsNet Model (Algorithm)
├── data # Directory for downloaded MNIST data
├── main.py # Streamlit App for evaluation and demo
├── train.py # Script to train the model
└── install_dependencies.py # Setup script
The project now uses torchvision.datasets.MNIST to automatically download and load the MNIST dataset.
- The data will be downloaded to the data/ directory.
- No manual CSV download is required.
for each epoch:
    for each batch of images and labels:
        1. Pass images through CapsNet -> Get Output Capsules
           (Conv1 -> PrimaryCaps -> DigitCaps with Routing)
        2. Calculate Loss (Margin Loss)
           - If digit is present: encourage long capsule vector (length -> 1)
           - If digit is absent: encourage short capsule vector (length -> 0)
        3. Backpropagation
           - Calculate gradients
           - Update weights (learn)

1. Load trained model.
2. Pass a test image through CapsNet.
3. Get 10 Output Vectors (one for each digit 0-9).
4. Calculate length (norm) of each vector.
- Length represents the probability that the digit exists.
5. The digit with the longest vector is the prediction.

Goal: Normalize a vector so its length is between 0 and 1, without changing its direction.
- Why? In CapsNet, the length of the vector is the probability. Standard normalization forces every vector to length 1, which destroys that information. Squash rescales the length into the range 0 to 1 while keeping long vectors long and short vectors short.
- Simple Term: "Make short vectors almost 0, and long vectors almost 1."
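A minimal sketch of the squash function, following the formula from the 2017 paper (scale the unit direction by ||s||^2 / (1 + ||s||^2)). The `eps` term and the argument names are assumptions added for numerical safety and readability; the version in capsnet.py may differ.

```python
import torch


def squash(s, dim=-1, eps=1e-8):
    """Scale vectors along `dim` so their length lies in [0, 1), keeping direction."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)  # ||s||^2
    scale = sq_norm / (1.0 + sq_norm)              # new length, always below 1
    return scale * s / torch.sqrt(sq_norm + eps)   # new length times unit direction


short_vec = torch.tensor([0.01, 0.0])
long_vec = torch.tensor([30.0, 40.0])
print(squash(short_vec).norm().item())  # close to 0
print(squash(long_vec).norm().item())   # close to 1
```

The direction of `long_vec` is unchanged by the call; only its length is compressed.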
Goal: Convert standard image features (from Conv1) into Capsules (vectors).
- How?
  - Run a standard Convolution.
  - Reshape the output so that instead of just "pixels", we have "vectors" (groups of neurons).
- Simple Term: "Bundle groups of neurons together to form the first set of capsules."
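The convolve-then-reshape idea above might look like the sketch below. The layer sizes (256 input channels, 32 maps of 8-dimensional capsules, 9x9 kernel, stride 2) are assumptions taken from the 2017 paper rather than read from capsnet.py, and `squash` is repeated here only so the example runs on its own.

```python
import torch
import torch.nn as nn


def squash(s, dim=-1, eps=1e-8):
    """Shrink vectors along `dim` to a length in [0, 1) without changing direction."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)


class PrimaryCaps(nn.Module):
    # Sizes are assumptions from the 2017 paper, not taken from this repository.
    def __init__(self, in_channels=256, num_maps=32, caps_dim=8):
        super().__init__()
        self.num_maps = num_maps
        self.caps_dim = caps_dim
        self.conv = nn.Conv2d(in_channels, num_maps * caps_dim, kernel_size=9, stride=2)

    def forward(self, x):
        out = self.conv(x)                        # (B, num_maps * caps_dim, H, W)
        b, _, h, w = out.shape
        out = out.view(b, self.num_maps, self.caps_dim, h, w)
        out = out.permute(0, 1, 3, 4, 2)          # (B, num_maps, H, W, caps_dim)
        out = out.reshape(b, -1, self.caps_dim)   # (B, num_maps * H * W, caps_dim)
        return squash(out)                        # each row is now one capsule vector


features = torch.randn(2, 256, 20, 20)  # stand-in for the Conv1 output on a 28x28 image
capsules = PrimaryCaps()(features)
print(capsules.shape)                   # torch.Size([2, 1152, 8])
```

With these sizes each image yields 32 x 6 x 6 = 1152 primary capsules of 8 numbers each.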
Goal: The final layer that decides "Which digit is this?".
- Dynamic Routing (The "Agreement" Algorithm):
  - Lower capsules (Primary) try to predict the output of higher capsules (Digit).
  - If a Primary Capsule's prediction agrees with the Digit Capsule's actual state, the connection gets stronger.
  - This happens in a loop (usually 3 times).
- Simple Term: "Capsules vote on what the object is. If they agree, the vote counts more."
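The voting loop above can be sketched as routing-by-agreement in the style of the 2017 paper. The class layout, the weight initialization, and the default sizes (1152 input capsules of 8 dimensions, 10 output capsules of 16 dimensions) are assumptions for illustration and are not necessarily what capsnet.py does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def squash(s, dim=-1, eps=1e-8):
    """Shrink vectors along `dim` to a length in [0, 1) without changing direction."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)


class DigitCaps(nn.Module):
    # Default sizes are assumptions from the 2017 paper.
    def __init__(self, num_in=1152, in_dim=8, num_out=10, out_dim=16, iterations=3):
        super().__init__()
        self.iterations = iterations
        # One (out_dim x in_dim) matrix per (lower capsule, higher capsule) pair.
        self.W = nn.Parameter(0.01 * torch.randn(num_in, num_out, out_dim, in_dim))

    def forward(self, u):
        # u: (B, num_in, in_dim). Each lower capsule i predicts every higher capsule j:
        # u_hat[b, i, j] = W[i, j] @ u[b, i]  -> (B, num_in, num_out, out_dim)
        u_hat = torch.einsum("ijdk,bik->bijd", self.W, u)
        logits = torch.zeros(u_hat.shape[:3], device=u.device)  # routing logits, start equal
        for _ in range(self.iterations):
            c = F.softmax(logits, dim=2)              # each lower capsule splits its vote over outputs
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # weighted sum of votes: (B, num_out, out_dim)
            v = squash(s)                             # current state of each higher capsule
            # Agreement = dot product between a prediction and the current output;
            # predictions that agree get a larger logit for the next round.
            logits = logits + (u_hat * v.unsqueeze(1)).sum(dim=-1)
        return v


primary = torch.randn(2, 1152, 8)  # stand-in for PrimaryCaps output
digits = DigitCaps()(primary)
print(digits.shape)                # torch.Size([2, 10, 16])
```

The routing coefficients `c` are recomputed on every forward pass; only `W` is learned by backpropagation.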
Goal: Put it all together.
- Structure: Image -> Conv1 -> PrimaryCaps -> DigitCaps -> Prediction.
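At the end of this pipeline, both the margin loss used during training and the prediction used at inference depend only on the lengths of the ten output vectors. The sketch below shows one way to compute them. The margins 0.9 and 0.1 and the down-weighting factor 0.5 are the values from the 2017 paper and are assumptions here; the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def margin_loss(v, labels, m_pos=0.9, m_neg=0.1, lam=0.5):
    """Margin loss on capsule lengths. v: (B, 10, D), labels: (B,) integer digits."""
    lengths = v.norm(dim=-1)                                      # (B, 10)
    t = F.one_hot(labels, num_classes=lengths.size(1)).float()    # 1 where the digit is present
    present = t * F.relu(m_pos - lengths) ** 2                    # push the true digit's length up
    absent = lam * (1.0 - t) * F.relu(lengths - m_neg) ** 2       # push the other lengths down
    return (present + absent).sum(dim=1).mean()


def predict(v):
    """Return the digit whose capsule vector is longest."""
    return v.norm(dim=-1).argmax(dim=1)


# Hand-built output: image 0 has a long capsule for digit 3, image 1 for digit 7.
v = torch.zeros(2, 10, 16)
v[0, 3, 0] = 0.95
v[1, 7, 0] = 0.95
print(predict(v))                                # tensor([3, 7])
print(margin_loss(v, torch.tensor([3, 7])))      # zero loss: lengths are on the right side of both margins
print(margin_loss(v, torch.tensor([7, 3])))      # positive loss: wrong digits are long, true digits are short
```

In a training loop, `margin_loss(model(images), labels).backward()` followed by an optimizer step would cover the loss and backpropagation steps.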