tgautam23/Transformer_Memory_Optimization

LLM Memory Techniques — Hands‑on Homework (CPU only)

This repo is an educational lab to implement and compare:

  1. FlashAttention-style tiling + online softmax (CPU pedagogical version)
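  2. Multi-query attention (MQA: all query heads share one K/V head) and grouped-query attention (GQA: each group of query heads shares one K/V head)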
  3. Activation Checkpointing (recompute tradeoffs)

You’ll fill in TODOs and run small experiments that log: accuracy parity (on toy tasks), runtime, and peak memory.
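As a warm-up for the first technique, here is a minimal pure-Python sketch of online softmax over key/value tiles for a single query row. It is not the repo's flash_attention_cpu.py; the function names, the block size, and the list-of-lists layout are all assumptions made for illustration. The tiled version keeps only a running max, a running denominator, and a running output accumulator, and rescales them whenever a new tile raises the max.

```python
import math
import random


def attention_row_baseline(q, keys, values):
    """softmax(q . K^T / sqrt(d)) @ V for one query, materializing every score."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def attention_row_tiled(q, keys, values, block=4):
    """Same result, but only one tile of scores is alive at any time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old max (0.0 on the first tile).
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(weights)
        acc = [a * correction + sum(w * v[j] for w, v in zip(weights, v_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)  # pinned seed so the comparison is reproducible
q = [rng.gauss(0.0, 1.0) for _ in range(8)]
keys = random_matrix(10, 8, rng)    # 10 keys: the last tile is ragged
values = random_matrix(10, 8, rng)

ref = attention_row_baseline(q, keys, values)
out = attention_row_tiled(q, keys, values, block=4)
max_abs_diff = max(abs(a - b) for a, b in zip(ref, out))
print(f"max abs diff: {max_abs_diff:.3e}")
```

The same parity check (max absolute difference against the baseline) is what the labs ask you to report; this toy runs in Python floats (float64), so the difference here is far below the fp32 tolerance.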

Why these exercises?

  • Baseline vs. Optimized: You’ll first implement a simple, obvious baseline to internalize the cost. Then you’ll implement the optimized version and measure gains. This mirrors real systems work.
  • CPU-scaled: Everything runs on small CPU tensors and is profiled with the Python standard library (no CUDA), so you can iterate quickly on a laptop.
  • Metrics-first: Each script reports wall time and process peak RSS, so you see the effect of each technique.
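One way to collect the two metrics with only the standard library is sketched below. This is an assumption about how such a harness could look, not the code in the repo's scripts; `measure` and `peak_rss_bytes` are hypothetical names. The `resource` module is Unix-only, so the sketch returns None for peak RSS where it is unavailable (for example on Windows).

```python
import sys
import time

try:
    import resource  # Unix only; not available on Windows
except ImportError:
    resource = None


def peak_rss_bytes():
    """Process-wide peak resident set size in bytes, or None if unavailable."""
    if resource is None:
        return None
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    return peak if sys.platform == "darwin" else peak * 1024


def measure(fn, *args, **kwargs):
    """Run fn once and return (result, wall_seconds, peak_rss_bytes)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed, peak_rss_bytes()


result, seconds, peak = measure(lambda n: sum(i * i for i in range(n)), 100_000)
print(f"time: {seconds:.4f}s, peak RSS: {peak}")
```

Peak RSS is a high-water mark for the whole process and never goes down, so a fair baseline-versus-optimized comparison runs each variant in its own process.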

Setup

python -m venv .venv && source .venv/bin/activate  # Windows: .venv\Scripts\activate
pip install -r requirements.txt
pip install -e .

Running the labs

# 1) Flash vs Baseline (tiling + online softmax)
python scripts/01_flash_vs_baseline.py --seq 1024 --d 64 --heads 4

# 2) MQA / GQA vs MHA
python scripts/02_mqa_gqa_compare.py --seq 1024 --d 64 --heads 8 --gqa-groups 4

# 3) Activation checkpointing
python scripts/03_checkpointing_compare.py --layers 6 --seq 512 --d 128 --batch 4

Outputs show time and peak memory. Start with small sizes and scale until your laptop struggles.
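To make the MQA/GQA comparison concrete, the sketch below maps query heads to shared K/V heads and counts K/V elements for the sizes used in the second command. It assumes that --gqa-groups is the number of K/V heads and that --d is the per-head dimension; the repo may define either differently, and the function names are hypothetical.

```python
def kv_head_for_query_head(head, num_heads, num_kv_heads):
    """Index of the K/V head that a given query head reads from."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    return head // (num_heads // num_kv_heads)


def kv_elements(seq_len, head_dim, num_kv_heads):
    """Number of stored K and V elements for one sequence (factor 2 = K plus V)."""
    return 2 * num_kv_heads * seq_len * head_dim


# Assumed reading of: --seq 1024 --d 64 --heads 8 --gqa-groups 4
heads, groups, seq, head_dim = 8, 4, 1024, 64

mapping = [kv_head_for_query_head(h, heads, groups) for h in range(heads)]
mha = kv_elements(seq, head_dim, heads)   # one K/V head per query head
gqa = kv_elements(seq, head_dim, groups)  # one K/V head per group
mqa = kv_elements(seq, head_dim, 1)       # a single shared K/V head

print(mapping)        # [0, 0, 1, 1, 2, 2, 3, 3]
print(mha, gqa, mqa)  # 1048576 524288 131072
```

Under these assumptions GQA with 4 groups halves the K/V footprint relative to MHA, and MQA cuts it by a factor of 8 (the number of heads).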

What you’ll hand‑check

  • Correctness: The max absolute difference between optimized and baseline outputs should be small (roughly 1e-6 to 1e-5) in fp32.
  • Memory: Optimized variants should reduce peak RSS.
  • Runtime: The optimized variants should be comparable or faster at larger sizes, even on CPU. Absolute numbers won't match those in GPU papers, but the trend should be visible.

Tips

  • Prefer torch.no_grad() for inference benchmarks.
  • Pin random seeds for reproducibility.
  • Use pytest -q to check your work.

Suggested 5–10 hour path

  1. Hours 1–2: Implement baseline_attention.py and write the reference tests.
  2. Hours 2–4: Implement flash_attention_cpu.py (non-causal first), pass the tests, then run 01_flash_vs_baseline.py and scale --seq.
  3. Hours 4–6: Implement mqa.py and gqa.py, then 02_mqa_gqa_compare.py; try different --heads and --gqa-groups values.
  4. Hours 6–8: Implement CheckpointedBlock and wire it into 03_checkpointing_compare.py. Note the memory and runtime changes.
  5. Hours 8–10 (optional): Add a causal mask path, and plug flash-style attention into the tiny transformer.
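Before writing CheckpointedBlock, it can help to see the recompute tradeoff on a toy problem. The sketch below is a pure-Python stand-in, not the repo's implementation: each "layer" is the scalar function tanh(w * x), and all names are hypothetical. The full version stores every layer input for the backward pass; the checkpointed version stores only one input per segment and recomputes the rest during the backward pass.

```python
import math


def layer(x, w):
    return math.tanh(w * x)


def layer_grad(x, w):
    """d layer / d x at input x."""
    y = math.tanh(w * x)
    return w * (1.0 - y * y)


def grad_full(x, weights):
    """Store every layer input, then backpropagate. Returns (grad, stored)."""
    inputs = []
    for w in weights:
        inputs.append(x)
        x = layer(x, w)
    grad = 1.0
    for inp, w in zip(reversed(inputs), reversed(weights)):
        grad *= layer_grad(inp, w)
    return grad, len(inputs)


def grad_checkpointed(x, weights, segment=4):
    """Store one input per segment; recompute inside each segment on backward."""
    checkpoints = []
    for start in range(0, len(weights), segment):
        checkpoints.append(x)
        for w in weights[start:start + segment]:
            x = layer(x, w)  # intermediate inputs are discarded here
    grad = 1.0
    peak_stored = 0
    for seg in reversed(range(len(checkpoints))):
        seg_weights = weights[seg * segment:(seg + 1) * segment]
        inputs = []
        h = checkpoints[seg]
        for w in seg_weights:  # extra forward work: the recompute cost
            inputs.append(h)
            h = layer(h, w)
        # Counts checkpoints plus the one recomputed segment (an upper bound).
        peak_stored = max(peak_stored, len(checkpoints) + len(inputs))
        for inp, w in zip(reversed(inputs), reversed(seg_weights)):
            grad *= layer_grad(inp, w)
    return grad, peak_stored


weights = [0.5 + 0.05 * i for i in range(16)]  # 16 toy layers
g_full, stored_full = grad_full(0.3, weights)
g_ckpt, stored_ckpt = grad_checkpointed(0.3, weights, segment=4)
print(g_full, g_ckpt)
print(stored_full, stored_ckpt)  # 16 8
```

The gradients agree, the stored-value count drops from 16 to 8, and the price is one extra forward pass through each segment. The same tradeoff is what the third lab measures with real tensors.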

Stretch ideas (optional)

  • Add a causal mask fast path to flash_attention_cpu: for a query at position t, stop iterating once a key block starts after t, since every key in it is masked.
  • Try BF16/FP16 on CPU (if supported) to see numerical effects.
  • Log approximate attention I/O bytes per method for deeper insight.
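For the causal fast path, the block-skipping logic can be worked out separately from the attention math. The sketch below assumes that a query at position t may attend to keys at positions up to and including t; `causal_key_blocks` is a hypothetical helper, not part of the repo.

```python
def causal_key_blocks(q_start, q_end, seq_len, block):
    """Key blocks that the query block [q_start, q_end) must visit.

    Returns a list of (k_start, k_end, needs_mask). Blocks that lie entirely
    in the future of every query in the block are skipped.
    """
    visited = []
    for k_start in range(0, seq_len, block):
        k_end = min(k_start + block, seq_len)
        if k_start > q_end - 1:
            break  # every key here is after the last query position
        # A mask is needed only if some key in the block is after some query.
        needs_mask = (k_end - 1) > q_start
        visited.append((k_start, k_end, needs_mask))
    return visited


seq_len, block = 16, 4
print(causal_key_blocks(8, 12, seq_len, block))
# [(0, 4, False), (4, 8, False), (8, 12, True)]

total_visited = sum(
    len(causal_key_blocks(s, min(s + block, seq_len), seq_len, block))
    for s in range(0, seq_len, block)
)
print(total_visited)  # 10 of the 16 (query block, key block) pairs
```

Only the diagonal blocks need an element-wise mask; blocks strictly below the diagonal can use the unmasked code path, and blocks above it are never touched.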
