This repo is an educational lab to implement and compare:
- FlashAttention-style tiling + online softmax (CPU pedagogical version)
- MQA (shared K/V) and GQA (grouped queries)
- Activation Checkpointing (recompute tradeoffs)
You’ll fill in TODOs and run small experiments that log: accuracy parity (on toy tasks), runtime, and peak memory.
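To make the MHA / GQA / MQA distinction concrete before you open the TODOs, here is a minimal sketch of how query heads could map onto shared K/V heads. The function names are hypothetical and not taken from the repo. The sketch assumes contiguous grouping and that `--gqa-groups` denotes the number of K/V heads; the lab's `mqa.py` and `gqa.py` may organize this differently.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the index of the K/V head that query head `q_head` reads from.

    Assumption: consecutive query heads share one K/V head.
    num_kv_heads == num_q_heads is plain MHA, num_kv_heads == 1 is MQA,
    and anything in between is GQA.
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    if not 0 <= q_head < num_q_heads:
        raise ValueError("q_head out of range")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def kv_elements(seq_len: int, head_dim: int, num_kv_heads: int) -> int:
    """Count the stored K and V elements for one sequence (K plus V)."""
    return 2 * seq_len * head_dim * num_kv_heads


# 8 query heads sharing 4 K/V heads: pairs of query heads share K/V.
print([kv_head_for_query_head(h, 8, 4) for h in range(8)])
# -> [0, 0, 1, 1, 2, 2, 3, 3]
```

The memory saving comes entirely from `num_kv_heads`: under these assumptions, halving the number of K/V heads halves the stored K/V elements, while the number of query heads is unchanged.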
- Baseline vs. Optimized: You’ll first implement a simple, obvious baseline to internalize the cost. Then you’ll implement the optimized version and measure gains. This mirrors real systems work.
- CPU-scaled: We use small tensors and the Python stdlib to profile (no CUDA), so you can iterate quickly on a laptop.
- Metrics-first: Each script reports wall time and process peak RSS, so you see the effect of each technique.
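As a rough illustration of the metrics-first idea, wall time and peak RSS can both be read with the standard library alone. This is a sketch under stated assumptions, not the repo's actual helper: `measure` and `peak_rss_mib` are hypothetical names, and the `resource` module is only available on Unix-like systems.

```python
import sys
import time

try:
    import resource  # Unix only; missing on Windows
except ImportError:
    resource = None


def peak_rss_mib():
    """Peak resident set size of this process in MiB, or None if unavailable."""
    if resource is None:
        return None
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def measure(fn, *args, **kwargs):
    """Call fn and return (result, elapsed_seconds, peak_rss_mib)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed, peak_rss_mib()
```

Peak RSS is a high-water mark for the whole process and never goes down, so comparing a baseline and an optimized variant fairly generally means running each in its own process.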
python -m venv .venv && source .venv/bin/activate # Windows: .venv\Scripts\activate
pip install -r requirements.txt
pip install -e .
# 1) Flash vs Baseline (tiling + online softmax)
python scripts/01_flash_vs_baseline.py --seq 1024 --d 64 --heads 4
# 2) MQA / GQA vs MHA
python scripts/02_mqa_gqa_compare.py --seq 1024 --d 64 --heads 8 --gqa-groups 4
# 3) Activation checkpointing
python scripts/03_checkpointing_compare.py --layers 6 --seq 512 --d 128 --batch 4

Outputs show time and peak memory. Start with small sizes and scale up until your laptop struggles.
- Correctness: The max absolute difference (optimized vs. baseline) should be small (≈1e‑5 to 1e‑6) in fp32.
- Memory: Optimized variants should reduce peak RSS.
- Runtime: Should be comparable or better at larger sizes, even on CPU. It won’t match the speedups reported in GPU papers, but the trend should be visible.
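The correctness check above can be sketched end to end in pure Python: a baseline that materializes each full score row, a tiled version that visits keys block by block with an online softmax, and a max-absolute-difference comparison. This is a single-head, list-of-lists illustration with hypothetical function names, not the lab's `baseline_attention.py` or `flash_attention_cpu.py`. Python floats are 64-bit, so the difference here will be far below the fp32 range quoted above.

```python
import math
import random


def baseline_attention(q, k, v):
    """Materialize every score row, softmax it, then weight the values."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        denom = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / denom
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block=4):
    """Visit keys in blocks, keeping only a running max, sum, and output."""
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out = []
    for qi in q:
        m = float("-inf")   # running max of all scores seen so far
        l = 0.0             # running softmax denominator, relative to m
        acc = [0.0] * d_v   # running unnormalized output, relative to m
        for start in range(0, len(k), block):
            k_blk = k[start:start + block]
            v_blk = v[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            m_new = max(m, max(scores))
            correction = math.exp(m - m_new)  # rescale old state to the new max
            weights = [math.exp(s - m_new) for s in scores]
            l = l * correction + sum(weights)
            acc = [a * correction + sum(w * vj[c] for w, vj in zip(weights, v_blk))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out


def max_abs_diff(x, y):
    """Largest elementwise absolute difference between two 2-D lists."""
    return max(abs(a - b) for rx, ry in zip(x, y) for a, b in zip(rx, ry))


random.seed(0)
n, d = 10, 8  # n is deliberately not a multiple of the block size
q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
k = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
v = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
print(max_abs_diff(baseline_attention(q, k, v), tiled_attention(q, k, v, block=4)))
```

The tiled version never holds more than one block of scores per query. That is the property the memory measurements are meant to expose at larger sizes.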
- Prefer `torch.no_grad()` for inference benchmarks.
- Pin random seeds for reproducibility.
- Use `pytest -q` to check your work.
- Hour 1–2: Implement `baseline_attention.py` and write the reference tests.
- Hour 2–4: Implement `flash_attention_cpu.py` (non-causal first), pass tests, run `01_flash_vs_baseline.py`, and scale `--seq`.
- Hour 4–6: Implement `mqa.py` and `gqa.py`, then `02_mqa_gqa_compare.py`; try different `--heads` and `--gqa-groups`.
- Hour 6–8: Implement `CheckpointedBlock` and wire it into `03_checkpointing_compare.py`. Note memory/runtime changes.
- Hour 8–10 (optional): Add a causal mask path, and plug flash-style attention into the tiny transformer.
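Before implementing the checkpointing step in the roadmap, it may help to see the recompute tradeoff on a toy chain of scalar layers. The sketch below is not `CheckpointedBlock`, and every name in it is hypothetical. It assumes each layer is `tanh(w * x)` and computes the gradient of the output with respect to the input two ways: storing every layer input, and storing only segment-boundary inputs and recomputing the rest during the backward pass.

```python
import math


def layer(x, w):
    """Toy layer: a scalar tanh unit."""
    return math.tanh(w * x)


def layer_grad(x, w):
    """Derivative of layer(x, w) with respect to x."""
    return w * (1.0 - math.tanh(w * x) ** 2)


def grad_full(x0, weights):
    """Keep every layer input from the forward pass, then backprop."""
    inputs = []
    x = x0
    for w in weights:
        inputs.append(x)
        x = layer(x, w)
    g = 1.0
    for x_in, w in zip(reversed(inputs), reversed(weights)):
        g *= layer_grad(x_in, w)
    return g, len(inputs)  # gradient, activations retained from forward


def grad_checkpointed(x0, weights, segment=2):
    """Keep only segment-boundary inputs; recompute the rest in backward."""
    checkpoints = {}
    x = x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x
        x = layer(x, w)
    g = 1.0
    for start in sorted(checkpoints, reverse=True):
        seg_w = weights[start:start + segment]
        # Extra forward work: rebuild this segment's inputs from its checkpoint.
        inputs = []
        x = checkpoints[start]
        for w in seg_w:
            inputs.append(x)
            x = layer(x, w)
        for x_in, w in zip(reversed(inputs), reversed(seg_w)):
            g *= layer_grad(x_in, w)
    return g, len(checkpoints)  # gradient, activations retained from forward
```

Both functions return the same gradient. The checkpointed one retains fewer activations between forward and backward, and pays for that by running each segment's forward pass a second time.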
- Add a causal mask fast path to `flash_attention_cpu` (stop blocks after `t`).
- Try BF16/FP16 on CPU (if supported) to see numerical effects.
- Log attention I/O bytes approximations per method for deeper insight.
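For the causal fast path in the list above, one way to picture "stop blocks after `t`" is to enumerate which key ranges a query at position `t` actually needs. The helper below is a hypothetical sketch, not part of `flash_attention_cpu`. It assumes a per-query-row loop and truncates the last visible block at `t + 1`, so no elementwise mask is required inside it.

```python
def causal_key_blocks(t: int, seq_len: int, block: int):
    """Return [(start, stop), ...] key ranges visible to query position t.

    Keys at positions > t are in the future under a causal mask, so any
    block that starts after t is skipped, and the block containing t is
    cut off at t + 1.
    """
    if not 0 <= t < seq_len:
        raise ValueError("t out of range")
    if block <= 0:
        raise ValueError("block must be positive")
    ranges = []
    for start in range(0, seq_len, block):
        if start > t:
            break  # this block and all later ones are entirely in the future
        ranges.append((start, min(start + block, t + 1)))
    return ranges


print(causal_key_blocks(5, 16, 4))
# -> [(0, 4), (4, 6)]
```

Summed over all query positions, this visits roughly half as many key blocks as the non-causal loop, which is where the fast path's savings would come from.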