mlx-lm-lens

Mechanistic interpretability CLI for transformer models on Apple Silicon.

Analyze per-layer predictions, monitor activation drift against reference models, compare fine-tuned vs base models, and discover model circuits—all without a GPU.

Author: Adeel Ahmad | License: Apache 2.0

Features

  • Logit Lens — Stream token generation with per-layer top-K predictions at each position
  • Drift Correction — Measure and correct activation angular drift against reference models (base, adapters, quantized variants)
  • Rich TUI — Live terminal UI with colored streaming, scrollable output, markdown support, real-time drift metrics
  • Reference Model Comparison — Compare fine-tuned/adapted models against base models in activation space
  • LoRA & Quantization Support — Transparent handling of 4-bit, 8-bit, and LoRA adapters via mlx-lm
  • Circuit Discovery — Ablate layers to trace information flow
  • JSON Logging — Export per-layer angles, predictions, and timing for analysis

Why It Matters

Modern language models are black boxes. mlx-lm-lens provides:

  • Mechanistic interpretability on consumer hardware — M1/M2/M3 chips via MLX, no GPU needed
  • Insight into fine-tuning effects — see how activation space changes after fine-tuning or adapter application
  • Model behavior debugging — track prediction confidence changes across layers and spot layer collapse
  • Reference-based drift detection — quantify how adapted models diverge from base models and apply geometric corrections

Installation

Requirements: Python ≥3.10, Apple Silicon (M1/M2/M3)

# Clone and install from source
git clone https://github.com/adeelahmad/mlx-lm-lens.git
cd mlx-lm-lens
pip install -e .

# Or install directly from GitHub
pip install git+https://github.com/adeelahmad/mlx-lm-lens.git@main#egg=mlx-lm-lens

Verify installation:

mlx-lm-lens --help

Quick Start

Generate with logit lens

mlx-lm-lens logit-lens generate \
  --model /path/to/model \
  --prompt "The future of AI is" \
  --max-tokens 50 \
  --top-k 5

Output: Live TUI with colored generation stream, per-layer predictions, agreement metrics.

Compare models with drift correction

mlx-lm-lens logit-lens generate \
  --model /path/to/finetuned-model \
  --reference-model /path/to/base-model \
  --drift-correction \
  --drift-threshold 0.3 \
  --drift-log /tmp/drift.jsonl \
  --prompt "Explain quantum mechanics" \
  --max-tokens 100 \
  --no-tui

Measures per-layer activation angles against reference baselines. Logs angles to JSONL for analysis.
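The JSONL drift log is easy to post-process offline. The sketch below assumes each record carries `layer` and `angle` fields; check your log for the actual schema before relying on these names.

```python
# Summarize a per-layer drift log. Field names ("layer", "angle") are
# assumptions about the JSONL schema -- adjust to match your log.
import json
from collections import defaultdict

def summarize_drift(path):
    """Return mean angular drift per layer from a JSONL drift log."""
    angles = defaultdict(list)
    with open(path) as f:
        for line in f:
            rec = json.loads(line)
            angles[rec["layer"]].append(rec["angle"])
    # Mean angle per layer, keyed by layer index in ascending order
    return {layer: sum(a) / len(a) for layer, a in sorted(angles.items())}
```

For example, `summarize_drift("/tmp/drift.jsonl")` returns a dict mapping layer index to mean angle, which is enough to spot which layers drift most.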

Analyze activations

mlx-lm-lens activations \
  --model /path/to/model \
  --prompt "Your prompt" \
  --metrics cosine,cka,mad

Output: per-layer metric values (see the Activation Analysis section below for a sample table).

For reference, the logit-lens per-layer prediction table looks like:

Layer  Top-1 Token   Confidence  Flip?  Top-2           Top-3
─────  ───────────   ──────────  ─────  ──────────────  ──────────────
0      " the"        0.23        No     " a"            " is"
1      " answer"     0.31        No     " that"         " is"
...
15     "4"           0.87        Yes    " four"         " 4"

Activation Analysis: Compare Model States

Measure how similar hidden states are between two models across layers.

mlx-lm-lens activations \
  --model /path/to/model1 \
  --reference-model /path/to/model2 \
  --prompt "Hello world" \
  --metrics cosine,cka,procrustes

Output:

Layer  Cosine Sim  CKA     Procrustes  MAD
─────  ──────────  ──────  ──────────  ──────
0      0.94        0.91    0.12        0.08
1      0.89        0.84    0.18        0.11
...
15     0.42        0.31    0.58        0.45
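For intuition, here are minimal NumPy versions of the first two metrics. These are reference definitions, not the package's internals, which may differ in detail.

```python
# Reference implementations of cosine similarity and linear CKA over
# two activation matrices of shape (tokens, hidden).
import numpy as np

def cosine_sim(x, y):
    """Mean cosine similarity between matched rows of x and y."""
    num = (x * y).sum(axis=1)
    den = np.linalg.norm(x, axis=1) * np.linalg.norm(y, axis=1)
    return float((num / den).mean())

def linear_cka(x, y):
    """Linear Centered Kernel Alignment between x and y."""
    x = x - x.mean(axis=0)          # center features
    y = y - y.mean(axis=0)
    num = np.linalg.norm(x.T @ y, "fro") ** 2
    den = np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro")
    return float(num / den)
```

Both return 1.0 for identical inputs; CKA is invariant to orthogonal rotations and isotropic scaling, which is why it is popular for cross-model comparison.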

Circuit Discovery: Find Important Layers

Ablate each layer and measure its impact on the final prediction.

mlx-lm-lens circuit ablate \
  --model /path/to/model \
  --prompt "The capital of France is" \
  --method zero

Output:

Layer  Importance (KL)  Rank  Impact
─────  ────────────────  ────  ──────────────────────
15     2.34              1     High
14     1.89              2     High
13     0.92              3     Medium
...
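The Importance (KL) column is presumably the KL divergence between the model's normal next-token distribution and the distribution after ablating that layer. A minimal version from raw logits:

```python
# KL(P || Q) between baseline and ablated next-token distributions,
# computed from raw logit vectors (in nats). A sketch of the quantity
# reported, not necessarily the CLI's exact implementation.
import numpy as np

def kl_divergence(p_logits, q_logits):
    # Softmax with max-subtraction for numerical stability
    p = np.exp(p_logits - p_logits.max()); p /= p.sum()
    q = np.exp(q_logits - q_logits.max()); q /= q.sum()
    return float(np.sum(p * (np.log(p) - np.log(q))))
```

Identical logits give 0; larger values mean the ablated layer mattered more for that prediction.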

CLI Reference

logit-lens generate — Stream Generation with Per-Layer Predictions

Generate tokens and log what each layer predicts at every step.

mlx-lm-lens logit-lens generate [OPTIONS]

Required arguments:

-m, --model TEXT                 Path to MLX model [required]
-p, --prompt TEXT                Initial prompt [required]

Generation options:

-n, --max-tokens INTEGER         Tokens to generate (default: 100, max: 1000)
-l, --log-from-token INTEGER     Start logging from token N (default: 0)
-k, --top-k INTEGER              Top predictions per layer (default: 5, min: 1)
-t, --temperature FLOAT          Sampling temperature (default: 1.0, min: 0.01)
--top-p FLOAT                    Nucleus sampling cutoff (default: None, range: 0-1)
--sampling-method TEXT           greedy | top_k | nucleus (default: greedy)
-s, --seed INTEGER               Random seed for reproducibility
--include-prompt                 Include prompt tokens in logging

Template & formatting:

--chat-template / --no-chat-template   Apply chat template (default: True)
-f, --format TEXT                Output format: table | json | csv (default: table)
-o, --output TEXT                Save results to JSON file

TUI & display:

--show-progress / --no-progress  Show progress bar (default: True)
--no-tui                         Disable interactive TUI, print to console
--wordcloud                      Show word frequency cloud after generation
--wordcloud-out TEXT             Save wordcloud to file

Drift correction options:

--drift-correction               Enable per-layer drift correction
--drift-threshold FLOAT          Angle threshold in degrees (default: 0.3, range: 0.001-179.9)
--drift-baseline-tokens INT      Tokens for baseline accumulation (default: 256, min: 1)
--drift-log TEXT                 JSONL file for per-layer angle log
--reference-model TEXT           Reference model path for drift baseline
--reference-adapter TEXT         Reference adapter path

Stop sequences:

--stop TEXT                      Repeatable: stop generation on this string

Example:

mlx-lm-lens logit-lens generate \
  --model Qwen3-0.6B \
  --prompt "The future of AI" \
  --max-tokens 100 \
  --top-k 5 \
  --temperature 0.7 \
  --sampling-method nucleus \
  --top-p 0.9

# With drift correction from reference model
mlx-lm-lens logit-lens generate \
  --model finetuned-model \
  --reference-model base-model \
  --prompt "Explain quantum mechanics" \
  --max-tokens 100 \
  --drift-correction \
  --drift-threshold 0.3 \
  --drift-log /tmp/drift.jsonl \
  --no-tui

# With adapter and output
mlx-lm-lens logit-lens generate \
  --model base-model \
  --prompt "Hello" \
  --max-tokens 50 \
  --output results.json \
  --format json

activations — Analyze Layer Representations

Compare hidden states between models using similarity metrics.

mlx-lm-lens activations [OPTIONS]

Required arguments:

-m, --model TEXT                 Primary model path [required]
-p, --prompt TEXT                Input text [required]
--metrics TEXT                   Comma-separated metrics [required]
                                 Available: cosine,cka,procrustes,grassmannian,mad,
                                           effective-dim,energy-kl,rsa

Comparison options:

--reference-model TEXT           Reference model path for comparison
--adapter-path TEXT              LoRA adapter path for primary model
--reference-adapter TEXT         Reference adapter path
--compare-base                   Compare base vs adapted (requires adapter)

Batch processing:

--batch-prompts TEXT             JSONL file with multiple prompts (one per line)

Output options:

-f, --format TEXT                table | json | csv (default: table)
--output-dir TEXT                Directory for output files
--output TEXT                    Single output file

Performance:

--verbose                        Enable debug logging

Example:

# Single prompt, multiple metrics
mlx-lm-lens activations \
  --model model1 \
  --prompt "Hello world" \
  --metrics cosine,cka,mad

# Compare base vs fine-tuned
mlx-lm-lens activations \
  --model finetuned-model \
  --reference-model base-model \
  --prompt "Your prompt here" \
  --metrics cosine,cka,procrustes \
  --output-dir results/

# Batch analysis with adapter
mlx-lm-lens activations \
  --model base-model \
  --adapter-path my-adapter.safetensors \
  --batch-prompts prompts.jsonl \
  --metrics effective-dim,rsa \
  --format json \
  --output results.json

Available metrics:

  • cosine — Cosine similarity (0=orthogonal, 1=identical)
  • cka — Centered Kernel Alignment (0=uncorrelated, 1=identical)
  • procrustes — Orthogonal transformation distance
  • grassmannian — Principal angles between subspaces
  • mad — Mean Absolute Deviation
  • effective-dim — Effective dimensionality (1.0=full rank)
  • energy-kl — Energy-based KL divergence
  • rsa — Representational Similarity Analysis
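The reference above does not pin down the effective-dim formula. One common definition is the participation ratio of the activation covariance spectrum, normalized by hidden size; the tool's definition may differ.

```python
# Effective dimensionality as a normalized participation ratio.
# This is one common definition, assumed here for illustration.
import numpy as np

def effective_dim(x):
    """x: (tokens, hidden) activations; returns a value in (0, 1]."""
    x = x - x.mean(axis=0)
    s = np.linalg.svd(x, compute_uv=False) ** 2  # unnormalized covariance eigenvalues
    pr = s.sum() ** 2 / (s ** 2).sum()           # participation ratio
    return float(pr / x.shape[1])                # 1.0 = spectrum spread over full rank
```

Rank-1 activations score near 1/hidden; isotropic activations score near 1.0.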

circuit ablate — Find Important Layers

Ablate each layer and measure impact on predictions.

mlx-lm-lens circuit ablate [OPTIONS]

Required arguments:

--model TEXT                     Model path [required]
-p, --prompt TEXT                Input text [required]
--method TEXT                    Ablation method [required]:
                                 zero | mean | noise | knockout

Ablation options:

--target-token TEXT              Specific token to measure (default: last token)
--layer-range TEXT               Layer range, e.g., "5-15" (default: all)
--batch-size INTEGER             Batch size for ablation (default: 1)

Output options:

-f, --format TEXT                table | json | csv (default: table)
--output-dir TEXT                Output directory for results

Example:

mlx-lm-lens circuit ablate \
  --model my-model \
  --prompt "The capital of France is" \
  --method zero \
  --format json \
  --output-dir ablation_results/

Ablation methods:

  • zero — Replace layer output with zeros
  • mean — Replace with mean activation across all positions
  • noise — Replace with Gaussian noise (0 mean, unit variance)
  • knockout — Zero out residual stream at layer input
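The three output-side strategies can be sketched as pure array transforms. This is a simplification of what the CLI does inside the forward pass; knockout acts on the layer input and needs a forward hook, so it is not modeled here.

```python
# Sketches of the output-side ablation strategies applied to a layer
# output h of shape (tokens, hidden).
import numpy as np

def ablate(h, method, rng=None):
    if method == "zero":
        return np.zeros_like(h)
    if method == "mean":
        # Every position is replaced by the mean activation across positions
        return np.broadcast_to(h.mean(axis=0), h.shape).copy()
    if method == "noise":
        rng = rng or np.random.default_rng()
        return rng.normal(0.0, 1.0, size=h.shape)  # zero mean, unit variance
    # "knockout" zeroes the residual stream at the layer *input*,
    # which requires hooking the forward pass and is not modeled here.
    raise ValueError(f"unsupported method: {method}")
```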

circuit patch — Activation Patching

Patch activations from reference model into main model.

mlx-lm-lens circuit patch [OPTIONS]

Required arguments:

--model TEXT                     Model path [required]
--reference TEXT                 Reference model path [required]
-p, --prompt TEXT                Input text [required]

Patching options:

--patch-layer INTEGER            Specific layer to patch (default: sweep all)
--start-layer INTEGER            Start layer for sweep (default: 0)
--end-layer INTEGER              End layer for sweep

Output options:

-f, --format TEXT                table | json | csv (default: table)
--output TEXT                    Output file

Example:

mlx-lm-lens circuit patch \
  --model my-model \
  --reference base-model \
  --prompt "Hello world" \
  --format json \
  --output patch_results.json

circuit decompose — Residual Stream Decomposition

Decompose residual stream at a specific position.

mlx-lm-lens circuit decompose [OPTIONS]

Required arguments:

--model TEXT                     Model path [required]
-p, --prompt TEXT                Input text [required]

Options:

--position INTEGER               Token position (-1 = last, default: -1)
--show-contributions             Show layer-wise contributions
-f, --format TEXT                table | json | csv (default: table)
--output TEXT                    Output file

Example:

mlx-lm-lens circuit decompose \
  --model my-model \
  --prompt "Explain AI" \
  --position -1 \
  --show-contributions \
  --format json

circuit angles — Weight Angle Analysis

Compute weight angles between adapter versions.

mlx-lm-lens circuit angles [OPTIONS]

Required arguments:

--base-adapter TEXT              Base adapter path [required]
--current-adapter TEXT           Current adapter path [required]

Options:

--per-layer                      Show per-layer angles (default: show summary)
-f, --format TEXT                table | json | csv (default: table)
--output TEXT                    Output file

Example:

mlx-lm-lens circuit angles \
  --base-adapter v1.safetensors \
  --current-adapter v2.safetensors \
  --per-layer \
  --format json

Global Options

Available on all commands:

--verbose                        Enable debug logging
--help                           Show command help

Examples

Use Case 1: Understanding Qwen 0.6B Predictions

Scenario: You want to see which layer in Qwen3-0.6B decides to output "Paris" when asked "The capital of France is..."

mlx-lm-lens logit-lens generate \
  --model /path/to/qwen3-0.6b \
  --prompt "The capital of France is" \
  --top-k 3

What to expect: Early layers (0-5) might predict generic tokens like "the" or "a". Middle layers (6-12) start hinting at geography. By layer 15, "Paris" dominates with >90% confidence.

Insight: If "Paris" appears early and stays high, the model is confident and hierarchical. If it flips multiple times, the model is uncertain and corrects itself layer-to-layer.

Use Case 2: Comparing Base vs Fine-Tuned Models

Scenario: You fine-tuned a model on math problems. Where do the activations diverge most?

# First, compute activation analysis on the fine-tuned version
mlx-lm-lens activations \
  --model /path/to/finetuned \
  --reference-model /path/to/base \
  --prompt "What is the derivative of x^2?" \
  --metrics cosine,cka,procrustes \
  --format csv \
  --output-dir results/

# Look at the CSV to find layers with largest divergence
cat results/activations.csv | sort -t',' -k3 -n

What to expect: Early layers (embeddings, first 5 layers) show high similarity (cosine sim >0.9). Middle layers diverge (0.7-0.8). Head layers might collapse to near-zero similarity if fine-tuning specialized them.

Insight: Large drops in CKA/Procrustes indicate layers that the fine-tuning process restructured most. These are likely domain-specific feature detectors.

Use Case 3: Finding Identity Layers via Ablation

Scenario: Your model sometimes fails on identity tasks (copying the input). Which layers are responsible?

mlx-lm-lens circuit ablate \
  --model /path/to/model \
  --prompt "repeat: hello -> hello" \
  --method zero \
  --output-dir results/

# Interpret: high importance = ablating it hurts identity task

What to expect: Layers 8-12 likely show high importance (KL divergence >1.0 when ablated). Removing these layers breaks the identity behavior.

Insight: These layers contain the "copy" circuit. Visualize their attention patterns to understand the mechanism.

Architecture

mlx-lm-lens is organized into modular, testable components:

  • core/ — Model loading, activation capture, logit projection (low-level APIs)
  • metrics/ — 9 pluggable similarity metrics with unified interface
  • ablations/ — 4 ablation strategies for circuit discovery
  • formatters/ — Output to table, JSON, CSV
  • cli/ — User-facing commands (typer-based)
  • shaman/ — Hypothesis validators (H1-H15 from SHAMAN framework)
  • circuit/ — Patching and residual stream decomposition

Data flows: Model → Activation Capture → Metrics/Ablations → Formatters → User.

All components are dependency-injected via config dataclasses (no global state).
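A sketch of what config-dataclass injection looks like in practice. The names below are invented for illustration and are not the package's API.

```python
# Hypothetical illustration of the config-dataclass injection pattern:
# all state arrives through an immutable config object, never globals.
from dataclasses import dataclass, field

@dataclass(frozen=True)
class LensConfig:
    model_path: str
    top_k: int = 5
    max_tokens: int = 100
    stop: tuple = field(default_factory=tuple)

def run_logit_lens(cfg: LensConfig):
    # Everything the component needs is on cfg, so it can be constructed
    # and tested in isolation.
    return f"lens over {cfg.model_path} (top-{cfg.top_k})"
```

Because the dataclass is frozen, a config cannot be mutated mid-run, which keeps components reproducible and independently testable.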

Troubleshooting

"CUDA not available" or MLX import fails

Cause: MLX targets Apple Silicon (M1/M2/M3); this tool does not support Linux or Windows. Fix: Ensure you're on a Mac with Apple Silicon. Check: python -c "import mlx.core; print(mlx.core.default_device())"

"scipy module not found" — RSA metric fails

Cause: RSA requires scipy but it's optional. Fix: Install dev dependencies: pip install -e .[dev] or manually: pip install scipy

Memory errors on large models

Cause: Quantized models (4-bit/8-bit) still allocate memory for activations, and a 4-bit 70B model needs roughly 35 GB for weights alone. Fix:

  • Use smaller models for exploration (0.6B-7B)
  • Use --all-positions sparingly (captures all token positions, not just last)
  • Run on a Mac with more unified memory (e.g., an M3 Max with 96 GB)

"Top prediction never flips" in logit lens

Cause: Either the model is very confident, or the prompt is ambiguous. Fix: Try prompts with more open-ended predictions (e.g., story continuations instead of facts).

"Reference model not loaded" error

Cause: Activation analysis requires both models to be the same architecture (same layer count, hidden size). Fix: Ensure both model paths are correct and compatible. Use mlx-lm-lens activations --model M1 --reference-model M2 --prompt "x" to verify both load.
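Before a long run, you can sanity-check compatibility yourself. This sketch assumes each model directory carries a Hugging Face-style config.json (typically true for models converted with mlx-lm); the key names are the usual HF fields, not something this package defines.

```python
# Pre-flight check that two local model directories share an architecture,
# by comparing fields in their Hugging Face-style config.json files.
import json
from pathlib import Path

def compatible(model_dir_a, model_dir_b,
               keys=("num_hidden_layers", "hidden_size")):
    a = json.loads((Path(model_dir_a) / "config.json").read_text())
    b = json.loads((Path(model_dir_b) / "config.json").read_text())
    # Both configs must agree on every checked field
    return all(a.get(k) == b.get(k) for k in keys)
```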

Contributing

We welcome contributions! See CONTRIBUTING.md for development setup, testing, and style guidelines.

Key principles:

  • Every change must maintain ≥90% test coverage
  • New features require tests and documentation
  • All code passes strict ruff linter (zero warnings)
  • Keep files under 200 lines (forces modularity)

License

Apache License 2.0. See LICENSE for details.


Built for researchers and practitioners who want to understand transformer internals without a GPU farm.

Questions? Open an issue or check the examples/ directory for more use cases.
