turboquant: guard L=1 value kernels behind not use_rht (fix masked decode under RHT)#1244
Merged
Merged
Conversation
The L=1 value-reconstruction Metal kernels (_metal_mse_weighted_sum, _metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with matmul(weighted_rot, rotation), which is only the inverse for a plain rotation. _TurboQuantMSECodec defaults to use_rht=True (randomized Hadamard transform) whose inverse is _rht_inverse(.; signs), so under RHT these kernels return uncorrelated output (~140% reconstruction error at every bit depth). weighted_sum_from_scores already guards its kernel with `if not self.use_rht`; weighted_sum and weighted_sum_stats_from_scores did not, corrupting the slow (masked / array-mask) single-query decode path used by continuous-batching decode with per-request left-padding. Add the same guard so RHT falls back to the correct einsum + _rotate_inverse path. Verified 140% -> ~1%.
Blaizzy
approved these changes
May 30, 2026
Owner
Blaizzy
left a comment
There was a problem hiding this comment.
LGTM!
Thanks, will merge for now but will revisit it later
popfido
added a commit
to popfido/omlx
that referenced
this pull request
May 31, 2026
…pstream) Blaizzy/mlx-vlm#1244 (the RHT masked-decode 'not use_rht' guard) is merged. Bump the pin fea81522 -> 6f60ee4 (includes the merge) and delete the interim _fix_masked_decode_rht monkey-patch — B>1 masked decode now relies on the upstream fix. Removed the docs/upstream PR artifacts (PR is merged). Verified: masked decode 1.2% with no patch; 26 TQ tests pass; single/batch coherent.
jundot
pushed a commit
to jundot/omlx
that referenced
this pull request
May 31, 2026
…orse than single (#1547) * feat(turboquant): wire batched KV conversion (Phase 1) Re-enable TurboQuant KV under continuous batching by quantizing the completed fp16 prefill cache once (post-prefill), instead of the #717 on-the-fly-during-prefill conversion that corrupted hidden states and was reverted in #771. - Add Scheduler._turboquant_eligible() gate: only dense KVCache (and CacheList of KVCache for VLM) is converted. Chunked/rotating caches (Llama-4, sliding-window) stay fp16 — closes the #771 SIGABRT class. - Call _apply_turboquant_kv_convert() at the end of _do_external_prefill (after boundary snapshots, so paged-SSD format stays fp16). The per request TurboQuantKVCache is turned into a BatchTurboQuantKVCache by mlx-lm _merge_caches() at insert() time via the existing merge patch. - Empty/short-prompt path converts too (empty TQ for fresh, from_cache for restored). - Tests: eligibility gate across cache types; from_cache -> merge -> decode_attention batch path (offset tracking + real attention shape). * fix(turboquant): work around two broken mlx-vlm f96138e decode kernels Wiring up TurboQuant decode (prev. commit) exposed that mlx-vlm f96138e's TurboQuant decode path produces garbage even at 8-bit. TQ decode was never actually exercised before (conversion was dead code since #771), so these upstream bugs were dormant. Two distinct kernels are broken: 1. Fused single-token quantize (_try_fused_kv_quantize, used only when keys.shape[-2] == 1, i.e. every decode step): ~140% reconstruction error at all bit depths; the non-fused quantize() path (T>=2 prefill) is fine. Fix: _fix_decode_single_token_quantize() forces the non-fused path. -> single-seq TQ output now matches fp16. 2. Masked decode_attention path (taken whenever an array mask is passed, i.e. all B>1 continuous-batching decode for per-request left-padding): ~140% error. Fix: route array-mask decode through dequantize + standard SDPA, the same approach mlx-vlm uses for its own BatchTurboQuantKVCache. B=1 keeps mask=None/causal and the correct fused kernel. -> batched TQ output is now coherent. Both verified numerically (err 140% -> ~1% at 8-bit) and end-to-end on Llama-3.2-1B-Instruct-4bit. Regression tests added for both paths. NOTE: both are upstream mlx-vlm bugs and should be reported there. * test(turboquant): batched accuracy + memory/occupancy harness (Phase 2) Adds tests/test_turboquant_batch_memory.py: compares batched TurboQuant to single-seq on a real model across the three axes requested — - occupancy: KV bytes/token TQ vs fp16 (~0.31x at 4-bit), batch vs single, left-padding waste (analytical, measured at a cache_step-aligned length so over-allocation slack cancels), plus long-context savings projection. - accuracy: concurrent B>1 TQ vs single-seq TQ token match + coherence gate. - peak memory: live peak for single/batch x fp16/TQ (with the honest caveat that at short context the model weights dominate; B>1 TQ dequantizes the batch KV per step so peak is not below batch fp16). Model-gated (skips if not cached); writes tq_batch_memory.md report artifact (gitignored). Also notes in the attention patch that Bug #1 is fixed on mlx-vlm main while Bug #2 (masked decode) is not — the latter is the planned upstream PR. * test(turboquant): SSD prefix-cache round-trip, single + batch (Phase 3) Validates TurboQuant + paged-SSD now that TQ decode actually engages. Key finding: prefill boundary snapshots are stored fp16 and re-quantized deterministically on a cache hit, so there is NO double-quant (TQ->fp16->TQ) drift — the bespoke __turboquant_v2__ SSD path is not even exercised by the common flow (verified via logs). - single-request: cache hit reproduces fresh exactly (bit-identical) at 4-bit. - batch, 8-bit: hit reproduces fresh exactly -> structural round-trip is sound. - batch, 4-bit: hit may differ by a few tokens where quantization tips a greedy near-tie (fp16+SSD is exact; single TQ is exact), output stays coherent. This residual divergence is the same B>1 dequant sensitivity as Bug 2 and resolves when the upstream masked-decode kernel is fixed. 3 model-gated tests; all skip when the model is not cached. * docs(upstream): mlx-vlm TurboQuant RHT masked-decode fix (Bug 2) PR artifacts The L=1 value kernels (_metal_mse_weighted_sum, _metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with matmul(.,rotation) but ignore use_rht (RHT), so they corrupt the masked decode path (~140% err). weighted_sum / weighted_sum_stats_from_scores lack the 'if not self.use_rht' guard that weighted_sum_from_scores has. Patch + PR description for upstream Blaizzy/mlx-vlm (against main). * fix(turboquant): forward mlx-vlm to HEAD + land Bug-2 masked-decode in oMLX Forward the mlx-vlm pin f96138e -> fea81522 (main), which fixes the fused single-token quantize decode kernel (Bug 1) upstream — so the oMLX _fix_decode_single_token_quantize workaround is dropped. Bug 2 (the RHT-incompatible L=1 value kernels corrupting the masked decode path) is still unmerged upstream, so carry it as an oMLX monkey-patch (_fix_masked_decode_rht: disable those kernels -> correct einsum fallback). With the masked path now correct, route B>1 continuous-batching decode through decode_attention instead of the dequantize+SDPA workaround — no per-step batch dequantize, and it resolves the batch-4-bit SSD fresh-vs-hit divergence ([False] -> [True] at 4-bit; verified). uv.lock is gitignored; regenerate it ('uv lock') and run the full suite in a controlled env before release. Tests updated for the new routing; 29 TQ tests pass on HEAD. * fix(turboquant): correct B>1 make_mask for left-padded batches BatchTurboQuantKVCache.make_mask hand-rolled a causal term that compared each request's sequence length (offset) against the column index, then ANDed the left_padding term — which masked out the valid left-padded tokens. Left-padded requests in a ragged batch attended to ~nothing and decoded garbage, making batch mode worse than single (fp16's BatchKVCache was unaffected). Delegate to mlx-lm's create_causal_mask(N, offset=phys, left_padding=...), exactly like BatchKVCache, so the masks are identical. After the fix: - ragged-batch token-match to single-seq: 25% -> 71% (== same-length batch); - teacher-forced top-1 agreement single-vs-batch (left-padded member): 12/12, i.e. batch is computationally equivalent to single; residual greedy token divergence is cascade noise, not quality loss. This is an oMLX-only bug (separate from the mlx-vlm RHT PR). * test(turboquant): group TQ tests under a 'turboquant' marker Register a 'turboquant' pytest marker and apply it to the three TQ test files (test_turboquant.py, test_turboquant_batch_memory.py, test_turboquant_ssd.py) so the whole suite runs with 'pytest -m turboquant'. The model-loading files are also marked 'slow' (deselected by default). * test(turboquant): split semicolon statements in mask test (E702) * chore(turboquant): drop RHT monkey-patch; pin merged mlx-vlm (Bug 2 upstream) Blaizzy/mlx-vlm#1244 (the RHT masked-decode 'not use_rht' guard) is merged. Bump the pin fea81522 -> 6f60ee4 (includes the merge) and delete the interim _fix_masked_decode_rht monkey-patch — B>1 masked decode now relies on the upstream fix. Removed the docs/upstream PR artifacts (PR is merged). Verified: masked decode 1.2% with no patch; 26 TQ tests pass; single/batch coherent.
Blaizzy
added a commit
that referenced
this pull request
Jun 1, 2026
…nder RHT) (#1252) The single-token value kernels (_metal_mse_weighted_sum and friends) return the weighted value sum in the codec's rotated space and undo it with a hard-coded matmul(weighted_rot, rotation). That inverse is only correct for the dense-rotation codec, so #1244 disabled these kernels whenever the codec uses the Randomized Hadamard Transform (use_rht) — every RHT decode fell back to the slower einsum path. Pass the codec's RHT signs into the wrappers and apply the matching inverse (_rht_inverse when signs are set, else the dense matmul) via a small _value_rotate_inverse helper, mirroring _TurboQuantMSECodec._rotate_inverse and the already-correct fused-decode path. The kernel computes weighted_rot correctly regardless of rotation type; only the post-kernel inverse needed fixing. The not-self.use_rht guards are dropped so RHT decode takes the kernel again. Prod-mode fused-decode call sites pass no signs (default None) and keep the exact dense matmul — byte-for-byte unchanged. Verified: RHT weighted_sum / weighted_sum_stats_from_scores match the einsum fallback and the dequantize ground truth to <1e-4 / <1e-3 across dims {64,128,256} x bits {2,3,4,8} x repeats {1,4}; the L=1 value reconstruction runs 3.5-13.7x faster than the einsum fallback (the gap grows with context length). Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
_TurboQuantMSECodec.weighted_sumandweighted_sum_stats_from_scorescall theL=1 value-reconstruction Metal kernels (
_metal_mse_weighted_sum,_metal_mse_weighted_sum_sum_from_scores) without theif not self.use_rhtguard that the sibling
weighted_sum_from_scoresalready has.Those kernels finish with
matmul(weighted_rot, rotation)to undo the codecrotation — correct only for a plain rotation.
_TurboQuantMSECodecdefaults touse_rht=True(randomized Hadamard transform), whose inverse is_rht_inverse(.; signs), notmatmul(.; rotation). So under RHT these kernelsapply the wrong inverse transform and return essentially uncorrelated output
(~140% reconstruction error at every bit depth, 2–8).
Impact
L=1)decode_attentionpath is corrupt._fused_mse_decode_kernelwhenmaskisNone/"causal"); it only surfaceswhen an array mask forces the slow path — e.g. continuous-batching decode
with per-request left-padding (
B > 1), which then produces garbage.Fix
Add
not self.use_rht andto the two L=1 guards, mirroring the existingweighted_sum_from_scores. Under RHT this takes the correcteinsum +
_rotate_inversefallback; with a plain rotation (use_rht=False) thekernels still run.
Verification
End-to-end on
mlx-community/Llama-3.2-1B-Instruct-4bit, continuous-batchingdecode (
B>1, left-padded) produces coherent output after the fix; before it isgarbage.
Note
Conservative fix (fall back to the correct math) matching the existing
weighted_sum_from_scoresbehavior. A deeper fix would teach the kernels theRHT inverse so the RHT path could keep using the fast kernels.