diff --git a/.gitignore b/.gitignore index d8b7fac75..d38805d0a 100644 --- a/.gitignore +++ b/.gitignore @@ -118,4 +118,6 @@ omlx/admin/tailwindcss-* docs/native_app_architecture.md # UV lockfile -uv.lock \ No newline at end of file +uv.lock +# generated TurboQuant memory report (machine-specific) +tq_batch_memory.md diff --git a/omlx/patches/turboquant_attention.py b/omlx/patches/turboquant_attention.py index 31c7dddff..5599e2855 100644 --- a/omlx/patches/turboquant_attention.py +++ b/omlx/patches/turboquant_attention.py @@ -51,6 +51,11 @@ def patched_sdpa( if isinstance(real_cache, (_TQCache, BatchTurboQuantKVCache)): if queries.shape[-2] == 1: + # Decode (B=1 and B>1). Continuous-batching decode passes a + # per-request left-padding array mask; the masked decode_attention + # path runs the quantized kernels directly (no full-batch + # dequantize per step). The RHT masked-decode fix landed upstream + # in mlx-vlm (Blaizzy/mlx-vlm#1244, in the pinned commit). return real_cache.decode_attention( queries, keys_state=keys, diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 7005a5c58..2e41ba6d1 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -1659,15 +1659,36 @@ def _on_prompt_progress(self, updates: list[tuple[int, int, int]]) -> None: # External prefill (composition pattern — replaces _process_prompts) # ------------------------------------------------------------------ - def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None: - """Replace KVCache with empty TurboQuantKVCache before prefill. + def _turboquant_eligible(self, prompt_cache: list[Any]) -> bool: + """True if every cache layer can be safely TurboQuant-converted for + continuous batching. + + Only plain KVCache (and CacheList of KVCache, for VLM) implement the + merge/filter/extract/extend batch protocol that the monkey-patched + TurboQuantKVCache.merge relies on inside BatchGenerator. Chunked- and + rotating-attention caches (Llama-4, sliding-window) need + maybe_trim_front / rotating semantics that BatchTurboQuantKVCache does + not provide, so those models stay fp16 — no crash, no TurboQuant. + """ + from mlx_lm.models.cache import CacheList, KVCache + + def _ok(c: Any) -> bool: + if isinstance(c, KVCache): + return True + if isinstance(c, CacheList): + return all(_ok(inner) for inner in c.caches) + return False - NOTE: Not currently called -- see #771. Kept for future use when - TurboQuantKVCache implements merge()/maybe_trim_front(). + return bool(prompt_cache) and all(_ok(c) for c in prompt_cache) + + def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None: + """Replace empty KVCache layers with empty TurboQuantKVCache. Tokens are quantized on the fly during update_and_fetch, avoiding the peak memory spike from storing full-precision KV then converting. - Skips the last KVCache layer if turboquant_skip_last is set. + Used only when there is no prefill history to preserve (the single + last token is quantized during insert()'s prompt step). Skips the + last KVCache layer if turboquant_skip_last is set. """ from mlx_lm.models.cache import CacheList, KVCache from mlx_vlm.turboquant import TurboQuantKVCache @@ -1701,13 +1722,13 @@ def _apply_turboquant_kv_empty(self, prompt_cache: list[Any]) -> None: ) def _apply_turboquant_kv_convert(self, prompt_cache: list[Any]) -> None: - """Convert existing KVCache data to TurboQuantKVCache via from_cache(). - - NOTE: Not currently called -- see #771. Kept for future use when - TurboQuantKVCache implements merge()/maybe_trim_front(). + """Convert populated KVCache data to TurboQuantKVCache via from_cache(). - Used when an existing cache is provided (e.g. from SSD prefix cache). - Uses from_cache() to quantize the existing KV data. + Called AFTER fp16 prefill completes (or on an SSD-restored fp16 + cache): the completed full-precision KV is quantized once, so prefill + hidden states stay exact and quantization error only enters at + decode-time reads. This is the key difference from #717/#771, which + quantized on the fly during prefill and corrupted hidden states. """ from mlx_lm.models.cache import CacheList, KVCache from mlx_vlm.turboquant import TurboQuantKVCache @@ -1770,13 +1791,17 @@ def _do_external_prefill( """ n_tokens = len(tokens) if n_tokens <= 1: - # Nothing to prefill, return cache + tokens as-is + # Nothing to prefill, return cache + tokens as-is. cache = existing_cache or make_prompt_cache(self.model) - # NOTE: Do NOT apply TurboQuant here. TurboQuantKVCache does not - # support merge(), which is called by _merge_caches() inside - # BatchGenerator when insert() creates a PromptProcessingBatch. - # TurboQuant conversion must happen inside BatchGenerator after - # the batch cache is created, not on individual per-request caches. + # TurboQuant: a TQ cache here makes _merge_caches() build a + # BatchTurboQuantKVCache (via the monkey-patched merge), so the + # one decode token quantizes against TQ history. An empty fresh + # cache gets empty TQ layers; a restored cache preserves its data. + if self._turboquant_kv_bits is not None and self._turboquant_eligible(cache): + if existing_cache is None: + self._apply_turboquant_kv_empty(cache) + else: + self._apply_turboquant_kv_convert(cache) return cache, tokens # Create or reuse cache @@ -1785,14 +1810,10 @@ def _do_external_prefill( else: prompt_cache = make_prompt_cache(self.model) - # NOTE: TurboQuant conversion is NOT applied during external prefill. - # TurboQuantKVCache does not support merge() or maybe_trim_front(), - # so passing it to insert() would fail in _merge_caches() or cause - # AttributeError in chunked-attention models (e.g. Llama-4-Scout). - # Additionally, on-the-fly quantization during prefill causes - # precision loss that corrupts hidden states across layers (#771). - # Prefill runs with standard KVCache; TurboQuant quantization - # happens inside BatchGenerator during the decode phase. + # TurboQuant runs in fp16 during the prefill loop below and is + # quantized once at the end (see the _apply_turboquant_kv_convert call + # before the return). Chunked/rotating models are gated out by + # _turboquant_eligible and stay fp16. # Clear stale mRoPE position state for text-only requests. if vlm_embeds is None and hasattr(self.model, "clear_vlm_position_state"): @@ -2043,6 +2064,15 @@ def _do_external_prefill( self.model._language_model._rope_deltas = _saved_rope_deltas request._prefill_saved_rope_deltas = None + # Quantize the completed fp16 KV cache to TurboQuant for decode. + # Done here (after the prefill loop, after boundary snapshots are + # captured fp16) so prefill hidden states stay exact and the paged-SSD + # format is unchanged. _merge_caches() then builds a + # BatchTurboQuantKVCache when this request is inserted. Gated to dense + # KVCache models — chunked/rotating caches stay fp16. + if self._turboquant_kv_bits is not None and self._turboquant_eligible(prompt_cache): + self._apply_turboquant_kv_convert(prompt_cache) + return prompt_cache, last_token # ------------------------------------------------------------------ @@ -5454,8 +5484,9 @@ def _sparse_progress(processed: int, total: int) -> None: if request.sampling_params.seed is not None: mx.random.seed(request.sampling_params.seed) - # NOTE: TurboQuant KV conversion is not applied during prefill. - # See _do_external_prefill() comment for rationale (#771). + # TurboQuant KV is quantized at the end of _do_external_prefill + # (fp16 prefill → quantize once); _merge_caches() turns the per + # request TQ cache into a BatchTurboQuantKVCache on insert. # VLM MTP routing: if a gemma4_assistant drafter is attached, run # an extra last-token forward to capture hidden + shared_kv_states, diff --git a/omlx/turboquant_kv.py b/omlx/turboquant_kv.py index 4c41c946e..84d09ae10 100644 --- a/omlx/turboquant_kv.py +++ b/omlx/turboquant_kv.py @@ -245,20 +245,15 @@ def make_mask( return create_attention_mask(N, offset, return_array, window_size) if isinstance(offset, mx.array) and offset.size == 1: return create_attention_mask(N, offset.item(), return_array, window_size) - # B>1: batched causal mask - max_offset = offset.max().item() - total = max_offset + N - rinds = mx.arange(total)[None, None, :] - linds = mx.arange(N)[None, None, :, None] - off = offset[:, None, None, None] - linds = linds + off - mask = linds >= rinds - if window_size is not None: - mask = mask & (linds < rinds + window_size) - if self.left_padding is not None: - lp = self.left_padding[:, None, None, None] - mask = mask & (rinds >= lp) - return mask + # B>1: delegate to mlx-lm's create_causal_mask with the physical column + # count + per-request left_padding, exactly like BatchKVCache. The old + # hand-rolled term compared each request's sequence length (offset) + # against the column index, which masked out valid left-padded tokens — + # so left-padded requests attended to ~nothing and decoded garbage. + phys = offset.max().item() + return create_causal_mask( + N, offset=phys, window_size=window_size, left_padding=self.left_padding + ) # prefill_attention and dequantize inherited from TurboQuantKVCache diff --git a/pyproject.toml b/pyproject.toml index b9578d743..bfe422976 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,8 +72,12 @@ dependencies = [ "jsonschema>=4.0.0", # Harmony format parser for gpt-oss models "openai-harmony", - # mlx-vlm from commit (f96138e) - Gemma4 MTP server batching (PR #1166), speculative utils refactor (PR #1169), Qwen native MTP drafter, MiniCPM-V 4.6 - "mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm@f96138eef1f5ce7fb5d97f8dd41a664a195b5659", + # mlx-vlm from commit (6f60ee4, main) - includes both TurboQuant decode fixes: + # "Bug 1" (fused single-token-quantize, fixed earlier on main) and "Bug 2" (the + # RHT masked-decode kernel, Blaizzy/mlx-vlm#1244, now merged) — so oMLX no longer + # needs its interim monkey-patch. (Also provides: Gemma4 MTP server batching PR + # #1166, speculative utils refactor PR #1169, Qwen native MTP drafter, MiniCPM-V 4.6.) + "mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm@6f60ee4458d85b636e2e6c09c33d32fc360d5e62", "Pillow>=9.0.0", # dflash-mlx v0.1.7 (1ba6713) — bstnxbt repo. Qwen thinking/GDN exactness fix, GQA SDPA reshape, DDTree + CopySpec decode path, prefix cache identity hardening, fp16 draft on old Apple chips "dflash-mlx @ git+https://github.com/bstnxbt/dflash-mlx@1ba671372b289c025b435c1a13aabb4bfb80b183", diff --git a/pytest.ini b/pytest.ini index 09c92df03..20175d510 100644 --- a/pytest.ini +++ b/pytest.ini @@ -14,6 +14,7 @@ asyncio_mode = auto markers = slow: marks tests as slow (require model loading, deselect with '-m "not slow"') integration: marks tests as integration tests (require running server) + turboquant: marks TurboQuant KV cache tests (run the suite with '-m turboquant') # Default options addopts = diff --git a/tests/test_turboquant.py b/tests/test_turboquant.py index 2d5258b9a..d0062b638 100644 --- a/tests/test_turboquant.py +++ b/tests/test_turboquant.py @@ -14,6 +14,8 @@ from omlx.turboquant_kv import BatchTurboQuantKVCache, _rebuild_codecs, _infer_head_dim +pytestmark = pytest.mark.turboquant + def _sample_unit_vectors(count: int, dim: int) -> mx.array: vectors = mx.random.normal((count, dim)) @@ -152,6 +154,32 @@ def test_batch_tq_continuous_batching_extend(): # offset is now mx.array after extend +def test_batch_make_mask_matches_fp16_left_padding(): + """Regression: B>1 make_mask must match mlx-lm's BatchKVCache for left-padded + batches. The old hand-rolled causal term compared each request's sequence + length against the column index and masked out valid left-padded tokens, so + left-padded requests attended to ~nothing and decoded garbage (batch worse + than single). It now delegates to create_causal_mask like BatchKVCache. + """ + from mlx_lm.models.cache import BatchKVCache + + lp = [0, 4, 2] + K = mx.random.normal((3, 2, 8, 16)) + V = mx.random.normal((3, 2, 8, 16)) + bk = BatchKVCache(lp) + bk.update_and_fetch(K, V) + bt = BatchTurboQuantKVCache(lp, bits=8.0) + bt.update_and_fetch(K, V) + + ref = bk.make_mask(1, return_array=True) # decode-step mask + got = bt.make_mask(1, return_array=True) + assert mx.array_equal(ref, got).item(), ( + "B>1 make_mask diverges from BatchKVCache for left-padding " + f"(member masks: BK={ref[:,0,0,:].sum(-1).tolist()} " + f"TQ={got[:,0,0,:].sum(-1).tolist()})" + ) + + def test_batch_tq_filter(): batch = BatchTurboQuantKVCache([0, 0, 0], bits=4.0) keys = mx.random.normal((3, 2, 8, 32)) @@ -314,3 +342,163 @@ def test_ssd_type_map_completeness(): "TurboQuantSplitState": TurboQuantSplitState, } assert set(_type_map.keys()) == expected_types + + +# --------------------------------------------------------------------------- +# Batched TurboQuant wiring (Phase 1): eligibility gate + post-prefill +# conversion path (from_cache -> merge -> BatchTurboQuantKVCache) +# --------------------------------------------------------------------------- + + +def test_turboquant_eligible_gate(): + """Only dense KVCache (and CacheList of KVCache) is batch-convertible. + + Chunked/rotating/quantized caches must gate OFF so chunked-attention + models (Llama-4) and sliding-window models stay fp16 instead of crashing + in _merge_caches() — the #771 SIGABRT class. + """ + from types import SimpleNamespace + + from mlx_lm.models.cache import ( + CacheList, + ChunkedKVCache, + KVCache, + QuantizedKVCache, + RotatingKVCache, + ) + + from omlx.scheduler import Scheduler + + # _turboquant_eligible is pure (ignores self); call the unbound method + # with a throwaway self so we don't construct a full Scheduler. + def elig(cache): + return Scheduler._turboquant_eligible(SimpleNamespace(), cache) + + assert elig([KVCache(), KVCache()]) is True + assert elig([]) is False + assert elig([KVCache(), ChunkedKVCache(8192)]) is False + assert elig([KVCache(), RotatingKVCache(32)]) is False + assert elig([QuantizedKVCache()]) is False + assert elig([CacheList(KVCache(), KVCache())]) is True + assert elig([CacheList(KVCache(), RotatingKVCache(32))]) is False + + +def test_from_cache_merge_builds_working_batch(): + """Mirror the scheduler path: fp16 prefill -> from_cache (post-prefill + quantize) -> _merge_caches builds a BatchTurboQuantKVCache that decodes. + + Importing omlx.scheduler installs the TurboQuantKVCache.merge monkey-patch + that _merge_caches() relies on, so caches[0].merge([...]) is what the + BatchGenerator actually calls at insert() time. + """ + import omlx.scheduler # noqa: F401 (applies the merge monkey-patch) + + per_request = [] + for length in (8, 4): # two requests of different prefill lengths + kv = KVCache() + kv.update_and_fetch( + mx.random.normal((1, 2, length, 32)), + mx.random.normal((1, 2, length, 32)), + ) + per_request.append(TurboQuantKVCache.from_cache(kv, bits=4.0)) + mx.eval(*[c.keys for c in per_request]) + + # Exactly what mlx-lm _merge_caches() does for one layer. + batch = per_request[0].merge(per_request) + assert isinstance(batch, BatchTurboQuantKVCache) + assert batch.left_padding.tolist() == [0, 4] # request 1 left-padded + assert batch.offset.tolist() == [8, 4] # per-request valid lengths + + # A decode step + the real attention path the model uses: update_and_fetch + # returns correctly-sliced state proxies (NOT the full reserved buffer), + # and decode_attention runs over the batched left-padding mask. + ks, vs = batch.update_and_fetch( + mx.random.normal((2, 2, 1, 32)), + mx.random.normal((2, 2, 1, 32)), + ) + assert batch.offset.tolist() == [9, 5] # both requests advanced by 1 + out = batch.decode_attention( + mx.random.normal((2, 2, 1, 32)), + keys_state=ks, + values_state=vs, + scale=32**-0.5, + mask=batch.make_mask(1, return_array=True), + ) + mx.eval(out) + assert out.shape == (2, 2, 1, 32) # (B, n_q_heads, 1, D) + + +def test_decode_single_token_quantize_is_accurate(): + """Regression: the decode step appends ONE token via update_and_fetch. + + An earlier mlx-vlm fused single-token quantize kernel (used only for + keys.shape[-2] == 1) was broken — ~140% reconstruction error at every bit + depth — which garbled generation once TurboQuant decode engaged. It is fixed + on the pinned mlx-vlm (main). This test fails loudly if that regresses. + """ + from omlx.patches.turboquant_attention import apply_turboquant_attention_patch + + apply_turboquant_attention_patch() + + ctx_k = mx.random.normal((1, 8, 40, 64)) * 0.1 + ctx_v = mx.random.normal((1, 8, 40, 64)) * 0.1 + new_k = mx.random.normal((1, 8, 1, 64)) * 0.1 + new_v = mx.random.normal((1, 8, 1, 64)) * 0.1 + + tq = TurboQuantKVCache(bits=8.0) + tq.update_and_fetch(ctx_k, ctx_v) + tq.update_and_fetch(new_k, new_v) # the decode-step append (T=1) + dk, _ = tq.dequantize() + + rel_err = ( + mx.mean(mx.abs(dk[:, :, 40:41, :] - new_k)).item() + / mx.mean(mx.abs(new_k)).item() + ) + # 8-bit TurboQuant is near-lossless; broken kernel gives >100%. + assert rel_err < 0.05, f"decode-token quantize error {rel_err:.1%} (kernel bug?)" + + +def test_batch_masked_decode_is_accurate(): + """Regression: B>1 continuous-batching decode passes an array mask. + + The L=1 value kernels formerly corrupted the masked decode_attention path + under RHT (~140% error); the `not use_rht` guard is now fixed upstream in the + pinned mlx-vlm (Blaizzy/mlx-vlm#1244). This verifies the patched + scaled_dot_product_attention produces correct masked decode output for a B>1 + array mask — matching the dequantize+SDPA reference over the same states. + """ + from mlx_lm.models import base as mlx_base + + from omlx.patches.turboquant_attention import apply_turboquant_attention_patch + + apply_turboquant_attention_patch() + + # B=2 ragged batch (different prefill lengths) -> needs an array mask. + singles = [] + for length in (12, 8): + fp = KVCache() + fp.update_and_fetch( + mx.random.normal((1, 4, length, 32)) * 0.1, + mx.random.normal((1, 4, length, 32)) * 0.1, + ) + singles.append(TurboQuantKVCache.from_cache(fp, bits=8.0)) + batch = BatchTurboQuantKVCache.merge(singles) + + q = mx.random.normal((2, 16, 1, 32)) * 0.1 # B=2, 16 q-heads / 4 kv-heads + ks, vs = batch.update_and_fetch( + mx.random.normal((2, 4, 1, 32)) * 0.1, + mx.random.normal((2, 4, 1, 32)) * 0.1, + ) + dk, dv = batch.dequantize(ks, vs) + t_len = dk.shape[2] + mask = mx.ones((2, 1, 1, t_len), dtype=mx.bool_) + + out = mlx_base.scaled_dot_product_attention(q, ks, vs, batch, scale=32**-0.5, mask=mask) + ref = mx.fast.scaled_dot_product_attention( + q, dk.astype(q.dtype), dv.astype(q.dtype), scale=32**-0.5, mask=mask + ) + mx.eval(out, ref) + rel = mx.mean(mx.abs(out - ref)).item() / mx.mean(mx.abs(ref)).item() + # 8-bit quantized masked decode vs dequantize+SDPA over the same states. + # Broken RHT kernels give ~140%; the fix brings it into quantization noise. + assert rel < 0.05, f"B>1 masked decode inaccurate (err {rel:.1%}) — RHT fix missing from pinned mlx-vlm?" diff --git a/tests/test_turboquant_batch_memory.py b/tests/test_turboquant_batch_memory.py new file mode 100644 index 000000000..0acabdc60 --- /dev/null +++ b/tests/test_turboquant_batch_memory.py @@ -0,0 +1,229 @@ +"""Phase 2: batched TurboQuant accuracy + memory/occupancy vs single-seq. + +Three comparisons on a real model: + - occupancy: KV-cache bytes/token, TQ vs fp16, single vs batch (+ pad waste), + measured at a controlled length so over-allocation slack cancels; + long-context savings projected from per-token bytes. + - accuracy : concurrent B>1 TQ vs single-seq TQ (token match) + coherence. + - peak mem : live peak during decode, TQ vs fp16, single vs batch (with the + caveat that at short context the model weights dominate). + +Skips when the model is not cached. Run directly to write the report: + python tests/test_turboquant_batch_memory.py +""" +import importlib.util +from pathlib import Path + +import mlx.core as mx +import pytest +from mlx_lm.models.cache import KVCache, make_prompt_cache +from mlx_vlm.turboquant import TurboQuantKVCache + +MODEL_REPO = "mlx-community/Llama-3.2-1B-Instruct-4bit" +TQ_BITS = 4.0 +MAX_TOKENS = 32 +OCC_LEN = 512 # multiple of TurboQuant cache_step (256) → no over-alloc slack + + +def _model_path(): + try: + from huggingface_hub import snapshot_download + + return snapshot_download(MODEL_REPO, local_files_only=True) + except Exception: + return None + + +pytestmark = [ + pytest.mark.turboquant, + pytest.mark.slow, + pytest.mark.skipif(_model_path() is None, reason=f"{MODEL_REPO} not cached"), +] + + +def _helpers(): + spec = importlib.util.spec_from_file_location( + "itest", str(Path(__file__).parent / "integration" / "test_full_integration.py") + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def _prompts(tokenizer): + msgs = [ + "Name three primary colors.", + "What is the capital of Japan?", + "Write one sentence about the ocean.", + "List two kinds of fruit.", + ] + return [ + list(tokenizer.apply_chat_template( + [{"role": "user", "content": m}], add_generation_prompt=True)) + for m in msgs + ] + + +def _convert_to_tq(cache, bits, skip_last=True): + """Mirror Scheduler._apply_turboquant_kv_convert (dense KVCache only).""" + kv = [i for i, c in enumerate(cache) if isinstance(c, KVCache)] + last = kv[-1] if (skip_last and len(kv) > 1) else -1 + return [ + (c if (not isinstance(c, KVCache) or i == last) + else TurboQuantKVCache.from_cache(c, bits=bits)) + for i, c in enumerate(cache) + ] + + +def _occupancy_at(model, length, bits=None): + """KV bytes after feeding `length` tokens (fp16, or TQ-converted).""" + cache = make_prompt_cache(model) + model(mx.zeros((1, length), dtype=mx.int32), cache=cache) + mx.eval([c.state for c in cache]) + if bits is not None: + cache = _convert_to_tq(cache, bits) + mx.eval([c.state for c in cache if not isinstance(c, KVCache) or c.offset]) + return sum(c.nbytes for c in cache) + + +def _peak(fn): + mx.reset_peak_memory() + out = fn() + return out, mx.get_peak_memory() + + +def _gather(): + from mlx_lm import load + + helpers = _helpers() + model, tokenizer = load(_model_path()) + prompts = _prompts(tokenizer) + lens = [len(p) for p in prompts] + + # --- occupancy at a controlled length (over-alloc slack cancels) --- + occ_fp16 = _occupancy_at(model, OCC_LEN) + occ_tq = _occupancy_at(model, OCC_LEN, bits=TQ_BITS) + bpt_fp16 = occ_fp16 / OCC_LEN # bytes per token, fp16 + bpt_tq = occ_tq / OCC_LEN # bytes per token, TQ + # batch (B requests, left-padded to max len): analytical, no over-alloc noise + max_len = max(lens) + batch_bytes_tq = len(lens) * max_len * bpt_tq + batch_bytes_fp16 = len(lens) * max_len * bpt_fp16 # same lengths, fp16 + pad_waste = (len(lens) * max_len - sum(lens)) * bpt_tq + + # --- accuracy + live peak (through the real scheduler) --- + (single_tq, peak_single_tq) = _peak( + lambda: [helpers._generate_tokens(model, tokenizer, p, max_tokens=MAX_TOKENS, turboquant_bits=TQ_BITS)[0] for p in prompts]) + (_, peak_single_fp) = _peak( + lambda: [helpers._generate_tokens(model, tokenizer, p, max_tokens=MAX_TOKENS)[0] for p in prompts]) + (_, peak_batch_fp) = _peak( + lambda: helpers._generate_batch(model, tokenizer, prompts, mode="concurrent", max_tokens=MAX_TOKENS)) + (batch_tq_res, peak_batch_tq) = _peak( + lambda: helpers._generate_batch(model, tokenizer, prompts, mode="concurrent", max_tokens=MAX_TOKENS, turboquant_bits=TQ_BITS)) + + batch_tq = {rid: toks for rid, toks, _ in batch_tq_res} + matches = [] + for i in range(len(prompts)): + s, b = single_tq[i], batch_tq.get(f"batch-{i}", []) + n = min(len(s), len(b)) + matches.append(100.0 * sum(1 for k in range(n) if s[k] == b[k]) / n if n else 0.0) + + return dict( + lens=lens, occ_len=OCC_LEN, + occ_fp16=occ_fp16, occ_tq=occ_tq, bpt_fp16=bpt_fp16, bpt_tq=bpt_tq, + batch_bytes_tq=batch_bytes_tq, batch_bytes_fp16=batch_bytes_fp16, + pad_waste=pad_waste, max_len=max_len, + peak_single_fp=peak_single_fp, peak_single_tq=peak_single_tq, + peak_batch_fp=peak_batch_fp, peak_batch_tq=peak_batch_tq, + batch_tq=batch_tq, matches=matches, + ) + + +_M = None + + +def _metrics(): + global _M + if _M is None: + _M = _gather() + return _M + + +def test_batch_tq_coherent_and_tracks_single(): + m = _metrics() + for i in range(len(m["lens"])): + assert len(m["batch_tq"].get(f"batch-{i}", [])) >= 5, f"batch req {i} degenerate" + assert max(m["matches"]) >= 50.0, f"no request tracked single-seq: {m['matches']}" + + +def test_occupancy_tq_below_fp16(): + m = _metrics() + ratio = m["occ_tq"] / m["occ_fp16"] + assert ratio < 0.6, f"TQ occupancy ratio {ratio:.2f} not below fp16" + + +def test_batch_occupancy_beats_fp16_and_pad_nonnegative(): + m = _metrics() + # same lengths, so the batch saving equals the per-token ratio (<0.6) + assert m["batch_bytes_tq"] < 0.6 * m["batch_bytes_fp16"], "batch TQ not saving vs fp16" + assert m["pad_waste"] >= 0 + + +def test_peaks_recorded(): + m = _metrics() + for k in ("peak_single_fp", "peak_single_tq", "peak_batch_fp", "peak_batch_tq"): + assert m[k] > 0 + + +def _write_report(m, path="tq_batch_memory.md"): + gb, kb = 1024 ** 3, 1024 + nb = len(m["lens"]) + ratio = m["occ_tq"] / m["occ_fp16"] + # project savings at a long context where KV (not weights) dominates + proj_ctx = 8192 + proj_fp16 = nb * proj_ctx * m["bpt_fp16"] / gb + proj_tq = m["batch_bytes_tq"] / m["max_len"] * proj_ctx / gb + lines = [ + f"# Batched TurboQuant — memory/occupancy ({MODEL_REPO}, {TQ_BITS}-bit)\n", + f"Batch requests: {m['lens']} tokens; occupancy measured at {m['occ_len']} tokens.\n", + "## KV occupancy (storage)\n", + "| metric | value |", + "|---|---:|", + f"| fp16 bytes/token | {m['bpt_fp16']:,.0f} B |", + f"| TQ bytes/token | {m['bpt_tq']:,.0f} B |", + f"| TQ / fp16 ratio | {ratio:.3f}x |", + f"| batch(B={nb}) TQ bytes | {m['batch_bytes_tq']/kb:,.0f} KB |", + f"| batch(B={nb}) fp16 bytes | {m['batch_bytes_fp16']/kb:,.0f} KB |", + f"| batch TQ / fp16 (same lengths) | {m['batch_bytes_tq']/m['batch_bytes_fp16']:.3f}x |", + f"| left-padding waste | {m['pad_waste']/kb:,.1f} KB ({100*m['pad_waste']/m['batch_bytes_tq']:.0f}% of batch) |\n", + f"## Projected KV at {proj_ctx}-token context, B={nb} (where KV dominates)\n", + "| | total KV |", + "|---|---:|", + f"| fp16 | {proj_fp16:.2f} GB |", + f"| TQ | {proj_tq:.2f} GB |", + f"| saved | {proj_fp16 - proj_tq:.2f} GB ({100*(1-proj_tq/proj_fp16):.0f}%) |\n", + "## Peak memory, live decode (short prompts → weights dominate)\n", + "| scenario | peak |", + "|---|---:|", + f"| single-seq fp16 | {m['peak_single_fp']/gb:.3f} GB |", + f"| single-seq TQ | {m['peak_single_tq']/gb:.3f} GB |", + f"| batch fp16 | {m['peak_batch_fp']/gb:.3f} GB |", + f"| batch TQ | {m['peak_batch_tq']/gb:.3f} GB |", + "", + "_Note: at short context the 1B model weights (~0.7 GB) dominate peak;_", + "_TQ's win shows in the projected long-context KV above. B>1 decode now_", + "_runs the quantized kernels directly (no per-step batch dequantize)._\n", + "## Accuracy: batch vs single-seq TQ (token match)\n", + "| request | match % |", + "|---|---:|", + ] + for i, pct in enumerate(m["matches"]): + lines.append(f"| batch-{i} | {pct:.0f}% |") + Path(path).write_text("\n".join(lines) + "\n") + return path + + +if __name__ == "__main__": + p = _write_report(_metrics()) + print(f"wrote {p}\n") + print(Path(p).read_text()) diff --git a/tests/test_turboquant_ssd.py b/tests/test_turboquant_ssd.py new file mode 100644 index 000000000..990e89f57 --- /dev/null +++ b/tests/test_turboquant_ssd.py @@ -0,0 +1,129 @@ +"""Phase 3: TurboQuant + paged-SSD prefix cache (single + batch). + +Validates the SSD round-trip now that TurboQuant decode actually engages: +prefill boundary snapshots are stored fp16 and re-quantized deterministically +on a cache hit, so a hit reproduces the fresh run exactly — no double-quant +(TQ->fp16->TQ) drift. Covers both single-request and concurrent-batch decode. + +Skips when the model is not cached locally. +""" +import importlib.util +import shutil +import tempfile +from pathlib import Path + +import pytest + +MODEL_REPO = "mlx-community/Llama-3.2-1B-Instruct-4bit" +TQ_BITS = 4.0 +BLOCK = 256 + + +def _model_path(): + try: + from huggingface_hub import snapshot_download + + return snapshot_download(MODEL_REPO, local_files_only=True) + except Exception: + return None + + +pytestmark = [ + pytest.mark.turboquant, + pytest.mark.slow, + pytest.mark.skipif(_model_path() is None, reason=f"{MODEL_REPO} not cached"), +] + + +def _helpers(): + spec = importlib.util.spec_from_file_location( + "itest", str(Path(__file__).parent / "integration" / "test_full_integration.py") + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +_LOADED = None + + +def _load(): + global _LOADED + if _LOADED is None: + from mlx_lm import load + + helpers = _helpers() + model, tok = load(_model_path()) + # ~400-token prompt so a full 256-block is cached + text = "The history of computing spans many centuries of innovation. " * 40 + ids = list(tok.encode(text))[:400] + _LOADED = (helpers, model, tok, ids) + return _LOADED + + +def test_tq_ssd_single_hit_matches_fresh(): + helpers, model, tok, ids = _load() + tmp = tempfile.mkdtemp(prefix="ssd_tq_") + try: + fresh, c1 = helpers._generate_tokens( + model, tok, ids, max_tokens=16, + ssd_cache_dir=tmp, block_size=BLOCK, turboquant_bits=TQ_BITS) + cached, c2 = helpers._generate_tokens( + model, tok, ids, max_tokens=16, + ssd_cache_dir=tmp, block_size=BLOCK, turboquant_bits=TQ_BITS) + finally: + shutil.rmtree(tmp, ignore_errors=True) + + assert len(fresh) >= 5, "fresh TQ+SSD run produced no output" + assert c2 > 0, "second run did not hit the SSD cache" + # Deterministic re-quantization on restore -> identical to fresh. + assert fresh == cached, "TQ+SSD cache hit diverged from fresh (double-quant drift?)" + + +def _batch_fresh_vs_hit(helpers, model, tok, prompts, bits): + tmp = tempfile.mkdtemp(prefix="ssd_tq_batch_") + try: + fresh = {rid: t for rid, t, _ in helpers._generate_batch( + model, tok, prompts, mode="concurrent", max_tokens=16, + ssd_cache_dir=tmp, block_size=BLOCK, turboquant_bits=bits)} + hit = {rid: (t, c) for rid, t, c in helpers._generate_batch( + model, tok, prompts, mode="concurrent", max_tokens=16, + ssd_cache_dir=tmp, block_size=BLOCK, turboquant_bits=bits)} + finally: + shutil.rmtree(tmp, ignore_errors=True) + return fresh, hit + + +def test_tq_ssd_batch_roundtrip_exact_at_high_bits(): + """Structural SSD correctness: at near-lossless 8-bit, a batched cache hit + reproduces the fresh run exactly — proving the fp16-snapshot round-trip and + re-quantization introduce no drift in the B>1 path.""" + helpers, model, tok, ids = _load() + prefix = ids[:300] + prompts = [prefix + list(tok.encode(f" Topic {k}."))[:24] for k in range(3)] + fresh, hit = _batch_fresh_vs_hit(helpers, model, tok, prompts, bits=8.0) + for i in range(len(prompts)): + ft = fresh[f"batch-{i}"] + ht, hc = hit[f"batch-{i}"] + assert hc > 0, f"batch req {i} did not hit SSD cache" + assert ft == ht, f"8-bit batch req {i} hit diverged from fresh (round-trip drift)" + + +def test_tq_ssd_batch_coherent_at_low_bits(): + """At lossy 4-bit, batched fresh-vs-hit may diverge by a few tokens where + quantization tips a greedy near-tie (single-request stays exact; fp16 is + exact) — output must still be coherent with the cache hit working. This + residual divergence resolves when the upstream masked-decode kernel (Bug 2) + lets B>1 use the same fused path as B=1.""" + helpers, model, tok, ids = _load() + prefix = ids[:300] + prompts = [prefix + list(tok.encode(f" Topic {k}."))[:24] for k in range(3)] + fresh, hit = _batch_fresh_vs_hit(helpers, model, tok, prompts, bits=TQ_BITS) + for i in range(len(prompts)): + ft = fresh[f"batch-{i}"] + ht, hc = hit[f"batch-{i}"] + assert len(ht) >= 3, f"batch req {i} degenerate under TQ+SSD" + assert hc > 0, f"batch req {i} did not hit SSD cache" + n = min(len(ft), len(ht)) + match = sum(1 for k in range(n) if ft[k] == ht[k]) / n if n else 0.0 + assert match >= 0.5, f"batch req {i} hit overlap {match:.0%} too low (not just a near-tie)"