diff --git a/README.md b/README.md
index 72a984b9..a9229513 100644
--- a/README.md
+++ b/README.md
@@ -1,85 +1,107 @@
-#
TorchLens
+
+
+
+ Extract and visualize activations and computational graphs from any PyTorch model — in one line. +
+ + + + + ++ Paper • + CoLab Tutorial • + Model Gallery • + Metadata Overview +
+ +--- ## Overview -*TorchLens* is a package for doing exactly two things: +**TorchLens** is a Python library for doing exactly two things: -1) Easily extracting the activations from every single intermediate operation in a PyTorch model—no - modifications needed—in one line of code. "Every operation" means every operation; "one line" means one line. -2) Understanding the model's computational structure via an intuitive automatic visualization and extensive - metadata ([partial list here](https://static-content.springer.com/esm/art%3A10.1038%2Fs41598-023-40807-0/MediaObjects/41598_2023_40807_MOESM1_ESM.pdf)) - about the network's computational graph. +1. Extract **all intermediate activations** from any PyTorch model with **one line of code**. +2. Visualize the model's **computational graph** and get extensive metadata. -Here it is in action for a very simple recurrent model; as you can see, you just define the model like normal and pass -it in, and *TorchLens* returns a full log of the forward pass along with a visualization: +- #### No model modification required. +- #### Works for any PyTorch model. (Tested on over 700 different models) + +--- + +## Example: +```python +import torchlens as tl + +model_history = tl.log_forward_pass(pytorch_model, example_input, vis_opt='rolled') +``` +This one line performs a forward pass, creates and displays the graph. + +It returns a ModelHistory object containing the intermediate layer activations and accompanying metadata. + + +--- + +## Installation +> Requires PyTorch version **1.8.0 or higher** + +- #### Install Graphviz (required for visualization) + `sudo apt install graphviz` + +- #### Install TorchLens + `pip install torchlens` + + +--- +## Detailed examples + +#### Simple Recurrent Model ```python class SimpleRecurrent(nn.Module): def __init__(self): super().__init__() - self.fc = nn.Linear(in_features=5, out_features=5) + self.fc = nn.Linear(5, 5) def forward(self, x): - for r in range(4): + for _ in range(4): x = self.fc(x) x = x + 1 x = x * 2 return x - simple_recurrent = SimpleRecurrent() -model_history = tl.log_forward_pass(simple_recurrent, x, - layers_to_save='all', - vis_opt='rolled') -print(model_history['linear_1_1:2'].tensor_contents) # second pass of first linear layer - -''' -tensor([[-0.0690, -1.3957, -0.3231, -0.1980, 0.7197], - [-0.1083, -1.5051, -0.2570, -0.2024, 0.8248], - [ 0.1031, -1.4315, -0.5999, -0.4017, 0.7580], - [-0.0396, -1.3813, -0.3523, -0.2008, 0.6654], - [ 0.0980, -1.4073, -0.5934, -0.3866, 0.7371], - [-0.1106, -1.2909, -0.3393, -0.2439, 0.7345]]) -''' +model_history = tl.log_forward_pass(simple_recurrent, x, layers_to_save='all', vis_opt='rolled') ``` +
+
+
-And here it is for a very complex transformer model ([swin_v2_b](https://arxiv.org/abs/2103.14030)) with 1932 operations
-in its forward pass; you can grab the saved outputs of every last one:
-
-
-
-The goal of *TorchLens* is to do this for any PyTorch model whatsoever. You can see a bunch of example model
-visualizations in this [model menagerie](https://drive.google.com/drive/u/0/folders/1BsM6WPf3eB79-CRNgZejMxjg38rN6VCb).
-
-## Installation
-
-To install *TorchLens*, first install graphviz if you haven't already (required to generate the network visualizations),
-and then install *TorchLens* using pip:
-
-```bash
-sudo apt install graphviz
-pip install torchlens
-```
+#### Complex Transformer Model
-*TorchLens* is compatible with versions 1.8.0+ of PyTorch.
+TorchLens also works with large models like [Swin V2](https://arxiv.org/abs/2103.14030), with 1932 operations.
+
+
+
+
+
-
-You can pull out information about a given layer, including its activations and helpful metadata, by indexing
-the ModelHistory object in any of these equivalent ways:
-1) the name of a layer (with the convention that 'conv2d_3_7' is the 3rd convolutional layer, and the 7th layer overall)
-2) the name of a module (e.g., 'features' or 'classifier.3') for which that layer is an output, or
-3) the ordinal position of the layer (e.g., 2 for the 2nd layer, -5 for the fifth-to-last; inputs and outputs count as
- layers here).
+### Accessing Layers
-To quickly figure out these names, you can look at the graph visualization, or at the output of printing the
-ModelHistory object (both shown above). Here are some examples of how to pull out information about a
-particular layer, and also how to pull out the actual activations from that layer:
+You can access any layer’s metadata and activations using several methods:
+#### By name:
```python
-print(model_history['conv2d_3_7']) # pulling out layer by its name
-# The following commented lines pull out the same layer:
-# model_history['conv2d_3'] you can omit the second number (since strictly speaking it's redundant)
-# model_history['conv2d_3_7:1'] colon indicates the pass of a layer (here just one)
-# model_history['features.6'] can grab a layer by the module for which it is an output
-# model_history[7] the 7th layer overall
-# model_history[-17] the 17th-to-last layer
+print(model_history['conv2d_3_7'])
'''
Layer conv2d_3_7, operation 8/24:
Output tensor: shape=(1, 384, 13, 13), dype=torch.float32, size=253.5 KB
@@ -173,8 +142,17 @@ Layer conv2d_3_7, operation 8/24:
Output of bottom-level module: features.6
Lookup keys: -17, 7, conv2d_3_7, conv2d_3_7:1, features.6, features.6:1
'''
-
-# You can pull out the actual output activations from a layer with the tensor_contents field:
+```
+#### By module
+```python
+print(model_history['features.6'])
+```
+#### By index
+```python
+print(model_history[7])
+```
+#### Activations
+```python
print(model_history['conv2d_3_7'].tensor_contents)
'''
tensor([[[[-0.0867, -0.0787, -0.0817, ..., -0.0820, -0.0655, -0.0195],
@@ -187,56 +165,42 @@ tensor([[[[-0.0867, -0.0787, -0.0817, ..., -0.0820, -0.0655, -0.0195],
'''
```
-If you do not wish to save the activations for all layers (e.g., to save memory), you can specify which layers to save
-with the `layers_to_save` argument when calling `log_forward_pass`; you can either indicate layers in the same way
-as indexing them above, or by passing in a desired substring for filtering the layers (e.g., 'conv'
-will pull out all conv layers):
+To save only specific layers:
```python
-# Pull out conv2d_3_7, the output of the 'features' module, the fifth-to-last layer, and all linear (i.e., fc) layers:
-model_history = tl.log_forward_pass(alexnet, x, vis_opt='unrolled',
- layers_to_save=['conv2d_3_7', 'features', -5, 'linear'])
+model_history = tl.log_forward_pass(alexnet, x, vis_opt='unrolled', layers_to_save=['conv2d_3_7', 'features', -5, 'linear'])
print(model_history.layer_labels)
'''
['conv2d_3_7', 'maxpool2d_3_13', 'linear_1_17', 'dropout_2_19', 'linear_2_20', 'linear_3_22']
'''
```
-The main function of *TorchLens* is `log_forward_pass`; the remaining functions are:
+---
-1) `get_model_metadata`, to retrieve all model metadata without saving any activations (e.g., to figure out which
- layers you wish to save; note that this is the same as calling `log_forward_pass` with `layers_to_save=None`)
-2) `show_model_graph`, which visualizes the model graph without saving any activations
-3) `validate_model_activations`, which runs a procedure to check that the activations are correct: specifically,
- it runs a forward pass and saves all intermediate activations, re-runs the forward pass from each intermediate
- layer, and checks that the resulting output matches the ground-truth output. It also checks that swapping in
- random nonsense activations instead of the saved activations generates the wrong output. **If this function ever
- returns False (i.e., the saved activations are wrong), please contact me via email (johnmarkedwardtaylor@gmail.com)
- or on this GitHub page with a description of the problem, and I will update TorchLens to fix the problem.**
+## API Reference
-And that's it. *TorchLens* remains in active development, and the goal is for it to work with any PyTorch model
-whatosever without exception. As of the time of this writing, it has been tested with over 700
-image, video, auditory, multimodal, and language models, including feedforward, recurrent, transformer,
-and graph neural networks.
+- `log_forward_pass(model, input, ...)`: main function. Logs activations and metadata.
+- `get_model_metadata(model, input)`: get only the metadata (no activations).
+- `show_model_graph(model, input)`: render visual graph without saving tensors.
+- `validate_model_activations(...)`: verify saved activations are accurate.
-## Miscellaneous Features
+---
-- You can visualize models at different levels of nesting depth using the `vis_nesting_depth` argument
- to `log_forward_pass`; for example, here you can see one of GoogLeNet's "inception" modules at different levels of
- nesting depth:
+## Features
-
+#### Visualize modules at different nesting levels:
+Use the `vis_nesting_depth` argument in the `log_forward_pass` function.
+
+
+
+
+
+