Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions crates/rmlx-models/src/prefill_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ pub fn module_key_for_class(arch_class: &str) -> &'static str {
"LagunaForCausalLM" => "laguna",
"Qwen3_5MoeForConditionalGeneration" | "Qwen3_5ForConditionalGeneration" => "qwen3_5_moe",
"BitNetForCausalLM" => "bitnet",
// Qwen3VLMoe (single-shot prefill, no `prefill_chunk_for` call) and any
// unsupported class: no dedicated prefill-chunk row, fall through to
"Qwen3VLMoeForConditionalGeneration" => "qwen3_vl_moe",
// Any unsupported class: no dedicated prefill-chunk row, fall through to
// FALLBACK (safe, conservative).
_ => "",
}
Expand All @@ -131,6 +131,11 @@ fn arch_default(arch: &str) -> Option<usize> {
// available via `RMLX_PREFILL_CHUNK_QWEN3_5_MOE=256` for users who
// want to test it.
"qwen3_5_moe" => Some(64),
// qwen3_vl_moe: plain GQA MoE (no GDN linear attention), so it tolerates
// the same large chunk as gemma4. Native image tiling produces thousands
// of soft tokens; a single-shot forward over the full prompt trips the
// Metal ~10s GPU watchdog, so the image prefill is chunked at 512.
"qwen3_vl_moe" => Some(512),
"gemma3" => Some(256),
// gemma4 default 512: p0b-ttft bench measured -30% cold TTFT at 8K
// and -12% at 32K vs chunk=64 with no Metal watchdog at max-ctx 64K.
Expand Down
10 changes: 6 additions & 4 deletions crates/rmlx-models/src/prefill_chunk_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ fn defaults_match_recommendations() {
}
assert_eq!(arch_default("qwen3"), Some(256));
assert_eq!(arch_default("qwen3_5_moe"), Some(64));
assert_eq!(arch_default("qwen3_vl_moe"), Some(512));
assert_eq!(arch_default("gemma3"), Some(256));
assert_eq!(arch_default("gemma4"), Some(512));
assert_eq!(arch_default("qwen2"), Some(256));
Expand Down Expand Up @@ -47,13 +48,14 @@ fn module_key_for_class_maps_supported_classes() {
"qwen3_5_moe"
);
assert_eq!(module_key_for_class("BitNetForCausalLM"), "bitnet");

// Unknown / single-shot-prefill classes → "" → FALLBACK chunk, never the
// oversized gemma4 default.
// Qwen3-VL-MoE chunks its image prefill (native tiling → thousands of soft
// tokens would trip the Metal watchdog in one forward).
assert_eq!(
module_key_for_class("Qwen3VLMoeForConditionalGeneration"),
""
"qwen3_vl_moe"
);

// Unknown classes → "" → FALLBACK chunk, never the oversized gemma4 default.
assert_eq!(module_key_for_class("JinaEmbeddingsV4Model"), "");
assert_eq!(module_key_for_class("TotallyUnknownArch"), "");
}
Expand Down
86 changes: 73 additions & 13 deletions crates/rmlx-models/src/qwen3_vl_moe/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use rmlx_mlx::{Array, Device, Dtype};

