Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,6 @@ omlx/admin/tailwindcss-*
docs/native_app_architecture.md

# UV lockfile
uv.lock
uv.lock
# generated TurboQuant memory report (machine-specific)
tq_batch_memory.md
63 changes: 63 additions & 0 deletions docs/upstream/mlx-vlm-turboquant-rht-decode-PR.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Upstream PR: fix TurboQuant L=1 value kernels under RHT (mlx-vlm)

**Target:** `Blaizzy/mlx-vlm` `main` (verified against `fea81522`; same on `f96138e` / `v0.5.0`).
**File:** `mlx_vlm/turboquant.py` — `_TurboQuantMSECodec`.

## Summary

`_TurboQuantMSECodec.weighted_sum` and `weighted_sum_stats_from_scores` call the
L=1 value-reconstruction Metal kernels (`_metal_mse_weighted_sum`,
`_metal_mse_weighted_sum_sum_from_scores`) **without** the `if not self.use_rht`
guard that the sibling `weighted_sum_from_scores` already has. Those kernels
finish with `matmul(weighted_rot, rotation)`, which is the inverse only for a
plain rotation. The codec defaults to `use_rht=True` (randomized Hadamard
transform), whose inverse is `_rht_inverse(.; signs)`. So under RHT the kernels
apply the wrong inverse transform and return essentially uncorrelated output.

## Impact

- Single-query **decode attention** through the slow/masked path is corrupt.
- Reproduction error is ~140% (of signal magnitude) at every bit depth (2–8),
i.e. not a precision issue — a wrong-transform issue.
- Latent because the common decode path uses the fused `_fused_mse_decode_kernel`
(mask is `None`/`"causal"`); the bug only shows when an array mask forces the
slow path — e.g. continuous-batching decode with per-request left-padding.

## Root cause

`weighted_sum_from_scores` is guarded:
```python
if not self.use_rht:
fast_output = _metal_mse_weighted_sum_from_scores(...)
...
```
but `weighted_sum` and `weighted_sum_stats_from_scores` are not. The non-fused
fallback paths (einsum + `self._rotate_inverse(...)`) are correct for both RHT
and plain rotation.

## Fix

Add `not self.use_rht and` to the two L=1 guards (see the patch). Under RHT this
takes the correct einsum/`_rotate_inverse` fallback; with a plain rotation
(`use_rht=False`) the kernels still run.

## Verification

```
# _TurboQuantMSECodec, 8-bit, single-query decode through the masked path
array-mask decode error: before = 140.0% after = 1.2%
```
End-to-end on `mlx-community/Llama-3.2-1B-Instruct-4bit`, continuous-batching
decode (B>1, left-padded) produces coherent output after the fix; before it is
garbage.

## Notes / suggested follow-up

- A proper fix in the kernels themselves (apply the RHT inverse instead of
`matmul(., rotation)`) would let RHT use the fast path; this PR takes the
conservative route (fall back to the correct math) matching the existing
`weighted_sum_from_scores` behavior.
- Related: the fused single-token quantize kernel (`_try_fused_kv_quantize` /
`_fused_kv_quantize_kernel`, the T=1 path) had an analogous decode-time
defect that was fixed on `main` (`fea81522`); this PR addresses the remaining
value-kernel/RHT case.
61 changes: 61 additions & 0 deletions docs/upstream/mlx-vlm-turboquant-rht-decode.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
From: oMLX maintainers
Subject: [PATCH] turboquant: guard L=1 value Metal kernels behind `not use_rht`

The L=1 value-reconstruction Metal kernels (`_metal_mse_weighted_sum`,
`_metal_mse_weighted_sum_sum_from_scores`) apply `matmul(weighted_rot, rotation)`
to undo the codec rotation. That is only correct when the codec uses a plain
rotation matrix. When `use_rht=True` (a randomized Hadamard transform, the
default for `_TurboQuantMSECodec`), the inverse is `_rht_inverse(.; signs)`, not
`matmul(.; rotation)`, so these kernels return garbage (~140% reconstruction
error at every bit depth).

`weighted_sum_from_scores` already guards its kernel with `if not self.use_rht`.
`weighted_sum` and `weighted_sum_stats_from_scores` do not — so the
single-query decode path (and any continuous-batching decode that takes the
masked branch) is corrupt under RHT. This was latent because the fused decode
fast path (`_fused_mse_decode_kernel`) is used for the common mask=None/causal
case; the bug only surfaces when an array mask forces the slow path.

Fix: add the same `not self.use_rht` guard to both methods, mirroring
`weighted_sum_from_scores`. Under RHT this falls back to the correct einsum +
`_rotate_inverse` path. Verified 140% -> ~1% reconstruction error.

