feat(turboquant): batched KV-cache compression (single + batch), no worse than single#1
Closed
popfido wants to merge 9 commits into
Closed
feat(turboquant): batched KV-cache compression (single + batch), no worse than single#1popfido wants to merge 9 commits into
popfido wants to merge 9 commits into
Conversation
Re-enable TurboQuant KV under continuous batching by quantizing the completed fp16 prefill cache once (post-prefill), instead of the jundot#717 on-the-fly-during-prefill conversion that corrupted hidden states and was reverted in jundot#771. - Add Scheduler._turboquant_eligible() gate: only dense KVCache (and CacheList of KVCache for VLM) is converted. Chunked/rotating caches (Llama-4, sliding-window) stay fp16 — closes the jundot#771 SIGABRT class. - Call _apply_turboquant_kv_convert() at the end of _do_external_prefill (after boundary snapshots, so paged-SSD format stays fp16). The per request TurboQuantKVCache is turned into a BatchTurboQuantKVCache by mlx-lm _merge_caches() at insert() time via the existing merge patch. - Empty/short-prompt path converts too (empty TQ for fresh, from_cache for restored). - Tests: eligibility gate across cache types; from_cache -> merge -> decode_attention batch path (offset tracking + real attention shape).
Wiring up TurboQuant decode (prev. commit) exposed that mlx-vlm f96138e's TurboQuant decode path produces garbage even at 8-bit. TQ decode was never actually exercised before (conversion was dead code since jundot#771), so these upstream bugs were dormant. Two distinct kernels are broken: 1. Fused single-token quantize (_try_fused_kv_quantize, used only when keys.shape[-2] == 1, i.e. every decode step): ~140% reconstruction error at all bit depths; the non-fused quantize() path (T>=2 prefill) is fine. Fix: _fix_decode_single_token_quantize() forces the non-fused path. -> single-seq TQ output now matches fp16. 2. Masked decode_attention path (taken whenever an array mask is passed, i.e. all B>1 continuous-batching decode for per-request left-padding): ~140% error. Fix: route array-mask decode through dequantize + standard SDPA, the same approach mlx-vlm uses for its own BatchTurboQuantKVCache. B=1 keeps mask=None/causal and the correct fused kernel. -> batched TQ output is now coherent. Both verified numerically (err 140% -> ~1% at 8-bit) and end-to-end on Llama-3.2-1B-Instruct-4bit. Regression tests added for both paths. NOTE: both are upstream mlx-vlm bugs and should be reported there.
Adds tests/test_turboquant_batch_memory.py: compares batched TurboQuant to single-seq on a real model across the three axes requested — - occupancy: KV bytes/token TQ vs fp16 (~0.31x at 4-bit), batch vs single, left-padding waste (analytical, measured at a cache_step-aligned length so over-allocation slack cancels), plus long-context savings projection. - accuracy: concurrent B>1 TQ vs single-seq TQ token match + coherence gate. - peak memory: live peak for single/batch x fp16/TQ (with the honest caveat that at short context the model weights dominate; B>1 TQ dequantizes the batch KV per step so peak is not below batch fp16). Model-gated (skips if not cached); writes tq_batch_memory.md report artifact (gitignored). Also notes in the attention patch that Bug #1 is fixed on mlx-vlm main while Bug jundot#2 (masked decode) is not — the latter is the planned upstream PR.
Validates TurboQuant + paged-SSD now that TQ decode actually engages. Key finding: prefill boundary snapshots are stored fp16 and re-quantized deterministically on a cache hit, so there is NO double-quant (TQ->fp16->TQ) drift — the bespoke __turboquant_v2__ SSD path is not even exercised by the common flow (verified via logs). - single-request: cache hit reproduces fresh exactly (bit-identical) at 4-bit. - batch, 8-bit: hit reproduces fresh exactly -> structural round-trip is sound. - batch, 4-bit: hit may differ by a few tokens where quantization tips a greedy near-tie (fp16+SSD is exact; single TQ is exact), output stays coherent. This residual divergence is the same B>1 dequant sensitivity as Bug 2 and resolves when the upstream masked-decode kernel is fixed. 3 model-gated tests; all skip when the model is not cached.
…rtifacts The L=1 value kernels (_metal_mse_weighted_sum, _metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with matmul(.,rotation) but ignore use_rht (RHT), so they corrupt the masked decode path (~140% err). weighted_sum / weighted_sum_stats_from_scores lack the 'if not self.use_rht' guard that weighted_sum_from_scores has. Patch + PR description for upstream Blaizzy/mlx-vlm (against main).
…n oMLX
Forward the mlx-vlm pin f96138e -> fea81522 (main), which fixes the fused
single-token quantize decode kernel (Bug 1) upstream — so the oMLX
_fix_decode_single_token_quantize workaround is dropped.
Bug 2 (the RHT-incompatible L=1 value kernels corrupting the masked decode
path) is still unmerged upstream, so carry it as an oMLX monkey-patch
(_fix_masked_decode_rht: disable those kernels -> correct einsum fallback).
With the masked path now correct, route B>1 continuous-batching decode
through decode_attention instead of the dequantize+SDPA workaround — no
per-step batch dequantize, and it resolves the batch-4-bit SSD fresh-vs-hit
divergence ([False] -> [True] at 4-bit; verified).
uv.lock is gitignored; regenerate it ('uv lock') and run the full suite in a
controlled env before release. Tests updated for the new routing; 29 TQ tests
pass on HEAD.
BatchTurboQuantKVCache.make_mask hand-rolled a causal term that compared each request's sequence length (offset) against the column index, then ANDed the left_padding term — which masked out the valid left-padded tokens. Left-padded requests in a ragged batch attended to ~nothing and decoded garbage, making batch mode worse than single (fp16's BatchKVCache was unaffected). Delegate to mlx-lm's create_causal_mask(N, offset=phys, left_padding=...), exactly like BatchKVCache, so the masks are identical. After the fix: - ragged-batch token-match to single-seq: 25% -> 71% (== same-length batch); - teacher-forced top-1 agreement single-vs-batch (left-padded member): 12/12, i.e. batch is computationally equivalent to single; residual greedy token divergence is cascade noise, not quality loss. This is an oMLX-only bug (separate from the mlx-vlm RHT PR).
Register a 'turboquant' pytest marker and apply it to the three TQ test files (test_turboquant.py, test_turboquant_batch_memory.py, test_turboquant_ssd.py) so the whole suite runs with 'pytest -m turboquant'. The model-loading files are also marked 'slow' (deselected by default).
de5ece7 to
e7d06d5
Compare
Owner
Author
|
Rebased onto current upstream and retargeting the PR to jundot/omlx. |
popfido
added a commit
that referenced
this pull request
Jun 1, 2026
…orse than single (jundot#1547) * feat(turboquant): wire batched KV conversion (Phase 1) Re-enable TurboQuant KV under continuous batching by quantizing the completed fp16 prefill cache once (post-prefill), instead of the jundot#717 on-the-fly-during-prefill conversion that corrupted hidden states and was reverted in jundot#771. - Add Scheduler._turboquant_eligible() gate: only dense KVCache (and CacheList of KVCache for VLM) is converted. Chunked/rotating caches (Llama-4, sliding-window) stay fp16 — closes the jundot#771 SIGABRT class. - Call _apply_turboquant_kv_convert() at the end of _do_external_prefill (after boundary snapshots, so paged-SSD format stays fp16). The per request TurboQuantKVCache is turned into a BatchTurboQuantKVCache by mlx-lm _merge_caches() at insert() time via the existing merge patch. - Empty/short-prompt path converts too (empty TQ for fresh, from_cache for restored). - Tests: eligibility gate across cache types; from_cache -> merge -> decode_attention batch path (offset tracking + real attention shape). * fix(turboquant): work around two broken mlx-vlm f96138e decode kernels Wiring up TurboQuant decode (prev. commit) exposed that mlx-vlm f96138e's TurboQuant decode path produces garbage even at 8-bit. TQ decode was never actually exercised before (conversion was dead code since jundot#771), so these upstream bugs were dormant. Two distinct kernels are broken: 1. Fused single-token quantize (_try_fused_kv_quantize, used only when keys.shape[-2] == 1, i.e. every decode step): ~140% reconstruction error at all bit depths; the non-fused quantize() path (T>=2 prefill) is fine. Fix: _fix_decode_single_token_quantize() forces the non-fused path. -> single-seq TQ output now matches fp16. 2. Masked decode_attention path (taken whenever an array mask is passed, i.e. all B>1 continuous-batching decode for per-request left-padding): ~140% error. Fix: route array-mask decode through dequantize + standard SDPA, the same approach mlx-vlm uses for its own BatchTurboQuantKVCache. B=1 keeps mask=None/causal and the correct fused kernel. -> batched TQ output is now coherent. Both verified numerically (err 140% -> ~1% at 8-bit) and end-to-end on Llama-3.2-1B-Instruct-4bit. Regression tests added for both paths. NOTE: both are upstream mlx-vlm bugs and should be reported there. * test(turboquant): batched accuracy + memory/occupancy harness (Phase 2) Adds tests/test_turboquant_batch_memory.py: compares batched TurboQuant to single-seq on a real model across the three axes requested — - occupancy: KV bytes/token TQ vs fp16 (~0.31x at 4-bit), batch vs single, left-padding waste (analytical, measured at a cache_step-aligned length so over-allocation slack cancels), plus long-context savings projection. - accuracy: concurrent B>1 TQ vs single-seq TQ token match + coherence gate. - peak memory: live peak for single/batch x fp16/TQ (with the honest caveat that at short context the model weights dominate; B>1 TQ dequantizes the batch KV per step so peak is not below batch fp16). Model-gated (skips if not cached); writes tq_batch_memory.md report artifact (gitignored). Also notes in the attention patch that Bug #1 is fixed on mlx-vlm main while Bug jundot#2 (masked decode) is not — the latter is the planned upstream PR. * test(turboquant): SSD prefix-cache round-trip, single + batch (Phase 3) Validates TurboQuant + paged-SSD now that TQ decode actually engages. Key finding: prefill boundary snapshots are stored fp16 and re-quantized deterministically on a cache hit, so there is NO double-quant (TQ->fp16->TQ) drift — the bespoke __turboquant_v2__ SSD path is not even exercised by the common flow (verified via logs). - single-request: cache hit reproduces fresh exactly (bit-identical) at 4-bit. - batch, 8-bit: hit reproduces fresh exactly -> structural round-trip is sound. - batch, 4-bit: hit may differ by a few tokens where quantization tips a greedy near-tie (fp16+SSD is exact; single TQ is exact), output stays coherent. This residual divergence is the same B>1 dequant sensitivity as Bug 2 and resolves when the upstream masked-decode kernel is fixed. 3 model-gated tests; all skip when the model is not cached. * docs(upstream): mlx-vlm TurboQuant RHT masked-decode fix (Bug 2) PR artifacts The L=1 value kernels (_metal_mse_weighted_sum, _metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with matmul(.,rotation) but ignore use_rht (RHT), so they corrupt the masked decode path (~140% err). weighted_sum / weighted_sum_stats_from_scores lack the 'if not self.use_rht' guard that weighted_sum_from_scores has. Patch + PR description for upstream Blaizzy/mlx-vlm (against main). * fix(turboquant): forward mlx-vlm to HEAD + land Bug-2 masked-decode in oMLX Forward the mlx-vlm pin f96138e -> fea81522 (main), which fixes the fused single-token quantize decode kernel (Bug 1) upstream — so the oMLX _fix_decode_single_token_quantize workaround is dropped. Bug 2 (the RHT-incompatible L=1 value kernels corrupting the masked decode path) is still unmerged upstream, so carry it as an oMLX monkey-patch (_fix_masked_decode_rht: disable those kernels -> correct einsum fallback). With the masked path now correct, route B>1 continuous-batching decode through decode_attention instead of the dequantize+SDPA workaround — no per-step batch dequantize, and it resolves the batch-4-bit SSD fresh-vs-hit divergence ([False] -> [True] at 4-bit; verified). uv.lock is gitignored; regenerate it ('uv lock') and run the full suite in a controlled env before release. Tests updated for the new routing; 29 TQ tests pass on HEAD. * fix(turboquant): correct B>1 make_mask for left-padded batches BatchTurboQuantKVCache.make_mask hand-rolled a causal term that compared each request's sequence length (offset) against the column index, then ANDed the left_padding term — which masked out the valid left-padded tokens. Left-padded requests in a ragged batch attended to ~nothing and decoded garbage, making batch mode worse than single (fp16's BatchKVCache was unaffected). Delegate to mlx-lm's create_causal_mask(N, offset=phys, left_padding=...), exactly like BatchKVCache, so the masks are identical. After the fix: - ragged-batch token-match to single-seq: 25% -> 71% (== same-length batch); - teacher-forced top-1 agreement single-vs-batch (left-padded member): 12/12, i.e. batch is computationally equivalent to single; residual greedy token divergence is cascade noise, not quality loss. This is an oMLX-only bug (separate from the mlx-vlm RHT PR). * test(turboquant): group TQ tests under a 'turboquant' marker Register a 'turboquant' pytest marker and apply it to the three TQ test files (test_turboquant.py, test_turboquant_batch_memory.py, test_turboquant_ssd.py) so the whole suite runs with 'pytest -m turboquant'. The model-loading files are also marked 'slow' (deselected by default). * test(turboquant): split semicolon statements in mask test (E702) * chore(turboquant): drop RHT monkey-patch; pin merged mlx-vlm (Bug 2 upstream) Blaizzy/mlx-vlm#1244 (the RHT masked-decode 'not use_rht' guard) is merged. Bump the pin fea81522 -> 6f60ee4 (includes the merge) and delete the interim _fix_masked_decode_rht monkey-patch — B>1 masked decode now relies on the upstream fix. Removed the docs/upstream PR artifacts (PR is merged). Verified: masked decode 1.2% with no patch; 26 TQ tests pass; single/batch coherent.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Makes TurboQuant KV-cache compression actually work in oMLX — in both single-sequence
and continuous-batching (B>1) decode — for the first time. TQ decode was previously
dead code (the conversion was reverted in jundot#771 and never re-enabled), so
turboquant_kv_enabledsilently ran fp16. Wiring it up exposed and fixed three real bugs(two upstream in mlx-vlm, one in oMLX), and the result is proven no worse than single
mode on every axis while saving ~69% KV memory.
What changed (runtime)
omlx/scheduler.py— quantize the completed fp16 prefill cache once, post-prefill(not on-the-fly, which corrupted hidden states in fix: wire up TurboQuant KV cache conversion + skip sensitive last layer (#661) jundot/omlx#717/0.3.5-rc1: scheduler._do_external_prefill applies TurboQuantKV during prefill → garbage output & SIGABRT on quantized-KV models jundot/omlx#771), gated to dense-KVCache
models via a new
_turboquant_eligible()(chunked/rotating models stay fp16 — no crash).Fully TQ-gated: for non-TQ models these paths are a no-op (zero behavioural change).
omlx/patches/turboquant_attention.py— route B>1 decode through the quantizeddecode_attention, plus a monkey-patch (_fix_masked_decode_rht) for the unmergedmlx-vlm Bug 2 (below).
omlx/turboquant_kv.py— fixBatchTurboQuantKVCache.make_maskfor left-paddedbatches (it was masking out valid tokens, so left-padded requests decoded garbage); now
delegates to mlx-lm
create_causal_mask, identical toBatchKVCache.Upstream dependency (mlx-vlm)
Wiring TQ decode exposed two mlx-vlm kernel bugs (both ~140% error, latent because TQ
decode never ran):
main; this PRforwards the pin
f96138e → fea81522to pick it up (the full default test suitepasses on it — see below).
turboquant: guard L=1 value kernels behind
not use_rht(fix masked decode under RHT) Blaizzy/mlx-vlm#1244. Carried here as the_fix_masked_decode_rhtmonkey-patchuntil it merges; delete the patch + bump the pin once it lands. Patch + write-up in
docs/upstream/.Proof: batch is no worse than single
Measured on
Llama-3.2-1B-Instruct-4bit, TQ 4-bit:¹ Same context, no greedy cascade. Raw greedy token-match is lower but that's cascade
noise, not quality (each step's prediction is identical — the 97%). Before the make_mask
fix, ragged-batch was genuinely broken (garbage).
Scope / impact validation
scheduler.pychanges are TQ-gated (no-op for non-TQ).failures are
tests/test_dflash_engine.pyfrom a pre-existing staledflash-mlxdependency (
ModuleNotFoundError: dflash_mlx.runtime.config), unrelated to this change.Tests
TQ tests are grouped under a
turboquantmarker — run withpytest -m turboquant(30tests across
test_turboquant.py,test_turboquant_batch_memory.py,test_turboquant_ssd.py;the model-loading ones are also
slow).Notes for the merger
uv.lockis gitignored; regenerate withuv lockin a controlled env.not use_rht(fix masked decode under RHT) Blaizzy/mlx-vlm#1244 merges: drop_fix_masked_decode_rht, bump the pin.