Handle zero-causal PrefixLM FA3 pass as no-op#13
Conversation
|
When the causal response length is zero for a sequence (in which case |
Good question. I verified this mixed-zero-causal case on H100. The case I tested was:
I ran the causal FA3 backward twice with identical inputs:
The final outputs matched exactly: The K/V gradient spans for the zero-causal sequences were also zero, so the sentinel values were not preserved. This suggests that when The guard in this PR is for the separate all-zero-causal case, where |
|
Follow-up commit The normal path is unchanged in the part that matters for performance:
For the all-zero-causal path:
This keeps the hot path unchanged and also avoids the extra zero tensor allocation/add in the rare all-zero-causal branch. Validation run: |
|
Follow-up H100 validation after commit Result: Also re-ran the mixed zero-causal K/V grad check: Result: Runtime was H100 80GB, torch
|
|
LGTM. Left a minor comment. |
Summary
This revisits the zero-causal PrefixLM FA3 boundary from #9 while addressing the hot-path cost concern from #10.
The previous runtime guard initialized
dk2/dv2withzeros_likeunconditionally, which adds a K/V-shaped memset to every backward. This version keeps the normal path unchanged:max_seqlen_causal > 0, allocatedk2/dv2withempty_likeand run the causal FA3 forward/backward as before;max_seqlen_causal == 0, skip the causal FA3 forward/backward pass and initializedk2/dv2to zero because the skipped pass is mathematically a no-op.Why this may still be useful
The SFT data-layer fix in #11 covers empty-response rows prepared by
scripts/prepare_sft_data.py. The more general runtime boundary, however, isfinal resp_len <= 1after tokenization/truncation, not only an empty raw response.For example, a non-empty response can still become zero-causal if the prefix consumes nearly the full context budget:
context_size = 4097inst_len = 4096allowed_resp = 1final_resp_len = 1causal_len = final_resp_len - 1 = 0A data-layer-only fix would need to enforce
resp_len > 1in every V1Dataset-producing path after tokenization/truncation, including the pretrainingdata_iopipeline. This runtime guard is defensive and keeps the normal FA3 hot path unchanged.Validation
python -m py_compile models/flash_attention_prefixlm_v2.pygit diff --check origin/main..HEADempty_like, zero-causal branch useszeros_like, and forward skips the zero-causal FA3 passmax_steps=2286, crossing the previous failing step 2284