feat(turboquant): batched KV-cache compression (single + batch), no worse than single#1547
Conversation
Re-enable TurboQuant KV under continuous batching by quantizing the completed fp16 prefill cache once (post-prefill), instead of the jundot#717 on-the-fly-during-prefill conversion that corrupted hidden states and was reverted in jundot#771. - Add Scheduler._turboquant_eligible() gate: only dense KVCache (and CacheList of KVCache for VLM) is converted. Chunked/rotating caches (Llama-4, sliding-window) stay fp16 — closes the jundot#771 SIGABRT class. - Call _apply_turboquant_kv_convert() at the end of _do_external_prefill (after boundary snapshots, so paged-SSD format stays fp16). The per request TurboQuantKVCache is turned into a BatchTurboQuantKVCache by mlx-lm _merge_caches() at insert() time via the existing merge patch. - Empty/short-prompt path converts too (empty TQ for fresh, from_cache for restored). - Tests: eligibility gate across cache types; from_cache -> merge -> decode_attention batch path (offset tracking + real attention shape).
Wiring up TurboQuant decode (prev. commit) exposed that mlx-vlm f96138e's TurboQuant decode path produces garbage even at 8-bit. TQ decode was never actually exercised before (conversion was dead code since jundot#771), so these upstream bugs were dormant. Two distinct kernels are broken: 1. Fused single-token quantize (_try_fused_kv_quantize, used only when keys.shape[-2] == 1, i.e. every decode step): ~140% reconstruction error at all bit depths; the non-fused quantize() path (T>=2 prefill) is fine. Fix: _fix_decode_single_token_quantize() forces the non-fused path. -> single-seq TQ output now matches fp16. 2. Masked decode_attention path (taken whenever an array mask is passed, i.e. all B>1 continuous-batching decode for per-request left-padding): ~140% error. Fix: route array-mask decode through dequantize + standard SDPA, the same approach mlx-vlm uses for its own BatchTurboQuantKVCache. B=1 keeps mask=None/causal and the correct fused kernel. -> batched TQ output is now coherent. Both verified numerically (err 140% -> ~1% at 8-bit) and end-to-end on Llama-3.2-1B-Instruct-4bit. Regression tests added for both paths. NOTE: both are upstream mlx-vlm bugs and should be reported there.
Adds tests/test_turboquant_batch_memory.py: compares batched TurboQuant to single-seq on a real model across the three axes requested — - occupancy: KV bytes/token TQ vs fp16 (~0.31x at 4-bit), batch vs single, left-padding waste (analytical, measured at a cache_step-aligned length so over-allocation slack cancels), plus long-context savings projection. - accuracy: concurrent B>1 TQ vs single-seq TQ token match + coherence gate. - peak memory: live peak for single/batch x fp16/TQ (with the honest caveat that at short context the model weights dominate; B>1 TQ dequantizes the batch KV per step so peak is not below batch fp16). Model-gated (skips if not cached); writes tq_batch_memory.md report artifact (gitignored). Also notes in the attention patch that Bug #1 is fixed on mlx-vlm main while Bug jundot#2 (masked decode) is not — the latter is the planned upstream PR.
Validates TurboQuant + paged-SSD now that TQ decode actually engages. Key finding: prefill boundary snapshots are stored fp16 and re-quantized deterministically on a cache hit, so there is NO double-quant (TQ->fp16->TQ) drift — the bespoke __turboquant_v2__ SSD path is not even exercised by the common flow (verified via logs). - single-request: cache hit reproduces fresh exactly (bit-identical) at 4-bit. - batch, 8-bit: hit reproduces fresh exactly -> structural round-trip is sound. - batch, 4-bit: hit may differ by a few tokens where quantization tips a greedy near-tie (fp16+SSD is exact; single TQ is exact), output stays coherent. This residual divergence is the same B>1 dequant sensitivity as Bug 2 and resolves when the upstream masked-decode kernel is fixed. 3 model-gated tests; all skip when the model is not cached.
…rtifacts The L=1 value kernels (_metal_mse_weighted_sum, _metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with matmul(.,rotation) but ignore use_rht (RHT), so they corrupt the masked decode path (~140% err). weighted_sum / weighted_sum_stats_from_scores lack the 'if not self.use_rht' guard that weighted_sum_from_scores has. Patch + PR description for upstream Blaizzy/mlx-vlm (against main).
…n oMLX
Forward the mlx-vlm pin f96138e -> fea81522 (main), which fixes the fused
single-token quantize decode kernel (Bug 1) upstream — so the oMLX
_fix_decode_single_token_quantize workaround is dropped.
Bug 2 (the RHT-incompatible L=1 value kernels corrupting the masked decode
path) is still unmerged upstream, so carry it as an oMLX monkey-patch
(_fix_masked_decode_rht: disable those kernels -> correct einsum fallback).
With the masked path now correct, route B>1 continuous-batching decode
through decode_attention instead of the dequantize+SDPA workaround — no
per-step batch dequantize, and it resolves the batch-4-bit SSD fresh-vs-hit
divergence ([False] -> [True] at 4-bit; verified).
uv.lock is gitignored; regenerate it ('uv lock') and run the full suite in a
controlled env before release. Tests updated for the new routing; 29 TQ tests
pass on HEAD.
BatchTurboQuantKVCache.make_mask hand-rolled a causal term that compared each request's sequence length (offset) against the column index, then ANDed the left_padding term — which masked out the valid left-padded tokens. Left-padded requests in a ragged batch attended to ~nothing and decoded garbage, making batch mode worse than single (fp16's BatchKVCache was unaffected). Delegate to mlx-lm's create_causal_mask(N, offset=phys, left_padding=...), exactly like BatchKVCache, so the masks are identical. After the fix: - ragged-batch token-match to single-seq: 25% -> 71% (== same-length batch); - teacher-forced top-1 agreement single-vs-batch (left-padded member): 12/12, i.e. batch is computationally equivalent to single; residual greedy token divergence is cascade noise, not quality loss. This is an oMLX-only bug (separate from the mlx-vlm RHT PR).
Register a 'turboquant' pytest marker and apply it to the three TQ test files (test_turboquant.py, test_turboquant_batch_memory.py, test_turboquant_ssd.py) so the whole suite runs with 'pytest -m turboquant'. The model-loading files are also marked 'slow' (deselected by default).
…pstream) Blaizzy/mlx-vlm#1244 (the RHT masked-decode 'not use_rht' guard) is merged. Bump the pin fea81522 -> 6f60ee4 (includes the merge) and delete the interim _fix_masked_decode_rht monkey-patch — B>1 masked decode now relies on the upstream fix. Removed the docs/upstream PR artifacts (PR is merged). Verified: masked decode 1.2% with no patch; 26 TQ tests pass; single/batch coherent.
|
Update: the mlx-vlm RHT masked-decode fix (Blaizzy/mlx-vlm#1244) is merged. Bumped the pin |
|
Thanks, this is useful and the main TQ batch path looks good. I found one narrow mixed empty/non-empty batch crash in BatchTurboQuantKVCache.merge(): an empty TQ row is skipped while the batch metadata still counts it, so the next decode append sees B=2 inputs with B=1 quantized state. It is straightforward enough that I'll merge this and fold the fix into an immediate follow-up on main. |
|
One follow-up note: this also bumps mlx-vlm from fea81522 to 6f60ee4, so I'll do a separate pass over the mlx-vlm-related monkey patches on main. I'll check whether any local patches are now covered upstream by the new pin and remove or adjust them in a follow-up commit where appropriate. |
That‘s really helpful for available TQ for batch mode inference. Glad to talk about any edge case I've ever missed. |
That's also what I mean. Next we'll need to wait for updated mlx-vlm version that including HEAD fixes so that we can bump to a released stable version for mlx-vlm dependency. I'll take a cautious look on it. I'm also raising performance optimization for TQ batch in mlx-vlm, which may benefit if I can catch up with recent release of mlx-vlm, so that the TQ batch mode in oMLX would enjoy better performance than single seq model far more than just 1.5x. |
Summary
Makes TurboQuant KV-cache compression actually work in oMLX — in both single-sequence
and continuous-batching (B>1) decode — for the first time. TQ decode was previously
dead code (the conversion was reverted in #771 and never re-enabled), so
turboquant_kv_enabledsilently ran fp16. Wiring it up exposed and fixed three real bugs(two upstream in mlx-vlm — now both merged — and one in oMLX), and the result is proven
no worse than single mode on every axis while saving ~69% KV memory.
What changed (runtime)
omlx/scheduler.py— quantize the completed fp16 prefill cache once, post-prefill(not on-the-fly, which corrupted hidden states in fix: wire up TurboQuant KV cache conversion + skip sensitive last layer (#661) #717/0.3.5-rc1: scheduler._do_external_prefill applies TurboQuantKV during prefill → garbage output & SIGABRT on quantized-KV models #771), gated to dense-KVCache
models via a new
_turboquant_eligible()(chunked/rotating models stay fp16 — no crash).Fully TQ-gated: for non-TQ models these paths are a no-op (zero behavioural change).
omlx/patches/turboquant_attention.py— route B>1 decode through the quantizeddecode_attention.omlx/turboquant_kv.py— fixBatchTurboQuantKVCache.make_maskfor left-paddedbatches (it was masking out valid tokens, so left-padded requests decoded garbage); now
delegates to mlx-lm
create_causal_mask, identical toBatchKVCache.Upstream dependency (mlx-vlm) — both fixes now merged
Wiring TQ decode exposed two mlx-vlm kernel bugs (both ~140% error, latent because TQ
decode never ran), both now fixed upstream:
not use_rht(fix masked decode under RHT) Blaizzy/mlx-vlm#1244, merged.This PR pins mlx-vlm
6f60ee4(which contains both fixes), so oMLX needs no workaround —B>1 masked decode relies on the upstream fix directly.
Proof: batch is no worse than single
Measured on
Llama-3.2-1B-Instruct-4bit, TQ 4-bit:¹ Same forced context (no greedy cascade), n=210. At the rare single-vs-batch disagreements
(1.4%) it's a coin-flip which matches fp16 (2 vs 1) — batch is not systematically worse.
Scope / impact validation
scheduler.pychanges are TQ-gated (no-op for non-TQ).stale
dflash-mlxdependency, unrelated); mlx-vlm-dependent modules re-validated on6f60ee4.Tests
TQ tests are grouped under a
turboquantmarker — run withpytest -m turboquant(30tests across
test_turboquant.py,test_turboquant_batch_memory.py,test_turboquant_ssd.py;the model-loading ones are also
slow).Notes for the merger
uv.lockis gitignored; regenerate withuv lockin a controlled env.