Training-free sparse attention that dynamically allocates per-head budgets based on attention entropy. Uses FA4 block-sparse kernels on Blackwell GPUs. Two modes: online (compute masks on the fly) and cached (pre-computed masks for zero overhead).
```bash
pip install -e .
pip install -e third_party/fa4
```

Requires torch >= 2.5, an NVIDIA B200/Blackwell GPU, and FA4 (vendored in `third_party/fa4/`).
Computes importance scores and entropy per layer per step. Simple to use but adds overhead from mask generation.
```python
from esa.processor import sparse_processor_mode, _make_step_callback
from esa.models import load_pipeline, generate_video

pipe = load_pipeline("wan-t2v-1.3b")
with sparse_processor_mode(pipe, mean_budget=0.5) as state:
    cb = _make_step_callback(state)
    video = generate_video(pipe, "A cat walking on grass", "wan-t2v-1.3b",
                           callback_on_step_end=cb)
```

Pre-compute block-sparse masks offline from a small set of calibration prompts, then reuse them at inference with zero mask-generation overhead. The masks capture which K-blocks matter for each Q-block at each layer and diffusion step, averaged across prompts so they generalize to unseen inputs.
Why pre-compute? Online importance scoring adds ~1.8ms per layer (matmul + softmax + entropy + topk), which dominates the 0.4ms dense attention at lower resolutions. Pre-computed masks eliminate this entirely — the processor just indexes into a tensor.
Step 1: Calibrate (offline, once per model + resolution)
Run a few prompts through the dense pipeline while recording importance scores at every layer and step. The scores are averaged across prompts, then converted to block-sparse masks for each requested budget.
```bash
python scripts/calibrate.py \
    --model wan-t2v-1.3b \
    --height 720 --width 1280 \
    --num-prompts 5 \
    --budgets 0.7,0.5,0.3 \
    --output masks/wan-1.3b-720p.pt
```

This produces a `.pt` file containing:
- **Averaged importance** `[num_layers, num_steps, H, nq, nk]`: the raw block importance scores averaged over all calibration prompts, stored in float16.
- **Pre-built masks per budget**: `full_block_cnt` and `full_block_idx` tensors ready for FA4, so loading is instant.
The masks are tied to a specific resolution (nq/nk depend on token count) and number of diffusion steps (each step has its own mask). As few as 2 calibration prompts work well — in our stress test, masks from 2 prompts achieved SSIM 0.92 on 5 unseen prompts.
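The artifact layout described above can be sketched with a dummy file. The key names (`importance`, `masks`) and the exact nesting are illustrative assumptions, not the tool's actual schema; the tensor shapes follow the description above.

```python
import os
import tempfile
import torch

# Toy dimensions: 2 layers, 3 diffusion steps, 12 heads, 8 Q-blocks, 16 K-blocks.
num_layers, num_steps, H, nq, nk = 2, 3, 12, 8, 16
k = nk // 2  # K-blocks kept per Q-block at budget 0.5

# Hypothetical schema: averaged importance plus pre-built masks per budget.
artifact = {
    "importance": torch.rand(num_layers, num_steps, H, nq, nk).half(),
    "masks": {
        0.5: {
            "full_block_cnt": torch.full((num_layers, num_steps, H, nq), k),
            "full_block_idx": torch.zeros(num_layers, num_steps, H, nq, k,
                                          dtype=torch.long),
        }
    },
}

path = os.path.join(tempfile.mkdtemp(), "example_masks.pt")
torch.save(artifact, path)
loaded = torch.load(path)
print(tuple(loaded["importance"].shape))  # (2, 3, 12, 8, 16)
```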
Step 2: Generate with cached masks
```python
from esa.cached_processor import cached_processor_mode
from esa.processor import _make_step_callback
from esa.models import load_pipeline, generate_video

pipe = load_pipeline("wan-t2v-1.3b")
with cached_processor_mode(pipe, "masks/wan-1.3b-720p.pt", budget=0.5) as state:
    cb = _make_step_callback(state)
    video = generate_video(pipe, "A cat walking on grass", "wan-t2v-1.3b",
                           height=720, width=1280, callback_on_step_end=cb)
```

At inference, the processor does: tensor index lookup → pad Q/K/V → FA4 block-sparse call. No importance scoring, no entropy, no top-k. If you request a budget that wasn't pre-computed during calibration, it is derived on the fly from the stored importance tensor (once at load time, not per call).
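Deriving a mask for an un-calibrated budget reduces to one top-k over the stored importance tensor. A minimal sketch, with illustrative function and variable names (not the actual API), and a fixed k per head for simplicity, whereas the online mode varies k per head:

```python
import torch

def build_mask(importance: torch.Tensor, budget: float):
    """importance: [H, nq, nk] averaged block scores for one layer/step."""
    nk = importance.shape[-1]
    k = max(1, round(budget * nk))                         # K-blocks kept per Q-block
    full_block_idx = importance.topk(k, dim=-1).indices    # [H, nq, k]
    full_block_cnt = torch.full(importance.shape[:-1], k)  # [H, nq]
    return full_block_cnt, full_block_idx

imp = torch.rand(12, 8, 16)          # 12 heads, 8 Q-blocks, 16 K-blocks
cnt, idx = build_mask(imp, budget=0.5)
print(cnt.shape, idx.shape)          # torch.Size([12, 8]) torch.Size([12, 8, 8])
```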
| Metric | Value |
|---|---|
| Speedup | 1.25x (with warmup) / 1.33x (layer-0 only) |
| PSNR | 30.80 dB |
| SSIM | 0.9354 |
| LPIPS | 0.0335 |
| Budget | Speedup | PSNR | SSIM | LPIPS |
|---|---|---|---|---|
| b=0.7 | 1.06x | 33.38 | 0.9577 | 0.0241 |
| b=0.5 | 1.15x | 29.15 | 0.9306 | 0.0460 |
| b=0.3 | 1.24x | 25.22 | 0.8781 | 0.1056 |
Speedup scales with resolution: attention cost grows quadratically with token count, so attention's share of total runtime rises at higher resolutions.
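As a rough model of that scaling: token count N grows linearly with pixel area, attention cost scales as N², and the projection/FFN cost as N, so attention's fraction of runtime rises with resolution. The constants below are made up for illustration, not profiled numbers:

```python
# Toy cost model: attention ~ N^2 * d, projections/FFN ~ c * N * d^2.
# d and c are illustrative constants, not taken from any real model.
d, c = 1536, 8

def attn_fraction(n_tokens: int) -> float:
    attn = n_tokens ** 2 * d
    mlp = c * n_tokens * d ** 2
    return attn / (attn + mlp)

for label, n in [("480p-ish", 33_000), ("720p-ish", 76_000)]:
    print(label, round(attn_fraction(n), 2))
```

The fraction simplifies to N / (N + c·d), which increases monotonically in N: the larger the token count, the more of the runtime sparse attention can save.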
```bash
# 480p (default resolution)
python scripts/calibrate.py --model wan-t2v-1.3b --num-prompts 5 \
    --budgets 0.7,0.5,0.3 --output masks/wan-1.3b.pt

# 720p
python scripts/calibrate.py --model wan-t2v-1.3b --num-prompts 5 \
    --height 720 --width 1280 --budgets 0.5 --output masks/wan-1.3b-720p.pt
```

```bash
# Online mode
python scripts/benchmark.py --model wan-t2v-1.3b --steps 5 --budgets 0.7,0.5,0.3

# Cached mode
python scripts/benchmark_cached.py --model wan-t2v-1.3b \
    --masks masks/wan-1.3b.pt --budgets 0.7,0.5,0.3 --steps 5
```

```bash
python scripts/generate.py --model wan-t2v-1.3b --method dense \
    --output-dir results/dense
python scripts/generate.py --model wan-t2v-1.3b --method entropy --budget 0.5 \
    --output-dir results/sparse-0.5
```

```bash
python scripts/evaluate.py \
    --reference results/dense \
    --generated results/sparse-0.5 \
    --output results/sparse-0.5/metrics.json
```

```bash
python experiments/run_experiment.py \
    --model wan-t2v-1.3b --height 720 --width 1280 --budget 0.5 \
    --tag 1.3b-720p-b05

# Or launch multiple experiments via srun
bash experiments/launch_all.sh
```

```bash
python scripts/export_mp4.py results/dense
```

Each experiment (calibrate + benchmark + quality) is self-contained and needs a single GPU. Use `srun` for interactive jobs or `sbatch` for batch submission. Multiple experiments can run in parallel on different GPUs.
Single experiment with srun:

```bash
srun --gres=gpu:1 --job-name=esa-1.3b-720p \
    python experiments/run_experiment.py \
    --model wan-t2v-1.3b --height 720 --width 1280 --budget 0.5 \
    --tag 1.3b-720p-b05
```

Multiple experiments in parallel with srun:

```bash
# Each srun grabs its own GPU; run these from the same shell
srun --gres=gpu:1 --job-name=esa-1.3b-720p-b07 \
    python experiments/run_experiment.py \
    --model wan-t2v-1.3b --height 720 --width 1280 --budget 0.7 \
    --tag 1.3b-720p-b07 > experiments/logs/1.3b-720p-b07.log 2>&1 &

srun --gres=gpu:1 --job-name=esa-14b-480p-b05 \
    python experiments/run_experiment.py \
    --model wan-t2v-14b --budget 0.5 --cpu-offload \
    --tag 14b-480p-b05 > experiments/logs/14b-480p-b05.log 2>&1 &

# Monitor
tail -f experiments/logs/*.log
squeue -u $USER
```

Or use the provided launcher script, which submits all experiments at once:

```bash
bash experiments/launch_all.sh
```

Batch submission with sbatch:
Create a script (e.g., `experiments/job.sbatch`):

```bash
#!/bin/bash
#SBATCH --job-name=esa-experiment
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --time=6:00:00
#SBATCH --output=experiments/logs/%x-%j.log

python experiments/run_experiment.py \
    --model wan-t2v-1.3b --height 720 --width 1280 --budget ${BUDGET:-0.5} \
    --tag 1.3b-720p-b${BUDGET:-0.5}
```

Then submit multiple jobs with different parameters (sbatch options must precede the script path):

```bash
for budget in 0.7 0.5 0.3; do
    sbatch --job-name=esa-b${budget} --export=ALL,BUDGET=${budget} \
        experiments/job.sbatch
done
```

Notes:
- The 14B model needs `--cpu-offload` on a single GPU, or request multiple GPUs for tensor parallelism.
- Calibration is the slowest stage (~10 min/prompt at 480p, ~50 min/prompt at 720p for Wan-1.3B). Use `--skip-calibrate` to reuse existing masks.
- Results go to `results/<tag>/` with `speed.json` and `quality.json`.
| Parameter | Default | Description |
|---|---|---|
| `mean_budget` | 0.5 | Fraction of K-blocks to keep (0.0-1.0). Lower = faster, more lossy |
| `warmup_frac` | 0.2 | First N% of diffusion steps use dense attention |
| `dense_first_layer` | `True` | Layer 0 always uses dense attention |
| `Q_BLOCK` | 256 | Query block size (matches FA4 SM100 q_stage=2) |
| `K_BLOCK` | 128 | Key block size (FA4 tile_n) |
- **Importance scoring**: For each Q-block (256 tokens), sample 4 query tokens and score them against all K-blocks (128 tokens) via batched matmul + softmax + block-sum.
- **Entropy-adaptive budgets**: Per-head normalized Shannon entropy determines budget allocation. High-entropy heads (dispersed attention) get larger budgets; focused heads get smaller ones. Mean-preserving across heads.
- **Block selection**: Batched top-k on importance with per-head variable k. No Python loops over heads.
- **FA4 execution**: Selected blocks are passed directly to FA4's `flash_attn_func` as `BlockSparseTensorsTorch` in native BSHD layout.
- **Warmup**: The first layer and the first 20% of diffusion steps use dense attention for stability.
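The entropy-adaptive budget step can be sketched as below. This is a minimal sketch assuming `[H, nq, nk]` importance scores, not the actual implementation; the key property it demonstrates is that budgets scale with per-head entropy while the mean budget is preserved.

```python
import math
import torch

def entropy_budgets(importance: torch.Tensor, mean_budget: float = 0.5) -> torch.Tensor:
    """Per-head budgets scaled by normalized Shannon entropy, mean-preserving.

    importance: [H, nq, nk] non-negative block importance scores.
    """
    # Normalize each Q-block's row to a probability distribution over K-blocks.
    p = importance / importance.sum(-1, keepdim=True).clamp_min(1e-9)
    # Shannon entropy per (head, Q-block), then normalize by log(nk) to [0, 1].
    ent = -(p * p.clamp_min(1e-9).log()).sum(-1)            # [H, nq]
    ent = ent.mean(-1) / math.log(importance.shape[-1])     # [H]
    # Rescale so the mean across heads equals mean_budget (before clamping).
    budgets = ent / ent.mean() * mean_budget
    return budgets.clamp(0.0, 1.0)

imp = torch.rand(12, 8, 16).softmax(-1)   # 12 heads, 8 Q-blocks, 16 K-blocks
budgets = entropy_budgets(imp, mean_budget=0.5)
print(budgets.shape, float(budgets.mean()))  # torch.Size([12]) ~0.5
```

Dispersed (high-entropy) heads end up above the mean budget and focused heads below it, so the total compute matches a flat `mean_budget` allocation.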
| Name | Pipeline | Heads | Head Dim | Default Resolution |
|---|---|---|---|---|
| `wan-t2v-1.3b` | WanPipeline | 12 | 128 | 81f @ 480x832 |
| `wan-t2v-14b` | WanPipeline | 40 | 128 | 81f @ 480x832 |
| `cogvideox-5b` | CogVideoXPipeline | 48 | 64 | 49f @ 480x720 |
| `hunyuan-video` | HunyuanVideoPipeline | 24 | 128 | 129f @ 544x960 |
```
esa/
  processor.py                    Online sparse attention processor (FA4)
  cached_processor.py             Cached sparse processor (zero mask overhead)
  models.py                       Model loading via diffusers
  metrics.py                      PSNR, SSIM, LPIPS
scripts/
  calibrate.py                    Offline mask calibration
  benchmark.py                    Speed benchmark (online mode)
  benchmark_cached.py             Speed benchmark (cached mode)
  generate.py                     Video generation
  evaluate.py                     Quality metrics
  export_mp4.py                   Tensor to MP4
experiments/
  run_experiment.py               Full pipeline: calibrate + bench + quality
  launch_all.sh                   Launch experiments via srun
  sparse_attention_processor.py   AR pipeline processor (14B)
third_party/fa4/                  Vendored FlashAttention-4
```