Skip to content

fix: cast eagle_acts to draft dtype before send to avoid bf16/fp16 bit reinterpretation#18

Open
WLLEGit wants to merge 1 commit into
tanishqkumar:mainfrom
WLLEGit:fix-eagle3-dtype-cast
Open

fix: cast eagle_acts to draft dtype before send to avoid bf16/fp16 bit reinterpretation#18
WLLEGit wants to merge 1 commit into
tanishqkumar:mainfrom
WLLEGit:fix-eagle3-dtype-cast

Conversation

@WLLEGit

@WLLEGit WLLEGit commented May 21, 2026

Copy link
Copy Markdown

Summary

Fix dtype mismatch in async draft prefill: eagle_acts is sent from the target (bf16) without casting, then received into the draft's buffer (fp16). dist.recv just copies bytes, corrupting the conditioning hidden states fed into the draft's prefill. This poisons every entry of the draft KV cache and degrades acceptance for the entire generation.

The speculate path at speculator_async.py:175 already casts via recovery_activations.to(self.draft_dtype). The prefill path at speculator_async.py:89 was missing the same cast.

Fix

-            dist.send(eagle_acts, dst=self.draft_runner_rank, group=self.async_pg)
+            dist.send(eagle_acts.to(self.draft_dtype), dst=self.draft_runner_rank, group=self.async_pg)

Diagnosis

Reproduced on Llama-3.1-8B (bf16) + yuhuili EAGLE3 (fp16), humaneval, K=5, B=8 async + JIT backup.

Layer trace at draft prefill last position (213), reading conditioning into fc:

Tensor SSD (buggy) Oracle
input tok_emb norm 0.4341 0.4341
input cond (= fc(eagle_acts[212])) norm 479.03 31.38
attn_out norm at first JIT step 86.64 114.16

The conditioning passed into fc matched between SSD and the oracle (verified via the dumped target_recovery_activations), but the conditioning received by the draft for prefill positions was numerically off by ~15×. Independently computing fc(eagle_acts[212]) in fp32 confirms the oracle's value (31.38) is correct.

Impact

humaneval, 8 prompts × 256 output tokens, K=5:

Metric Before After vLLM EAGLE3
avg tokens / step 2.28 3.94 3.43
pos0 accept (JIT / cache-miss) 0.604 0.761 0.764
pos0 accept (all) 0.610 0.832 0.764
pos4 accept (all) 0.050 0.387 0.278
cache hit rate 0.51 0.71

JIT-path pos0 accept now matches vLLM almost exactly (0.761 vs 0.764), confirming the per-step computation is correct once the conditioning is clean.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant