
turboquant: guard L=1 value kernels behind not use_rht (fix masked decode under RHT)#1244

Merged: Blaizzy merged 2 commits into Blaizzy:main from popfido:fix/turboquant-rht-masked-decode on May 30, 2026

@popfido popfido commented May 30, 2026

Summary

_TurboQuantMSECodec.weighted_sum and weighted_sum_stats_from_scores call the
L=1 value-reconstruction Metal kernels (_metal_mse_weighted_sum,
_metal_mse_weighted_sum_sum_from_scores) without the `if not self.use_rht`
guard that the sibling weighted_sum_from_scores already has.

Those kernels finish with matmul(weighted_rot, rotation) to undo the codec
rotation — correct only for a plain rotation. _TurboQuantMSECodec defaults to
use_rht=True (randomized Hadamard transform), whose inverse is
_rht_inverse(.; signs), not matmul(.; rotation). So under RHT these kernels
apply the wrong inverse transform and return essentially uncorrelated output
(~140% reconstruction error at every bit depth, 2–8).
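
To make the failure mode concrete, here is a minimal pure-Python sketch of the mismatch. It is not the mlx-vlm code: `fwht`, `rht_forward`, `rht_inverse`, and `random_orthogonal` are illustrative helpers, and it assumes (a) the RHT is a random sign flip followed by an orthonormal Hadamard transform and (b) the dense `rotation` is an orthogonal matrix unrelated to that transform.

```python
import math
import random


def fwht(v):
    """Orthonormal fast Walsh-Hadamard transform; len(v) must be a power of two."""
    out = list(v)
    n = len(out)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [x * scale for x in out]


def rht_forward(x, signs):
    # Assumed form of the transform: y = H (D x), with D = diag(signs).
    return fwht([s * xi for s, xi in zip(signs, x)])


def rht_inverse(y, signs):
    # H is symmetric and orthonormal, so H^-1 = H; D is its own inverse.
    return [s * xi for s, xi in zip(signs, fwht(y))]


def random_orthogonal(d, rng):
    """Dense orthogonal matrix built by Gram-Schmidt on Gaussian rows."""
    rows = []
    for _ in range(d):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for r in rows:
            dot = sum(a * b for a, b in zip(v, r))
            v = [a - dot * b for a, b in zip(v, r)]
        norm = math.sqrt(sum(a * a for a in v))
        rows.append([a / norm for a in v])
    return rows


def vec_matmul(v, m):
    # Row vector times matrix, i.e. matmul(v, m).
    return [sum(v[k] * m[k][j] for k in range(len(v))) for j in range(len(m[0]))]


def rel_error(approx, exact):
    # Relative L2 error over a batch of vectors.
    num = sum((a - b) ** 2 for va, vb in zip(approx, exact) for a, b in zip(va, vb))
    den = sum(b * b for vb in exact for b in vb)
    return math.sqrt(num / den)


rng = random.Random(0)
d, n = 64, 16
signs = [rng.choice((-1.0, 1.0)) for _ in range(d)]
rotation = random_orthogonal(d, rng)  # stand-in for an unrelated dense rotation
xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
ys = [rht_forward(x, signs) for x in xs]

err_rht = rel_error([rht_inverse(y, signs) for y in ys], xs)
err_dense = rel_error([vec_matmul(y, rotation) for y in ys], xs)
print(f"matching RHT inverse:   {err_rht:.2e}")
print(f"mismatched dense matmul: {err_dense:.2f}")
```

Under these assumptions the matching inverse recovers the input to floating-point precision, while the mismatched one yields output with the right norm but essentially no correlation with the original.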

Impact

  • The slow / masked single-query (L=1) decode_attention path is corrupt.
  • It's latent for the common decode path (which uses the fused
    _fused_mse_decode_kernel when mask is None/"causal"); it only surfaces
    when an array mask forces the slow path — e.g. continuous-batching decode
    with per-request left-padding (B > 1), which then produces garbage.

Fix

Add `not self.use_rht and` to the two L=1 guards, mirroring the existing
weighted_sum_from_scores. Under RHT this takes the correct
einsum + _rotate_inverse fallback; with a plain rotation (use_rht=False) the
kernels still run.
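
The shape of that guard can be sketched as follows. This is a hypothetical stand-in, not the mlx-vlm class: `GuardedCodecSketch`, `kernel_available`, and the two private methods are invented for illustration, and the real guards may test additional conditions.

```python
class GuardedCodecSketch:
    """Hypothetical dispatch logic only; the methods return the path taken."""

    def __init__(self, use_rht=True, kernel_available=True):
        self.use_rht = use_rht
        self.kernel_available = kernel_available

    def _fast_kernel(self, weights, values):
        # Placeholder for the L=1 kernel, which ends in matmul(., rotation).
        return "kernel"

    def _fallback(self, weights, values):
        # Placeholder for the einsum + _rotate_inverse path.
        return "fallback"

    def weighted_sum(self, weights, values):
        query_len = len(weights)  # L: number of query positions
        # The added `not self.use_rht and` keeps RHT codecs off the kernel,
        # whose final inverse is only valid for a plain rotation.
        if not self.use_rht and self.kernel_available and query_len == 1:
            return self._fast_kernel(weights, values)
        return self._fallback(weights, values)


weights = [[0.25, 0.75]]  # one query position (L=1) over two keys
values = [[1.0, 0.0], [0.0, 1.0]]
print(GuardedCodecSketch(use_rht=True).weighted_sum(weights, values))
print(GuardedCodecSketch(use_rht=False).weighted_sum(weights, values))
```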

Verification

# _TurboQuantMSECodec, 8-bit, single-query decode through the masked path
array-mask decode error:  before = 140.0%   after = 1.2%
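
One plausible reading of the ~140% figure, assuming the metric is relative L2 error and the mismatched inverse is still norm-preserving (both are assumptions, not stated above): for a reconstruction $\hat v$ with $\lVert \hat v \rVert_2 = \lVert v \rVert_2$ at angle $\theta$ to $v$,

```latex
\frac{\lVert \hat v - v \rVert_2}{\lVert v \rVert_2}
  = \frac{\sqrt{\lVert \hat v \rVert_2^2 + \lVert v \rVert_2^2 - 2\,\langle \hat v, v \rangle}}{\lVert v \rVert_2}
  = \sqrt{2 - 2\cos\theta}
  \;\xrightarrow{\;\cos\theta \to 0\;}\; \sqrt{2} \approx 1.41
```

So an error near 140% is what fully uncorrelated output of the right magnitude would produce, independent of bit depth.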

End-to-end on mlx-community/Llama-3.2-1B-Instruct-4bit, continuous-batching
decode (B>1, left-padded) produces coherent output after the fix; before it is
garbage.

Note

Conservative fix (fall back to the correct math) matching the existing
weighted_sum_from_scores behavior. A deeper fix would teach the kernels the
RHT inverse so the RHT path could keep using the fast kernels.

The L=1 value-reconstruction Metal kernels (_metal_mse_weighted_sum,
_metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with
matmul(weighted_rot, rotation), which is only the inverse for a plain
rotation. _TurboQuantMSECodec defaults to use_rht=True (randomized Hadamard
transform) whose inverse is _rht_inverse(.; signs), so under RHT these kernels
return uncorrelated output (~140% reconstruction error at every bit depth).

weighted_sum_from_scores already guards its kernel with `if not self.use_rht`;
weighted_sum and weighted_sum_stats_from_scores did not, corrupting the slow
(masked / array-mask) single-query decode path used by continuous-batching
decode with per-request left-padding. Add the same guard so RHT falls back to
the correct einsum + _rotate_inverse path. Verified 140% -> ~1%.

@Blaizzy Blaizzy left a comment

LGTM!
Thanks, will merge for now but will revisit it later

@Blaizzy Blaizzy merged commit 3ef17cb into Blaizzy:main May 30, 2026
1 check passed
popfido added a commit to popfido/omlx that referenced this pull request May 31, 2026
…pstream)

Blaizzy/mlx-vlm#1244 (the RHT masked-decode 'not use_rht' guard) is merged.
Bump the pin fea81522 -> 6f60ee4 (includes the merge) and delete the interim
_fix_masked_decode_rht monkey-patch — B>1 masked decode now relies on the
upstream fix. Removed the docs/upstream PR artifacts (PR is merged). Verified:
masked decode 1.2% with no patch; 26 TQ tests pass; single/batch coherent.
jundot pushed a commit to jundot/omlx that referenced this pull request May 31, 2026
…orse than single (#1547)

* feat(turboquant): wire batched KV conversion (Phase 1)

Re-enable TurboQuant KV under continuous batching by quantizing the
completed fp16 prefill cache once (post-prefill), instead of the #717
on-the-fly-during-prefill conversion that corrupted hidden states and
was reverted in #771.

- Add Scheduler._turboquant_eligible() gate: only dense KVCache (and
  CacheList of KVCache for VLM) is converted. Chunked/rotating caches
  (Llama-4, sliding-window) stay fp16 — closes the #771 SIGABRT class.
- Call _apply_turboquant_kv_convert() at the end of _do_external_prefill
  (after boundary snapshots, so paged-SSD format stays fp16). The per
  request TurboQuantKVCache is turned into a BatchTurboQuantKVCache by
  mlx-lm _merge_caches() at insert() time via the existing merge patch.
- Empty/short-prompt path converts too (empty TQ for fresh, from_cache
  for restored).
- Tests: eligibility gate across cache types; from_cache -> merge ->
  decode_attention batch path (offset tracking + real attention shape).

* fix(turboquant): work around two broken mlx-vlm f96138e decode kernels

Wiring up TurboQuant decode (prev. commit) exposed that mlx-vlm f96138e's
TurboQuant decode path produces garbage even at 8-bit. TQ decode was never
actually exercised before (conversion was dead code since #771), so these
upstream bugs were dormant. Two distinct kernels are broken:

1. Fused single-token quantize (_try_fused_kv_quantize, used only when
   keys.shape[-2] == 1, i.e. every decode step): ~140% reconstruction error
   at all bit depths; the non-fused quantize() path (T>=2 prefill) is fine.
   Fix: _fix_decode_single_token_quantize() forces the non-fused path.
   -> single-seq TQ output now matches fp16.

2. Masked decode_attention path (taken whenever an array mask is passed,
   i.e. all B>1 continuous-batching decode for per-request left-padding):
   ~140% error. Fix: route array-mask decode through dequantize + standard
   SDPA, the same approach mlx-vlm uses for its own BatchTurboQuantKVCache.
   B=1 keeps mask=None/causal and the correct fused kernel.
   -> batched TQ output is now coherent.

Both verified numerically (err 140% -> ~1% at 8-bit) and end-to-end on
Llama-3.2-1B-Instruct-4bit. Regression tests added for both paths.

NOTE: both are upstream mlx-vlm bugs and should be reported there.

* test(turboquant): batched accuracy + memory/occupancy harness (Phase 2)

Adds tests/test_turboquant_batch_memory.py: compares batched TurboQuant to
single-seq on a real model across the three axes requested —

- occupancy: KV bytes/token TQ vs fp16 (~0.31x at 4-bit), batch vs single,
  left-padding waste (analytical, measured at a cache_step-aligned length so
  over-allocation slack cancels), plus long-context savings projection.
- accuracy: concurrent B>1 TQ vs single-seq TQ token match + coherence gate.
- peak memory: live peak for single/batch x fp16/TQ (with the honest caveat
  that at short context the model weights dominate; B>1 TQ dequantizes the
  batch KV per step so peak is not below batch fp16).

Model-gated (skips if not cached); writes tq_batch_memory.md report artifact
(gitignored). Also notes in the attention patch that Bug #1 is fixed on
mlx-vlm main while Bug #2 (masked decode) is not — the latter is the planned
upstream PR.

* test(turboquant): SSD prefix-cache round-trip, single + batch (Phase 3)

Validates TurboQuant + paged-SSD now that TQ decode actually engages. Key
finding: prefill boundary snapshots are stored fp16 and re-quantized
deterministically on a cache hit, so there is NO double-quant (TQ->fp16->TQ)
drift — the bespoke __turboquant_v2__ SSD path is not even exercised by the
common flow (verified via logs).

- single-request: cache hit reproduces fresh exactly (bit-identical) at 4-bit.
- batch, 8-bit: hit reproduces fresh exactly -> structural round-trip is sound.
- batch, 4-bit: hit may differ by a few tokens where quantization tips a greedy
  near-tie (fp16+SSD is exact; single TQ is exact), output stays coherent. This
  residual divergence is the same B>1 dequant sensitivity as Bug 2 and resolves
  when the upstream masked-decode kernel is fixed.

3 model-gated tests; all skip when the model is not cached.

* docs(upstream): mlx-vlm TurboQuant RHT masked-decode fix (Bug 2) PR artifacts

The L=1 value kernels (_metal_mse_weighted_sum,
_metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with
matmul(.,rotation) but ignore use_rht (RHT), so they corrupt the masked
decode path (~140% err). weighted_sum / weighted_sum_stats_from_scores lack
the 'if not self.use_rht' guard that weighted_sum_from_scores has. Patch +
PR description for upstream Blaizzy/mlx-vlm (against main).

* fix(turboquant): forward mlx-vlm to HEAD + land Bug-2 masked-decode in oMLX

Forward the mlx-vlm pin f96138e -> fea81522 (main), which fixes the fused
single-token quantize decode kernel (Bug 1) upstream — so the oMLX
_fix_decode_single_token_quantize workaround is dropped.

Bug 2 (the RHT-incompatible L=1 value kernels corrupting the masked decode
path) is still unmerged upstream, so carry it as an oMLX monkey-patch
(_fix_masked_decode_rht: disable those kernels -> correct einsum fallback).
With the masked path now correct, route B>1 continuous-batching decode
through decode_attention instead of the dequantize+SDPA workaround — no
per-step batch dequantize, and it resolves the batch-4-bit SSD fresh-vs-hit
divergence ([False] -> [True] at 4-bit; verified).

uv.lock is gitignored; regenerate it ('uv lock') and run the full suite in a
controlled env before release. Tests updated for the new routing; 29 TQ tests
pass on HEAD.

* fix(turboquant): correct B>1 make_mask for left-padded batches

BatchTurboQuantKVCache.make_mask hand-rolled a causal term that compared each
request's sequence length (offset) against the column index, then ANDed the
left_padding term — which masked out the valid left-padded tokens. Left-padded
requests in a ragged batch attended to ~nothing and decoded garbage, making
batch mode worse than single (fp16's BatchKVCache was unaffected).

Delegate to mlx-lm's create_causal_mask(N, offset=phys, left_padding=...),
exactly like BatchKVCache, so the masks are identical. After the fix:
- ragged-batch token-match to single-seq: 25% -> 71% (== same-length batch);
- teacher-forced top-1 agreement single-vs-batch (left-padded member): 12/12,
  i.e. batch is computationally equivalent to single; residual greedy token
  divergence is cascade noise, not quality loss.

This is an oMLX-only bug (separate from the mlx-vlm RHT PR).
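
The intended mask semantics can be sketched in plain Python. This is an illustration only, not the mlx-lm implementation: `left_padded_causal_mask` is a hypothetical name, and it assumes a key column is visible when it is at or before the query's position and not inside that request's left padding.

```python
def left_padded_causal_mask(n_queries, offset, left_padding):
    """Return mask[b][i][j]: True when query i of batch row b may attend key j.

    Keys span columns 0 .. offset + n_queries - 1; query i sits at column
    offset + i. left_padding[b] is the number of pad columns at the start
    of row b.
    """
    total = offset + n_queries
    return [
        [
            # Causal term (j <= offset + i) AND not-a-pad term (j >= pad).
            [pad <= j <= offset + i for j in range(total)]
            for i in range(n_queries)
        ]
        for pad in left_padding
    ]


# Single-token decode step (L=1) with 5 cached columns; row 1 has 3 pad columns.
mask = left_padded_causal_mask(n_queries=1, offset=5, left_padding=[0, 3])
for row in mask:
    print(["x" if keep else "." for keep in row[0]])
```

Row 0 sees every column; row 1 sees only the columns after its padding, which is the per-request array mask that forces the slow decode path.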

* test(turboquant): group TQ tests under a 'turboquant' marker

Register a 'turboquant' pytest marker and apply it to the three TQ test
files (test_turboquant.py, test_turboquant_batch_memory.py,
test_turboquant_ssd.py) so the whole suite runs with 'pytest -m turboquant'.
The model-loading files are also marked 'slow' (deselected by default).
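
A minimal sketch of how such markers are typically registered, assuming a `conftest.py` hook is used (the actual oMLX change may register them in `pyproject.toml` or `pytest.ini` instead; the marker descriptions here are invented):

```python
# conftest.py (sketch)
def pytest_configure(config):
    # Registering markers avoids unknown-marker warnings and lets
    # `pytest -m turboquant` select the whole group.
    config.addinivalue_line("markers", "turboquant: TurboQuant KV-cache tests")
    config.addinivalue_line("markers", "slow: tests that load a real model")


# In each TQ test module, a module-level mark applies to every test in the file:
#     import pytest
#     pytestmark = pytest.mark.turboquant
# Model-loading modules would carry both marks:
#     pytestmark = [pytest.mark.turboquant, pytest.mark.slow]
```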

* test(turboquant): split semicolon statements in mask test (E702)

* chore(turboquant): drop RHT monkey-patch; pin merged mlx-vlm (Bug 2 upstream)

Blaizzy/mlx-vlm#1244 (the RHT masked-decode 'not use_rht' guard) is merged.
Bump the pin fea81522 -> 6f60ee4 (includes the merge) and delete the interim
_fix_masked_decode_rht monkey-patch — B>1 masked decode now relies on the
upstream fix. Removed the docs/upstream PR artifacts (PR is merged). Verified:
masked decode 1.2% with no patch; 26 TQ tests pass; single/batch coherent.
Blaizzy added a commit that referenced this pull request Jun 1, 2026
…nder RHT) (#1252)

The single-token value kernels (_metal_mse_weighted_sum and friends) return
the weighted value sum in the codec's rotated space and undo it with a
hard-coded matmul(weighted_rot, rotation). That inverse is only correct for
the dense-rotation codec, so #1244 disabled these kernels whenever the codec
uses the Randomized Hadamard Transform (use_rht) — every RHT decode fell back
to the slower einsum path.

Pass the codec's RHT signs into the wrappers and apply the matching inverse
(_rht_inverse when signs are set, else the dense matmul) via a small
_value_rotate_inverse helper, mirroring _TurboQuantMSECodec._rotate_inverse
and the already-correct fused-decode path. The kernel computes weighted_rot
correctly regardless of rotation type; only the post-kernel inverse needed
fixing. The not-self.use_rht guards are dropped so RHT decode takes the
kernel again.

Prod-mode fused-decode call sites pass no signs (default None) and keep the
exact dense matmul — byte-for-byte unchanged.

Verified: RHT weighted_sum / weighted_sum_stats_from_scores match the einsum
fallback and the dequantize ground truth to <1e-4 / <1e-3 across
dims {64,128,256} x bits {2,3,4,8} x repeats {1,4}; the L=1 value
reconstruction runs 3.5-13.7x faster than the einsum fallback (the gap grows
with context length).
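
The dispatch described above can be sketched in plain Python. `value_rotate_inverse` here is a hypothetical analogue of the helper named in the commit, not its actual code; it assumes the dense forward rotation is `x @ rotation.T` and that the RHT is a sign flip followed by an orthonormal Hadamard transform.

```python
import math


def fwht(v):
    """Orthonormal fast Walsh-Hadamard transform; len(v) must be a power of two."""
    out = list(v)
    n = len(out)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    return [x / math.sqrt(n) for x in out]


def value_rotate_inverse(weighted_rot, rotation, signs=None):
    """Map a rotated-space vector back to the original space."""
    if signs is not None:
        # RHT codec: undo the Hadamard transform, then the sign flips.
        return [s * x for s, x in zip(signs, fwht(weighted_rot))]
    # Dense codec: matmul(weighted_rot, rotation).
    width = len(rotation[0])
    return [
        sum(weighted_rot[k] * rotation[k][j] for k in range(len(weighted_rot)))
        for j in range(width)
    ]


# RHT round trip.
signs = [1.0, -1.0, -1.0, 1.0]
x = [1.0, 2.0, 3.0, 4.0]
y_rht = fwht([s * xi for s, xi in zip(signs, x)])
x_rht = value_rotate_inverse(y_rht, rotation=None, signs=signs)

# Dense round trip with a 2-D rotation; forward is y = x @ R^T.
c, s = math.cos(0.3), math.sin(0.3)
rotation = [[c, -s], [s, c]]
x2 = [1.0, 2.0]
y_dense = [sum(x2[j] * rotation[i][j] for j in range(2)) for i in range(2)]
x_dense = value_rotate_inverse(y_dense, rotation)

print(x_rht, x_dense)
```

Leaving `signs` at its default of `None` keeps the dense path untouched, which matches the note that call sites passing no signs are unchanged.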

Co-authored-by: Prince Canuma <prince.gdt@gmail.com>