Fix Qwen MTP batched target-verify drift #1210
Merged
Use uniform deferred verification for non-greedy batched MTP so target sampling consumes RNG in the same lockstep order as no-drafter batches. Keep ragged acceptance enabled for greedy decoding, where argmax has no RNG-order drift and preserves the faster batch path.
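A minimal sketch of why ragged acceptance is safe under greedy decoding, assuming each row's drafted tokens are compared against the target's argmax at the same positions. The function names and the list-based layout are hypothetical and are not taken from `mlx_vlm`. Each row keeps its own accepted prefix plus one target token, and because no random draw is involved, the emitted tokens always equal a prefix of the target's greedy sequence regardless of how many drafts were accepted.

```python
from typing import List


def accepted_prefix_len(draft: List[int], target_argmax: List[int]) -> int:
    """Count leading draft tokens that match the target's greedy choice."""
    n = 0
    for d, t in zip(draft, target_argmax):
        if d != t:
            break
        n += 1
    return n


def ragged_greedy_accept(
    drafts: List[List[int]], target_argmax: List[List[int]]
) -> List[List[int]]:
    """Per-row acceptance with row-specific lengths.

    Assumption: target_argmax[r] holds len(drafts[r]) + 1 tokens, the target's
    greedy token at every drafted position plus one extra position.
    """
    emitted = []
    for draft, tgt in zip(drafts, target_argmax):
        n = accepted_prefix_len(draft, tgt)
        # Accepted drafts, then the target's own token at the first mismatch
        # (or the extra position when every draft was accepted).
        emitted.append(draft[:n] + [tgt[n]])
    return emitted


drafts = [[5, 6, 7], [5, 9, 7]]
target = [[5, 6, 8, 1], [5, 6, 7, 2]]
emitted = ragged_greedy_accept(drafts, target)
```

Here row 0 accepts two drafts and row 1 accepts one, so the rows advance by different amounts, yet each emitted list is exactly a prefix of that row's greedy target tokens.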
Add a positioned target sampler so no-drafter and MTP consume deterministic per-position target draws instead of relying on global RNG order. This keeps sampled batched decoding exact while allowing Qwen MTP to use the ragged acceptance path.
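The idea of per-position draws can be illustrated with a small standard-library sketch. This is not the `_PositionedTargetSampler` implementation. The key derivation, the `stream` identifier, and the Gumbel-max sampling are all assumptions chosen for illustration. The point is only that when the randomness for a draw depends on `(seed, stream, position)` rather than on how many draws came before, the order in which positions are sampled no longer matters.

```python
import hashlib
import math
import random
from typing import Dict, Sequence


def rng_for(seed: int, stream: int, position: int) -> random.Random:
    """Derive a generator from (seed, stream, position) only (hypothetical keying)."""
    key = f"{seed}:{stream}:{position}".encode()
    digest = hashlib.sha256(key).digest()
    return random.Random(int.from_bytes(digest[:8], "big"))


def sample_at(
    logprobs: Sequence[float],
    temperature: float,
    seed: int,
    stream: int,
    position: int,
) -> int:
    """Gumbel-max categorical sample that ignores global RNG order."""
    rng = rng_for(seed, stream, position)
    best_token, best_score = 0, -math.inf
    for token, lp in enumerate(logprobs):
        u = max(rng.random(), 1e-300)  # random() may return 0.0; avoid log(0)
        gumbel = -math.log(-math.log(u))
        score = lp / temperature + gumbel
        if score > best_score:
            best_token, best_score = token, score
    return best_token


LOGPROBS = [math.log(p) for p in (0.5, 0.3, 0.2)]


def draw(positions: Sequence[int]) -> Dict[int, int]:
    """Sample the given positions in the given order."""
    return {
        p: sample_at(LOGPROBS, 0.8, seed=42, stream=0, position=p)
        for p in positions
    }


forward = draw([0, 1, 2, 3])
backward = draw([3, 2, 1, 0])
```

Sampling positions front to back, back to front, or one at a time yields the same token at each position, which is the property a verify path needs when it evaluates several positions at once.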
# Conflicts:
#   mlx_vlm/generate.py
#   mlx_vlm/models/qwen3_5/language.py
#   mlx_vlm/server/generation.py
#   mlx_vlm/tests/test_generate.py
#   mlx_vlm/tests/test_server.py
#   mlx_vlm/tests/test_speculative.py
# Conflicts:
#   mlx_vlm/generate/ar.py
lucasnewman
reviewed
Jun 1, 2026
return mx.random.categorical(logprobs * (1 / args.temperature))

return sampler
return _PositionedTargetSampler(
Collaborator
Should this only be used when there's a draft model? I'm wondering if the vmap-based sampler might be a bit slower...
Summary
Fixes Qwen3.5/Qwen3.6 MTP batch drift by making target-verify paths match singleton numerics for left-padded and mixed-length batches.
Changes include:
Root Cause
The remaining batch drift came from target-verify fallback projections that were only split by time, not by row. Small GDN projections fell back to batched GEMM across rows, which changed recurrent GDN state numerics for mixed batches. Left-padded batched prefill also needed singleton-equivalent position handling so no-drafter and MTP prefill states stay aligned.
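A small illustration of why a different projection path can change results at all, assuming the difference comes down to floating-point reduction order. This is a pure-Python sketch with a hypothetical `dot_split` helper, not the MLX kernels involved. Floating-point addition is not associative, so the same dot product accumulated in one pass or in two partial sums can differ in the last bit, and a recurrent state that feeds on such values can carry that difference forward.

```python
from typing import Sequence


def dot_split(x: Sequence[float], w: Sequence[float], split: int) -> float:
    """Dot product accumulated as two partial sums joined at `split`.

    split == len(x) reproduces a single left-to-right pass.
    """

    def partial(lo: int, hi: int) -> float:
        acc = 0.0
        for i in range(lo, hi):
            acc += x[i] * w[i]
        return acc

    return partial(0, split) + partial(split, len(x))


x = [0.1, 0.2, 0.3]
w = [1.0, 1.0, 1.0]

one_pass = dot_split(x, w, len(x))  # ((0.1 + 0.2) + 0.3)
two_pass = dot_split(x, w, 1)       # 0.1 + (0.2 + 0.3)
same_order_again = dot_split(x, w, 1)
```

The two orders disagree by about one unit in the last place, while repeating the same order is bitwise reproducible. Fixing the reduction order per row is what makes a row's result independent of what else is in the batch.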
Validation
Focused tests:
Result:
4 passed.

Qwen3.6-35B-A3B AIME 2026 ids 1-4, max_tokens=256, temperature=0, seed=42, thinking enabled:

MTP batch-4 is 2.81x faster than no-drafter batch-4 for this short run, and 1.33x faster than sequential singleton MTP.

Qwen3.5 9B 5-bit Temperature Sweep

AIME 2026 prompts, max_tokens=2048, seed=42, thinking enabled. All runs below were token-identical vs their no-drafter reference.

Batch 4, first 4 prompts. Before is the uniform sampled-walk fallback; After is the positioned ragged sampled path.

Current positioned ragged sampled path at additional batch sizes: