DecartAI/entropy-sparse-attn

ESA: Entropy-based Sparse Attention for Video Generation

Training-free sparse attention that dynamically allocates per-head budgets based on attention entropy. Uses FA4 block-sparse kernels on Blackwell GPUs. Two modes: online (compute masks on the fly) and cached (pre-computed masks for zero overhead).

Setup

pip install -e .
pip install -e third_party/fa4

Requires torch >= 2.5, an NVIDIA B200/Blackwell GPU, and FA4 (vendored in third_party/fa4/).

Quick Start

Online mode (no calibration needed)

Computes importance scores and entropy at every layer and diffusion step. Simple to use, but mask generation adds per-step overhead.

from esa.processor import sparse_processor_mode, _make_step_callback
from esa.models import load_pipeline, generate_video

pipe = load_pipeline("wan-t2v-1.3b")

with sparse_processor_mode(pipe, mean_budget=0.5) as state:
    cb = _make_step_callback(state)
    video = generate_video(pipe, "A cat walking on grass", "wan-t2v-1.3b",
                           callback_on_step_end=cb)

Cached mode (recommended)

Pre-compute block-sparse masks offline from a small set of calibration prompts, then reuse them at inference with zero mask-generation overhead. The masks capture which K-blocks matter for each Q-block at each layer and diffusion step — averaged across prompts so they generalize to unseen inputs.

Why pre-compute? Online importance scoring adds ~1.8 ms per layer (matmul + softmax + entropy + top-k), which at lower resolutions dominates the ~0.4 ms dense attention call itself. Pre-computed masks eliminate this cost entirely: the processor just indexes into a tensor.

Step 1: Calibrate (offline, once per model + resolution)

Run a few prompts through the dense pipeline while recording importance scores at every layer and step. The scores are averaged across prompts, then converted to block-sparse masks for each requested budget.

python scripts/calibrate.py \
    --model wan-t2v-1.3b \
    --height 720 --width 1280 \
    --num-prompts 5 \
    --budgets 0.7,0.5,0.3 \
    --output masks/wan-1.3b-720p.pt

This produces a .pt file containing:

  • Averaged importance [num_layers, num_steps, H, nq, nk] — the raw block importance scores averaged over all calibration prompts. Stored in float16.
  • Pre-built masks per budget — full_block_cnt and full_block_idx tensors ready for FA4, so loading is instant.

The masks are tied to a specific resolution (nq/nk depend on token count) and number of diffusion steps (each step has its own mask). As few as 2 calibration prompts work well — in our stress test, masks from 2 prompts achieved SSIM 0.92 on 5 unseen prompts.
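The conversion from averaged importance to ready-to-use masks can be sketched as follows. This is illustrative only: the shapes follow the description above, and `full_block_cnt`/`full_block_idx` reuse the names from the calibration file, but the exact dtypes and layout FA4 expects may differ.

```python
import torch

def build_masks(importance: torch.Tensor, budget: float):
    """Sketch: turn averaged block importance into block-sparse index tensors.

    importance: [num_layers, num_steps, H, nq, nk], averaged over prompts.
    Returns per-Q-block K-block counts and sorted K-block indices.
    """
    L, S, H, nq, nk = importance.shape
    k = max(1, int(round(budget * nk)))           # K-blocks kept per Q-block
    # top-k most important K-blocks for every (layer, step, head, q-block)
    idx = importance.topk(k, dim=-1).indices      # [L, S, H, nq, k]
    full_block_idx = idx.sort(dim=-1).values.to(torch.int32)
    full_block_cnt = torch.full((L, S, H, nq), k, dtype=torch.int32)
    return full_block_cnt, full_block_idx
```

With a fixed (non-entropy-adaptive) budget every Q-block keeps the same number of K-blocks, so `full_block_cnt` is constant; the entropy-adaptive path described under "How It Works" would vary k per head instead.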

Step 2: Generate with cached masks

from esa.cached_processor import cached_processor_mode
from esa.processor import _make_step_callback
from esa.models import load_pipeline, generate_video

pipe = load_pipeline("wan-t2v-1.3b")

with cached_processor_mode(pipe, "masks/wan-1.3b-720p.pt", budget=0.5) as state:
    cb = _make_step_callback(state)
    video = generate_video(pipe, "A cat walking on grass", "wan-t2v-1.3b",
                           height=720, width=1280, callback_on_step_end=cb)

At inference, the processor does: tensor index lookup → pad Q/K/V → FA4 block-sparse call. No importance scoring, no entropy, no topk. If you request a budget that wasn't pre-computed during calibration, it will be derived on the fly from the stored importance tensor (once at load time, not per-call).
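The padding step in that path can be sketched as below. The helper name is hypothetical (not the repo's code); the block sizes come from the Configuration table.

```python
import torch
import torch.nn.functional as F

Q_BLOCK, K_BLOCK = 256, 128  # block sizes from the Configuration table

def pad_to_block(x: torch.Tensor, block: int):
    """Pad a [B, S, H, D] tensor along S up to the next multiple of `block`.

    Block-sparse kernels index whole blocks, so the sequence length must be
    a block multiple; the pad amount is returned so output can be un-padded.
    """
    pad = (-x.shape[1]) % block
    return F.pad(x, (0, 0, 0, 0, 0, pad)), pad
```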

Benchmark Results

Wan2.1-1.3B, 81 frames @ 720p, b=0.5

| Metric  | Value                                      |
|---------|--------------------------------------------|
| Speedup | 1.25x (with warmup) / 1.33x (layer-0 only) |
| PSNR    | 30.80 dB                                   |
| SSIM    | 0.9354                                     |
| LPIPS   | 0.0335                                     |

Wan2.1-1.3B, 81 frames @ 480x832

| Budget | Speedup | PSNR  | SSIM   | LPIPS  |
|--------|---------|-------|--------|--------|
| b=0.7  | 1.06x   | 33.38 | 0.9577 | 0.0241 |
| b=0.5  | 1.15x   | 29.15 | 0.9306 | 0.0460 |
| b=0.3  | 1.24x   | 25.22 | 0.8781 | 0.1056 |

Speedup scales with resolution: attention cost grows quadratically with token count while the rest of the network grows roughly linearly, so attention's share of total compute, and hence the benefit of sparsifying it, increases at higher resolutions.
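This scaling can be made concrete with a generic transformer FLOPs sketch. The constants below are illustrative textbook approximations, not measurements from this repo.

```python
def attention_fraction(n: int, d: int = 128, mlp_mult: int = 4) -> float:
    """Rough fraction of per-layer FLOPs spent in attention at sequence length n.

    Attention (QK^T and AV) scales as O(n^2 * d); the qkv/out projections and
    MLP scale as O(n * d^2). Constants are generic, up to hardware details.
    """
    attn = 2 * n * n * d                       # two n x n x d matmuls
    linear = (4 + 2 * mlp_mult) * n * d * d    # 4 projections + MLP up/down
    return attn / (attn + linear)
```

Quadrupling the token count (e.g. moving from a lower to a higher resolution) sharply raises this fraction, which is why the measured speedups above improve at 720p.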

Scripts

Calibrate masks (offline)

# 480p (default resolution)
python scripts/calibrate.py --model wan-t2v-1.3b --num-prompts 5 \
    --budgets 0.7,0.5,0.3 --output masks/wan-1.3b.pt

# 720p
python scripts/calibrate.py --model wan-t2v-1.3b --num-prompts 5 \
    --height 720 --width 1280 --budgets 0.5 --output masks/wan-1.3b-720p.pt

Speed benchmark

# Online mode
python scripts/benchmark.py --model wan-t2v-1.3b --steps 5 --budgets 0.7,0.5,0.3

# Cached mode
python scripts/benchmark_cached.py --model wan-t2v-1.3b \
    --masks masks/wan-1.3b.pt --budgets 0.7,0.5,0.3 --steps 5

Generate videos

python scripts/generate.py --model wan-t2v-1.3b --method dense \
    --output-dir results/dense

python scripts/generate.py --model wan-t2v-1.3b --method entropy --budget 0.5 \
    --output-dir results/sparse-0.5

Evaluate quality

python scripts/evaluate.py \
    --reference results/dense \
    --generated results/sparse-0.5 \
    --output results/sparse-0.5/metrics.json

Run full experiment (calibrate + speed + quality)

python experiments/run_experiment.py \
    --model wan-t2v-1.3b --height 720 --width 1280 --budget 0.5 \
    --tag 1.3b-720p-b05

# Or launch multiple experiments via srun
bash experiments/launch_all.sh

Export to MP4

python scripts/export_mp4.py results/dense

Running experiments on a Slurm cluster

Each experiment (calibrate + benchmark + quality) is self-contained and needs a single GPU. Use srun for interactive jobs or sbatch for batch submission. Multiple experiments can run in parallel on different GPUs.

Single experiment with srun:

srun --gres=gpu:1 --job-name=esa-1.3b-720p \
    python experiments/run_experiment.py \
    --model wan-t2v-1.3b --height 720 --width 1280 --budget 0.5 \
    --tag 1.3b-720p-b05

Multiple experiments in parallel with srun:

# Each srun grabs its own GPU — run these from the same shell
srun --gres=gpu:1 --job-name=esa-1.3b-720p-b07 \
    python experiments/run_experiment.py \
    --model wan-t2v-1.3b --height 720 --width 1280 --budget 0.7 \
    --tag 1.3b-720p-b07 > experiments/logs/1.3b-720p-b07.log 2>&1 &

srun --gres=gpu:1 --job-name=esa-14b-480p-b05 \
    python experiments/run_experiment.py \
    --model wan-t2v-14b --budget 0.5 --cpu-offload \
    --tag 14b-480p-b05 > experiments/logs/14b-480p-b05.log 2>&1 &

# Monitor
tail -f experiments/logs/*.log
squeue -u $USER

Or use the provided launcher script which submits all experiments at once:

bash experiments/launch_all.sh

Batch submission with sbatch:

Create a script (e.g., experiments/job.sbatch):

#!/bin/bash
#SBATCH --job-name=esa-experiment
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --time=6:00:00
#SBATCH --output=experiments/logs/%x-%j.log

python experiments/run_experiment.py \
    --model wan-t2v-1.3b --height 720 --width 1280 \
    --budget ${BUDGET:-0.5} --tag 1.3b-720p-b${BUDGET:-0.5}

Then submit multiple jobs with different budgets (sbatch options must come before the script name, or they are passed to the script as arguments):

for budget in 0.7 0.5 0.3; do
    sbatch --job-name=esa-b${budget} --export=ALL,BUDGET=${budget} \
        experiments/job.sbatch
done

Notes:

  • The 14B model needs --cpu-offload on a single GPU, or request multiple GPUs for tensor parallelism.
  • Calibration is the slowest stage (~10 min/prompt at 480p, ~50 min/prompt at 720p for Wan-1.3B). Use --skip-calibrate to reuse existing masks.
  • Results go to results/<tag>/ with speed.json and quality.json.

Configuration

| Parameter         | Default | Description                                                  |
|-------------------|---------|--------------------------------------------------------------|
| mean_budget       | 0.5     | Fraction of K-blocks to keep (0.0-1.0). Lower = faster, more lossy |
| warmup_frac       | 0.2     | First N% of diffusion steps use dense attention              |
| dense_first_layer | True    | Layer 0 always uses dense attention                          |
| Q_BLOCK           | 256     | Query block size (matches FA4 SM100 q_stage=2)               |
| K_BLOCK           | 128     | Key block size (FA4 tile_n)                                  |

How It Works

  1. Importance scoring: For each Q-block (256 tokens), sample 4 query tokens and score against all K-blocks (128 tokens) via batched matmul + softmax + block-sum.

  2. Entropy-adaptive budgets: Per-head normalized Shannon entropy determines budget allocation. High-entropy heads (dispersed attention) get larger budgets; focused heads get smaller. Mean-preserving across heads.

  3. Block selection: Batched top-k on importance with per-head variable k. No Python loops over heads.

  4. FA4 execution: Selected blocks passed directly to FA4's flash_attn_func as BlockSparseTensorsTorch in native BSHD layout.

Warmup: First layer and first 20% of diffusion steps use dense attention for stability.
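Steps 2 and 3 can be sketched as follows. This is an illustrative reading of the description above, not the repo's implementation; in particular, the clamp bounds are assumptions, and clamping means the mean is only approximately preserved at extreme entropy spreads.

```python
import torch

def entropy_budgets(importance: torch.Tensor, mean_budget: float) -> torch.Tensor:
    """Sketch of entropy-adaptive budget allocation.

    importance: [H, nq, nk] non-negative block scores per head.
    Returns a per-head fraction of K-blocks to keep, with mean ~= mean_budget.
    """
    # normalize scores to a distribution over K-blocks per (head, q-block)
    p = importance / importance.sum(dim=-1, keepdim=True).clamp_min(1e-9)
    # Shannon entropy per (head, q-block), averaged over q-blocks -> [H]
    ent = -(p * p.clamp_min(1e-9).log()).sum(dim=-1).mean(dim=-1)
    # normalize by log(nk) so a uniform (fully dispersed) head scores 1.0
    ent = ent / torch.log(torch.tensor(float(importance.shape[-1])))
    # shift so the mean over heads equals mean_budget (up to clamping)
    return (ent - ent.mean() + mean_budget).clamp(0.05, 1.0)
```

Per-head k for the top-k block selection is then `round(budget[h] * nk)`, which is what makes the batched top-k variable-k rather than a single global cutoff.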

Supported Models

| Name          | Pipeline            | Heads | Head Dim | Default Resolution |
|---------------|---------------------|-------|----------|--------------------|
| wan-t2v-1.3b  | WanPipeline         | 12    | 128      | 81f @ 480x832      |
| wan-t2v-14b   | WanPipeline         | 40    | 128      | 81f @ 480x832      |
| cogvideox-5b  | CogVideoXPipeline   | 48    | 64       | 49f @ 480x720      |
| hunyuan-video | HunyuanVideoPipeline | 24   | 128      | 129f @ 544x960     |

Project Structure

esa/
  processor.py          Online sparse attention processor (FA4)
  cached_processor.py   Cached sparse processor (zero mask overhead)
  models.py             Model loading via diffusers
  metrics.py            PSNR, SSIM, LPIPS

scripts/
  calibrate.py          Offline mask calibration
  benchmark.py          Speed benchmark (online mode)
  benchmark_cached.py   Speed benchmark (cached mode)
  generate.py           Video generation
  evaluate.py           Quality metrics
  export_mp4.py         Tensor to MP4

experiments/
  run_experiment.py     Full pipeline: calibrate + bench + quality
  launch_all.sh         Launch experiments via srun

sparse_attention_processor.py   AR pipeline processor (14B)
third_party/fa4/                Vendored FlashAttention-4