use crate::constraint::ConstraintEngine;
use crate::decode_loop::ProbeStep;
use crate::kv_cache::kv_max_seq_and_ceiling;
use crate::prompt_cache::{chained_block_hashes_seeded, Consumed, ReusePolicy, FNV_OFFSET};
use crate::sampler::{apply_mask_argmax, sample_token_array, Pcg32, PenaltyConfig, SamplerConfig};
use rmlx_kv_quant::{KvCache, KvQuant};
Expand Down Expand Up @@ -128,9 +129,9 @@ fn make_step(token_id: u32, tokenizer: &tokenizers::Tokenizer) -> ProbeStep {
/// requests. Pass 1 for single-slot; pass N for multi-slot prefix matching.
/// Recommended: 4. Only the Exact-hit path is active under
/// [`ReusePolicy::ExactOnly`] (identical-prompt repeat skips re-prefill
/// entirely, same contract as Qwen2 / Qwen3 dense). `max_ctx_override` is unused
/// here: the text decoder prefills the whole prompt in one shot and grows its KV
/// ring lazily, so there is no pre-sized ceiling to clamp.
/// entirely, same contract as Qwen2 / Qwen3 dense). `max_ctx_override` sizes the
/// KV ring ceiling so a long prompt (up to the effective `--max-ctx`) grows to
/// fit and an over-cap prompt is rejected cleanly; see [`generate_image`].
#[allow(clippy::too_many_arguments)]
pub fn generate_greedy(
model: &Qwen3VlMoeText,
Expand All @@ -139,7 +140,7 @@ pub fn generate_greedy(
n_tokens: usize,
device: Device,
kv_quant: KvQuant,
_max_ctx_override: Option<i32>,
max_ctx_override: Option<i32>,
prompt_cache_slots: usize,
eos_ids: &[u32],
step_fn: &mut dyn FnMut(&ProbeStep) -> Option<u32>,
Expand Down Expand Up @@ -229,13 +230,37 @@ pub fn generate_greedy(
return Ok(steps);
}

// Path B (Miss): one-shot prefill from scratch.
// Path B (Miss): chunked prefill from scratch. Size the KV ring from the
// effective `--max-ctx` (lazy start + growth ceiling) so a prompt longer
// than the lazy KV_MAX_SEQ_DEFAULT=4096 start grows to fit instead of
// overflowing the fixed decode buffer; an over-cap prompt is rejected
// cleanly with KvCeilingExceeded (→ context_overflow). The shared
// chunked_prefill helper brackets enter_prefill()/exit_prefill() and flushes
// each chunk's command buffer under the ~10s Metal GPU watchdog — a
// single-shot forward over a multi-thousand-token prompt trips it. Mirrors
// the other arch text paths (qwen3_5_moe / gemma4).
let (initial_max_seq, max_seq_ceiling) =
kv_max_seq_and_ceiling(max_ctx_override, model.cfg.max_position_embeddings as i32);
let mut kv: Vec<KvCache> = (0..n_layers)
.enumerate()
.map(|(i, _)| KvCache::with_quant(kv_quant).with_layer_idx(i))
.map(|i| {
KvCache::with_quant_max_seq(kv_quant, initial_max_seq)
.with_max_seq_ceiling(max_seq_ceiling)
.with_layer_idx(i)
})
.collect();

let logits = model.forward_seq_with_cache(prompt_ids, Some(&mut kv), device)?;
let prefill_chunk = crate::prefill_chunk::prefill_chunk_for("qwen3_vl_moe");
let Some(logits) = crate::decode_loop::chunked_prefill(
&mut kv,
prompt_ids,
prefill_chunk,
device,
"Qwen3VLMoeForConditionalGeneration",
|chunk, kv| model.forward_seq_with_cache(chunk, Some(kv), device),
)?
else {
return Ok(steps);
};
let first = pick_token(
&logits,
vocab,
Expand Down Expand Up @@ -410,6 +435,7 @@ pub fn generate_image(
n_tokens: usize,
device: Device,
kv_quant: KvQuant,
max_ctx_override: Option<i32>,
eos_ids: &[u32],
step_fn: &mut dyn FnMut(&ProbeStep) -> Option<u32>,
mut constraint: Option<&mut dyn ConstraintEngine>,
Expand Down Expand Up @@ -441,26 +467,56 @@ pub fn generate_image(
&visual_positions,
device,
)?;
// Materialize the scatter-merged embeds before the chunked prefill so the
// full-sequence scatter is its own command buffer, not folded into the first
// prefill chunk.
Array::eval(&inputs_embeds)?;

// 2. 3D M-RoPE positions for the augmented sequence.
let pos = get_rope_index(&ids_i64, image_grids, image_token_id, spatial_merge_size)?;

// 3. prefill (deepstack injected after layers 0..len(deepstack)).
//
// The augmented image prompt is long (thousands of image soft tokens for
// native Qwen3-VL tiling — e.g. a 2560×2560 image → ~6400 soft tokens),
// far above the lazy KV_MAX_SEQ_DEFAULT=4096 start. Size the KV ring from
// the effective `--max-ctx`: `initial_max_seq` is the lazy start and
// `max_seq_ceiling` caps lazy growth and rejects an over-cap prompt with a
// clean `KvCeilingExceeded` (→ context_overflow) instead of a cryptic
// `slice_update` broadcast. Bracketing the chunked forward with
// enter_prefill()/exit_prefill() routes the prefill through the lazy-grow
// raw buffer (mirrors the Gemma4 image path) so it grows to fit.
let (initial_max_seq, max_seq_ceiling) =
kv_max_seq_and_ceiling(max_ctx_override, model.cfg.max_position_embeddings as i32);
let n_layers = model.cfg.num_hidden_layers;
let mut kv: Vec<KvCache> = (0..n_layers)
.enumerate()
.map(|(i, _)| KvCache::with_quant(kv_quant).with_layer_idx(i))
.map(|i| {
KvCache::with_quant_max_seq(kv_quant, initial_max_seq)
.with_max_seq_ceiling(max_seq_ceiling)
.with_layer_idx(i)
})
.collect();
let logits = model.forward_embeds(
for c in &mut kv {
c.enter_prefill();
}
// Chunk the image prefill so a long augmented prompt (thousands of image
// soft tokens) does not run a single multi-thousand-token forward in one
// Metal command buffer (the ~10s GPU watchdog). The chunk size comes from
// the per-arch prefill-chunk table (512 for this plain-GQA MoE arch).
let prefill_chunk = crate::prefill_chunk::prefill_chunk_for("qwen3_vl_moe");
let logits = model.forward_embeds_chunked(
&inputs_embeds,
seq,
&pos,
0,
&vision.deepstack_embeds,
&visual_positions,
Some(&mut kv),
prefill_chunk,
&mut kv,
device,
)?;
for c in &mut kv {
c.exit_prefill(device)?;
}

let mut steps = Vec::with_capacity(n_tokens);
let first = pick_token(
Expand Down Expand Up @@ -524,3 +580,7 @@ pub fn generate_image(
}
Ok(steps)
}

#[cfg(test)]
#[path = "generate_tests.rs"]
mod generate_tests;
144 changes: 144 additions & 0 deletions crates/rmlx-models/src/qwen3_vl_moe/generate_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//! KV-cache sizing for the Qwen3-VL-MoE generate paths.
//!
//! Native Qwen3-VL image tiling produces thousands of image soft tokens (a
//! 2560×2560 image → ~6400 soft tokens), so the augmented prompt routinely
//! exceeds the lazy `KV_MAX_SEQ_DEFAULT = 4096` ring start. The bug these tests
//! pin: the image (and text) generate path built its per-layer caches with the
//! bare 4096 default and never bracketed the prefill, so a >4096-token prompt
//! overflowed the fixed decode buffer with
//! `slice_update: [broadcast_shapes] Shapes (1,4,6776,128) and (1,4,4096,128)`.
//!
//! The fix sizes the ring from the effective `--max-ctx` via
//! [`kv_max_seq_and_ceiling`] and brackets the prefill with
//! `enter_prefill()` / `exit_prefill()` so the lazy-grow path is used. The
//! production paths additionally split the prefill into chunks; these tests
//! reproduce the same cache-construction + bracketed-prefill pattern at the
//! cache level with a single `update` call (CPU, `KvQuant::None`, the model's
//! real head_dim=128 / n_kv_heads=4 shape) without loading the 30B model:
//!
//! * a 6776-token prefill under a 16384 ceiling grows to fit and completes, and
//! * an over-cap prefill is rejected with a clean `KvCeilingExceeded`
//! (→ HTTP `context_overflow`), not a `slice_update` broadcast panic.

use crate::kv_cache::kv_max_seq_and_ceiling;
use rmlx_kv_quant::{KvCache, KvQuant};
use rmlx_mlx::{Array, Device, Dtype};

// Real Qwen3-VL-30B-A3B text-decoder KV shape (from the serve log:
// n_kv_heads=4, head_dim=128). max_position_embeddings is large, so a 16384
// `--max-ctx` ceiling resolves to exactly 16384.
/// Number of KV heads; axis 1 of the `[1, KV_H, seq, HEAD_DIM]` test tensors.
const KV_H: i32 = 4;
/// Per-head dimension; last axis of the `[1, KV_H, seq, HEAD_DIM]` test tensors.
const HEAD_DIM: i32 = 128;
/// Stand-in for the model's `max_position_embeddings`, passed as the second
/// argument to `kv_max_seq_and_ceiling` in these tests.
const QWEN3VL_MPE: i32 = 262_144;

/// Builds an `F32` [`Array`] with the given `shape` from a slice of `f32`s.
///
/// Test-only helper: the slice is reinterpreted as raw bytes and handed to
/// `Array::from_bytes`; a failure there aborts the test via `.expect()`.
#[allow(
    clippy::expect_used,
    reason = "test helper: .expect() surfaces an Array::from_bytes failure as the test message"
)]
fn f32_arr(data: &[f32], shape: &[i32]) -> Array {
    // Take the byte length from the slice itself instead of a hard-coded
    // `len * 4`, so the view can never disagree with the element type.
    let byte_len = std::mem::size_of_val(data);
    // SAFETY: `data` is a live, initialized `&[f32]` spanning exactly `byte_len`
    // bytes, and `u8` has alignment 1, so the reinterpreted view is in-bounds
    // and well-aligned. Apple-Silicon-only build (CLAUDE.md hard rule 1); f32 is
    // 4-byte LE on this target. `data` is borrowed read-only and copied into MLX
    // before the borrow ends.
    #[allow(
        unsafe_code,
        reason = "zero-copy byte view of an f32 slice for Array::from_bytes; copied before the borrow ends"
    )]
    let bytes = unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), byte_len) };
    Array::from_bytes(bytes, shape, Dtype::F32).expect("Array::from_bytes")
}

