Skip to content

Handle zero-causal PrefixLM FA3 pass as no-op#13

Open
Zane12518 wants to merge 2 commits into
sapientinc:mainfrom
Zane12518:fix-fa3-zero-causal-noop
Open

Handle zero-causal PrefixLM FA3 pass as no-op#13
Zane12518 wants to merge 2 commits into
sapientinc:mainfrom
Zane12518:fix-fa3-zero-causal-noop

Conversation

@Zane12518
Copy link
Copy Markdown
Contributor

Summary

This revisits the zero-causal PrefixLM FA3 boundary from #9 while addressing the hot-path cost concern from #10.

The previous runtime guard initialized dk2/dv2 with zeros_like unconditionally, which adds a K/V-shaped memset to every backward. This version keeps the normal path unchanged:

  • when max_seqlen_causal > 0, allocate dk2/dv2 with empty_like and run the causal FA3 forward/backward as before;
  • when max_seqlen_causal == 0, skip the causal FA3 forward/backward pass and initialize dk2/dv2 to zero because the skipped pass is mathematically a no-op.

Why this may still be useful

The SFT data-layer fix in #11 covers empty-response rows prepared by scripts/prepare_sft_data.py. The more general runtime boundary, however, is final resp_len <= 1 after tokenization/truncation, not only an empty raw response.

For example, a non-empty response can still become zero-causal if the prefix consumes nearly the full context budget:

  • context_size = 4097
  • inst_len = 4096
  • allowed_resp = 1
  • final_resp_len = 1
  • causal_len = final_resp_len - 1 = 0

A data-layer-only fix would need to enforce resp_len > 1 in every V1Dataset-producing path after tokenization/truncation, including the pretraining data_io pipeline. This runtime guard is defensive and keeps the normal FA3 hot path unchanged.

Validation

  • python -m py_compile models/flash_attention_prefixlm_v2.py
  • git diff --check origin/main..HEAD
  • local static guard check: normal causal backward uses empty_like, zero-causal branch uses zeros_like, and forward skips the zero-causal FA3 pass
  • previous single-H100 zero-causal boundary repro: PASS
  • previous 6 x 8 H100 training validation: reached max_steps=2286, crossing the previous failing step 2284

@imoneoi
Copy link
Copy Markdown
Contributor

imoneoi commented May 29, 2026

When the causal response length is zero for a sequence (in which case max_seqlen_causal may not be zero), cu_seqlens_q will contain a zero-length sequence. In this case, will the FA3 op fill dk2/dv2, or will it still contain corrupted values?

@Zane12518
Copy link
Copy Markdown
Contributor Author

Zane12518 commented May 29, 2026

When the causal response length is zero for a sequence (in which case max_seqlen_causal may not be zero), cu_seqlens_q will contain a zero-length sequence. In this case, will the FA3 op fill dk2/dv2, or will it still contain corrupted values?

Good question. I verified this mixed-zero-causal case on H100.

The case I tested was:

  • prefix_lens=[4096,4096,4089,4080]
  • causal_lens=[0,0,7,16]
  • max_seqlen_causal=16
  • max_seqlen_all=4096
  • nheads=20, head_dim=128

I ran the causal FA3 backward twice with identical inputs:

  1. dk/dv initialized with zeros;
  2. dk/dv initialized with a sentinel value.

The final outputs matched exactly:

max_abs_diff_vs_zero_init {'dk': 0.0, 'dv': 0.0}

The K/V gradient spans for the zero-causal sequences were also zero, so the sentinel values were not preserved. This suggests that when max_seqlen_causal > 0, FA3 does initialize dk2/dv2 correctly even if seqused_q contains some zero-length query entries.

The guard in this PR is for the separate all-zero-causal case, where max_seqlen_causal == 0 and the causal FA3 pass has no query work at all.

@Zane12518
Copy link
Copy Markdown
Contributor Author

Follow-up commit 428dfc5 switches the zero-causal backward no-op from materializing an explicit zero dk2/dv2 tensor to returning the bidirectional K/V gradients directly.

The normal path is unchanged in the part that matters for performance:

  • when max_seqlen_causal_int > 0, dk2/dv2 are still allocated with empty_like and the causal FA3 backward runs as before;
  • dk2/dv2 padding is zeroed before adding into dk1/dv1, preserving the previous padding cleanup behavior.

For the all-zero-causal path:

  • the causal FA3 forward/backward remains skipped;
  • the causal pass is mathematically a no-op, so instead of allocating zeros_like(k/v) and returning dk1 + 0, the code now returns dk1/dv1 directly.

This keeps the hot path unchanged and also avoids the extra zero tensor allocation/add in the rare all-zero-causal branch.

Validation run:

python -m py_compile models/flash_attention_prefixlm_v2.py
git diff --check
static_implementation_check=PASS

@Zane12518
Copy link
Copy Markdown
Contributor Author

Follow-up H100 validation after commit 428dfc5:

CUDA_VISIBLE_DEVICES=0 scripts/validate_fa3_zero_causal_single_gpu.sh

Result:

VALIDATION_PASS
shape {'total_seqlen': 4096, 'numseqs': 1, 'max_seqlen_prefix': 4096, 'max_seqlen_causal': 0, 'max_seqlen_all': 4096}

Also re-ran the mixed zero-causal K/V grad check:

PREFIX_LENS="4096,4096,4089,4080" \
CAUSAL_LENS="0,0,7,16" \
NHEADS=20 \
HEAD_DIM=128 \
CUDA_VISIBLE_DEVICES=0 \
scripts/validate_fa3_mixed_zero_causal_kv_grads.sh

Result:

max_abs_diff_vs_zero_init {'dk': 0.0, 'dv': 0.0}
VALIDATION_PASS

Runtime was H100 80GB, torch 2.11.0+cu128, CUDA 12.8. This covers both boundaries:

  • all-zero-causal batch: max_seqlen_causal == 0, no-op path passes;
  • mixed batch: max_seqlen_causal > 0 with some zero-length causal sequences, FA3 causal backward writes K/V grads deterministically.

@imoneoi
Copy link
Copy Markdown
Contributor

imoneoi commented May 30, 2026

LGTM. Left a minor comment.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants