Comprehensive Transformer Architecture Guide

Table of Contents

  1. Introduction
  2. Input Processing
  3. Self-Attention Mechanism
  4. Multi-Head Attention
  5. Encoder Architecture
  6. Decoder Architecture
  7. Training Process
  8. Advanced Topics: Flash Attention
  9. Resources

Introduction

The Transformer architecture, introduced in "Attention is All You Need" (Vaswani et al., 2017), revolutionized natural language processing by using attention mechanisms exclusively, without recurrent or convolutional layers. This guide provides a comprehensive understanding of how transformers work, from basic concepts to implementation details.

Key Innovation

The Transformer replaces sequential processing with parallel attention computations, allowing models to:

  • Process all tokens simultaneously
  • Capture long-range dependencies efficiently
  • Scale to very large datasets and model sizes

Input Processing

Token Embeddings

Before any attention computation can occur, text must be converted to numerical representations:

Input Tokens: ["The", "cat", "sat", "on", "the", "mat"]
           ↓ (tokenizer: vocabulary lookup)
Input IDs:    [105, 6587, 5475, 3578, 65, 2603]   # illustrative IDs, one per distinct token
           ↓ (embedding matrix: vocab_size × d_model)
Embeddings:   6×512 numerical matrix
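The two lookups above can be sketched in a few lines. This is a minimal illustration, assuming a toy six-word vocabulary and d_model = 4 instead of 512; the names `vocab` and `embedding_matrix` are hypothetical and the IDs do not come from any real tokenizer.

```python
import random

random.seed(0)

# Hypothetical toy vocabulary: token string -> integer ID.
vocab = {"The": 0, "cat": 1, "sat": 2, "on": 3, "the": 4, "mat": 5}
d_model = 4  # 512 in the running example; kept small here for readability

# Embedding matrix of shape vocab_size x d_model, drawn from N(0, 0.02).
embedding_matrix = [
    [random.gauss(0.0, 0.02) for _ in range(d_model)] for _ in range(len(vocab))
]

tokens = ["The", "cat", "sat", "on", "the", "mat"]
input_ids = [vocab[t] for t in tokens]                  # step 1: tokens -> IDs
embeddings = [embedding_matrix[i] for i in input_ids]   # step 2: IDs -> matrix rows

print(input_ids)                            # [0, 1, 2, 3, 4, 5]
print(len(embeddings), len(embeddings[0]))  # 6 4
```

The embedding step is only a row lookup; in a real model the rows are trained together with the rest of the network.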

Evolution of Embeddings:

  • Early approach: Fixed random lookup tables
  • Word2Vec/GloVe: Pre-trained semantic embeddings
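  • Transformer approach: A learned embedding table supplies one vector per token, and the attention layers turn these into contextual representations, so the same word gets different representations based on context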

Common Initialization Ranges:

  • Embedding layers: Normal distribution (mean=0, std=0.02-0.1)
  • Linear layers: Xavier/Glorot initialization (±√(6/(fan_in + fan_out)))
  • For a 512×512 linear layer: the Xavier bound is √(6/1024) ≈ 0.077, so weights fall within roughly [-0.08, 0.08]
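As a quick check of the Xavier/Glorot formula, the sketch below computes the bound for a 512×512 layer and samples a weight matrix from it. The helper name `xavier_uniform` is illustrative here, not a reference to any particular library function.

```python
import math
import random

def xavier_uniform(fan_in, fan_out, rng):
    """Return (weights, bound) with weights drawn from U(-bound, bound),
    where bound = sqrt(6 / (fan_in + fan_out))."""
    bound = math.sqrt(6.0 / (fan_in + fan_out))
    weights = [[rng.uniform(-bound, bound) for _ in range(fan_out)]
               for _ in range(fan_in)]
    return weights, bound

rng = random.Random(0)
weights, bound = xavier_uniform(512, 512, rng)

print(round(bound, 4))  # 0.0765
```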

Positional Encoding

Since transformers process all tokens in parallel, they need explicit position information. Positional encodings provide this crucial sequential context.

Key Properties:

  • Computed once and reused for all sentences
  • Deterministic function of position only
  • Added directly to token embeddings

Mathematical Formula:

PE(pos, 2i) = sin(pos / 10000^(2i/d_model))     # for even dimensions
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))   # for odd dimensions

Where:

  • pos = position in the sequence
  • i = index of the sine/cosine dimension pair (0 ≤ i < d_model/2)
  • d_model = embedding dimension (e.g., 512)

Why This Works:

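  • Each position always gets the same encoding vector (position 0, position 1, and so on)
  • Each dimension pair oscillates at a different wavelength, so nearby and distant positions get distinguishable patterns
  • Content-independent: "The cat sat" and "Hello world" use the same positional encodings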

Implementation: Pre-compute positional encoding matrix up to maximum sequence length, then add appropriate encodings based on each token's position.
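A minimal sketch of that pre-computation, following the sine/cosine formula above. It assumes a small d_model of 8 for readability; the `embeddings` values are placeholders, not real token embeddings.

```python
import math

def positional_encoding(max_len, d_model):
    """Sinusoidal encodings: one row per position, one column per dimension."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for two_i in range(0, d_model, 2):      # two_i plays the role of 2i
            angle = pos / (10000 ** (two_i / d_model))
            pe[pos][two_i] = math.sin(angle)    # even dimension
            if two_i + 1 < d_model:
                pe[pos][two_i + 1] = math.cos(angle)  # odd dimension
    return pe

# Computed once, up to a maximum sequence length.
pe = positional_encoding(max_len=50, d_model=8)

# Placeholder embeddings for a 6-token sentence; the encoding is simply added.
embeddings = [[0.1] * 8 for _ in range(6)]
inputs = [[e + p for e, p in zip(embeddings[pos], pe[pos])] for pos in range(6)]

print(pe[0])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Any sentence of up to 50 tokens would reuse the same `pe` table, which is what makes the encoding content-independent.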


Self-Attention Mechanism

Self-attention is the core innovation that allows each token to "look at" and relate to all other tokens in the sequence.

The Attention Formula

Attention(Q,K,V) = softmax(QK^T/√d_k) × V

Creating Q, K, V Matrices

Common Misconception: Q, K, V are NOT just the input sentence.

Reality: They are learned linear transformations of the input:

Q (Query) = Input × W_Q    # 6×512 × 512×512 = 6×512
K (Key)   = Input × W_K    # 6×512 × 512×512 = 6×512  
V (Value) = Input × W_V    # 6×512 × 512×512 = 6×512

Weight Matrices:

  • W_Q, W_K, W_V: 512×512 matrices (for d_model=512)
  • Start with random initialization
  • Learn specialized roles through training

Step-by-Step Attention Computation

  1. Compute Attention Scores:

    Scores = Q × K^T / √d_k
    # (6×512) × (512×6) / √512 = (6×6) matrix
    
  2. Apply Softmax:

    Attention_Weights = softmax(Scores)
    # Each row sums to 1.0
    
  3. Weight the Values:

    Output = Attention_Weights × V
    # (6×6) × (6×512) = (6×512)
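The three steps can be sketched in plain Python as follows. This is an illustration under assumptions: d_model is 8 rather than 512, the input and weight matrices are random, and the helper names (`matmul`, `rand_matrix`, `attention`) are made up for this example.

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """Return (output, weights) for softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(k[0])
    k_t = [list(col) for col in zip(*k)]                       # K^T
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, k_t)]
    weights = [softmax(row) for row in scores]                 # each row sums to 1
    return matmul(weights, v), weights

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

n, d_model = 6, 8                      # 6 tokens; 512 dimensions in the text
x = rand_matrix(n, d_model)            # stand-in for embeddings + positions
w_q = rand_matrix(d_model, d_model)
w_k = rand_matrix(d_model, d_model)
w_v = rand_matrix(d_model, d_model)

q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
output, weights = attention(q, k, v)

print(len(weights), len(weights[0]))   # 6 6
print(len(output), len(output[0]))     # 6 8
```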
    

Mathematical Specialization

The attention formula forces different roles for Q, K, V:

Q and K interact directly:

  • Q·K^T creates attention scores
  • Q learns: "What should I ask for?"
  • K learns: "What should I respond to?"
  • They co-evolve to be compatible

V has a separate role:

  • Gets weighted by attention scores
  • V learns: "What information should I carry forward?"
  • Doesn't participate in the attention score computation

Properties of Self-Attention

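  1. Permutation Equivariant: Without positional encoding, shuffling the input tokens just shuffles the outputs the same way, so word order carries no information
  2. No Parameters: The attention operation itself has no learnable parameters (the parameters are in W_Q, W_K, W_V and the output projection W_O)
  3. Diagonal Dominance: Words often attend strongly to themselves, although this varies by head and layer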
  4. Masking: Can prevent future tokens from being seen by setting attention scores to -∞
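The masking property can be illustrated with a small sketch. It assumes a 4×4 matrix of all-zero scores so the effect of the mask is easy to read; `causal_mask` is a hypothetical helper name.

```python
import math

def softmax(row):
    """Numerically stable softmax; exp(-inf) evaluates to 0.0."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_mask(scores):
    """Keep scores[i][j] for j <= i and set future positions (j > i) to -inf."""
    n = len(scores)
    return [[scores[i][j] if j <= i else -math.inf for j in range(n)]
            for i in range(n)]

scores = [[0.0] * 4 for _ in range(4)]          # uniform scores before masking
weights = [softmax(row) for row in causal_mask(scores)]

print(weights[0])  # [1.0, 0.0, 0.0, 0.0]
print(weights[1])  # [0.5, 0.5, 0.0, 0.0]
```

Because the diagonal is never masked, every row keeps at least one finite score and the softmax stays well defined.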

Multi-Head Attention

Multi-head attention allows the model to attend to information from different representation subspaces simultaneously.

Concept

Instead of using one large attention computation, split into multiple "heads":

  • Each head learns different types of relationships
  • Some heads might focus on syntax, others on semantics
  • Heads operate in parallel and are concatenated

Implementation

Split the d_model dimension across h heads:

d_model = 512, h = 8 heads
d_k = d_v = d_model / h = 64 per head

Each head processes:

head_i = Attention(Q_i, K_i, V_i)
where Q_i = Q[:, i*d_k:(i+1)*d_k]  # slice the matrix

Final output:

MultiHead = Concat(head_1, ..., head_h) × W_O
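A minimal sketch of the split, per-head attention, concatenation, and output projection described above. It assumes d_model = 16 and h = 4 for readability, uses random matrices in place of trained projections, and the helper names are illustrative.

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for a single head."""
    d_k = len(k[0])
    scores = [[sum(a * b for a, b in zip(qr, kr)) / math.sqrt(d_k) for kr in k]
              for qr in q]
    return matmul([softmax(row) for row in scores], v)

def multi_head_attention(q, k, v, w_o, h):
    d_k = len(q[0]) // h
    heads = []
    for i in range(h):
        cols = slice(i * d_k, (i + 1) * d_k)     # Q[:, i*d_k:(i+1)*d_k]
        heads.append(attention([r[cols] for r in q],
                               [r[cols] for r in k],
                               [r[cols] for r in v]))
    # Concatenate the h head outputs along the feature dimension.
    concat = [[x for head in heads for x in head[t]] for t in range(len(q))]
    return matmul(concat, w_o)

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

n, d_model, h = 6, 16, 4
q, k, v = rand_matrix(n, d_model), rand_matrix(n, d_model), rand_matrix(n, d_model)
w_o = rand_matrix(d_model, d_model)

out = multi_head_attention(q, k, v, w_o, h)
print(len(out), len(out[0]))  # 6 16
```

Slicing already-projected Q, K, V is equivalent to giving each head its own smaller projection matrices, which is how the original paper writes it.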

Encoder Architecture

The encoder processes the entire input sequence to create rich, contextualized representations.

Structure

Each encoder layer contains:

  1. Self-Attention

    • All tokens can attend to all other tokens
    • No masking (unlike decoder)
    • Captures bidirectional context
  2. Feed-Forward Network

    • Applied to each position independently
    • Typically: d_model → 4×d_model → d_model
    • Adds non-linearity and capacity
  3. Residual Connections + Layer Normalization

    • Around both sub-layers
    • Helps with gradient flow and training stability
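The feed-forward sub-layer with its residual connection and layer normalization might look like the sketch below. It assumes the post-norm arrangement LayerNorm(x + Sublayer(x)) and ReLU activation from the original paper, omits the learned gain and bias of layer normalization, and uses d_model = 8 with random weights.

```python
import math
import random

def layer_norm(x, eps=1e-5):
    """Normalize one vector to zero mean and unit variance (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def feed_forward(x, w1, b1, w2, b2):
    """d_model -> d_ff with ReLU, then d_ff -> d_model, for one position."""
    hidden = [max(0.0, sum(xi * w for xi, w in zip(x, col)) + b)
              for col, b in zip(zip(*w1), b1)]
    return [sum(hi * w for hi, w in zip(hidden, col)) + b
            for col, b in zip(zip(*w2), b2)]

def ffn_sublayer(tokens, w1, b1, w2, b2):
    """Apply LayerNorm(x + FFN(x)) to every position independently."""
    return [layer_norm([a + f for a, f in zip(x, feed_forward(x, w1, b1, w2, b2))])
            for x in tokens]

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

d_model, d_ff = 8, 32                       # d_ff = 4 × d_model
w1, b1 = rand_matrix(d_model, d_ff), [0.0] * d_ff
w2, b2 = rand_matrix(d_ff, d_model), [0.0] * d_model
tokens = rand_matrix(6, d_model)            # stand-in for the attention output

out = ffn_sublayer(tokens, w1, b1, w2, b2)
print(len(out), len(out[0]))  # 6 8
```

The self-attention sub-layer is wrapped the same way, with attention in place of `feed_forward`.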

Process Flow

Input: "The cat sat on the mat"
↓
Token + Positional Embeddings
↓
Layer 1: Self-Attention → Feed-Forward
↓
Layer 2: Self-Attention → Feed-Forward
↓
...
↓
Layer N: Self-Attention → Feed-Forward
↓
Rich contextualized representations

Purpose

The encoder's job is to thoroughly understand the input, creating representations where each token contains:

  • Original word meaning
  • Grammatical role
  • Relationships to other words
  • Full contextual information

Decoder Architecture

The decoder generates output sequences one token at a time, using both encoder output and previously generated tokens.

Key Differences from Encoder

The decoder has three main components:

  1. Masked Self-Attention

    • Looks only at previously generated tokens
    • Prevents "cheating" by looking at future tokens
    • Enforced by setting future positions to -∞ before softmax
  2. Cross-Attention (Encoder-Decoder Attention)

    • Q (queries) come from decoder's previous layer
    • K and V (keys, values) come from encoder output
    • Allows decoder to "read" the source sequence
  3. Feed-Forward Network

    • Same as in encoder

Cross-Attention Mechanics

This is the key connection between encoder and decoder:

Q_decoder = Decoder_Hidden × W_Q     # from decoder
K_encoder = Encoder_Output × W_K     # from encoder  
V_encoder = Encoder_Output × W_V     # from encoder

Cross_Attention = softmax(Q_decoder × K_encoder^T / √d_k) × V_encoder
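To make the shapes concrete, here is a small sketch of cross-attention with 3 decoder positions and 5 encoder positions. It assumes d_model = 8, random matrices in place of trained weights, and includes the usual 1/√d_k scaling; all helper names are illustrative.

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

d_model = 8
decoder_hidden = rand_matrix(3, d_model)   # 3 target tokens generated so far
encoder_output = rand_matrix(5, d_model)   # 5 source tokens

w_q, w_k, w_v = (rand_matrix(d_model, d_model) for _ in range(3))
q = matmul(decoder_hidden, w_q)            # queries come from the decoder
k = matmul(encoder_output, w_k)            # keys come from the encoder
v = matmul(encoder_output, w_v)            # values come from the encoder

scores = [[sum(a * b for a, b in zip(qr, kr)) / math.sqrt(d_model) for kr in k]
          for qr in q]
weights = [softmax(row) for row in scores]  # 3 × 5: one row per decoder position
cross = matmul(weights, v)                  # 3 × 8

print(len(weights), len(weights[0]))  # 3 5
print(len(cross), len(cross[0]))      # 3 8
```

Each decoder position ends up with a weighted average of encoder values, which is how it "reads" the source sequence.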

Autoregressive Generation

During inference, the decoder works step-by-step:

  1. Start with <START> token
  2. Decoder generates probability distribution over vocabulary
  3. Sample/pick next token
  4. Feed that token back as input for next position
  5. Repeat until <END> token
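The loop above can be sketched with greedy decoding. `TOY_MODEL` is a hypothetical stand-in for a trained decoder: it looks only at the last token, whereas a real decoder conditions on every generated token and on the encoder output.

```python
START, END = "<START>", "<END>"

# Hypothetical next-token distributions, keyed by the most recent token.
TOY_MODEL = {
    "<START>": {"Bonjour": 0.9, "monde": 0.05, "<END>": 0.05},
    "Bonjour": {"monde": 0.8, "Bonjour": 0.1, "<END>": 0.1},
    "monde": {"<END>": 0.95, "monde": 0.05},
}

def next_token_distribution(generated):
    """Return a probability distribution over the vocabulary (toy version)."""
    return TOY_MODEL[generated[-1]]

def greedy_decode(max_len=10):
    generated = [START]                          # 1. start with <START>
    while len(generated) < max_len:
        probs = next_token_distribution(generated)   # 2. distribution over vocab
        token = max(probs, key=probs.get)            # 3. pick the most likely token
        generated.append(token)                      # 4. feed it back as input
        if token == END:                             # 5. stop at <END>
            break
    return generated

print(greedy_decode())  # ['<START>', 'Bonjour', 'monde', '<END>']
```

Replacing the `max` call with sampling from `probs` gives stochastic generation; the `max_len` guard prevents an endless loop if `<END>` is never produced.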

Example: Translation Process

Input: "Hello world" → Output: "Bonjour monde"

1. Encoder processes "Hello world" once
2. Decoder generates "Bonjour" (attending to encoder output)
3. Decoder generates "monde" (attending to "Bonjour" + encoder output)
4. Decoder generates <END>, which stops generation

Training Process

Unified Training Objective

Critical Insight: All matrices (W_Q, W_K, W_V) are trained together with the same objective.

Single Training Loop:
Input → [W_Q, W_K, W_V] → Q, K, V → Attention → Output → Loss
                ↑                                            ↓
              Update ← ← ← ← ← Backpropagation ← ← ← ← ← ← ← ←

Why Do They Specialize?

Even with identical training:

  1. Different Roles in Computation

    • W_Q creates the queries, the left operand of the Q·K^T dot product
    • W_K creates the keys, the right operand of the same Q·K^T dot product
    • W_V creates the values that the attention weights average
  2. Different Gradient Flows

    ∂Loss/∂W_Q flows through the query side of Q·K^T
    ∂Loss/∂W_K flows through the key side of Q·K^T
    ∂Loss/∂W_V flows through the attention-weighted sum of values
    
  3. Random Initialization

    • Different starting points lead to different evolutionary paths
    • Symmetry breaking allows specialization

Training Details

  • Loss Function: Typically cross-entropy for next token prediction
  • Optimization: Adam optimizer with learning rate scheduling
  • Regularization: Dropout, layer normalization
  • Gradient Flow: Residual connections help gradients reach early layers
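As an illustration of the loss, the sketch below computes the average cross-entropy over a short sequence. It assumes a toy vocabulary of 3 tokens and hand-written logits; `cross_entropy` is written out here rather than taken from a library.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits_per_position, target_ids):
    """Mean of -log p(target) over all positions."""
    total = 0.0
    for logits, target in zip(logits_per_position, target_ids):
        probs = softmax(logits)
        total += -math.log(probs[target])
    return total / len(target_ids)

# Two positions, vocabulary of 3 tokens; targets are the true next-token IDs.
logits = [[2.0, 0.5, 0.1], [0.1, 0.2, 3.0]]
targets = [0, 2]

loss = cross_entropy(logits, targets)
uniform_loss = cross_entropy([[0.0, 0.0, 0.0]] * 2, targets)  # equals log(3)

print(loss < uniform_loss)  # True
```

A model that assigns equal probability to every token scores log(vocab_size); training pushes the loss below that baseline.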


Advanced Topics: Flash Attention

Flash Attention is a hardware-aware algorithm for computing attention that significantly improves speed and memory efficiency.

The Problem

Standard attention computation requires storing large intermediate matrices:

  • Attention scores: O(n²) memory for sequence length n
  • For long sequences, this becomes prohibitive

Flash Attention Solution

  • Computes attention in blocks without storing full attention matrix
  • Uses "online" algorithms for softmax computation
  • Dramatically reduces memory usage while maintaining mathematical equivalence

Safe Softmax and Online Softmax

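A naive softmax exponentiates raw scores, which can overflow for large values. Flash Attention uses:

  • Safe softmax (subtract the maximum score before exponentiating, the idea behind the LogSumExp trick) for numerical stability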
  • Online/streaming computation to avoid storing large matrices
  • Block-wise matrix operations
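The difference between the two softmax variants can be sketched as follows. This covers only the softmax statistics, not a full Flash Attention kernel: the online version keeps a running maximum and a running sum in one pass, rescaling the sum whenever the maximum changes. Function names are illustrative.

```python
import math

def safe_softmax(xs):
    """Two passes: find the max, then exponentiate the shifted values."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax_stats(xs):
    """One pass: running max m and running sum d of exp(x - m)."""
    m, d = -math.inf, 0.0
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new max, then add the current term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m, d

def online_softmax(xs):
    m, d = online_softmax_stats(xs)
    return [math.exp(x - m) / d for x in xs]

# Scores this large would overflow a naive exp(x) in double precision.
scores = [1000.0, 1001.0, 999.0]

a = safe_softmax(scores)
b = online_softmax(scores)
print(all(abs(x - y) < 1e-12 for x, y in zip(a, b)))  # True
```

Because the statistics can be updated as new scores arrive, a block-wise kernel never needs the whole row of scores in memory at once.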

Implementation with Triton

Triton allows writing GPU kernels in Python-like syntax:

import triton
import triton.language as tl

# Triton kernel skeleton for Flash Attention (body left as a stub)
@triton.jit
def flash_attention_kernel(
    q_ptr, k_ptr, v_ptr, output_ptr,
    # ... other parameters
):
    # Block-wise attention computation
    # Optimized memory access patterns
    pass

Performance Benefits

  • Roughly 2-5x speedup over standard attention on modern GPUs, depending on sequence length and hardware
  • Significantly reduced memory footprint
  • Enables training on longer sequences

Resources for Implementation


Resources

Essential Papers

  1. "Attention is All You Need" (Vaswani et al., 2017)

Academic Tutorials

  1. Stanford CS224N Lecture Notes (2023)

  2. Purdue University Tutorial (Avinash Kak)

  3. Jurafsky & Martin Textbook - Chapter 8

  4. Sebastian Raschka's Lecture Notes

Recommended Reading Order

  1. Start with Stanford CS224N for mathematical foundations
  2. Read the original paper for complete context
  3. Use Purdue tutorial for implementation details
  4. Explore Flash Attention resources for optimization

Key Takeaways

  1. Transformers are fundamentally about attention: The ability for tokens to selectively focus on other tokens
  2. Q, K, V specialization emerges naturally: Same training objective, different roles due to mathematical structure
  3. Parallelization is key: Unlike RNNs, transformers can process entire sequences simultaneously
  4. Position matters: Positional encodings are crucial since self-attention on its own ignores token order (it is permutation-equivariant)
  5. Encoder-decoder separation: Different attention patterns for understanding vs. generation
  6. Modern optimizations matter: Techniques like Flash Attention make transformers practical at scale

This architecture forms the foundation of modern language models like GPT, BERT, Claude, and many others. Understanding these fundamentals provides the basis for working with and improving transformer-based systems.