--- a/mlx_vlm/turboquant.py
+++ b/mlx_vlm/turboquant.py
@@ class _TurboQuantMSECodec:
def weighted_sum(self, weights: mx.array, state: TurboQuantMSEState) -> mx.array:
- if weights.shape[-2] == 1:
+ if not self.use_rht and weights.shape[-2] == 1:
fast_output = _metal_mse_weighted_sum(
weights,
state,
self.bits,
self.codebook,
self.rotation,
)
if fast_output is not None:
return fast_output
@@ class _TurboQuantMSECodec:
def weighted_sum_stats_from_scores(
self, scores: mx.array, state: TurboQuantMSEState
) -> tuple[mx.array, mx.array, mx.array]:
max_scores = mx.max(scores, axis=-1)
# Metal kernel fast path: only for single-query decode (L=1)
- if scores.ndim == 5 and scores.shape[-2] == 1:
+ if not self.use_rht and scores.ndim == 5 and scores.shape[-2] == 1:
max_scores_2d = max_scores.reshape(
max_scores.shape[0],
max_scores.shape[1],
max_scores.shape[2],
)
fast_output = _metal_mse_weighted_sum_sum_from_scores(
scores,
state,
self.bits,
self.codebook,
self.rotation,
max_scores_2d,
)
if fast_output is not None:
denom = mx.sum(mx.exp(scores - max_scores[..., None]), axis=-1)
return fast_output, denom, max_scores
51 changes: 51 additions & 0 deletions omlx/patches/turboquant_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,49 @@
logger = logging.getLogger(__name__)

_PATCHED = False
_RHT_DECODE_FIXED = False


def _fix_masked_decode_rht() -> None:
"""Work around mlx-vlm's L=1 value Metal kernels that ignore RHT (Bug 2).

`_TurboQuantMSECodec.weighted_sum` and `weighted_sum_stats_from_scores` call
the single-query (L=1) value kernels (`_metal_mse_weighted_sum`,
`_metal_mse_weighted_sum_sum_from_scores`) WITHOUT the `if not self.use_rht`
guard that the sibling `weighted_sum_from_scores` has. Those kernels undo the
codec rotation with `matmul(., rotation)`, but TurboQuant KV codecs use
`use_rht=True` (randomized Hadamard transform) whose inverse is
`_rht_inverse` — so they corrupt the masked decode path (~140% error). That
is what makes B>1 continuous-batching decode (which passes a per-request
left-padding array mask) produce garbage.

We disable those two kernels so the codec falls back to the correct einsum +
`_rotate_inverse` path. Since our KV codecs are always `use_rht=True`, this is
equivalent to the upstream `not self.use_rht` guard, at the cost of one matmul
instead of a fused kernel on the slow decode path — negligible.

Temporary until the upstream fix lands; see
docs/upstream/mlx-vlm-turboquant-rht-decode-PR.md. (Bug 1, the fused
single-token quantize kernel, is already fixed on the pinned mlx-vlm main.)
"""
global _RHT_DECODE_FIXED
if _RHT_DECODE_FIXED:
return
try:
import mlx_vlm.turboquant as _tq
except ImportError:
return

def _decline(*args, **kwargs):
return None

_tq._metal_mse_weighted_sum = _decline
_tq._metal_mse_weighted_sum_sum_from_scores = _decline
_RHT_DECODE_FIXED = True
logger.info(
"TurboQuant decode fix applied: disabled RHT-incompatible L=1 value "
"kernels (mlx-vlm Bug 2 workaround)"
)


