Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,6 @@ omlx/admin/tailwindcss-*
docs/native_app_architecture.md

# UV lockfile
uv.lock
uv.lock
# generated TurboQuant memory report (machine-specific)
tq_batch_memory.md
5 changes: 5 additions & 0 deletions omlx/patches/turboquant_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def patched_sdpa(

if isinstance(real_cache, (_TQCache, BatchTurboQuantKVCache)):
if queries.shape[-2] == 1:
# Decode (B=1 and B>1). Continuous-batching decode passes a
# per-request left-padding array mask; the masked decode_attention
# path runs the quantized kernels directly (no full-batch
# dequantize per step). The RHT masked-decode fix landed upstream
# in mlx-vlm (Blaizzy/mlx-vlm#1244, in the pinned commit).
return real_cache.decode_attention(
queries,
keys_state=keys,
Expand Down
85 changes: 58 additions & 27 deletions omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,15 +1659,36 @@ def _on_prompt_progress(self, updates: list[tuple[int, int, int]]) -> None:
# External prefill (composition pattern — replaces _process_prompts)
# ------------------------------------------------------------------

def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None:
"""Replace KVCache with empty TurboQuantKVCache before prefill.
def _turboquant_eligible(self, prompt_cache: list[Any]) -> bool:
"""True if every cache layer can be safely TurboQuant-converted for
continuous batching.

Only plain KVCache (and CacheList of KVCache, for VLM) implement the
merge/filter/extract/extend batch protocol that the monkey-patched
TurboQuantKVCache.merge relies on inside BatchGenerator. Chunked- and
rotating-attention caches (Llama-4, sliding-window) need
maybe_trim_front / rotating semantics that BatchTurboQuantKVCache does
not provide, so those models stay fp16 — no crash, no TurboQuant.
"""
from mlx_lm.models.cache import CacheList, KVCache

def _ok(c: Any) -> bool:
if isinstance(c, KVCache):
return True
if isinstance(c, CacheList):
return all(_ok(inner) for inner in c.caches)
return False

NOTE: Not currently called -- see #771. Kept for future use when
TurboQuantKVCache implements merge()/maybe_trim_front().
return bool(prompt_cache) and all(_ok(c) for c in prompt_cache)

def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None:
"""Replace empty KVCache layers with empty TurboQuantKVCache.

Tokens are quantized on the fly during update_and_fetch, avoiding
the peak memory spike from storing full-precision KV then converting.
Skips the last KVCache layer if turboquant_skip_last is set.
Used only when there is no prefill history to preserve (the single
last token is quantized during insert()'s prompt step). Skips the
last KVCache layer if turboquant_skip_last is set.
"""
from mlx_lm.models.cache import CacheList, KVCache
from mlx_vlm.turboquant import TurboQuantKVCache
Expand Down Expand Up @@ -1701,13 +1722,13 @@ def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None:
)

def _apply_turboquant_kv_convert(self, prompt_cache: list[Any]) -> None:
"""Convert existing KVCache data to TurboQuantKVCache via from_cache().

NOTE: Not currently called -- see #771. Kept for future use when
TurboQuantKVCache implements merge()/maybe_trim_front().
"""Convert populated KVCache data to TurboQuantKVCache via from_cache().

Used when an existing cache is provided (e.g. from SSD prefix cache).
Uses from_cache() to quantize the existing KV data.
Called AFTER fp16 prefill completes (or on an SSD-restored fp16
cache): the completed full-precision KV is quantized once, so prefill
hidden states stay exact and quantization error only enters at
decode-time reads. This is the key difference from #717/#771, which
quantized on the fly during prefill and corrupted hidden states.
"""
from mlx_lm.models.cache import CacheList, KVCache
from mlx_vlm.turboquant import TurboQuantKVCache
Expand Down Expand Up @@ -1770,13 +1791,17 @@ def _do_external_prefill(
"""
n_tokens = len(tokens)
if n_tokens <= 1:
# Nothing to prefill, return cache + tokens as-is
# Nothing to prefill, return cache + tokens as-is.
cache = existing_cache or make_prompt_cache(self.model)
# NOTE: Do NOT apply TurboQuant here. TurboQuantKVCache does not
# support merge(), which is called by _merge_caches() inside
# BatchGenerator when insert() creates a PromptProcessingBatch.
# TurboQuant conversion must happen inside BatchGenerator after
# the batch cache is created, not on individual per-request caches.
# TurboQuant: a TQ cache here makes _merge_caches() build a
# BatchTurboQuantKVCache (via the monkey-patched merge), so the
# one decode token quantizes against TQ history. An empty fresh
# cache gets empty TQ layers; a restored cache preserves its data.
if self._turboquant_kv_bits is not None and self._turboquant_eligible(cache):
if existing_cache is None:
self._apply_turboquant_kv_empty(cache)
else:
self._apply_turboquant_kv_convert(cache)
return cache, tokens

# Create or reuse cache
Expand All @@ -1785,14 +1810,10 @@ def _do_external_prefill(
else:
prompt_cache = make_prompt_cache(self.model)

# NOTE: TurboQuant conversion is NOT applied during external prefill.
# TurboQuantKVCache does not support merge() or maybe_trim_front(),
# so passing it to insert() would fail in _merge_caches() or cause
# AttributeError in chunked-attention models (e.g. Llama-4-Scout).
# Additionally, on-the-fly quantization during prefill causes
# precision loss that corrupts hidden states across layers (#771).
# Prefill runs with standard KVCache; TurboQuant quantization
# happens inside BatchGenerator during the decode phase.
# With TurboQuant enabled, the prefill loop below still runs on the
# fp16 KV cache; that cache is quantized once at the end (see the
# _apply_turboquant_kv_convert call before the return). Chunked/rotating
# models are gated out by _turboquant_eligible and stay fp16.

# Clear stale mRoPE position state for text-only requests.
if vlm_embeds is None and hasattr(self.model, "clear_vlm_position_state"):
Expand Down Expand Up @@ -2043,6 +2064,15 @@ def _do_external_prefill(
self.model._language_model._rope_deltas = _saved_rope_deltas
request._prefill_saved_rope_deltas = None

# Quantize the completed fp16 KV cache to TurboQuant for decode.
# Done here (after the prefill loop, after boundary snapshots are
# captured fp16) so prefill hidden states stay exact and the paged-SSD
# format is unchanged. _merge_caches() then builds a
# BatchTurboQuantKVCache when this request is inserted. Gated to dense
# KVCache models — chunked/rotating caches stay fp16.
if self._turboquant_kv_bits is not None and self._turboquant_eligible(prompt_cache):
self._apply_turboquant_kv_convert(prompt_cache)

return prompt_cache, last_token

# ------------------------------------------------------------------
Expand Down Expand Up @@ -5454,8 +5484,9 @@ def _sparse_progress(processed: int, total: int) -> None:
if request.sampling_params.seed is not None:
mx.random.seed(request.sampling_params.seed)

# NOTE: TurboQuant KV conversion is not applied during prefill.
# See _do_external_prefill() comment for rationale (#771).
# TurboQuant KV is quantized at the end of _do_external_prefill
# (fp16 prefill → quantize once); _merge_caches() turns the
# per-request TQ cache into a BatchTurboQuantKVCache on insert.

# VLM MTP routing: if a gemma4_assistant drafter is attached, run
# an extra last-token forward to capture hidden + shared_kv_states,
Expand Down
23 changes: 9 additions & 14 deletions omlx/turboquant_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,20 +245,15 @@ def make_mask(
return create_attention_mask(N, offset, return_array, window_size)
if isinstance(offset, mx.array) and offset.size == 1:
return create_attention_mask(N, offset.item(), return_array, window_size)
# B>1: batched causal mask
max_offset = offset.max().item()
total = max_offset + N
rinds = mx.arange(total)[None, None, :]
linds = mx.arange(N)[None, None, :, None]
off = offset[:, None, None, None]
linds = linds + off
mask = linds >= rinds
if window_size is not None:
mask = mask & (linds < rinds + window_size)
if self.left_padding is not None:
lp = self.left_padding[:, None, None, None]
mask = mask & (rinds >= lp)
return mask
# B>1: delegate to mlx-lm's create_causal_mask with the physical column
# count + per-request left_padding, exactly like BatchKVCache. The old
# hand-rolled term compared each request's sequence length (offset)
# against the column index, which masked out valid left-padded tokens —
# so left-padded requests attended to ~nothing and decoded garbage.
phys = offset.max().item()
return create_causal_mask(
N, offset=phys, window_size=window_size, left_padding=self.left_padding
)

# prefill_attention and dequantize inherited from TurboQuantKVCache

Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,12 @@ dependencies = [
"jsonschema>=4.0.0",
# Harmony format parser for gpt-oss models
"openai-harmony",
# mlx-vlm from commit (f96138e) - Gemma4 MTP server batching (PR #1166), speculative utils refactor (PR #1169), Qwen native MTP drafter, MiniCPM-V 4.6
"mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm@f96138eef1f5ce7fb5d97f8dd41a664a195b5659",
# mlx-vlm from commit (6f60ee4, main) - includes both TurboQuant decode fixes:
# "Bug 1" (fused single-token-quantize, fixed earlier on main) and "Bug 2" (the
# RHT masked-decode kernel, Blaizzy/mlx-vlm#1244, now merged) — so oMLX no longer
# needs its interim monkey-patch. (Also provides: Gemma4 MTP server batching PR
# #1166, speculative utils refactor PR #1169, Qwen native MTP drafter, MiniCPM-V 4.6.)
"mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm@6f60ee4458d85b636e2e6c09c33d32fc360d5e62",
"Pillow>=9.0.0",
# dflash-mlx v0.1.7 (1ba6713) — bstnxbt repo. Qwen thinking/GDN exactness fix, GQA SDPA reshape, DDTree + CopySpec decode path, prefix cache identity hardening, fp16 draft on old Apple chips
"dflash-mlx @ git+https://github.com/bstnxbt/dflash-mlx@1ba671372b289c025b435c1a13aabb4bfb80b183",
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ asyncio_mode = auto
markers =
slow: marks tests as slow (require model loading, deselect with '-m "not slow"')
integration: marks tests as integration tests (require running server)
turboquant: marks TurboQuant KV cache tests (run the suite with '-m turboquant')

# Default options
addopts =
Expand Down
Loading