Fix Qwen MTP batched target-verify drift #1210

Merged: Blaizzy merged 29 commits into main from pc/qwen-mtp-batch-drift on Jun 1, 2026
Blaizzy (Owner) commented May 21, 2026

Summary

Fixes Qwen3.5/Qwen3.6 MTP batch drift by making target-verify paths match singleton numerics for left-padded and mixed-length batches.

Changes include:

  • preserve left-padding position offsets during Qwen batched prefill
  • make Qwen target-verify fallback projections row/time singleton exact when the dense helper cannot be used
  • pass the sliced attention mask through target-verify attention chunks
  • extend speculative tests to cover multi-row target-verify projections
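
As a rough illustration of the first bullet, the sketch below computes position ids for a left-padded batch so that each row's first real token sits at position 0, as it would in a singleton run. The function name `left_padded_positions` and the mask convention (1 for real tokens, 0 for leading padding) are assumptions made for this example; they are not taken from the PR's code.

```python
from typing import List


def left_padded_positions(attention_mask: List[List[int]]) -> List[List[int]]:
    """Return per-token position ids for a left-padded batch.

    Assumes padding only appears at the start of each row. Each row's first
    real token gets position 0, matching an unpadded singleton run. Padding
    slots are clamped to 0; they are expected to be masked out anyway.
    """
    positions = []
    for row in attention_mask:
        pad = len(row) - sum(row)  # number of leading padding slots
        positions.append([max(t - pad, 0) for t in range(len(row))])
    return positions


# Hypothetical batch: prompts of length 2 and 4, left-padded to length 4.
mask = [
    [0, 0, 1, 1],
    [1, 1, 1, 1],
]
positions = left_padded_positions(mask)
print(positions)  # [[0, 0, 0, 1], [0, 1, 2, 3]]
```

Without the offset, the short row's real tokens would be assigned positions 2 and 3 instead of 0 and 1, which is one way a batched prefill can diverge from its singleton reference.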

Root Cause

The remaining batch drift came from target-verify fallback projections that were only split by time, not by row. Small GDN projections fell back to batched GEMM across rows, which changed recurrent GDN state numerics for mixed batches. Left-padded batched prefill also needed singleton-equivalent position handling so no-drafter and MTP prefill states stay aligned.
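
The underlying numerical issue can be sketched in plain Python. Floating-point addition is not associative, so a kernel that reduces in a different order than the singleton kernel can return slightly different values, and a recurrent state then carries that difference forward. The two `dot_*` functions below are stand-ins for "singleton" and "batched" reduction orders, and `verify_project_row_time` is a hypothetical name for a fallback that splits by both row and time; none of these are the PR's actual functions.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # one inner list per output feature


def dot_sequential(x: Vector, w: Vector) -> float:
    """Accumulate left to right (stand-in for the singleton kernel)."""
    acc = 0.0
    for a, b in zip(x, w):
        acc += a * b
    return acc


def dot_reversed(x: Vector, w: Vector) -> float:
    """Accumulate right to left (stand-in for a kernel with another reduction order)."""
    acc = 0.0
    for a, b in zip(reversed(x), reversed(w)):
        acc += a * b
    return acc


def singleton_project(x: Vector, weight: Matrix) -> Vector:
    """Project one token vector, as a batch-1, single-step call would."""
    return [dot_sequential(x, w_row) for w_row in weight]


def verify_project_row_time(batch: List[List[Vector]], weight: Matrix) -> List[List[Vector]]:
    """Fallback sketch: split by row and by time so every call is singleton-shaped."""
    return [[singleton_project(x, weight) for x in row] for row in batch]


weight = [[1.0, 1.0, 1.0]]
x = [0.1, 0.2, 0.3]
print(dot_sequential(x, weight[0]))  # 0.6000000000000001
print(dot_reversed(x, weight[0]))    # 0.6

# Two rows, two time steps each: the split path reproduces singleton results exactly.
batch = [[x, [0.3, 0.2, 0.1]], [[0.5, 0.25, 0.125], x]]
split = verify_project_row_time(batch, weight)
```

Splitting only by time would still hand several rows to one kernel call; splitting by row as well keeps every call in the same shape as the singleton reference, at the cost of more, smaller calls.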

Validation

Focused tests:

PYTHONPATH=/tmp/codex-mlx-lm-target:. pytest \
  mlx_vlm/tests/test_speculative.py::test_qwen_target_verify_linear_matches_singleton_dense_gemv \
  mlx_vlm/tests/test_speculative.py::test_qwen_target_verify_small_projection_matches_singleton_dense_gemv \
  mlx_vlm/tests/test_speculative.py::test_qwen_target_verify_gated_norm_matches_singleton_path \
  mlx_vlm/tests/test_speculative.py::test_qwen_gdn_verify_conv_matches_singleton_windows -q

Result: 4 passed.

Qwen3.6-35B-A3B AIME 2026 ids 1-4, max_tokens=256, temperature=0, seed=42, thinking enabled:

| Mode | Batch | Wall | Tokens | Tok/s | Exactness |
| --- | --- | --- | --- | --- | --- |
| No drafter | singleton x4 | 27.60s | 1024 | 37.10 | reference |
| No drafter | 4 | 25.34s | 1024 | 40.42 | exact vs singleton |
| MTP | singleton x4 | 11.97s | 1024 | 85.52 | reference |
| MTP | 4 | 9.02s | 1024 | 113.48 | exact vs singleton |

MTP batch-4 is 2.81x faster than no-drafter batch-4 for this short run, and 1.33x faster than sequential singleton MTP.
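
For reference, the two ratios quoted above follow directly from the Tok/s column. The dictionary keys below are labels chosen for this sketch only.

```python
# Tok/s values copied from the Qwen3.6-35B-A3B table above.
tok_s = {
    "no_drafter_batch4": 40.42,
    "mtp_singleton_x4": 85.52,
    "mtp_batch4": 113.48,
}

vs_no_drafter = tok_s["mtp_batch4"] / tok_s["no_drafter_batch4"]
vs_singleton_mtp = tok_s["mtp_batch4"] / tok_s["mtp_singleton_x4"]

print(f"{vs_no_drafter:.2f}x")     # 2.81x
print(f"{vs_singleton_mtp:.2f}x")  # 1.33x
```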

Qwen3.5 9B 5-bit Temperature Sweep

AIME 2026 prompts, max_tokens=2048, seed=42, thinking enabled. All runs below were token-identical vs their no-drafter reference.

Batch 4, first 4 prompts. Before is the uniform sampled-walk fallback; After is the positioned ragged sampled path.

| Temp | No-drafter tok/s | Before Match | Before MTP tok/s | Before Speedup | Before Accept | After Match | After MTP tok/s | After Speedup | After Accept | MTP tok/s Δ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0.0 | 50.83 | 4/4 | 122.37 | 2.38x | 2.72 | 4/4 | 121.01 | 2.38x | 2.72 | -1.1% |
| 0.2 | 50.69 | 4/4 | 105.43 | 2.04x | 2.21 | 4/4 | 118.85 | 2.34x | 2.72 | +12.7% |
| 0.6 | 50.45 | 4/4 | 101.52 | 1.97x | 2.12 | 4/4 | 114.04 | 2.26x | 2.60 | +12.3% |
| 1.0 | 50.77 | 4/4 | 96.70 | 1.88x | 1.98 | 4/4 | 103.09 | 2.03x | 2.34 | +6.6% |

Current positioned ragged sampled path at additional batch sizes:

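| Batch | Temp | No-drafter tok/s | MTP tok/s | Speedup | Match | Accept | Rounds |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 2 | 0.0 | 55.88 | 100.64 | 1.80x | 2/2 | 2.75 | 754 |
| 2 | 0.2 | 55.73 | 97.14 | 1.74x | 2/2 | 2.75 | 752 |
| 2 | 0.6 | 56.57 | 88.35 | 1.56x | 2/2 | 2.62 | 820 |
| 2 | 1.0 | 56.25 | 85.95 | 1.53x | 2/2 | 2.52 | 820 |
| 8 | 0.0 | 52.18 | 135.93 | 2.60x | 8/8 | 2.76 | 770 |
| 8 | 0.2 | 52.59 | 134.35 | 2.55x | 8/8 | 2.76 | 775 |
| 8 | 0.6 | 52.32 | 126.78 | 2.42x | 8/8 | 2.63 | 829 |
| 8 | 1.0 | 52.66 | 118.08 | 2.24x | 8/8 | 2.34 | 935 |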

Blaizzy added 19 commits May 21, 2026 07:37
Use uniform deferred verification for non-greedy batched MTP so target sampling consumes RNG in the same lockstep order as no-drafter batches. Keep ragged acceptance enabled for greedy decoding, where argmax has no RNG-order drift and preserves the faster batch path.
Add a positioned target sampler so no-drafter and MTP consume deterministic per-position target draws instead of relying on global RNG order. This keeps sampled batched decoding exact while allowing Qwen MTP to use the ragged acceptance path.
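
The idea behind a positioned sampler can be sketched as follows: if each random draw is a pure function of (seed, row, position), the sampled token no longer depends on the order in which rows and positions are visited. This is only an illustration in plain Python; the hash-based keying, `positioned_draw`, and `sample_token` are assumptions for this example and not the implementation of `_PositionedTargetSampler`.

```python
import hashlib
import math
import random
from typing import Dict, List, Tuple


def positioned_draw(seed: int, row: int, position: int) -> float:
    """Uniform draw in [0, 1) that depends only on (seed, row, position)."""
    digest = hashlib.sha256(f"{seed}:{row}:{position}".encode()).digest()
    return random.Random(int.from_bytes(digest[:8], "big")).random()


def sample_token(logprobs: List[float], temperature: float, u: float) -> int:
    """Inverse-CDF categorical sample from temperature-scaled log-probabilities."""
    scaled = [lp / temperature for lp in logprobs]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scaled]
    threshold = u * sum(weights)
    running = 0.0
    for token, w in enumerate(weights):
        running += w
        if threshold < running:
            return token
    return len(weights) - 1  # guard against rounding at the upper edge


logprobs = [math.log(p) for p in (0.5, 0.3, 0.2)]
slots = [(row, pos) for row in range(2) for pos in range(3)]


def run(order: List[Tuple[int, int]]) -> Dict[Tuple[int, int], int]:
    """Sample every (row, position) slot, visiting slots in the given order."""
    return {
        (row, pos): sample_token(logprobs, 0.6, positioned_draw(42, row, pos))
        for row, pos in order
    }


lockstep = run(slots)                # row-major, as a no-drafter batch might proceed
ragged = run(list(reversed(slots)))  # a different visiting order
print(lockstep == ragged)  # True
```

With a single global RNG stream, the same two visiting orders would generally assign different draws to the same slot, which is why the earlier commit fell back to uniform lockstep verification for non-greedy decoding.
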
Blaizzy added 5 commits May 27, 2026 09:39
# Conflicts:
#	mlx_vlm/generate.py
#	mlx_vlm/models/qwen3_5/language.py
#	mlx_vlm/server/generation.py
#	mlx_vlm/tests/test_generate.py
#	mlx_vlm/tests/test_server.py
#	mlx_vlm/tests/test_speculative.py
Blaizzy marked this pull request as ready for review June 1, 2026 19:35
return mx.random.categorical(logprobs * (1 / args.temperature))

return sampler
return _PositionedTargetSampler(
Collaborator

Should this only be used when there's a draft model? I'm wondering if the vmap-based sampler might be a bit slower...

lucasnewman (Collaborator) left a comment

LGTM

Blaizzy merged commit eb7537b into main Jun 1, 2026
1 check passed
Blaizzy deleted the pc/qwen-mtp-batch-drift branch June 1, 2026 19:50
Development

Successfully merging this pull request may close these issues.

TypeError: _build_replacement_call got an unexpected keyword argument 'target_verify' in Qwen3.5/3.6 MTP models — PR #1210 does not resolve