- Introduction
- Input Processing
- Self-Attention Mechanism
- Multi-Head Attention
- Encoder Architecture
- Decoder Architecture
- Training Process
- Advanced Topics: Flash Attention
- Resources
The Transformer architecture, introduced in "Attention is All You Need" (Vaswani et al., 2017), revolutionized natural language processing by using attention mechanisms exclusively, without recurrent or convolutional layers. This guide provides a comprehensive understanding of how transformers work, from basic concepts to implementation details.
The Transformer replaces sequential processing with parallel attention computations, allowing models to:
- Process all tokens simultaneously
- Capture long-range dependencies efficiently
- Scale to very large datasets and model sizes
Before any attention computation can occur, text must be converted to numerical representations:
Input Tokens: ["The", "cat", "sat", "on", "the", "mat"]
↓ (tokenizer: vocabulary lookup)
Input IDs: [105, 6587, 5475, 3578, 65, 9462] (illustrative values; each distinct token maps to its own ID)
↓ (embedding matrix: vocab_size × d_model)
Embeddings: 6×512 numerical matrix
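The pipeline above can be sketched in a few lines of NumPy. This is a minimal illustration, not a real tokenizer: the six-word `vocab` dictionary and the randomly filled `embedding_matrix` are assumptions made for the example, and real systems use subword vocabularies with tens of thousands of entries and learned embeddings.

```python
import numpy as np

# Hypothetical toy vocabulary (assumption for illustration only).
vocab = {"The": 0, "cat": 1, "sat": 2, "on": 3, "the": 4, "mat": 5}
d_model = 512

rng = np.random.default_rng(0)
# Embedding matrix: one row of length d_model per vocabulary entry.
embedding_matrix = rng.normal(0.0, 0.02, size=(len(vocab), d_model))

tokens = ["The", "cat", "sat", "on", "the", "mat"]
input_ids = np.array([vocab[t] for t in tokens])  # shape (6,)
embeddings = embedding_matrix[input_ids]          # row lookup, shape (6, 512)

print(input_ids.tolist())
print(embeddings.shape)
```

The embedding step is just a row lookup: token ID `k` selects row `k` of the matrix.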
Evolution of Embeddings:
- Early approach: Fixed random lookup tables
- Word2Vec/GloVe: Pre-trained semantic embeddings
- Modern approach: Contextual embeddings where the same word gets different representations based on context
Common Initialization Ranges:
- Embedding layers: Normal distribution (mean=0, std=0.02-0.1)
- Linear layers: Xavier/Glorot initialization (±√(6/(fan_in + fan_out)))
- For d_model=512: a 512×512 linear layer has a Xavier bound of √(6/1024) ≈ 0.077, so initial weights fall well within [-0.1, 0.1]
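As a quick check of the Xavier/Glorot formula, the bound for a square 512×512 layer can be computed directly. The helper name `xavier_uniform_bound` is made up for this sketch.

```python
import math

def xavier_uniform_bound(fan_in: int, fan_out: int) -> float:
    """Half-width of the uniform range used by Xavier/Glorot initialization."""
    return math.sqrt(6.0 / (fan_in + fan_out))

bound = xavier_uniform_bound(512, 512)
print(round(bound, 4))  # 0.0765
```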
Since transformers process all tokens in parallel, they need explicit position information. Positional encodings provide this crucial sequential context.
Key Properties:
- Computed once and reused for all sentences
- Deterministic function of position only
- Added directly to token embeddings
Mathematical Formula:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) # for even dimensions
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) # for odd dimensions
Where:
- pos = position in the sequence
- i = dimension index
- d_model = embedding dimension (e.g., 512)
Why This Works:
- Position 0 always gets the same encoding vector
- Position 1 always gets the same encoding vector
- Content-independent: "The cat sat" and "Hello world" use the same positional encodings
Implementation: Pre-compute positional encoding matrix up to maximum sequence length, then add appropriate encodings based on each token's position.
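A minimal NumPy sketch of that pre-computation follows. It assumes an even `d_model`; the function name and the placeholder `token_embeddings` array are illustrative, not taken from any particular library.

```python
import numpy as np

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> np.ndarray:
    """Build the (max_len, d_model) sinusoidal table; d_model is assumed even."""
    positions = np.arange(max_len)[:, None]        # (max_len, 1)
    even_dims = np.arange(0, d_model, 2)[None, :]  # the values 2i, shape (1, d_model/2)
    angles = positions / np.power(10000.0, even_dims / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)  # even dimensions
    pe[:, 1::2] = np.cos(angles)  # odd dimensions
    return pe

# Compute once, then slice by position for any sentence.
pe = sinusoidal_positional_encoding(max_len=128, d_model=512)
token_embeddings = np.zeros((6, 512))  # placeholder for real embeddings
x = token_embeddings + pe[:6]
print(x.shape)
```

Because the table depends only on position, the same `pe` array can be reused for every input up to `max_len` tokens.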
Self-attention is the core innovation that allows each token to "look at" and relate to all other tokens in the sequence.
Attention(Q,K,V) = softmax(QK^T/√d_k) × V
Common Misconception: Q, K, V are NOT just the input sentence.
Reality: They are learned linear transformations of the input:
Q (Query) = Input × W_Q # 6×512 × 512×512 = 6×512
K (Key) = Input × W_K # 6×512 × 512×512 = 6×512
V (Value) = Input × W_V # 6×512 × 512×512 = 6×512
Weight Matrices:
- W_Q, W_K, W_V: 512×512 matrices (for d_model=512)
- Start with random initialization
- Learn specialized roles through training
- Compute Attention Scores:
Scores = Q × K^T / √d_k # (6×512) × (512×6) / √512 = (6×6) matrix
- Apply Softmax:
Attention_Weights = softmax(Scores) # Each row sums to 1.0
- Weight the Values:
Output = Attention_Weights × V # (6×6) × (6×512) = (6×512)
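The three steps can be put together in a short NumPy sketch. The random input and weight matrices are placeholders standing in for real embeddings and trained parameters.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum first for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, W_Q, W_K, W_V):
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V  # learned projections of the input
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # step 1: (n, n) attention scores
    weights = softmax(scores, axis=-1)   # step 2: each row sums to 1
    return weights @ V, weights          # step 3: weighted sum of values

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 512))  # placeholder for 6 token embeddings
W_Q, W_K, W_V = (rng.normal(0.0, 0.02, size=(512, 512)) for _ in range(3))
output, weights = self_attention(x, W_Q, W_K, W_V)
print(output.shape, weights.shape)
```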
The attention formula forces different roles for Q, K, V:
Q and K interact directly:
- Q·K^T creates attention scores
- Q learns: "What should I ask for?"
- K learns: "What should I respond to?"
- They co-evolve to be compatible
V has a separate role:
- Gets weighted by attention scores
- V learns: "What information should I carry forward?"
- Doesn't participate in the attention score computation
Key Properties of Self-Attention:
- Permutation Equivariant: Without positional encoding, attention has no notion of order; shuffling the input tokens just shuffles the outputs the same way
- No Parameters: The attention operation itself has no learnable parameters (they all live in W_Q, W_K, W_V)
- Diagonal Dominance: Tokens often attend strongly to themselves, though this varies by head and layer
- Masking: Future tokens can be hidden by setting their attention scores to -∞ before the softmax
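To make the masking property concrete, here is a small sketch of a causal mask applied before the softmax. The helper names are invented for the example, and the all-zero score matrix is only there to make the resulting weights easy to read.

```python
import numpy as np

def causal_mask(n: int) -> np.ndarray:
    # True strictly above the diagonal: position i must not see positions j > i.
    return np.triu(np.ones((n, n), dtype=bool), k=1)

def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    scores = np.where(mask, -np.inf, scores)  # masked entries get -inf
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)                        # exp(-inf) == 0
    return e / e.sum(axis=-1, keepdims=True)

scores = np.zeros((4, 4))  # uniform scores for readability
weights = masked_softmax(scores, causal_mask(4))
print(weights)
```

With uniform scores, row `i` spreads its weight evenly over positions `0..i` and assigns exactly zero to everything after it.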
Multi-head attention allows the model to attend to information from different representation subspaces simultaneously.
Instead of using one large attention computation, split into multiple "heads":
- Each head learns different types of relationships
- Some heads might focus on syntax, others on semantics
- Heads operate in parallel and are concatenated
Split the d_model dimension across h heads:
d_model = 512, h = 8 heads
d_k = d_v = d_model / h = 64 per head
Each head processes:
head_i = Attention(Q_i, K_i, V_i)
where Q_i = Q[:, i*d_k:(i+1)*d_k] # slice the matrix
Final output:
MultiHead = Concat(head_1, ..., head_h) × W_O
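A minimal sketch of the slicing scheme above, written as an explicit loop over heads for readability. Production code usually reshapes to a batched tensor instead; the random weights are placeholders for trained parameters.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_Q, W_K, W_V, W_O, h):
    n, d_model = x.shape
    d_k = d_model // h  # per-head width, e.g. 512 / 8 = 64
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    heads = []
    for i in range(h):
        cols = slice(i * d_k, (i + 1) * d_k)  # columns belonging to head i
        Q_i, K_i, V_i = Q[:, cols], K[:, cols], V[:, cols]
        weights = softmax(Q_i @ K_i.T / np.sqrt(d_k))
        heads.append(weights @ V_i)           # (n, d_k)
    return np.concatenate(heads, axis=-1) @ W_O  # (n, d_model)

rng = np.random.default_rng(0)
d_model, h = 512, 8
x = rng.normal(size=(6, d_model))
W_Q, W_K, W_V, W_O = (rng.normal(0.0, 0.02, size=(d_model, d_model)) for _ in range(4))
out = multi_head_attention(x, W_Q, W_K, W_V, W_O, h)
print(out.shape)
```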
The encoder processes the entire input sequence to create rich, contextualized representations.
Each encoder layer contains:
- Self-Attention
  - All tokens can attend to all other tokens
  - No masking (unlike decoder)
  - Captures bidirectional context
- Feed-Forward Network
  - Applied to each position independently
  - Typically: d_model → 4×d_model → d_model
  - Adds non-linearity and capacity
- Residual Connections + Layer Normalization
  - Around both sub-layers
  - Helps with gradient flow and training stability
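The sketch below wires these three pieces into one encoder layer, using the post-norm arrangement `LayerNorm(x + Sublayer(x))`. Several simplifications are assumptions made to keep it short: the attention stand-in `toy_self_attention` uses identity projections instead of W_Q, W_K, W_V, layer normalization omits its learnable scale and shift, and the feed-forward network uses ReLU.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Normalize each token vector; learnable scale and shift are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def toy_self_attention(x):
    # Stand-in with identity projections, for brevity only.
    weights = softmax(x @ x.T / np.sqrt(x.shape[-1]))
    return weights @ x

def feed_forward(x, W1, b1, W2, b2):
    # d_model -> 4*d_model -> d_model, applied to each position independently.
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

def encoder_layer(x, W1, b1, W2, b2):
    x = layer_norm(x + toy_self_attention(x))            # residual + norm
    x = layer_norm(x + feed_forward(x, W1, b1, W2, b2))  # residual + norm
    return x

d_model, d_ff = 512, 2048
rng = np.random.default_rng(0)
W1, b1 = rng.normal(0.0, 0.02, size=(d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(0.0, 0.02, size=(d_ff, d_model)), np.zeros(d_model)
x = rng.normal(size=(6, d_model))
out = encoder_layer(x, W1, b1, W2, b2)
print(out.shape)
```

Stacking N such layers, each with its own parameters, gives the full encoder.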
Input: "The cat sat on the mat"
↓
Token + Positional Embeddings
↓
Layer 1: Self-Attention → Feed-Forward
↓
Layer 2: Self-Attention → Feed-Forward
↓
...
↓
Layer N: Self-Attention → Feed-Forward
↓
Rich contextualized representations
The encoder's job is to thoroughly understand the input, creating representations where each token contains:
- Original word meaning
- Grammatical role
- Relationships to other words
- Full contextual information
The decoder generates output sequences one token at a time, using both encoder output and previously generated tokens.
The decoder has three main components:
- Masked Self-Attention
  - Looks only at previously generated tokens
  - Prevents "cheating" by looking at future tokens
  - Enforced by setting future positions to -∞ before softmax
- Cross-Attention (Encoder-Decoder Attention)
  - Q (queries) come from decoder's previous layer
  - K and V (keys, values) come from encoder output
  - Allows decoder to "read" the source sequence
- Feed-Forward Network
  - Same as in encoder
This is the key connection between encoder and decoder:
Q_decoder = Decoder_Hidden × W_Q # from decoder
K_encoder = Encoder_Output × W_K # from encoder
V_encoder = Encoder_Output × W_V # from encoder
Cross_Attention = softmax(Q_decoder × K_encoder^T / √d_k) × V_encoder
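A small NumPy sketch of cross-attention, using the same √d_k scaling as the self-attention formula earlier. The shapes (3 decoder positions, 6 encoder positions) and the random matrices are assumptions chosen for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(decoder_hidden, encoder_output, W_Q, W_K, W_V):
    Q = decoder_hidden @ W_Q  # queries come from the decoder
    K = encoder_output @ W_K  # keys come from the encoder
    V = encoder_output @ W_V  # values come from the encoder
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))  # (n_dec, n_enc)
    return weights @ V, weights

rng = np.random.default_rng(0)
decoder_hidden = rng.normal(size=(3, 512))  # 3 target positions so far
encoder_output = rng.normal(size=(6, 512))  # 6 source tokens
W_Q, W_K, W_V = (rng.normal(0.0, 0.02, size=(512, 512)) for _ in range(3))
context, weights = cross_attention(decoder_hidden, encoder_output, W_Q, W_K, W_V)
print(context.shape, weights.shape)
```

Note that the weight matrix is rectangular: each decoder position gets one distribution over the source tokens.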
During inference, the decoder works step-by-step:
- Start with the <START> token
- Decoder generates probability distribution over vocabulary
- Sample/pick next token
- Feed that token back as input for next position
- Repeat until the <END> token is produced
Input: "Hello world" → Output: "Bonjour monde"
1. Encoder processes "Hello world" once
2. Decoder generates "Bonjour" (attending to encoder output)
3. Decoder generates "monde" (attending to "Bonjour" + encoder output)
4. Continue until complete
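The generation loop can be sketched with greedy decoding (always picking the highest-scoring token). Everything model-related here is a stand-in: `toy_next_token_logits` is a hypothetical function that replays a fixed script instead of running a real decoder, and the token IDs are arbitrary.

```python
import numpy as np

START, END = 0, 1
VOCAB_SIZE = 5

def toy_next_token_logits(generated):
    # Hypothetical stand-in for a decoder forward pass:
    # it favors token 2, then token 3, then END.
    script = [2, 3, END]
    step = len(generated) - 1  # tokens produced after START
    logits = np.zeros(VOCAB_SIZE)
    logits[script[min(step, len(script) - 1)]] = 5.0
    return logits

def greedy_decode(next_token_logits, max_len=10):
    generated = [START]
    while len(generated) < max_len:
        logits = next_token_logits(generated)  # scores over the vocabulary
        next_token = int(np.argmax(logits))    # pick the most likely token
        generated.append(next_token)           # feed it back as input
        if next_token == END:
            break
    return generated

result = greedy_decode(toy_next_token_logits)
print(result)  # [0, 2, 3, 1]
```

Replacing `np.argmax` with sampling from the softmax of the logits turns this into stochastic decoding; the `max_len` cap guards against a model that never emits the end token.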
Critical Insight: All matrices (W_Q, W_K, W_V) are trained together with the same objective.
Single Training Loop:
Input → [W_Q, W_K, W_V] → Q, K, V → Attention → Output → Loss
↑ ↓
Update ← ← ← ← ← Backpropagation ← ← ← ← ← ← ← ←
Even with identical training, the three matrices end up specialized, for three reasons:
- Different Roles in Computation
  - W_Q creates queries for Q·K^T dot product
  - W_K creates keys for K·Q^T dot product
  - W_V creates values for attention weighting
- Different Gradient Flows
  - ∂Loss/∂W_Q flows through Q·K^T computation
  - ∂Loss/∂W_K flows through K·Q^T computation
  - ∂Loss/∂W_V flows through attention weighting
- Random Initialization
  - Different starting points lead to different evolutionary paths
  - Symmetry breaking allows specialization
- Loss Function: Typically cross-entropy for next token prediction
- Optimization: Adam optimizer with learning rate scheduling
- Regularization: Dropout (layer normalization mainly stabilizes training rather than regularizing)
- Gradient Flow: Residual connections help gradients reach early layers
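For reference, the cross-entropy loss for next-token prediction can be written in a few lines. This is a sketch on made-up logits; the function name is an assumption, and real training code would compute the same quantity inside an autodiff framework.

```python
import numpy as np

def next_token_cross_entropy(logits, target_ids):
    """Mean negative log-probability of the correct next token.

    logits: (n, vocab_size) unnormalized scores, one row per position.
    target_ids: (n,) index of the true next token at each position.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)  # stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(target_ids)), target_ids].mean()

# With uniform logits over 4 tokens, every token has probability 1/4.
uniform_logits = np.zeros((3, 4))
targets = np.array([0, 1, 2])
loss = next_token_cross_entropy(uniform_logits, targets)
print(np.isclose(loss, np.log(4)))  # True
```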
Flash Attention is a hardware-aware algorithm for computing attention that significantly improves speed and memory efficiency.
Standard attention computation requires storing large intermediate matrices:
- Attention scores: O(n²) memory for sequence length n
- For long sequences, this becomes prohibitive
Flash Attention's approach:
- Computes attention in blocks without storing full attention matrix
- Uses "online" algorithms for softmax computation
- Dramatically reduces memory usage while maintaining mathematical equivalence
Traditional softmax can cause numerical instability. Flash Attention uses:
- LogSumExp trick for numerical stability
- Online/streaming computation to avoid storing large matrices
- Block-wise matrix operations
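The online-softmax idea can be illustrated in plain NumPy. This sketch only tiles over keys and values and runs on the CPU; the real Flash Attention algorithm also tiles queries and fuses everything into a single GPU kernel, so treat this as a demonstration of the arithmetic rather than of the performance technique. Function names and the block size are assumptions.

```python
import numpy as np

def reference_attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return (e / e.sum(axis=-1, keepdims=True)) @ V

def blockwise_attention(Q, K, V, block_size=4):
    n, d_k = Q.shape
    out = np.zeros((n, V.shape[-1]))        # unnormalized output accumulator
    running_max = np.full((n, 1), -np.inf)  # max score seen so far, per row
    running_sum = np.zeros((n, 1))          # softmax denominator so far
    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = Q @ K_blk.T / np.sqrt(d_k)  # only an (n, block) tile
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        correction = np.exp(running_max - new_max)  # rescale old accumulators
        p = np.exp(scores - new_max)
        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ V_blk
        running_max = new_max
    return out / running_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(6, 16)) for _ in range(3))
blocked = blockwise_attention(Q, K, V, block_size=4)
print(np.allclose(blocked, reference_attention(Q, K, V)))  # True
```

The full (n×n) score matrix is never materialized: each tile is folded into the running maximum, running sum, and output accumulator, which is why the result matches standard attention exactly up to floating-point error.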
Triton allows writing GPU kernels in Python-like syntax:
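# Triton kernel skeleton for Flash Attention (outline only, not a working kernel)
import triton

@triton.jit
def flash_attention_kernel(
    q_ptr, k_ptr, v_ptr, output_ptr,
    # ... other parameters
):
    # Block-wise attention computation
    # Optimized memory access patterns
    pass

Performance Benefits:
- Roughly 2-5x speedup over standard attention on modern GPUs, depending on sequence length and hardware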
- Significantly reduced memory footprint
- Enables training on longer sequences
- Official Triton Tutorial: https://github.com/triton-lang/triton
- Tutorial 06: Fused Attention implementation
- SageAttention: Optimized for GPUs such as the RTX 4090
- "Attention is All You Need" (Vaswani et al., 2017)
  - Original transformer paper
  - URL: https://papers.neurips.cc/paper/7181-attention-is-all-you-need.pdf
- Stanford CS224N Lecture Notes (2023)
  - "Note 10: Self-Attention & Transformers"
  - URL: https://web.stanford.edu/class/cs224n/readings/cs224n-self-attention-transformers-2023_draft.pdf
- Purdue University Tutorial (Avinash Kak)
  - "Transformers: Learning with Purely Attention Based Networks"
  - URL: https://engineering.purdue.edu/DeepLearn/pdf-kak/Transformers.pdf
- Jurafsky & Martin Textbook - Chapter 8
  - "Speech and Language Processing"
  - URL: https://web.stanford.edu/~jurafsky/slp3/8.pdf
- Sebastian Raschka's Lecture Notes
  - "Self-Attention Mechanism & Transformers"
  - URL: https://sebastianraschka.com/pdf/lecture-notes/stat453ss21/L19_seq2seq_rnn-transformers__slides.pdf
- Start with Stanford CS224N for mathematical foundations
- Read the original paper for complete context
- Use Purdue tutorial for implementation details
- Explore Flash Attention resources for optimization
- Transformers are fundamentally about attention: The ability for tokens to selectively focus on other tokens
- Q, K, V specialization emerges naturally: Same training objective, different roles due to mathematical structure
- Parallelization is key: Unlike RNNs, transformers can process entire sequences simultaneously
- Position matters: Positional encodings are crucial since attention on its own ignores token order
- Encoder-decoder separation: Different attention patterns for understanding vs. generation
- Modern optimizations matter: Techniques like Flash Attention make transformers practical at scale
This architecture forms the foundation of modern language models like GPT, BERT, Claude, and many others. Understanding these fundamentals provides the basis for working with and improving transformer-based systems.