/// Cache-level stand-in for the generate paths' prefill.
///
/// Constructs a single layer-0 cache sized by `kv_max_seq_and_ceiling` (the
/// same construction chain `generate_image` / `generate_greedy` use), then
/// pushes `seq` tokens of constant-valued K/V through it in one `update` call
/// bracketed by `enter_prefill()` / `exit_prefill()`.
///
/// Returns the cache on success, or whatever error `update` / `exit_prefill`
/// reported.
fn prefill_one_shot(max_ctx_override: Option<i32>, seq: i32) -> rmlx_core::Result<KvCache> {
    let device = Device::Cpu;
    let (start_len, ceiling) = kv_max_seq_and_ceiling(max_ctx_override, QWEN3VL_MPE);
    let mut layer_cache = KvCache::with_quant_max_seq(KvQuant::None, start_len)
        .with_max_seq_ceiling(ceiling)
        .with_layer_idx(0);

    layer_cache.enter_prefill();

    // K and V share one [batch=1, kv_heads, seq, head_dim] shape; only the fill
    // value differs.
    let dims = [1_i32, KV_H, seq, HEAD_DIM];
    let elem_count = dims.iter().fold(1_usize, |acc, &d| acc * (d as usize));
    let constant_tensor = |fill: f32| f32_arr(&vec![fill; elem_count], &dims);
    let keys = constant_tensor(0.1_f32);
    let values = constant_tensor(0.2_f32);

    layer_cache.update(&keys, &values, device)?;
    layer_cache.exit_prefill(device)?;
    Ok(layer_cache)
}

