
csthesis-g8/capsule-network


Capsule Network (CapsNet) Implementation

An implementation of Capsule Networks in PyTorch.

What is a Capsule Network?

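A Capsule Network (CapsNet) is a neural network architecture built from capsules: groups of neurons whose output is a vector rather than a single number. The length of a capsule's output vector encodes the probability that a visual feature is present, and its direction encodes properties of that feature, such as pose. Higher-level capsules combine the outputs of lower-level ones, so the network represents visual information as a hierarchy of parts and wholes. Capsules were first introduced by Geoffrey Hinton and colleagues in 2011 and later popularized by the 2017 paper "Dynamic Routing Between Capsules".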

Project Directory

.
├── capsnet
│   ├── __init__.py
│   └── capsnet.py      # The PyTorch CapsNet Model (Algorithm)
├── data                # Directory for downloaded MNIST data
├── main.py             # Streamlit App for evaluation and demo
├── train.py            # Script to train the model
└── install_dependencies.py # Setup script

Data

The project now uses torchvision.datasets.MNIST to automatically download and load the MNIST dataset.

  • The data will be downloaded to the data/ directory.
  • No manual CSV download is required.

Pseudocode: How CapsNet Works

Training Logic

for each epoch:
    for each batch of images and labels:
        1. Pass images through CapsNet -> Get Output Capsules
           (Conv1 -> PrimaryCaps -> DigitCaps with Routing)
        
        2. Calculate Loss (Margin Loss)
           - If digit is present: encourage long capsule vector (length -> 1)
           - If digit is absent: encourage short capsule vector (length -> 0)
        
        3. Backpropagation
           - Calculate gradients
           - Update weights (learn)
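
The loop above could be sketched in PyTorch roughly as follows. The margins 0.9 and 0.1 and the down-weighting factor 0.5 are the values from the 2017 paper and are assumed here, not confirmed for train.py. The names margin_loss and train_step are hypothetical, and the model is assumed to return capsule vectors of shape (batch, 10, dim).

```python
import torch
import torch.nn.functional as F


def margin_loss(digit_caps, labels, m_pos=0.9, m_neg=0.1, lam=0.5):
    """digit_caps: (batch, 10, dim) capsule vectors; labels: (batch,) integer classes."""
    lengths = digit_caps.norm(dim=-1)  # (batch, 10), one length per digit capsule
    targets = F.one_hot(labels, num_classes=lengths.size(1)).float()
    # Digit present: penalize lengths that fall below m_pos.
    present = targets * F.relu(m_pos - lengths) ** 2
    # Digit absent: penalize lengths that rise above m_neg (down-weighted by lam).
    absent = lam * (1.0 - targets) * F.relu(lengths - m_neg) ** 2
    return (present + absent).sum(dim=1).mean()


def train_step(model, optimizer, images, labels):
    """One batch: forward pass, margin loss, backpropagation, weight update."""
    optimizer.zero_grad()
    loss = margin_loss(model(images), labels)
    loss.backward()   # calculate gradients
    optimizer.step()  # update weights
    return loss.item()
```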

Inference Logic (Prediction)

1. Load trained model.
2. Pass a test image through CapsNet.
3. Get 10 Output Vectors (one for each digit 0-9).
4. Calculate length (norm) of each vector.
   - Length represents the probability that the digit exists.
5. The digit with the longest vector is the prediction.
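
Steps 2 to 5 can be sketched as a small helper. This assumes the model is already loaded and returns a tensor of shape (batch, 10, dim); the function name predict is hypothetical, and checkpoint loading is left out because the README does not name the checkpoint file.

```python
import torch


@torch.no_grad()
def predict(model, images):
    """Return (predicted_digits, lengths) for a batch of images."""
    model.eval()
    digit_caps = model(images)         # assumed shape: (batch, 10, dim)
    lengths = digit_caps.norm(dim=-1)  # (batch, 10), one length per digit 0-9
    # The digit whose capsule vector is longest is the prediction.
    return lengths.argmax(dim=1), lengths
```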

CapsNet Model Components

1. Squash Function

Goal: Normalize a vector so its length is between 0 and 1, without changing its direction.

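  • Why? In CapsNet, the length of the vector is the probability. Standard (unit) normalization would force every vector to length exactly 1 and throw that information away. Squash keeps longer inputs longer than shorter ones while bounding every length below 1.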
  • Simple Term: "Make short vectors almost 0, and long vectors almost 1."
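
A minimal sketch of the squash function as defined in the 2017 paper: the input is scaled by |s|^2 / (1 + |s|^2) along its own direction. The eps term is an assumption added here to avoid division by zero, and the version in capsnet.py may differ in detail.

```python
import torch


def squash(s, dim=-1, eps=1e-8):
    """Scale vectors along `dim` to a length in [0, 1) without changing direction."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm)             # near 0 for short, near 1 for long vectors
    return scale * s / torch.sqrt(sq_norm + eps)  # unit direction times the new length
```

For example, the vector (3, 4) has length 5 and is squashed to length 25/26, still pointing along (0.6, 0.8).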

2. Primary Caps

Goal: Convert standard image features (from Conv1) into Capsules (vectors).

  • How?
    1. Run a standard Convolution.
    2. Reshape the output so that instead of just "pixels", we have "vectors" (groups of neurons).
  • Simple Term: "Bundle groups of neurons together to form the first set of capsules."
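
The two steps above might look like this in PyTorch. The defaults (256 input channels, 32 capsule maps of 8 dimensions, 9x9 kernel, stride 2) follow the 2017 paper and are assumptions about this repository, not values read from capsnet.py.

```python
import torch
import torch.nn as nn


class PrimaryCaps(nn.Module):
    """Sketch: a convolution whose output channels are regrouped into capsule vectors."""

    def __init__(self, in_channels=256, num_maps=32, caps_dim=8, kernel_size=9, stride=2):
        super().__init__()
        self.num_maps = num_maps
        self.caps_dim = caps_dim
        self.conv = nn.Conv2d(in_channels, num_maps * caps_dim, kernel_size, stride)

    def forward(self, x):
        out = self.conv(x)  # step 1: (batch, num_maps * caps_dim, H, W)
        batch, _, h, w = out.shape
        # Step 2: split channels into (map, dimension), then move the capsule
        # dimension last so each spatial position of each map is one vector.
        out = out.view(batch, self.num_maps, self.caps_dim, h, w)
        out = out.permute(0, 1, 3, 4, 2)
        return out.reshape(batch, self.num_maps * h * w, self.caps_dim)
```

With a 20x20 input feature map this yields 32 * 6 * 6 = 1152 capsules of 8 dimensions each. In the paper the squash function is then applied to each of these vectors.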

3. Digit Caps (with Dynamic Routing)

Goal: The final layer that decides "Which digit is this?".

  • Dynamic Routing (The "Agreement" Algorithm):
    • Lower capsules (Primary) try to predict the output of higher capsules (Digit).
    • If a Primary Capsule's prediction agrees with the Digit Capsule's current output, the coupling between the two is strengthened; if it disagrees, the coupling is weakened.
    • This update is repeated for a fixed number of iterations (usually 3).
  • Simple Term: "Capsules vote on what the object is. If they agree, the vote counts more."
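
A minimal sketch of routing-by-agreement as described in the 2017 paper. It assumes the predictions u_hat have already been computed by the learned transformation matrices. The function name and tensor layout are choices made for this example and may not match capsnet.py.

```python
import torch
import torch.nn.functional as F


def squash(s, dim=-1, eps=1e-8):
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)


def dynamic_routing(u_hat, num_iterations=3):
    """u_hat: (batch, num_in, num_out, out_dim) predictions made by lower capsules.

    Returns the output capsules with shape (batch, num_out, out_dim).
    """
    # Routing logits start at zero, so every vote counts equally at first.
    b = torch.zeros(u_hat.shape[:3], device=u_hat.device, dtype=u_hat.dtype)
    for i in range(num_iterations):
        c = F.softmax(b, dim=2)                   # coupling coefficients per lower capsule
        s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # weighted sum of votes
        v = squash(s)                             # (batch, num_out, out_dim)
        if i < num_iterations - 1:
            # Agreement is the dot product between each vote and the current output.
            b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)
    return v
```

When all lower capsules predict the same vector for one output capsule, their agreement raises the coupling to it on every iteration, so that output grows longer than it would under equal weighting.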

4. CapsNet (The Container)

Goal: Put it all together.

  • Structure: Image -> Conv1 -> PrimaryCaps -> DigitCaps -> Prediction.
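
To make the pipeline concrete, the sketch below traces tensor sizes for one 28x28 MNIST image. The layer hyperparameters (9x9 kernels, 256 Conv1 channels, 32 maps of 8-dimensional primary capsules, 16-dimensional digit capsules) are those of the 2017 paper and are assumed here, not read from capsnet.py.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (standard formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def capsnet_shapes(image_size=28):
    """Trace per-image shapes through Conv1 -> PrimaryCaps -> DigitCaps."""
    conv1 = conv_out(image_size, kernel=9, stride=1)  # Conv1 feature map side
    primary = conv_out(conv1, kernel=9, stride=2)     # PrimaryCaps grid side
    return {
        "image": (1, image_size, image_size),
        "conv1": (256, conv1, conv1),
        "primary_caps": (32 * primary * primary, 8),  # (number of capsules, dimension)
        "digit_caps": (10, 16),                       # one capsule per digit 0-9
    }


for name, shape in capsnet_shapes().items():
    print(f"{name:>12}: {shape}")
```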
