Skip to content

feat(turboquant): batched KV-cache compression (single + batch), no worse than single#1547

Merged
jundot merged 10 commits into
jundot:mainfrom
popfido:feat/turboquant-batch-mode
May 31, 2026
Merged

feat(turboquant): batched KV-cache compression (single + batch), no worse than single#1547
jundot merged 10 commits into
jundot:mainfrom
popfido:feat/turboquant-batch-mode

Conversation

@popfido
Copy link
Copy Markdown
Contributor

@popfido popfido commented May 30, 2026

Summary

Makes TurboQuant KV-cache compression actually work in oMLX — in both single-sequence
and continuous-batching (B>1) decode
— for the first time. TQ decode was previously
dead code (the conversion was reverted in #771 and never re-enabled), so
turboquant_kv_enabled silently ran fp16. Wiring it up exposed and fixed three real bugs
(two upstream in mlx-vlm — now both merged — and one in oMLX), and the result is proven
no worse than single mode on every axis while saving ~69% KV memory.

What changed (runtime)

Upstream dependency (mlx-vlm) — both fixes now merged

Wiring TQ decode exposed two mlx-vlm kernel bugs (both ~140% error, latent because TQ
decode never ran), both now fixed upstream:

  1. Bug 1 — fused single-token quantize kernel.
  2. Bug 2 — RHT-incompatible L=1 value kernels (masked decode): turboquant: guard L=1 value kernels behind not use_rht (fix masked decode under RHT) Blaizzy/mlx-vlm#1244, merged.

This PR pins mlx-vlm 6f60ee4 (which contains both fixes), so oMLX needs no workaround
B>1 masked decode relies on the upstream fix directly.

Proof: batch is no worse than single

Measured on Llama-3.2-1B-Instruct-4bit, TQ 4-bit:

axis single (B=1) batch (B>1) verdict
quality — teacher-forced agreement vs fp16¹ 96.2% 95.7% ✅ ≈ equal (symmetric near-ties)
speed — decode throughput 35.9 tok/s 53.2 tok/s ✅ 1.5× faster
memory — KV bytes/token 0.31× fp16 0.31× fp16 ✅ equal

¹ Same forced context (no greedy cascade), n=210. At the rare single-vs-batch disagreements
(1.4%) it's a coin-flip which matches fp16 (2 vs 1) — batch is not systematically worse.

Scope / impact validation

  • Runtime blast radius is 3 files; scheduler.py changes are TQ-gated (no-op for non-TQ).
  • Full default unit suite on the pin: 4754 passed (the only failures are a pre-existing
    stale dflash-mlx dependency, unrelated); mlx-vlm-dependent modules re-validated on 6f60ee4.

Tests

TQ tests are grouped under a turboquant marker — run with pytest -m turboquant (30
tests across test_turboquant.py, test_turboquant_batch_memory.py, test_turboquant_ssd.py;
the model-loading ones are also slow).

Notes for the merger

  • uv.lock is gitignored; regenerate with uv lock in a controlled env.

popfido added 10 commits May 30, 2026 22:35
Re-enable TurboQuant KV under continuous batching by quantizing the
completed fp16 prefill cache once (post-prefill), instead of the jundot#717
on-the-fly-during-prefill conversion that corrupted hidden states and
was reverted in jundot#771.

- Add Scheduler._turboquant_eligible() gate: only dense KVCache (and
  CacheList of KVCache for VLM) is converted. Chunked/rotating caches
  (Llama-4, sliding-window) stay fp16 — closes the jundot#771 SIGABRT class.
- Call _apply_turboquant_kv_convert() at the end of _do_external_prefill
  (after boundary snapshots, so paged-SSD format stays fp16). The per
  request TurboQuantKVCache is turned into a BatchTurboQuantKVCache by
  mlx-lm _merge_caches() at insert() time via the existing merge patch.
- Empty/short-prompt path converts too (empty TQ for fresh, from_cache
  for restored).
- Tests: eligibility gate across cache types; from_cache -> merge ->
  decode_attention batch path (offset tracking + real attention shape).
Wiring up TurboQuant decode (prev. commit) exposed that mlx-vlm f96138e's
TurboQuant decode path produces garbage even at 8-bit. TQ decode was never
actually exercised before (conversion was dead code since jundot#771), so these
upstream bugs were dormant. Two distinct kernels are broken:

1. Fused single-token quantize (_try_fused_kv_quantize, used only when
   keys.shape[-2] == 1, i.e. every decode step): ~140% reconstruction error
   at all bit depths; the non-fused quantize() path (T>=2 prefill) is fine.
   Fix: _fix_decode_single_token_quantize() forces the non-fused path.
   -> single-seq TQ output now matches fp16.

2. Masked decode_attention path (taken whenever an array mask is passed,
   i.e. all B>1 continuous-batching decode for per-request left-padding):
   ~140% error. Fix: route array-mask decode through dequantize + standard
   SDPA, the same approach mlx-vlm uses for its own BatchTurboQuantKVCache.
   B=1 keeps mask=None/causal and the correct fused kernel.
   -> batched TQ output is now coherent.

Both verified numerically (err 140% -> ~1% at 8-bit) and end-to-end on
Llama-3.2-1B-Instruct-4bit. Regression tests added for both paths.

NOTE: both are upstream mlx-vlm bugs and should be reported there.
Adds tests/test_turboquant_batch_memory.py: compares batched TurboQuant to
single-seq on a real model across the three axes requested —

- occupancy: KV bytes/token TQ vs fp16 (~0.31x at 4-bit), batch vs single,
  left-padding waste (analytical, measured at a cache_step-aligned length so
  over-allocation slack cancels), plus long-context savings projection.
- accuracy: concurrent B>1 TQ vs single-seq TQ token match + coherence gate.
- peak memory: live peak for single/batch x fp16/TQ (with the honest caveat
  that at short context the model weights dominate; B>1 TQ dequantizes the
  batch KV per step so peak is not below batch fp16).

Model-gated (skips if not cached); writes tq_batch_memory.md report artifact
(gitignored). Also notes in the attention patch that Bug #1 is fixed on
mlx-vlm main while Bug jundot#2 (masked decode) is not — the latter is the planned
upstream PR.
Validates TurboQuant + paged-SSD now that TQ decode actually engages. Key
finding: prefill boundary snapshots are stored fp16 and re-quantized
deterministically on a cache hit, so there is NO double-quant (TQ->fp16->TQ)
drift — the bespoke __turboquant_v2__ SSD path is not even exercised by the
common flow (verified via logs).

- single-request: cache hit reproduces fresh exactly (bit-identical) at 4-bit.
- batch, 8-bit: hit reproduces fresh exactly -> structural round-trip is sound.
- batch, 4-bit: hit may differ by a few tokens where quantization tips a greedy
  near-tie (fp16+SSD is exact; single TQ is exact), output stays coherent. This
  residual divergence is the same B>1 dequant sensitivity as Bug 2 and resolves
  when the upstream masked-decode kernel is fixed.

3 model-gated tests; all skip when the model is not cached.
…rtifacts

The L=1 value kernels (_metal_mse_weighted_sum,
_metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with
matmul(.,rotation) but ignore use_rht (RHT), so they corrupt the masked
decode path (~140% err). weighted_sum / weighted_sum_stats_from_scores lack
the 'if not self.use_rht' guard that weighted_sum_from_scores has. Patch +
PR description for upstream Blaizzy/mlx-vlm (against main).
…n oMLX

Forward the mlx-vlm pin f96138e -> fea81522 (main), which fixes the fused
single-token quantize decode kernel (Bug 1) upstream — so the oMLX
_fix_decode_single_token_quantize workaround is dropped.

Bug 2 (the RHT-incompatible L=1 value kernels corrupting the masked decode
path) is still unmerged upstream, so carry it as an oMLX monkey-patch
(_fix_masked_decode_rht: disable those kernels -> correct einsum fallback).
With the masked path now correct, route B>1 continuous-batching decode
through decode_attention instead of the dequantize+SDPA workaround — no
per-step batch dequantize, and it resolves the batch-4-bit SSD fresh-vs-hit
divergence ([False] -> [True] at 4-bit; verified).

uv.lock is gitignored; regenerate it ('uv lock') and run the full suite in a
controlled env before release. Tests updated for the new routing; 29 TQ tests
pass on HEAD.
BatchTurboQuantKVCache.make_mask hand-rolled a causal term that compared each
request's sequence length (offset) against the column index, then ANDed the
left_padding term — which masked out the valid left-padded tokens. Left-padded
requests in a ragged batch attended to ~nothing and decoded garbage, making
batch mode worse than single (fp16's BatchKVCache was unaffected).

Delegate to mlx-lm's create_causal_mask(N, offset=phys, left_padding=...),
exactly like BatchKVCache, so the masks are identical. After the fix:
- ragged-batch token-match to single-seq: 25% -> 71% (== same-length batch);
- teacher-forced top-1 agreement single-vs-batch (left-padded member): 12/12,
  i.e. batch is computationally equivalent to single; residual greedy token
  divergence is cascade noise, not quality loss.

This is an oMLX-only bug (separate from the mlx-vlm RHT PR).
Register a 'turboquant' pytest marker and apply it to the three TQ test
files (test_turboquant.py, test_turboquant_batch_memory.py,
test_turboquant_ssd.py) so the whole suite runs with 'pytest -m turboquant'.
The model-loading files are also marked 'slow' (deselected by default).
…pstream)

Blaizzy/mlx-vlm#1244 (the RHT masked-decode 'not use_rht' guard) is merged.
Bump the pin fea81522 -> 6f60ee4 (includes the merge) and delete the interim
_fix_masked_decode_rht monkey-patch — B>1 masked decode now relies on the
upstream fix. Removed the docs/upstream PR artifacts (PR is merged). Verified:
masked decode 1.2% with no patch; 26 TQ tests pass; single/batch coherent.
@popfido
Copy link
Copy Markdown
Contributor Author

popfido commented May 31, 2026

Update: the mlx-vlm RHT masked-decode fix (Blaizzy/mlx-vlm#1244) is merged. Bumped the pin fea81522 → 6f60ee4 (contains both TurboQuant decode fixes) and removed the interim _fix_masked_decode_rht monkey-patch + the docs/upstream/ artifacts — B>1 masked decode now relies on the upstream fix directly. Re-verified: masked decode 1.2% with no patch, 26 TQ tests + 86 mlx-vlm-dependent tests pass, single/batch coherent.

@jundot
Copy link
Copy Markdown
Owner

jundot commented May 31, 2026

Thanks, this is useful and the main TQ batch path looks good. I found one narrow mixed empty/non-empty batch crash in BatchTurboQuantKVCache.merge(): an empty TQ row is skipped while the batch metadata still counts it, so the next decode append sees B=2 inputs with B=1 quantized state. It is straightforward enough that I'll merge this and fold the fix into an immediate follow-up on main.

@jundot jundot merged commit ffb48b7 into jundot:main May 31, 2026
@jundot
Copy link
Copy Markdown
Owner

jundot commented May 31, 2026

One follow-up note: this also bumps mlx-vlm from fea81522 to 6f60ee4, so I'll do a separate pass over the mlx-vlm-related monkey patches on main. I'll check whether any local patches are now covered upstream by the new pin and remove or adjust them in a follow-up commit where appropriate.

@popfido
Copy link
Copy Markdown
Contributor Author

popfido commented May 31, 2026

Thanks, this is useful and the main TQ batch path looks good. I found one narrow mixed empty/non-empty batch crash in BatchTurboQuantKVCache.merge(): an empty TQ row is skipped while the batch metadata still counts it, so the next decode append sees B=2 inputs with B=1 quantized state. It is straightforward enough that I'll merge this and fold the fix into an immediate follow-up on main.

That‘s really helpful for available TQ for batch mode inference. Glad to talk about any edge case I've ever missed.

@popfido
Copy link
Copy Markdown
Contributor Author

popfido commented May 31, 2026

One follow-up note: this also bumps mlx-vlm from fea81522 to 6f60ee4, so I'll do a separate pass over the mlx-vlm-related monkey patches on main. I'll check whether any local patches are now covered upstream by the new pin and remove or adjust them in a follow-up commit where appropriate.

That's also what I mean. Next we'll need to wait for updated mlx-vlm version that including HEAD fixes so that we can bump to a released stable version for mlx-vlm dependency. I'll take a cautious look on it. I'm also raising performance optimization for TQ batch in mlx-vlm, which may benefit if I can catch up with recent release of mlx-vlm, so that the TQ batch mode in oMLX would enjoy better performance than single seq model far more than just 1.5x.

@popfido popfido deleted the feat/turboquant-batch-mode branch June 1, 2026 07:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants