From 90d56f7e7a7e2b4082c25eac124af1d0289581cd Mon Sep 17 00:00:00 2001 From: Pushkinist <4850452+Pushkinist@users.noreply.github.com> Date: Thu, 18 Jun 2026 14:34:11 +0700 Subject: [PATCH 1/2] fix(mm-cache): key encoder-output cache on model identity to stop cross-model leak (#132) --- README.md | 30 +++-- crates/rmlx-models/src/gemma3/vision.rs | 2 + crates/rmlx-models/src/gemma4/tests.rs | 2 +- crates/rmlx-models/src/gemma4/vision/mod.rs | 2 + .../rmlx-models/src/gemma4/vision/unified.rs | 2 + crates/rmlx-models/src/jina_v4/mod.rs | 10 +- crates/rmlx-models/src/multimodal_cache.rs | 72 +++++++++--- .../rmlx-models/src/multimodal_cache_tests.rs | 103 ++++++++++++++---- crates/rmlx-server/src/embeddings.rs | 15 ++- .../rmlx-server/src/engine/arch_generator.rs | 8 ++ crates/rmlx-server/src/engine/image.rs | 9 +- docs/CLI.md | 2 +- 12 files changed, 202 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 9e2cef4..011d6d6 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,10 @@ KV-quantization matrix any MLX server ships, including rotation-based KV familie (TurboQuant, IsoQuant, PlanarQuant, RotorQuant, ParoQuant) that no other MLX server offers. -> Status: **0.1.0** — first feature-complete native MLX backend. Apple Silicon -> only (Metal). See [What works](#what-works). +> Status: **0.2.2** — feature-complete native MLX backend: OpenAI- and +> Anthropic-compatible text, tool/function calling, streaming, image + audio +> input, embeddings, and a multi-model registry. Apple Silicon only (Metal). +> See [What works](#what-works). ## Why @@ -21,22 +23,30 @@ server offers. ## What works - **Text generation** — OpenAI-compatible `/v1/chat/completions` and - `/v1/completions`, plus an Anthropic-compatible surface. Temperature, top-k/p, - penalties, thinking-budget, constrained / schema-guided decoding. -- **Image input** — vision-capable models (Gemma 4 SigLIP tower, Qwen3-VL-MoE - deepstack) accept images via `image_url` content parts (data-URI, http, file - path, or base64). -- **Audio input** — audio transcription / translation endpoints for - audio-capable models. + `/v1/completions`, plus an Anthropic-compatible surface. Streaming (SSE), + temperature, top-k/p, penalties, thinking-budget, constrained / schema-guided + decoding. +- **Image input** — vision-capable models accept images via `image_url` content + parts (data-URI, http, file path, or base64): Gemma 4 SigLIP tower (e4b / + 26b), the encoder-free Gemma 4 12B `gemma4_unified` any-to-any architecture, + jina-v4, and Qwen3-VL-MoE deepstack. +- **Audio input** — audio-capable models accept audio (Gemma 4 unified Conformer + tower) plus Whisper speech-to-text via the model-agnostic `rmlx transcribe` + CLI (txt / vtt / srt / json, long-form chunking). - **Embeddings** — `/v1/embeddings`, including multimodal (text + image) jina-v4. - **Tool / function calling** — OpenAI `tool_calls` and Anthropic `tool_use`, multi-turn, multiple emit formats (Qwen XML, Hermes-JSON, Gemma). +- **Multi-model registry** — serve many models from one process with + load-on-demand / unload-on-idle, a bounded resident-model cap, and a shared + multimodal encoder-output cache (scoped per model). - **Quantization** — affine 2–8 bit, mxfp4 / mxfp8, nvfp4, ParoQuant weights; KV-cache quant incl. fp8, TurboQuant, RotorQuant, PlanarQuant, IsoQuant, paged-KV, mixed / asymmetric K/V, and an SSD KV tier. - **Speculative decoding** — MTP, DFlash, and Eagle3 drafters. - **Prompt caching** — automatic prefix caching with block hashing. -- **Conversion** — `rmlx convert` re-quantizes / repacks MLX → MLX. + +Conversion (`rmlx convert`, MLX → MLX re-quantize / layout repack) is a roadmap +target and not yet shipped. Continuously smoke-tested end-to-end. The first four families carry committed golden-token decode gates (temp=0, exact token-id match); embeddings and the diff --git a/crates/rmlx-models/src/gemma3/vision.rs b/crates/rmlx-models/src/gemma3/vision.rs index 97a4485..5270ce4 100644 --- a/crates/rmlx-models/src/gemma3/vision.rs +++ b/crates/rmlx-models/src/gemma3/vision.rs @@ -567,6 +567,7 @@ pub fn build_inputs_embeds( input_ids: &[u32], device: Device, mm_cache: Option<&crate::multimodal_cache::MultimodalCache>, + model_sig: u64, ) -> Result<(Array, Array)> { let hidden = model.cfg.hidden_size as i32; let seq = input_ids.len(); @@ -627,6 +628,7 @@ pub fn build_inputs_embeds( u16::try_from(pv.width).unwrap_or(u16::MAX), 3, crate::multimodal_cache::MmDtype::F32, + model_sig, ); let feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || { projector.forward(&vision.forward(pv, device)?, device) diff --git a/crates/rmlx-models/src/gemma4/tests.rs b/crates/rmlx-models/src/gemma4/tests.rs index 73b2feb..2a9ea8b 100644 --- a/crates/rmlx-models/src/gemma4/tests.rs +++ b/crates/rmlx-models/src/gemma4/tests.rs @@ -657,7 +657,7 @@ fn build_inputs_embeds_aligns_and_shapes() { ); let (embeds, masked) = - build_inputs_embeds(&model, &vision, &embedder, &[pv], &ids, device, None) + build_inputs_embeds(&model, &vision, &embedder, &[pv], &ids, device, None, 0) .expect("build embeds"); embeds.eval().expect("eval embeds"); diff --git a/crates/rmlx-models/src/gemma4/vision/mod.rs b/crates/rmlx-models/src/gemma4/vision/mod.rs index 0b83c62..14ec65d 100644 --- a/crates/rmlx-models/src/gemma4/vision/mod.rs +++ b/crates/rmlx-models/src/gemma4/vision/mod.rs @@ -814,6 +814,7 @@ pub fn build_inputs_embeds( input_ids: &[u32], device: Device, mm_cache: Option<&crate::multimodal_cache::MultimodalCache>, + model_sig: u64, ) -> Result<(Array, Array)> { let hidden = model.cfg.hidden_size as i32; let seq = input_ids.len(); @@ -867,6 +868,7 @@ pub fn build_inputs_embeds( u16::try_from(pv.width).unwrap_or(u16::MAX), 3, crate::multimodal_cache::MmDtype::F32, + model_sig, ); let feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || { embedder.forward(&vision.forward(pv, device)?, device) diff --git a/crates/rmlx-models/src/gemma4/vision/unified.rs b/crates/rmlx-models/src/gemma4/vision/unified.rs index 6175dc2..60d9f67 100644 --- a/crates/rmlx-models/src/gemma4/vision/unified.rs +++ b/crates/rmlx-models/src/gemma4/vision/unified.rs @@ -569,6 +569,7 @@ pub fn build_unified_inputs_embeds( input_ids: &[u32], device: Device, mm_cache: Option<&crate::multimodal_cache::MultimodalCache>, + model_sig: u64, ) -> Result<(Array, Array)> { let hidden = model.cfg.hidden_size as i32; let seq = input_ids.len(); @@ -619,6 +620,7 @@ pub fn build_unified_inputs_embeds( u16::try_from(pv.width).unwrap_or(u16::MAX), 3, crate::multimodal_cache::MmDtype::F32, + model_sig, ); let feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || { embedder.forward(pv, device) diff --git a/crates/rmlx-models/src/jina_v4/mod.rs b/crates/rmlx-models/src/jina_v4/mod.rs index 1d96152..5defe04 100644 --- a/crates/rmlx-models/src/jina_v4/mod.rs +++ b/crates/rmlx-models/src/jina_v4/mod.rs @@ -191,6 +191,7 @@ impl JinaV4 { pixel_values: &PixelValues, device: rmlx_mlx::Device, mm_cache: Option<&crate::multimodal_cache::MultimodalCache>, + model_sig: u64, ) -> Result<(rmlx_mlx::Array, Vec)> { let image_token_id = i64::from(self.config.image_token_id); let merge = self.config.vision_config.spatial_merge_size; @@ -212,6 +213,7 @@ impl JinaV4 { u16::try_from(g.w).unwrap_or(u16::MAX), 3, crate::multimodal_cache::MmDtype::F32, + model_sig, ); let vision_feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || { self.vision.forward(pixel_values, device) @@ -251,8 +253,10 @@ impl JinaV4 { device: rmlx_mlx::Device, truncate_dim: Option, mm_cache: Option<&crate::multimodal_cache::MultimodalCache>, + model_sig: u64, ) -> Result> { - let (hidden, input_ids) = self.image_hidden(prompt_ids, pixel_values, device, mm_cache)?; + let (hidden, input_ids) = + self.image_hidden(prompt_ids, pixel_values, device, mm_cache, model_sig)?; let (start, end) = image::vision_span( &input_ids, i64::from(self.config.vision_start_token_id), @@ -279,8 +283,10 @@ impl JinaV4 { pixel_values: &PixelValues, device: rmlx_mlx::Device, mm_cache: Option<&crate::multimodal_cache::MultimodalCache>, + model_sig: u64, ) -> Result>> { - let (hidden, _input_ids) = self.image_hidden(prompt_ids, pixel_values, device, mm_cache)?; + let (hidden, _input_ids) = + self.image_hidden(prompt_ids, pixel_values, device, mm_cache, model_sig)?; let projected = self.projector.forward(&hidden, device)?; pooling::multi_vector(&projected, device) } diff --git a/crates/rmlx-models/src/multimodal_cache.rs b/crates/rmlx-models/src/multimodal_cache.rs index 4bbc9bb..e948316 100644 --- a/crates/rmlx-models/src/multimodal_cache.rs +++ b/crates/rmlx-models/src/multimodal_cache.rs @@ -15,14 +15,22 @@ //! modality. Wrap in `Arc` for cheap clone. //! - **Hash recipe** (lifted from //! `dynamo/components/.../vllm/multimodal_utils/hash_utils.py:32-95`): -//! fixed 12-byte header `[version:u8=1, mode:u8 (0=image | 1=audio), +//! fixed 20-byte header `[version:u8=2, mode:u8 (0=image | 1=audio), //! dtype:u8 (0=bf16 | 1=f32), channels:u8, dim1:u16 LE, dim2:u16 LE, -//! reserved:u32 LE]` + the canonical preprocessed pixel/PCM bytes. -//! The fixed header prevents `(H,W)` collision when the same pixel byte -//! stream is reshaped to a different geometry. For audio, `dim1`/`dim2` are -//! `0` and the `reserved` field carries the sample rate; `channels` lands -//! in `header[3]` so mono vs stereo PCM byte runs cannot alias; `n_samples` -//! is implied by the trailing byte length. +//! reserved:u32 LE, model_sig:u64 LE]` + the canonical preprocessed +//! pixel/PCM bytes. The fixed header prevents `(H,W)` collision when the +//! same pixel byte stream is reshaped to a different geometry. For audio, +//! `dim1`/`dim2` are `0` and the `reserved` field carries the sample rate; +//! `channels` lands in `header[3]` so mono vs stereo PCM byte runs cannot +//! alias; `n_samples` is implied by the trailing byte length. +//! - **`model_sig`** is a stable per-loaded-model identity (a u64 hash of the +//! model's registry id / canonical snapshot path). It scopes every cached +//! encoder output to the model that produced it, so a multi-model +//! `--registry` server never serves model A's vision-tower output (projected +//! to A's hidden size) to model B for the same image. Without it, two models +//! sharing one cache and the same image collide on the content hash → a soft +//! token embedded at the wrong hidden dim → shape-mismatch failure. Same +//! model + same input ⇒ same `model_sig` ⇒ the cache still hits. //! - **Hasher**: `twox-hash` xxh3_64 (MIT, tiny, no transitive cost). Stored //! as 8 bytes (the raw digest). If the digest is ever widened, bump the //! `version` byte in the header along with the key array size — this type @@ -119,7 +127,10 @@ impl MmCacheKey { /// Build an image key. `pixel_bytes` is the **post-preprocess** flat /// pixel buffer (e.g. the same `&[f32]` the vision tower reads). The - /// 12-byte header prevents `(H,W)` reshape collisions. + /// 20-byte header prevents `(H,W)` reshape collisions. `model_sig` is the + /// stable per-loaded-model identity (see [`model_sig`]) — it scopes the + /// key to the producing model so encoder outputs are never shared across + /// models in a multi-model registry. #[must_use] pub fn image_key( pixel_bytes: &[u8], @@ -127,18 +138,28 @@ impl MmCacheKey { width: u16, channels: u8, dtype: MmDtype, + model_sig: u64, ) -> Self { - let header = build_header(MmMode::Image, dtype, channels, height, width, 0); + let header = build_header(MmMode::Image, dtype, channels, height, width, 0, model_sig); Self(digest(&header, pixel_bytes)) } /// Build an audio key. `pcm_bytes` is the **post-preprocess** PCM byte /// stream (typically f32 mono). `sample_rate` lands in the `reserved` /// field so two identical PCM byte runs at different sample rates do not - /// alias. `dim1`/`dim2` are unused for audio (set to 0). + /// alias. `dim1`/`dim2` are unused for audio (set to 0). `model_sig` is the + /// stable per-loaded-model identity (see [`model_sig`]) — it scopes the + /// key to the producing model so encoder outputs are never shared across + /// models in a multi-model registry. #[must_use] - pub fn audio_key(pcm_bytes: &[u8], sample_rate: u32, dtype: MmDtype, channels: u8) -> Self { - let header = build_header(MmMode::Audio, dtype, channels, 0, 0, sample_rate); + pub fn audio_key( + pcm_bytes: &[u8], + sample_rate: u32, + dtype: MmDtype, + channels: u8, + model_sig: u64, + ) -> Self { + let header = build_header(MmMode::Audio, dtype, channels, 0, 0, sample_rate, model_sig); Self(digest(&header, pcm_bytes)) } @@ -156,19 +177,21 @@ fn build_header( dim1: u16, dim2: u16, reserved: u32, -) -> [u8; 12] { - let mut header = [0u8; 12]; - header[0] = 1; // version + model_sig: u64, +) -> [u8; 20] { + let mut header = [0u8; 20]; + header[0] = 2; // version — bumped from 1 when `model_sig` was folded in header[1] = mode as u8; header[2] = dtype as u8; header[3] = channels; header[4..6].copy_from_slice(&dim1.to_le_bytes()); header[6..8].copy_from_slice(&dim2.to_le_bytes()); header[8..12].copy_from_slice(&reserved.to_le_bytes()); + header[12..20].copy_from_slice(&model_sig.to_le_bytes()); header } -fn digest(header: &[u8; 12], payload: &[u8]) -> [u8; 8] { +fn digest(header: &[u8; 20], payload: &[u8]) -> [u8; 8] { // Streaming xxh3_64 over (header || payload). Avoids the ~9.6 MiB temp // allocation a one-shot path would need for a 896×896×3 f32 image. // `alloc` feature on `twox-hash` is enabled at the workspace level for @@ -648,6 +671,23 @@ pub fn pcm_f32_bytes(pcm: &[f32]) -> &[u8] { pixel_f32_bytes(pcm) } +/// Derive the stable per-loaded-model signature folded into every +/// [`MmCacheKey`]. `id` should be a stable identity for the loaded model — +/// its registry id or canonical snapshot path. The same id always yields the +/// same `u64`, so repeat same-model requests still hit the cache; two distinct +/// models (different ids) yield different signatures, so their encoder outputs +/// never collide on a shared cache. +/// +/// `hidden_size` alone is deliberately NOT used: two different models can share +/// a hidden size yet produce different features. The id/path is the +/// collision-safe identity. +#[must_use] +pub fn model_sig(id: &str) -> u64 { + let mut h = XxHash3_64Hasher::new(); + h.write(id.as_bytes()); + h.finish() +} + #[cfg(test)] #[path = "multimodal_cache_tests.rs"] mod multimodal_cache_tests; diff --git a/crates/rmlx-models/src/multimodal_cache_tests.rs b/crates/rmlx-models/src/multimodal_cache_tests.rs index 7badc79..8e0f7b9 100644 --- a/crates/rmlx-models/src/multimodal_cache_tests.rs +++ b/crates/rmlx-models/src/multimodal_cache_tests.rs @@ -16,11 +16,16 @@ fn make_array(elems: i32) -> Array { Array::from_bytes(&data, &[elems], Dtype::F32).expect("Array::from_bytes") } +/// Fixed model signature used by tests that are not exercising the +/// per-model scoping itself. +const SIG_A: u64 = 0x1111_2222_3333_4444; +const SIG_B: u64 = 0x5555_6666_7777_8888; + #[test] fn image_key_same_pixels_same_hash() { let pixels = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let k1 = MmCacheKey::image_key(&pixels, 64, 64, 3, MmDtype::F32); - let k2 = MmCacheKey::image_key(&pixels, 64, 64, 3, MmDtype::F32); + let k1 = MmCacheKey::image_key(&pixels, 64, 64, 3, MmDtype::F32, SIG_A); + let k2 = MmCacheKey::image_key(&pixels, 64, 64, 3, MmDtype::F32, SIG_A); assert_eq!(k1, k2); } @@ -28,8 +33,8 @@ fn image_key_same_pixels_same_hash() { fn image_key_different_dims_different_hash() { // Same byte stream, different declared (H, W). Must NOT collide. let pixels = vec![0u8; 64]; - let k1 = MmCacheKey::image_key(&pixels, 8, 8, 1, MmDtype::F32); - let k2 = MmCacheKey::image_key(&pixels, 4, 16, 1, MmDtype::F32); + let k1 = MmCacheKey::image_key(&pixels, 8, 8, 1, MmDtype::F32, SIG_A); + let k2 = MmCacheKey::image_key(&pixels, 4, 16, 1, MmDtype::F32, SIG_A); assert_ne!( k1, k2, "(H,W) reshape collision: header is supposed to disambiguate" @@ -39,32 +44,90 @@ fn image_key_different_dims_different_hash() { #[test] fn image_key_dtype_disambiguates() { let pixels = vec![7u8; 16]; - let kf = MmCacheKey::image_key(&pixels, 4, 4, 1, MmDtype::F32); - let kb = MmCacheKey::image_key(&pixels, 4, 4, 1, MmDtype::Bf16); + let kf = MmCacheKey::image_key(&pixels, 4, 4, 1, MmDtype::F32, SIG_A); + let kb = MmCacheKey::image_key(&pixels, 4, 4, 1, MmDtype::Bf16, SIG_A); assert_ne!(kf, kb); } +#[test] +fn image_key_model_sig_disambiguates() { + // Same pixels + geometry + dtype, different model. Must NOT collide — + // this is the cross-model leak the signature exists to close. + let pixels = vec![9u8; 48]; + let ka = MmCacheKey::image_key(&pixels, 16, 16, 3, MmDtype::F32, SIG_A); + let kb = MmCacheKey::image_key(&pixels, 16, 16, 3, MmDtype::F32, SIG_B); + assert_ne!( + ka, kb, + "distinct models must produce distinct keys for the same image" + ); +} + +#[test] +fn image_key_same_model_sig_same_hash() { + // Same model + same input ⇒ same key ⇒ cache still hits (no perf + // regression for repeat same-model requests). + let pixels = vec![9u8; 48]; + let k1 = MmCacheKey::image_key(&pixels, 16, 16, 3, MmDtype::F32, SIG_A); + let k2 = MmCacheKey::image_key(&pixels, 16, 16, 3, MmDtype::F32, SIG_A); + assert_eq!(k1, k2); +} + #[test] fn audio_key_sr_disambiguates() { let pcm = vec![0u8; 32]; - let k16 = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1); - let k48 = MmCacheKey::audio_key(&pcm, 48_000, MmDtype::F32, 1); + let k16 = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1, SIG_A); + let k48 = MmCacheKey::audio_key(&pcm, 48_000, MmDtype::F32, 1, SIG_A); assert_ne!(k16, k48, "sample-rate must change the digest"); } #[test] fn audio_key_same_inputs_same_hash() { let pcm = vec![3u8; 40]; - let a = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1); - let b = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1); + let a = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1, SIG_A); + let b = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1, SIG_A); assert_eq!(a, b); } +#[test] +fn audio_key_model_sig_disambiguates() { + // Same PCM + sample-rate + dtype + channels, different model. Must NOT + // collide — symmetric hardening for the (not-yet-wired) Whisper path. + let pcm = vec![5u8; 40]; + let ka = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1, SIG_A); + let kb = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1, SIG_B); + assert_ne!( + ka, kb, + "distinct models must produce distinct keys for the same audio" + ); +} + +#[test] +fn audio_key_same_model_sig_same_hash() { + let pcm = vec![5u8; 40]; + let a = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1, SIG_A); + let b = MmCacheKey::audio_key(&pcm, 16_000, MmDtype::F32, 1, SIG_A); + assert_eq!(a, b); +} + +#[test] +fn model_sig_stable_and_distinct() { + // The helper used at every call site: same id ⇒ same sig; different id ⇒ + // different sig (with overwhelming probability for distinct strings). + assert_eq!( + model_sig("mlx-community__gemma-4-e2b-it-mxfp8"), + model_sig("mlx-community__gemma-4-e2b-it-mxfp8") + ); + assert_ne!( + model_sig("mlx-community__gemma-4-e2b-it-mxfp8"), + model_sig("mlx-community__gemma-4-e4b-it-mxfp8") + ); +} + #[test] fn disabled_cache_is_noop() { let c = MultimodalCache::new(0); assert!(c.is_disabled()); - let key = MmCacheKey::image_key(b"abc", 1, 1, 1, MmDtype::F32); + let key = MmCacheKey::image_key(b"abc", 1, 1, 1, MmDtype::F32, SIG_A); let arr = make_array(4); let sz = array_byte_size(&arr).expect("array_byte_size"); c.put(key, arr, sz); @@ -78,7 +141,7 @@ fn disabled_cache_is_noop() { #[test] fn get_miss_then_hit_after_put() { let c = MultimodalCache::new(1 << 20); - let key = MmCacheKey::image_key(b"px", 1, 1, 1, MmDtype::F32); + let key = MmCacheKey::image_key(b"px", 1, 1, 1, MmDtype::F32, SIG_A); assert!(c.get(&key).is_none()); let arr = make_array(8); let sz = array_byte_size(&arr).expect("array_byte_size"); @@ -100,9 +163,9 @@ fn lru_evicts_to_budget() { let budget = arr_bytes * 2; let c = MultimodalCache::new(budget); - let k1 = MmCacheKey::image_key(b"k1", 1, 1, 1, MmDtype::F32); - let k2 = MmCacheKey::image_key(b"k2", 1, 1, 1, MmDtype::F32); - let k3 = MmCacheKey::image_key(b"k3", 1, 1, 1, MmDtype::F32); + let k1 = MmCacheKey::image_key(b"k1", 1, 1, 1, MmDtype::F32, SIG_A); + let k2 = MmCacheKey::image_key(b"k2", 1, 1, 1, MmDtype::F32, SIG_A); + let k3 = MmCacheKey::image_key(b"k3", 1, 1, 1, MmDtype::F32, SIG_A); c.put(k1, make_array(elems_per_entry), arr_bytes); c.put(k2, make_array(elems_per_entry), arr_bytes); // Touch k1 so k2 becomes LRU. @@ -123,7 +186,7 @@ fn put_oversize_is_noop() { let c = MultimodalCache::new(16); let arr = make_array(64); // 256 bytes let sz = array_byte_size(&arr).expect("array_byte_size"); - let key = MmCacheKey::image_key(b"big", 1, 1, 1, MmDtype::F32); + let key = MmCacheKey::image_key(b"big", 1, 1, 1, MmDtype::F32, SIG_A); c.put(key, arr, sz); assert!(c.get(&key).is_none(), "oversize entry must not be cached"); let s = c.stats(); @@ -134,7 +197,7 @@ fn put_oversize_is_noop() { #[test] fn stats_track_hits_misses() { let c = MultimodalCache::new(1 << 20); - let key = MmCacheKey::image_key(b"x", 2, 2, 1, MmDtype::F32); + let key = MmCacheKey::image_key(b"x", 2, 2, 1, MmDtype::F32, SIG_A); // miss x3 for _ in 0..3 { let _ = c.get(&key); @@ -154,7 +217,7 @@ fn stats_track_hits_misses() { #[test] fn clear_drops_entries_keeps_counters() { let c = MultimodalCache::new(1 << 20); - let key = MmCacheKey::image_key(b"y", 2, 2, 1, MmDtype::F32); + let key = MmCacheKey::image_key(b"y", 2, 2, 1, MmDtype::F32, SIG_A); let _ = c.get(&key); let arr = make_array(4); let sz = array_byte_size(&arr).expect("array_byte_size"); @@ -175,7 +238,7 @@ fn concurrent_get_put_safe() { handles.push(std::thread::spawn(move || { for i in 0..32 { let buf = [t as u8, i as u8, 0, 0]; - let key = MmCacheKey::image_key(&buf, 1, 1, 1, MmDtype::F32); + let key = MmCacheKey::image_key(&buf, 1, 1, 1, MmDtype::F32, SIG_A); if c.get(&key).is_none() { let arr = make_array(2); let sz = array_byte_size(&arr).expect("array_byte_size"); @@ -201,6 +264,6 @@ fn array_byte_size_matches_dtype_and_shape() { #[test] fn key_short_hex_is_8_chars() { - let k = MmCacheKey::image_key(b"abc", 1, 1, 1, MmDtype::F32); + let k = MmCacheKey::image_key(b"abc", 1, 1, 1, MmDtype::F32, SIG_A); assert_eq!(k.short_hex().len(), 8); } diff --git a/crates/rmlx-server/src/embeddings.rs b/crates/rmlx-server/src/embeddings.rs index 2a28e92..8c359b4 100644 --- a/crates/rmlx-server/src/embeddings.rs +++ b/crates/rmlx-server/src/embeddings.rs @@ -653,12 +653,16 @@ fn compute_embeddings( prompt_ids, pixel_values, } => { + // Scope every mm-cache entry to this loaded model so a shared + // (multi-model) cache never serves another model's vision output + // for the same image. + let model_sig = rmlx_models::multimodal_cache::model_sig(model_id); data.reserve(pixel_values.len()); for (index, pv) in pixel_values.iter().enumerate() { if return_multivector { let mv = holder .model - .embed_image_multi(&prompt_ids, pv, device, mm_cache) + .embed_image_multi(&prompt_ids, pv, device, mm_cache, model_sig) .map_err(|e| { EmbedError::Compute(format!("embed_image_multi failed: {e}")) })?; @@ -666,7 +670,14 @@ fn compute_embeddings( } else { let v = holder .model - .embed_image_single(&prompt_ids, pv, device, truncate_dim, mm_cache) + .embed_image_single( + &prompt_ids, + pv, + device, + truncate_dim, + mm_cache, + model_sig, + ) .map_err(classify_single_err)?; push_single(&mut data, index, v); } diff --git a/crates/rmlx-server/src/engine/arch_generator.rs b/crates/rmlx-server/src/engine/arch_generator.rs index a29298e..40a2d65 100644 --- a/crates/rmlx-server/src/engine/arch_generator.rs +++ b/crates/rmlx-server/src/engine/arch_generator.rs @@ -743,6 +743,12 @@ impl Generator for ArchGenerator { // kv_quant_override and max_ctx_override are set at server-startup time via // from_snapshot and passed here; None means use arch default. let model_id_for_log = self.model_id.clone(); + // Per-loaded-model signature folded into every multimodal-cache key so + // a shared (multi-model `--registry`) encoder-output cache never serves + // one model's vision/audio features to another for the same input. The + // model id is the stable identity; same id ⇒ same sig ⇒ cache still + // hits for repeat same-model requests. + let model_sig = rmlx_models::multimodal_cache::model_sig(&self.model_id); tokio::task::spawn_blocking(move || { // Acquire the serialisation lock. let _guard = { @@ -964,6 +970,7 @@ impl Generator for ArchGenerator { &prompt_tokens, device, mm_cache.as_deref(), + model_sig, ) { Ok(triple) => Some(triple), Err(e) => { @@ -1074,6 +1081,7 @@ impl Generator for ArchGenerator { &penalty_cfg, &mut token_history, mm_cache.as_deref(), + model_sig, ) } else { match multimodal_inputs { diff --git a/crates/rmlx-server/src/engine/image.rs b/crates/rmlx-server/src/engine/image.rs index 67de39c..0d55549 100644 --- a/crates/rmlx-server/src/engine/image.rs +++ b/crates/rmlx-server/src/engine/image.rs @@ -68,6 +68,7 @@ pub(crate) fn build_image_prompt( prompt_tokens: &[u32], device: rmlx_mlx::Device, mm_cache: Option<&rmlx_models::multimodal_cache::MultimodalCache>, + model_sig: u64, ) -> rmlx_core::Result<(Vec, rmlx_mlx::Array, rmlx_mlx::Array)> { // Splice per-image token blocks (`` + N×image-token + ``) in after // the prompt's leading BOS token so the image conditions the whole turn @@ -138,7 +139,7 @@ pub(crate) fn build_image_prompt( ); let (embeds, masked_ids) = rmlx_models::gemma4::build_inputs_embeds( - model, vision, embedder, &pixels, &aug_ids, device, mm_cache, + model, vision, embedder, &pixels, &aug_ids, device, mm_cache, model_sig, )?; Ok((aug_ids, embeds, masked_ids)) } @@ -199,7 +200,7 @@ pub(crate) fn build_image_prompt( let pv_only: Vec = pixels.into_iter().map(|(pv, _)| pv).collect(); let (embeds, masked_ids) = rmlx_models::gemma4::build_unified_inputs_embeds( - model, embedder, &pv_only, &aug_ids, device, mm_cache, + model, embedder, &pv_only, &aug_ids, device, mm_cache, model_sig, )?; Ok((aug_ids, embeds, masked_ids)) } @@ -252,7 +253,7 @@ pub(crate) fn build_image_prompt( ); let (embeds, ids) = rmlx_models::gemma3::build_inputs_embeds( - model, vision, projector, &pixels, &aug_ids, device, mm_cache, + model, vision, projector, &pixels, &aug_ids, device, mm_cache, model_sig, )?; // Gemma3 has no masked-ids concept; pass the plain ids array through // the same (Vec, Array, Array) shape so the caller is arch-agnostic. @@ -304,6 +305,7 @@ pub(crate) fn run_qwen3vl_image( penalty_cfg: &rmlx_models::PenaltyConfig, token_history: &mut Vec, mm_cache: Option<&rmlx_models::multimodal_cache::MultimodalCache>, + model_sig: u64, ) -> rmlx_core::Result> { // Registering a thread-local GPU stream + CommandEncoder once per thread entry point. // tokio blocking-pool threads start with no GPU stream context; MLX's array @@ -370,6 +372,7 @@ pub(crate) fn run_qwen3vl_image( u16::try_from(gw).unwrap_or(u16::MAX), 3, rmlx_models::multimodal_cache::MmDtype::F32, + model_sig, ); if let Some(mut arrays) = cache.get_many(&key) { // First array is the merged image_embeds, the rest are deepstack. diff --git a/docs/CLI.md b/docs/CLI.md index a2f5e39..c7685f3 100644 --- a/docs/CLI.md +++ b/docs/CLI.md @@ -95,7 +95,7 @@ mutually exclusive. | `--whisper-tokenizer-path` | path | — | Path to a directory containing `tokenizer.json` (e.g. `openai/whisper-large-v3`). Required for audio endpoints; the mlx-community Whisper snapshot does not ship tokenizer files. Env: `RMLX_WHISPER_TOKENIZER_PATH`. | | `--tts-model-path` | path | — | Path to a Qwen3-TTS model snapshot directory. Required for `POST /v1/audio/speech`. Codec decoder not yet implemented; returns 501 until then. Env: `RMLX_TTS_MODEL_PATH`. | | `--tts-tokenizer-path` | path | — | Path to the Qwen3-TTS speech tokenizer snapshot directory. Used alongside `--tts-model-path`. Env: `RMLX_TTS_TOKENIZER_PATH`. | -| `--mm-cache-bytes` | usize | `536870912` (512 MiB) | Byte budget for the multimodal encoder-output cache. Vision-tower (and Whisper-encoder) outputs are cached keyed on the post-preprocess pixel/PCM content hash so repeated calls with identical inputs skip the encoder. `0` disables the cache. Env: `RMLX_MM_CACHE_BYTES`. | +| `--mm-cache-bytes` | usize | `536870912` (512 MiB) | Byte budget for the multimodal encoder-output cache. Vision-tower (and Whisper-encoder) outputs are cached keyed on the post-preprocess pixel/PCM content hash **plus the producing model's identity** so repeated calls with identical inputs skip the encoder. The model-identity component means a shared cache in multi-model `--registry` mode never serves one model's encoder output to another for the same image/audio (cached outputs are projected to a model's hidden size and must not cross models). `0` disables the cache. Env: `RMLX_MM_CACHE_BYTES`. | | `--kv-ssd-cache-gb` | f64 | 0.0 | SSD prompt-cache tier budget in GiB per namespace. `0` = tier off (RAM-only). Blocks land in `/cache/kv//`. | | `--project` | string | (model id) | SSD prompt-cache namespace name. Requires `--kv-ssd-cache-gb > 0`. | | `--kv-ssd-global-gb` | f64 | 0.0 | Global SSD pool ceiling across all namespaces in GiB. `0` = no global cap. Effective per-namespace ceiling is `min(--kv-ssd-cache-gb, --kv-ssd-global-gb)` when global > 0. | From d846206c27000ab170d8afdeac846e75d2a43643 Mon Sep 17 00:00:00 2001 From: Pushkinist <4850452+Pushkinist@users.noreply.github.com> Date: Thu, 18 Jun 2026 14:37:44 +0700 Subject: [PATCH 2/2] fix(mm-cache): update stale 12-byte header docstrings to 20 bytes --- crates/rmlx-models/src/multimodal_cache.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/rmlx-models/src/multimodal_cache.rs b/crates/rmlx-models/src/multimodal_cache.rs index e948316..54ae4b9 100644 --- a/crates/rmlx-models/src/multimodal_cache.rs +++ b/crates/rmlx-models/src/multimodal_cache.rs @@ -76,7 +76,7 @@ use rmlx_mlx::Array; use tracing::Level; use twox_hash::xxhash3_64::Hasher as XxHash3_64Hasher; -/// Mode discriminator embedded in the 12-byte header. +/// Mode discriminator embedded in the 20-byte header. #[allow( clippy::exhaustive_enums, reason = "fixed wire-format byte: adding a variant is a breaking change to the digest layout and requires bumping `version` in the header" @@ -89,7 +89,7 @@ pub enum MmMode { Audio = 1, } -/// Dtype discriminator embedded in the 12-byte header. Kept independent of +/// Dtype discriminator embedded in the 20-byte header. Kept independent of /// the MLX `Dtype` enum so the on-the-wire byte never breaks when MLX adds a /// new element type. #[allow( @@ -107,7 +107,7 @@ pub enum MmDtype { /// Content-hash key for the cache. 8 bytes (xxh3_64 digest). /// /// If the digest is ever widened (e.g. blake3 256-bit), bump the `version` -/// byte in the 12-byte header and grow this array — the on-the-wire layout +/// byte in the 20-byte header and grow this array — the on-the-wire layout /// is internal-only, no external ABI depends on the size. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct MmCacheKey([u8; 8]);