/// Headline regression case: with `--max-ctx 16384`, a 6776-token image prompt
/// (large image → thousands of soft tokens, plus text) must prefill
/// successfully, i.e. the ring grows beyond its 4096 lazy start rather than
/// overflowing the fixed decode buffer.
#[test]
#[allow(
    clippy::expect_used,
    reason = "test asserts the prefill succeeds; .expect() surfaces the failure as the test message"
)]
fn image_prompt_over_4096_fits_under_max_ctx() {
    // 6776 = the issue's augmented length (6715 image soft tokens + text).
    let outcome = prefill_one_shot(Some(16_384), 6776);
    let prefilled = outcome.expect("6776-token prefill must fit under a 16384 ceiling");
    let consumed = prefilled.offset();
    assert_eq!(
        consumed, 6776,
        "offset tracks the full prefilled prompt length",
    );
}

/// Pins the sizing contract the fix relies on: given `--max-ctx 16384`, the
/// resolved growth ceiling is the requested 16384 (not the 4096 default that
/// the pre-fix caches were built with), while the ring's initial size stays at
/// the lazy 4096 start.
#[test]
fn max_ctx_override_sizes_ceiling_not_default() {
    // `.0` is the lazy initial ring size, `.1` is the growth ceiling.
    let sizing = kv_max_seq_and_ceiling(Some(16_384), QWEN3VL_MPE);
    assert_eq!(
        sizing.1, 16_384,
        "ceiling honors --max-ctx, not the 4096 default"
    );
    assert_eq!(
        sizing.0, 4096,
        "ring still starts lazily at the 4096 default and grows up to the ceiling",
    );
}

