Merged
30 changes: 20 additions & 10 deletions README.md
@@ -7,8 +7,10 @@ KV-quantization matrix any MLX server ships, including rotation-based KV familie
(TurboQuant, IsoQuant, PlanarQuant, RotorQuant, ParoQuant) that no other MLX
server offers.

> Status: **0.1.0** — first feature-complete native MLX backend. Apple Silicon
> only (Metal). See [What works](#what-works).
> Status: **0.2.2** — feature-complete native MLX backend: OpenAI- and
> Anthropic-compatible text, tool/function calling, streaming, image + audio
> input, embeddings, and a multi-model registry. Apple Silicon only (Metal).
> See [What works](#what-works).

## Why

@@ -21,22 +23,30 @@ server offers.
## What works

- **Text generation** — OpenAI-compatible `/v1/chat/completions` and
`/v1/completions`, plus an Anthropic-compatible surface. Temperature, top-k/p,
penalties, thinking-budget, constrained / schema-guided decoding.
- **Image input** — vision-capable models (Gemma 4 SigLIP tower, Qwen3-VL-MoE
deepstack) accept images via `image_url` content parts (data-URI, http, file
path, or base64).
- **Audio input** — audio transcription / translation endpoints for
audio-capable models.
`/v1/completions`, plus an Anthropic-compatible surface. Streaming (SSE),
temperature, top-k/p, penalties, thinking-budget, constrained / schema-guided
decoding.
- **Image input** — vision-capable models accept images via `image_url` content
parts (data-URI, http, file path, or base64): Gemma 4 SigLIP tower (e4b /
26b), the encoder-free Gemma 4 12B `gemma4_unified` any-to-any architecture,
jina-v4, and Qwen3-VL-MoE deepstack.
- **Audio input** — audio-capable models accept audio (Gemma 4 unified Conformer
tower) plus Whisper speech-to-text via the model-agnostic `rmlx transcribe`
CLI (txt / vtt / srt / json, long-form chunking).
- **Embeddings** — `/v1/embeddings`, including multimodal (text + image) jina-v4.
- **Tool / function calling** — OpenAI `tool_calls` and Anthropic `tool_use`,
multi-turn, multiple emit formats (Qwen XML, Hermes-JSON, Gemma).
- **Multi-model registry** — serve many models from one process with
load-on-demand / unload-on-idle, a bounded resident-model cap, and a shared
multimodal encoder-output cache (scoped per model).
- **Quantization** — affine 2–8 bit, mxfp4 / mxfp8, nvfp4, ParoQuant weights;
KV-cache quant incl. fp8, TurboQuant, RotorQuant, PlanarQuant, IsoQuant,
paged-KV, mixed / asymmetric K/V, and an SSD KV tier.
- **Speculative decoding** — MTP, DFlash, and Eagle3 drafters.
- **Prompt caching** — automatic prefix caching with block hashing.
- **Conversion** — `rmlx convert` re-quantizes / repacks MLX → MLX.

Conversion (`rmlx convert`, MLX → MLX re-quantize / layout repack) is a roadmap
target and not yet shipped.

Continuously smoke-tested end-to-end. The first four families carry committed
golden-token decode gates (temp=0, exact token-id match); embeddings and the
2 changes: 2 additions & 0 deletions crates/rmlx-models/src/gemma3/vision.rs
@@ -567,6 +567,7 @@ pub fn build_inputs_embeds(
input_ids: &[u32],
device: Device,
mm_cache: Option<&crate::multimodal_cache::MultimodalCache>,
model_sig: u64,
) -> Result<(Array, Array)> {
let hidden = model.cfg.hidden_size as i32;
let seq = input_ids.len();
@@ -627,6 +628,7 @@ pub fn build_inputs_embeds(
u16::try_from(pv.width).unwrap_or(u16::MAX),
3,
crate::multimodal_cache::MmDtype::F32,
model_sig,
);
let feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || {
projector.forward(&vision.forward(pv, device)?, device)
2 changes: 1 addition & 1 deletion crates/rmlx-models/src/gemma4/tests.rs
@@ -657,7 +657,7 @@ fn build_inputs_embeds_aligns_and_shapes() {
);

let (embeds, masked) =
build_inputs_embeds(&model, &vision, &embedder, &[pv], &ids, device, None)
build_inputs_embeds(&model, &vision, &embedder, &[pv], &ids, device, None, 0)
.expect("build embeds");
embeds.eval().expect("eval embeds");

2 changes: 2 additions & 0 deletions crates/rmlx-models/src/gemma4/vision/mod.rs
@@ -814,6 +814,7 @@ pub fn build_inputs_embeds(
input_ids: &[u32],
device: Device,
mm_cache: Option<&crate::multimodal_cache::MultimodalCache>,
model_sig: u64,
) -> Result<(Array, Array)> {
let hidden = model.cfg.hidden_size as i32;
let seq = input_ids.len();
@@ -867,6 +868,7 @@ pub fn build_inputs_embeds(
u16::try_from(pv.width).unwrap_or(u16::MAX),
3,
crate::multimodal_cache::MmDtype::F32,
model_sig,
);
let feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || {
embedder.forward(&vision.forward(pv, device)?, device)
2 changes: 2 additions & 0 deletions crates/rmlx-models/src/gemma4/vision/unified.rs
@@ -569,6 +569,7 @@ pub fn build_unified_inputs_embeds(
input_ids: &[u32],
device: Device,
mm_cache: Option<&crate::multimodal_cache::MultimodalCache>,
model_sig: u64,
) -> Result<(Array, Array)> {
let hidden = model.cfg.hidden_size as i32;
let seq = input_ids.len();
@@ -619,6 +620,7 @@ pub fn build_unified_inputs_embeds(
u16::try_from(pv.width).unwrap_or(u16::MAX),
3,
crate::multimodal_cache::MmDtype::F32,
model_sig,
);
let feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || {
embedder.forward(pv, device)
10 changes: 8 additions & 2 deletions crates/rmlx-models/src/jina_v4/mod.rs
@@ -191,6 +191,7 @@ impl JinaV4 {
pixel_values: &PixelValues,
device: rmlx_mlx::Device,
mm_cache: Option<&crate::multimodal_cache::MultimodalCache>,
model_sig: u64,
) -> Result<(rmlx_mlx::Array, Vec<i64>)> {
let image_token_id = i64::from(self.config.image_token_id);
let merge = self.config.vision_config.spatial_merge_size;
@@ -212,6 +213,7 @@ impl JinaV4 {
u16::try_from(g.w).unwrap_or(u16::MAX),
3,
crate::multimodal_cache::MmDtype::F32,
model_sig,
);
let vision_feats = crate::multimodal_cache::get_or_compute(mm_cache, key, || {
self.vision.forward(pixel_values, device)
@@ -251,8 +253,10 @@ impl JinaV4 {
device: rmlx_mlx::Device,
truncate_dim: Option<usize>,
mm_cache: Option<&crate::multimodal_cache::MultimodalCache>,
model_sig: u64,
) -> Result<Vec<f32>> {
let (hidden, input_ids) = self.image_hidden(prompt_ids, pixel_values, device, mm_cache)?;
let (hidden, input_ids) =
self.image_hidden(prompt_ids, pixel_values, device, mm_cache, model_sig)?;
let (start, end) = image::vision_span(
&input_ids,
i64::from(self.config.vision_start_token_id),
@@ -279,8 +283,10 @@ impl JinaV4 {
pixel_values: &PixelValues,
device: rmlx_mlx::Device,
mm_cache: Option<&crate::multimodal_cache::MultimodalCache>,
model_sig: u64,
) -> Result<Vec<Vec<f32>>> {
let (hidden, _input_ids) = self.image_hidden(prompt_ids, pixel_values, device, mm_cache)?;
let (hidden, _input_ids) =
self.image_hidden(prompt_ids, pixel_values, device, mm_cache, model_sig)?;
let projected = self.projector.forward(&hidden, device)?;
pooling::multi_vector(&projected, device)
}
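As a rough illustration of why every caller above now threads `model_sig` down to the cache, here is a minimal get-or-compute sketch. `ToyCache`, its fields, and its method signature are hypothetical and much simpler than the crate's `MultimodalCache`: the toy keys on the raw `(model_sig, bytes)` pair instead of a digest and stores plain `Vec<f32>` features.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the shared encoder-output cache.
/// Keys pair the model signature with the preprocessed input bytes.
struct ToyCache {
    entries: HashMap<(u64, Vec<u8>), Vec<f32>>,
    /// Number of times the encoder closure actually ran (cache misses).
    computes: usize,
}

impl ToyCache {
    fn new() -> Self {
        Self { entries: HashMap::new(), computes: 0 }
    }

    /// Return the cached features for `(model_sig, content)`, or run
    /// `compute` once and remember its result.
    fn get_or_compute<F>(&mut self, model_sig: u64, content: &[u8], compute: F) -> Vec<f32>
    where
        F: FnOnce() -> Vec<f32>,
    {
        let key = (model_sig, content.to_vec());
        if let Some(hit) = self.entries.get(&key) {
            return hit.clone();
        }
        self.computes += 1;
        let feats = compute();
        self.entries.insert(key, feats.clone());
        feats
    }
}
```

With this shape, a repeat request from the same model for the same image is a hit, while a second model asking for the same image misses and computes features at its own hidden size rather than reusing the first model's.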
78 changes: 59 additions & 19 deletions crates/rmlx-models/src/multimodal_cache.rs
@@ -15,14 +15,22 @@
//! modality. Wrap in `Arc<MultimodalCache>` for cheap clone.
//! - **Hash recipe** (lifted from
//! `dynamo/components/.../vllm/multimodal_utils/hash_utils.py:32-95`):
//! fixed 12-byte header `[version:u8=1, mode:u8 (0=image | 1=audio),
//! fixed 20-byte header `[version:u8=2, mode:u8 (0=image | 1=audio),
//! dtype:u8 (0=bf16 | 1=f32), channels:u8, dim1:u16 LE, dim2:u16 LE,
//! reserved:u32 LE]` + the canonical preprocessed pixel/PCM bytes.
//! The fixed header prevents `(H,W)` collision when the same pixel byte
//! stream is reshaped to a different geometry. For audio, `dim1`/`dim2` are
//! `0` and the `reserved` field carries the sample rate; `channels` lands
//! in `header[3]` so mono vs stereo PCM byte runs cannot alias; `n_samples`
//! is implied by the trailing byte length.
//! reserved:u32 LE, model_sig:u64 LE]` + the canonical preprocessed
//! pixel/PCM bytes. The fixed header prevents `(H,W)` collision when the
//! same pixel byte stream is reshaped to a different geometry. For audio,
//! `dim1`/`dim2` are `0` and the `reserved` field carries the sample rate;
//! `channels` lands in `header[3]` so mono vs stereo PCM byte runs cannot
//! alias; `n_samples` is implied by the trailing byte length.
//! - **`model_sig`** is a stable per-loaded-model identity (a u64 hash of the
//! model's registry id / canonical snapshot path). It scopes every cached
//! encoder output to the model that produced it, so a multi-model
//! `--registry` server never serves model A's vision-tower output (projected
//! to A's hidden size) to model B for the same image. Without it, two models
//! sharing one cache and the same image collide on the content hash → a soft
//! token embedded at the wrong hidden dim → shape-mismatch failure. Same
//! model + same input ⇒ same `model_sig` ⇒ the cache still hits.
//! - **Hasher**: `twox-hash` xxh3_64 (MIT, tiny, no transitive cost). Stored
//! as 8 bytes (the raw digest). If the digest is ever widened, bump the
//! `version` byte in the header along with the key array size — this type
@@ -68,7 +76,7 @@ use rmlx_mlx::Array;
use tracing::Level;
use twox_hash::xxhash3_64::Hasher as XxHash3_64Hasher;

/// Mode discriminator embedded in the 12-byte header.
/// Mode discriminator embedded in the 20-byte header.
#[allow(
clippy::exhaustive_enums,
reason = "fixed wire-format byte: adding a variant is a breaking change to the digest layout and requires bumping `version` in the header"
@@ -81,7 +89,7 @@ pub enum MmMode {
Audio = 1,
}

/// Dtype discriminator embedded in the 12-byte header. Kept independent of
/// Dtype discriminator embedded in the 20-byte header. Kept independent of
/// the MLX `Dtype` enum so the on-the-wire byte never breaks when MLX adds a
/// new element type.
#[allow(
@@ -99,7 +107,7 @@ pub enum MmDtype {
/// Content-hash key for the cache. 8 bytes (xxh3_64 digest).
///
/// If the digest is ever widened (e.g. blake3 256-bit), bump the `version`
/// byte in the 12-byte header and grow this array — the on-the-wire layout
/// byte in the 20-byte header and grow this array — the on-the-wire layout
/// is internal-only, no external ABI depends on the size.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MmCacheKey([u8; 8]);
@@ -119,26 +127,39 @@ impl MmCacheKey {

/// Build an image key. `pixel_bytes` is the **post-preprocess** flat
/// pixel buffer (e.g. the same `&[f32]` the vision tower reads). The
/// 12-byte header prevents `(H,W)` reshape collisions.
/// 20-byte header prevents `(H,W)` reshape collisions. `model_sig` is the
/// stable per-loaded-model identity (see [`model_sig`]) — it scopes the
/// key to the producing model so encoder outputs are never shared across
/// models in a multi-model registry.
#[must_use]
pub fn image_key(
pixel_bytes: &[u8],
height: u16,
width: u16,
channels: u8,
dtype: MmDtype,
model_sig: u64,
) -> Self {
let header = build_header(MmMode::Image, dtype, channels, height, width, 0);
let header = build_header(MmMode::Image, dtype, channels, height, width, 0, model_sig);
Self(digest(&header, pixel_bytes))
}

/// Build an audio key. `pcm_bytes` is the **post-preprocess** PCM byte
/// stream (typically f32 mono). `sample_rate` lands in the `reserved`
/// field so two identical PCM byte runs at different sample rates do not
/// alias. `dim1`/`dim2` are unused for audio (set to 0).
/// alias. `dim1`/`dim2` are unused for audio (set to 0). `model_sig` is the
/// stable per-loaded-model identity (see [`model_sig`]) — it scopes the
/// key to the producing model so encoder outputs are never shared across
/// models in a multi-model registry.
#[must_use]
pub fn audio_key(pcm_bytes: &[u8], sample_rate: u32, dtype: MmDtype, channels: u8) -> Self {
let header = build_header(MmMode::Audio, dtype, channels, 0, 0, sample_rate);
pub fn audio_key(
pcm_bytes: &[u8],
sample_rate: u32,
dtype: MmDtype,
channels: u8,
model_sig: u64,
) -> Self {
let header = build_header(MmMode::Audio, dtype, channels, 0, 0, sample_rate, model_sig);
Self(digest(&header, pcm_bytes))
}

@@ -156,19 +177,21 @@ fn build_header(
dim1: u16,
dim2: u16,
reserved: u32,
) -> [u8; 12] {
let mut header = [0u8; 12];
header[0] = 1; // version
model_sig: u64,
) -> [u8; 20] {
let mut header = [0u8; 20];
header[0] = 2; // version — bumped from 1 when `model_sig` was folded in
header[1] = mode as u8;
header[2] = dtype as u8;
header[3] = channels;
header[4..6].copy_from_slice(&dim1.to_le_bytes());
header[6..8].copy_from_slice(&dim2.to_le_bytes());
header[8..12].copy_from_slice(&reserved.to_le_bytes());
header[12..20].copy_from_slice(&model_sig.to_le_bytes());
header
}

fn digest(header: &[u8; 12], payload: &[u8]) -> [u8; 8] {
fn digest(header: &[u8; 20], payload: &[u8]) -> [u8; 8] {
// Streaming xxh3_64 over (header || payload). Avoids the ~9.6 MiB temp
// allocation a one-shot path would need for a 896×896×3 f32 image.
// `alloc` feature on `twox-hash` is enabled at the workspace level for
@@ -648,6 +671,23 @@ pub fn pcm_f32_bytes(pcm: &[f32]) -> &[u8] {
pixel_f32_bytes(pcm)
}

/// Derive the stable per-loaded-model signature folded into every
/// [`MmCacheKey`]. `id` should be a stable identity for the loaded model —
/// its registry id or canonical snapshot path. The same id always yields the
/// same `u64`, so repeat same-model requests still hit the cache; two distinct
/// models (different ids) yield different signatures, so their encoder outputs
/// never collide on a shared cache.
///
/// `hidden_size` alone is deliberately NOT used: two different models can share
/// a hidden size yet produce different features. The id/path is the
/// collision-safe identity.
#[must_use]
pub fn model_sig(id: &str) -> u64 {
let mut h = XxHash3_64Hasher::new();
h.write(id.as_bytes());
h.finish()
}

#[cfg(test)]
#[path = "multimodal_cache_tests.rs"]
mod multimodal_cache_tests;
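The version-2 header layout described in the module docs can be sketched on its own, without the crate's dependencies. This is an illustrative sketch under stated assumptions: `hash64` uses the standard library's `DefaultHasher` as a stand-in for the xxh3_64 hasher the diff uses, so the digests will not match the crate's, and the simplified `image_key` hard-codes mode, dtype, and channel count.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

/// Stand-in 64-bit hash over a sequence of byte chunks.
/// Assumption: std's `DefaultHasher` replaces xxh3_64 to stay dependency-free.
fn hash64(chunks: &[&[u8]]) -> u64 {
    let mut h = DefaultHasher::new();
    for chunk in chunks {
        h.write(chunk);
    }
    h.finish()
}

/// Per-model signature derived from a stable id (registry id or snapshot path).
fn model_sig(id: &str) -> u64 {
    hash64(&[id.as_bytes()])
}

/// Fixed 20-byte header, mirroring the offsets in the diff:
/// [0] version, [1] mode, [2] dtype, [3] channels,
/// [4..6] dim1 LE, [6..8] dim2 LE, [8..12] reserved LE, [12..20] model_sig LE.
fn build_header(
    mode: u8,
    dtype: u8,
    channels: u8,
    dim1: u16,
    dim2: u16,
    reserved: u32,
    model_sig: u64,
) -> [u8; 20] {
    let mut header = [0u8; 20];
    header[0] = 2; // version 2: the layout that includes model_sig
    header[1] = mode;
    header[2] = dtype;
    header[3] = channels;
    header[4..6].copy_from_slice(&dim1.to_le_bytes());
    header[6..8].copy_from_slice(&dim2.to_le_bytes());
    header[8..12].copy_from_slice(&reserved.to_le_bytes());
    header[12..20].copy_from_slice(&model_sig.to_le_bytes());
    header
}

/// Simplified image key: mode = 0 (image), dtype = 1 (f32), 3 channels.
fn image_key(pixel_bytes: &[u8], height: u16, width: u16, model_sig: u64) -> [u8; 8] {
    let header = build_header(0, 1, 3, height, width, 0, model_sig);
    hash64(&[&header[..], pixel_bytes]).to_le_bytes()
}
```

Because the signature occupies bytes 12..20 of the hashed header, the same pixel buffer yields the same key for one model id and, barring a 64-bit hash collision, a different key for another; reshaping the same bytes to a different `(height, width)` changes the key in the same way.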