def apply_turboquant_attention_patch() -> bool:
Expand Down Expand Up @@ -51,6 +94,10 @@ def patched_sdpa(

if isinstance(real_cache, (_TQCache, BatchTurboQuantKVCache)):
if queries.shape[-2] == 1:
# Decode (B=1 and B>1). With the masked decode path corrected by
# _fix_masked_decode_rht(), continuous-batching decode using a
# per-request left-padding array mask runs the quantized kernels
# directly — no full-batch dequantize per step.
return real_cache.decode_attention(
queries,
keys_state=keys,
Expand Down Expand Up @@ -99,6 +146,10 @@ def patched_sdpa(
except ImportError:
pass

# Without this, B>1 (masked) decode is corrupt and TurboQuant batching
# produces garbage (see _fix_masked_decode_rht).
_fix_masked_decode_rht()

_PATCHED = True
logger.info("TurboQuant attention patch applied")
return True
85 changes: 58 additions & 27 deletions omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,15 +1659,36 @@ def _on_prompt_progress(self, updates: list[tuple[int, int, int]]) -> None:
# External prefill (composition pattern — replaces _process_prompts)
# ------------------------------------------------------------------

def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None:
"""Replace KVCache with empty TurboQuantKVCache before prefill.
def _turboquant_eligible(self, prompt_cache: list[Any]) -> bool:
"""True if every cache layer can be safely TurboQuant-converted for
continuous batching.

Only plain KVCache (and CacheList of KVCache, for VLM) implement the
merge/filter/extract/extend batch protocol that the monkey-patched
TurboQuantKVCache.merge relies on inside BatchGenerator. Chunked- and
rotating-attention caches (Llama-4, sliding-window) need
maybe_trim_front / rotating semantics that BatchTurboQuantKVCache does
not provide, so those models stay fp16 — no crash, no TurboQuant.
"""
from mlx_lm.models.cache import CacheList, KVCache

def _ok(c: Any) -> bool:
if isinstance(c, KVCache):
return True
if isinstance(c, CacheList):
return all(_ok(inner) for inner in c.caches)
return False

NOTE: Not currently called -- see #771. Kept for future use when
TurboQuantKVCache implements merge()/maybe_trim_front().
return bool(prompt_cache) and all(_ok(c) for c in prompt_cache)

def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None:
"""Replace empty KVCache layers with empty TurboQuantKVCache.

Tokens are quantized on the fly during update_and_fetch, avoiding
the peak memory spike from storing full-precision KV then converting.
Skips the last KVCache layer if turboquant_skip_last is set.
Used only when there is no prefill history to preserve (the single
last token is quantized during insert()'s prompt step). Skips the
last KVCache layer if turboquant_skip_last is set.
"""
from mlx_lm.models.cache import CacheList, KVCache
from mlx_vlm.turboquant import TurboQuantKVCache
Expand Down Expand Up @@ -1701,13 +1722,13 @@ def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None:
)

def _apply_turboquant_kv_convert(self, prompt_cache: list[Any]) -> None:
"""Convert existing KVCache data to TurboQuantKVCache via from_cache().

NOTE: Not currently called -- see #771. Kept for future use when
TurboQuantKVCache implements merge()/maybe_trim_front().
"""Convert populated KVCache data to TurboQuantKVCache via from_cache().

Used when an existing cache is provided (e.g. from SSD prefix cache).
Uses from_cache() to quantize the existing KV data.
Called AFTER fp16 prefill completes (or on an SSD-restored fp16
cache): the completed full-precision KV is quantized once, so prefill
hidden states stay exact and quantization error only enters at
decode-time reads. This is the key difference from #717/#771, which
quantized on the fly during prefill and corrupted hidden states.
"""
from mlx_lm.models.cache import CacheList, KVCache
from mlx_vlm.turboquant import TurboQuantKVCache
Expand Down Expand Up @@ -1770,13 +1791,17 @@ def _do_external_prefill(
"""
n_tokens = len(tokens)
if n_tokens <= 1:
# Nothing to prefill, return cache + tokens as-is
# Nothing to prefill, return cache + tokens as-is.
cache = existing_cache or make_prompt_cache(self.model)
# NOTE: Do NOT apply TurboQuant here. TurboQuantKVCache does not
# support merge(), which is called by _merge_caches() inside
# BatchGenerator when insert() creates a PromptProcessingBatch.
# TurboQuant conversion must happen inside BatchGenerator after
# the batch cache is created, not on individual per-request caches.
# TurboQuant: a TQ cache here makes _merge_caches() build a
# BatchTurboQuantKVCache (via the monkey-patched merge), so the
# one decode token quantizes against TQ history. An empty fresh
# cache gets empty TQ layers; a restored cache preserves its data.
if self._turboquant_kv_bits is not None and self._turboquant_eligible(cache):
if existing_cache is None:
self._apply_turboquant_kv_empty(cache)
else:
self._apply_turboquant_kv_convert(cache)
return cache, tokens

# Create or reuse cache
Expand All @@ -1785,14 +1810,10 @@ def _do_external_prefill(
else:
prompt_cache = make_prompt_cache(self.model)

# NOTE: TurboQuant conversion is NOT applied during external prefill.
# TurboQuantKVCache does not support merge() or maybe_trim_front(),
# so passing it to insert() would fail in _merge_caches() or cause
# AttributeError in chunked-attention models (e.g. Llama-4-Scout).
# Additionally, on-the-fly quantization during prefill causes
# precision loss that corrupts hidden states across layers (#771).
# Prefill runs with standard KVCache; TurboQuant quantization
# happens inside BatchGenerator during the decode phase.
# TurboQuant runs in fp16 during the prefill loop below and is
# quantized once at the end (see the _apply_turboquant_kv_convert call
# before the return). Chunked/rotating models are gated out by
# _turboquant_eligible and stay fp16.

# Clear stale mRoPE position state for text-only requests.
if vlm_embeds is None and hasattr(self.model, "clear_vlm_position_state"):
Expand Down Expand Up @@ -2043,6 +2064,15 @@ def _do_external_prefill(
self.model._language_model._rope_deltas = _saved_rope_deltas
request._prefill_saved_rope_deltas = None

# Quantize the completed fp16 KV cache to TurboQuant for decode.
# Done here (after the prefill loop, after boundary snapshots are
# captured fp16) so prefill hidden states stay exact and the paged-SSD
# format is unchanged. _merge_caches() then builds a
# BatchTurboQuantKVCache when this request is inserted. Gated to dense
# KVCache models — chunked/rotating caches stay fp16.
if self._turboquant_kv_bits is not None and self._turboquant_eligible(prompt_cache):
self._apply_turboquant_kv_convert(prompt_cache)

return prompt_cache, last_token

# ------------------------------------------------------------------
Expand Down Expand Up @@ -5454,8 +5484,9 @@ def _sparse_progress(processed: int, total: int) -> None:
if request.sampling_params.seed is not None:
mx.random.seed(request.sampling_params.seed)

# NOTE: TurboQuant KV conversion is not applied during prefill.
# See _do_external_prefill() comment for rationale (#771).
# TurboQuant KV is quantized at the end of _do_external_prefill
# (fp16 prefill → quantize once); _merge_caches() turns the per
# request TQ cache into a BatchTurboQuantKVCache on insert.

# VLM MTP routing: if a gemma4_assistant drafter is attached, run
# an extra last-token forward to capture hidden + shared_kv_states,
Expand Down
23 changes: 9 additions & 14 deletions omlx/turboquant_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,20 +245,15 @@ def make_mask(
return create_attention_mask(N, offset, return_array, window_size)
if isinstance(offset, mx.array) and offset.size == 1:
return create_attention_mask(N, offset.item(), return_array, window_size)
# B>1: batched causal mask
max_offset = offset.max().item()
total = max_offset + N
rinds = mx.arange(total)[None, None, :]
linds = mx.arange(N)[None, None, :, None]
off = offset[:, None, None, None]
linds = linds + off
mask = linds >= rinds
if window_size is not None:
mask = mask & (linds < rinds + window_size)
if self.left_padding is not None:
lp = self.left_padding[:, None, None, None]
mask = mask & (rinds >= lp)
return mask
# B>1: delegate to mlx-lm's create_causal_mask with the physical column
# count + per-request left_padding, exactly like BatchKVCache. The old
# hand-rolled term compared each request's sequence length (offset)
# against the column index, which masked out valid left-padded tokens —
# so left-padded requests attended to ~nothing and decoded garbage.
phys = offset.max().item()
return create_causal_mask(
N, offset=phys, window_size=window_size, left_padding=self.left_padding
)

# prefill_attention and dequantize inherited from TurboQuantKVCache

Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,13 @@ dependencies = [
"jsonschema>=4.0.0",
# Harmony format parser for gpt-oss models
"openai-harmony",
# mlx-vlm from commit (f96138e) - Gemma4 MTP server batching (PR #1166), speculative utils refactor (PR #1169), Qwen native MTP drafter, MiniCPM-V 4.6
"mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm@f96138eef1f5ce7fb5d97f8dd41a664a195b5659",
# mlx-vlm from commit (fea81522, main) - forwarded from f96138e to pick up the
# fused single-token-quantize decode fix (TurboQuant "Bug 1"). The masked-decode
# RHT fix ("Bug 2") is still unmerged upstream and is carried as an oMLX monkey-patch
# until the PR lands — see docs/upstream/mlx-vlm-turboquant-rht-decode-PR.md and
# omlx/patches/turboquant_attention.py. (f96138e provided: Gemma4 MTP server batching
# PR #1166, speculative utils refactor PR #1169, Qwen native MTP drafter, MiniCPM-V 4.6.)
"mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm@fea81522ec5d7f420cd033fe6dafe08a5d807aab",
"Pillow>=9.0.0",
# dflash-mlx v0.1.7 (1ba6713) — bstnxbt repo. Qwen thinking/GDN exactness fix, GQA SDPA reshape, DDTree + CopySpec decode path, prefix cache identity hardening, fp16 draft on old Apple chips
"dflash-mlx @ git+https://github.com/bstnxbt/dflash-mlx@1ba671372b289c025b435c1a13aabb4bfb80b183",
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ asyncio_mode = auto
markers =
slow: marks tests as slow (require model loading, deselect with '-m "not slow"')
integration: marks tests as integration tests (require running server)
turboquant: marks TurboQuant KV cache tests (run the suite with '-m turboquant')

# Default options
addopts =
Expand Down
Loading