/// An over-cap prompt (longer than the effective `--max-ctx`) must fail with a
/// clean `KvCeilingExceeded` error (mapped to HTTP `context_overflow`) rather
/// than the cryptic `slice_update` broadcast failure of the pre-fix path.
#[test]
#[allow(
    clippy::panic,
    reason = "test fails loudly if the over-cap prompt is unexpectedly accepted (KvCache has no Debug impl, so expect_err is unavailable)"
)]
fn over_cap_image_prompt_yields_context_overflow_not_broadcast_panic() {
    // ceiling = 4096 (small --max-ctx); a 6776-token prompt is over-cap.
    // KvCache has no Debug impl, so the Result is matched by hand instead of
    // going through expect_err (which would need to format the Ok side).
    let err = match prefill_one_shot(Some(4096), 6776) {
        Err(rejection) => rejection,
        Ok(_) => panic!("an over-cap prompt must be rejected before allocation"),
    };

    let is_ceiling_error = matches!(err, rmlx_core::error::Error::KvCeilingExceeded { .. });
    assert!(
        is_ceiling_error,
        "expected KvCeilingExceeded (→ context_overflow), got: {err}",
    );

    let msg = err.to_string();
    let names_ceiling = msg.contains("exceeds max-ctx ceiling");
    assert!(
        names_ceiling,
        "error must name the ceiling overflow, not a broadcast shape: {msg}",
    );
    let mentions_broadcast = msg.contains("broadcast");
    assert!(
        !mentions_broadcast,
        "must NOT surface a raw slice_update broadcast error: {msg}",
    );
}

/// Control case: a prompt well under the 4096 lazy start (e.g. a 448×448 image
/// → 196 soft tokens, 213 tokens with text) must keep working unchanged.
#[test]
#[allow(
    clippy::expect_used,
    reason = "test asserts the prefill succeeds; .expect() surfaces the failure as the test message"
)]
fn small_image_prompt_under_default_still_works() {
    let outcome = prefill_one_shot(Some(16_384), 213);
    let prefilled = outcome.expect("213-token small-image prefill must complete");
    let consumed = prefilled.offset();
    assert_eq!(consumed, 213);
}
Loading
Loading