diff --git a/crates/rmlx-models/src/prefill_chunk.rs b/crates/rmlx-models/src/prefill_chunk.rs index d9ef15e..eacfdf6 100644 --- a/crates/rmlx-models/src/prefill_chunk.rs +++ b/crates/rmlx-models/src/prefill_chunk.rs @@ -112,8 +112,8 @@ pub fn module_key_for_class(arch_class: &str) -> &'static str { "LagunaForCausalLM" => "laguna", "Qwen3_5MoeForConditionalGeneration" | "Qwen3_5ForConditionalGeneration" => "qwen3_5_moe", "BitNetForCausalLM" => "bitnet", - // Qwen3VLMoe (single-shot prefill, no `prefill_chunk_for` call) and any - // unsupported class: no dedicated prefill-chunk row, fall through to + "Qwen3VLMoeForConditionalGeneration" => "qwen3_vl_moe", + // Any unsupported class: no dedicated prefill-chunk row, fall through to // FALLBACK (safe, conservative). _ => "", } @@ -131,6 +131,11 @@ fn arch_default(arch: &str) -> Option { // available via `RMLX_PREFILL_CHUNK_QWEN3_5_MOE=256` for users who // want to test it. "qwen3_5_moe" => Some(64), + // qwen3_vl_moe: plain GQA MoE (no GDN linear attention), so it tolerates + // the same large chunk as gemma4. Native image tiling produces thousands + // of soft tokens; a single-shot forward over the full prompt trips the + // Metal ~10s GPU watchdog, so the image prefill is chunked at 512. + "qwen3_vl_moe" => Some(512), "gemma3" => Some(256), // gemma4 default 512: p0b-ttft bench measured -30% cold TTFT at 8K // and -12% at 32K vs chunk=64 with no Metal watchdog at max-ctx 64K. 
diff --git a/crates/rmlx-models/src/prefill_chunk_tests.rs b/crates/rmlx-models/src/prefill_chunk_tests.rs index 03bf6ed..d9e1779 100644 --- a/crates/rmlx-models/src/prefill_chunk_tests.rs +++ b/crates/rmlx-models/src/prefill_chunk_tests.rs @@ -11,6 +11,7 @@ fn defaults_match_recommendations() { } assert_eq!(arch_default("qwen3"), Some(256)); assert_eq!(arch_default("qwen3_5_moe"), Some(64)); + assert_eq!(arch_default("qwen3_vl_moe"), Some(512)); assert_eq!(arch_default("gemma3"), Some(256)); assert_eq!(arch_default("gemma4"), Some(512)); assert_eq!(arch_default("qwen2"), Some(256)); @@ -47,13 +48,14 @@ fn module_key_for_class_maps_supported_classes() { "qwen3_5_moe" ); assert_eq!(module_key_for_class("BitNetForCausalLM"), "bitnet"); - - // Unknown / single-shot-prefill classes → "" → FALLBACK chunk, never the - // oversized gemma4 default. + // Qwen3-VL-MoE chunks its image prefill (native tiling → thousands of soft + // tokens would trip the Metal watchdog in one forward). assert_eq!( module_key_for_class("Qwen3VLMoeForConditionalGeneration"), - "" + "qwen3_vl_moe" ); + + // Unknown classes → "" → FALLBACK chunk, never the oversized gemma4 default. 
assert_eq!(module_key_for_class("JinaEmbeddingsV4Model"), ""); assert_eq!(module_key_for_class("TotallyUnknownArch"), ""); } diff --git a/crates/rmlx-models/src/qwen3_vl_moe/generate.rs b/crates/rmlx-models/src/qwen3_vl_moe/generate.rs index e2da354..247113f 100644 --- a/crates/rmlx-models/src/qwen3_vl_moe/generate.rs +++ b/crates/rmlx-models/src/qwen3_vl_moe/generate.rs @@ -32,6 +32,7 @@ use rmlx_mlx::{Array, Device, Dtype}; use crate::constraint::ConstraintEngine; use crate::decode_loop::ProbeStep; +use crate::kv_cache::kv_max_seq_and_ceiling; use crate::prompt_cache::{chained_block_hashes_seeded, Consumed, ReusePolicy, FNV_OFFSET}; use crate::sampler::{apply_mask_argmax, sample_token_array, Pcg32, PenaltyConfig, SamplerConfig}; use rmlx_kv_quant::{KvCache, KvQuant}; @@ -128,9 +129,9 @@ fn make_step(token_id: u32, tokenizer: &tokenizers::Tokenizer) -> ProbeStep { /// requests. Pass 1 for single-slot; pass N for multi-slot prefix matching. /// Recommended: 4. Only the Exact-hit path is active under /// [`ReusePolicy::ExactOnly`] (identical-prompt repeat skips re-prefill -/// entirely, same contract as Qwen2 / Qwen3 dense). `max_ctx_override` is unused -/// here: the text decoder prefills the whole prompt in one shot and grows its KV -/// ring lazily, so there is no pre-sized ceiling to clamp. +/// entirely, same contract as Qwen2 / Qwen3 dense). `max_ctx_override` sizes the +/// KV ring ceiling so a long prompt (up to the effective `--max-ctx`) grows to +/// fit and an over-cap prompt is rejected cleanly; see [`generate_image`]. 
#[allow(clippy::too_many_arguments)] pub fn generate_greedy( model: &Qwen3VlMoeText, @@ -139,7 +140,7 @@ pub fn generate_greedy( n_tokens: usize, device: Device, kv_quant: KvQuant, - _max_ctx_override: Option, + max_ctx_override: Option, prompt_cache_slots: usize, eos_ids: &[u32], step_fn: &mut dyn FnMut(&ProbeStep) -> Option, @@ -229,13 +230,37 @@ pub fn generate_greedy( return Ok(steps); } - // Path B (Miss): one-shot prefill from scratch. + // Path B (Miss): chunked prefill from scratch. Size the KV ring from the + // effective `--max-ctx` (lazy start + growth ceiling) so a prompt longer + // than the lazy KV_MAX_SEQ_DEFAULT=4096 start grows to fit instead of + // overflowing the fixed decode buffer; an over-cap prompt is rejected + // cleanly with KvCeilingExceeded (→ context_overflow). The shared + // chunked_prefill helper brackets enter_prefill()/exit_prefill() and flushes + // each chunk's command buffer under the ~10s Metal GPU watchdog — a + // single-shot forward over a multi-thousand-token prompt trips it. Mirrors + // the other arch text paths (qwen3_5_moe / gemma4). + let (initial_max_seq, max_seq_ceiling) = + kv_max_seq_and_ceiling(max_ctx_override, model.cfg.max_position_embeddings as i32); let mut kv: Vec = (0..n_layers) - .enumerate() - .map(|(i, _)| KvCache::with_quant(kv_quant).with_layer_idx(i)) + .map(|i| { + KvCache::with_quant_max_seq(kv_quant, initial_max_seq) + .with_max_seq_ceiling(max_seq_ceiling) + .with_layer_idx(i) + }) .collect(); - let logits = model.forward_seq_with_cache(prompt_ids, Some(&mut kv), device)?; + let prefill_chunk = crate::prefill_chunk::prefill_chunk_for("qwen3_vl_moe"); + let Some(logits) = crate::decode_loop::chunked_prefill( + &mut kv, + prompt_ids, + prefill_chunk, + device, + "Qwen3VLMoeForConditionalGeneration", + |chunk, kv| model.forward_seq_with_cache(chunk, Some(kv), device), + )? 
+ else { + return Ok(steps); + }; let first = pick_token( &logits, vocab, @@ -410,6 +435,7 @@ pub fn generate_image( n_tokens: usize, device: Device, kv_quant: KvQuant, + max_ctx_override: Option, eos_ids: &[u32], step_fn: &mut dyn FnMut(&ProbeStep) -> Option, mut constraint: Option<&mut dyn ConstraintEngine>, @@ -441,26 +467,56 @@ pub fn generate_image( &visual_positions, device, )?; + // Materialize the scatter-merged embeds before the chunked prefill so the + // full-sequence scatter is its own command buffer, not folded into the first + // prefill chunk. + Array::eval(&inputs_embeds)?; // 2. 3D M-RoPE positions for the augmented sequence. let pos = get_rope_index(&ids_i64, image_grids, image_token_id, spatial_merge_size)?; // 3. prefill (deepstack injected after layers 0..len(deepstack)). + // + // The augmented image prompt is long (thousands of image soft tokens for + // native Qwen3-VL tiling — e.g. a 2560×2560 image → ~6400 soft tokens), + // far above the lazy KV_MAX_SEQ_DEFAULT=4096 start. Size the KV ring from + // the effective `--max-ctx`: `initial_max_seq` is the lazy start and + // `max_seq_ceiling` caps lazy growth and rejects an over-cap prompt with a + // clean `KvCeilingExceeded` (→ context_overflow) instead of a cryptic + // `slice_update` broadcast. Bracketing the chunked forward with + // enter_prefill()/exit_prefill() routes the prefill through the lazy-grow + // raw buffer (mirrors the Gemma4 image path) so it grows to fit. 
+ let (initial_max_seq, max_seq_ceiling) = + kv_max_seq_and_ceiling(max_ctx_override, model.cfg.max_position_embeddings as i32); let n_layers = model.cfg.num_hidden_layers; let mut kv: Vec = (0..n_layers) - .enumerate() - .map(|(i, _)| KvCache::with_quant(kv_quant).with_layer_idx(i)) + .map(|i| { + KvCache::with_quant_max_seq(kv_quant, initial_max_seq) + .with_max_seq_ceiling(max_seq_ceiling) + .with_layer_idx(i) + }) .collect(); - let logits = model.forward_embeds( + for c in &mut kv { + c.enter_prefill(); + } + // Chunk the image prefill so a long augmented prompt (thousands of image + // soft tokens) does not run a single multi-thousand-token forward in one + // Metal command buffer (the ~10s GPU watchdog). The chunk size comes from + // the per-arch prefill-chunk table (512 for this plain-GQA MoE arch). + let prefill_chunk = crate::prefill_chunk::prefill_chunk_for("qwen3_vl_moe"); + let logits = model.forward_embeds_chunked( &inputs_embeds, seq, &pos, - 0, &vision.deepstack_embeds, &visual_positions, - Some(&mut kv), + prefill_chunk, + &mut kv, device, )?; + for c in &mut kv { + c.exit_prefill(device)?; + } let mut steps = Vec::with_capacity(n_tokens); let first = pick_token( @@ -524,3 +580,7 @@ pub fn generate_image( } Ok(steps) } + +#[cfg(test)] +#[path = "generate_tests.rs"] +mod generate_tests; diff --git a/crates/rmlx-models/src/qwen3_vl_moe/generate_tests.rs b/crates/rmlx-models/src/qwen3_vl_moe/generate_tests.rs new file mode 100644 index 0000000..5344dc8 --- /dev/null +++ b/crates/rmlx-models/src/qwen3_vl_moe/generate_tests.rs @@ -0,0 +1,144 @@ +//! KV-cache sizing for the Qwen3-VL-MoE generate paths. +//! +//! Native Qwen3-VL image tiling produces thousands of image soft tokens (a +//! 2560×2560 image → ~6400 soft tokens), so the augmented prompt routinely +//! exceeds the lazy `KV_MAX_SEQ_DEFAULT = 4096` ring start. The bug these tests +//! pin: the image (and text) generate path built its per-layer caches with the +//! 
bare 4096 default and never bracketed the prefill, so a >4096-token prompt +//! overflowed the fixed decode buffer with +//! `slice_update: [broadcast_shapes] Shapes (1,4,6776,128) and (1,4,4096,128)`. +//! +//! The fix sizes the ring from the effective `--max-ctx` via +//! [`kv_max_seq_and_ceiling`] and brackets the (now chunked) prefill with +//! `enter_prefill()` / `exit_prefill()` so the lazy-grow path is used. These +//! tests reproduce that cache sizing + bracketing with one `update` at the +//! cache level (CPU, `KvQuant::None`, the model's real head_dim=128 / +//! n_kv_heads=4 shape) without loading the 30B model: +//! +//! * a 6776-token prefill under a 16384 ceiling grows to fit and completes, and +//! * an over-cap prefill is rejected with a clean `KvCeilingExceeded` +//! (→ HTTP `context_overflow`), not a `slice_update` broadcast panic. + +use crate::kv_cache::kv_max_seq_and_ceiling; +use rmlx_kv_quant::{KvCache, KvQuant}; +use rmlx_mlx::{Array, Device, Dtype}; + +// Real Qwen3-VL-30B-A3B text-decoder KV shape (from the serve log: +// n_kv_heads=4, head_dim=128). max_position_embeddings is large, so a 16384 +// `--max-ctx` ceiling resolves to exactly 16384. +const KV_H: i32 = 4; +const HEAD_DIM: i32 = 128; +const QWEN3VL_MPE: i32 = 262_144; + +#[allow( + clippy::expect_used, + reason = "test helper: .expect() surfaces an Array::from_bytes failure as the test message" +)] +fn f32_arr(data: &[f32], shape: &[i32]) -> Array { + // SAFETY: Apple-Silicon-only build (CLAUDE.md hard rule 1); f32 is 4-byte LE + // on this target. `data` is borrowed read-only and copied into MLX before + // the borrow ends.
+ #[allow( + unsafe_code, + reason = "zero-copy byte view of an f32 slice for Array::from_bytes; copied before the borrow ends" + )] + let bytes = unsafe { std::slice::from_raw_parts(data.as_ptr().cast::(), data.len() * 4) }; + Array::from_bytes(bytes, shape, Dtype::F32).expect("Array::from_bytes") +} + +/// Build the per-layer cache exactly as `generate_image` / `generate_greedy` +/// do, then run a single-shot prefill chunk of `seq` tokens. +fn prefill_one_shot(max_ctx_override: Option, seq: i32) -> rmlx_core::Result { + let device = Device::Cpu; + let (initial_max_seq, max_seq_ceiling) = kv_max_seq_and_ceiling(max_ctx_override, QWEN3VL_MPE); + let mut cache = KvCache::with_quant_max_seq(KvQuant::None, initial_max_seq) + .with_max_seq_ceiling(max_seq_ceiling) + .with_layer_idx(0); + + cache.enter_prefill(); + let shape = [1_i32, KV_H, seq, HEAD_DIM]; + let n: usize = shape.iter().map(|&d| d as usize).product(); + let k = f32_arr(&vec![0.1_f32; n], &shape); + let v = f32_arr(&vec![0.2_f32; n], &shape); + cache.update(&k, &v, device)?; + cache.exit_prefill(device)?; + Ok(cache) +} + +/// A 6776-token image prompt under `--max-ctx 16384` completes: the ring grows +/// past the 4096 default to fit instead of overflowing the fixed decode buffer. +/// This is the headline issue case (large image → ~6400 soft tokens + text). +#[test] +#[allow( + clippy::expect_used, + reason = "test asserts the prefill succeeds; .expect() surfaces the failure as the test message" +)] +fn image_prompt_over_4096_fits_under_max_ctx() { + // 6776 = the issue's augmented length (6715 image soft tokens + text). + let cache = prefill_one_shot(Some(16_384), 6776) + .expect("6776-token prefill must fit under a 16384 ceiling"); + assert_eq!( + cache.offset(), + 6776, + "offset tracks the full prefilled prompt length", + ); +} + +/// Without a ceiling-sized cache (the pre-fix bare 4096 default) the same prompt +/// would overflow; with the ceiling resolved from `--max-ctx` it does not. 
Pin +/// that the resolved ceiling is the requested value, not the 4096 default. +#[test] +fn max_ctx_override_sizes_ceiling_not_default() { + let (initial, ceiling) = kv_max_seq_and_ceiling(Some(16_384), QWEN3VL_MPE); + assert_eq!( + ceiling, 16_384, + "ceiling honors --max-ctx, not the 4096 default" + ); + assert_eq!( + initial, 4096, + "ring still starts lazily at the 4096 default and grows up to the ceiling", + ); +} + +/// A prompt that exceeds the effective `--max-ctx` is rejected with a clean +/// `KvCeilingExceeded` (mapped to HTTP `context_overflow`), NOT the cryptic +/// `slice_update` broadcast panic the pre-fix path produced. +#[test] +#[allow( + clippy::panic, + reason = "test fails loudly if the over-cap prompt is unexpectedly accepted (KvCache has no Debug impl, so expect_err is unavailable)" +)] +fn over_cap_image_prompt_yields_context_overflow_not_broadcast_panic() { + // ceiling = 4096 (small --max-ctx); a 6776-token prompt is over-cap. + // KvCache has no Debug impl, so match the Result directly rather than + // unwrapping the Ok side via expect_err. + let Err(err) = prefill_one_shot(Some(4096), 6776) else { + panic!("an over-cap prompt must be rejected before allocation"); + }; + assert!( + matches!(err, rmlx_core::error::Error::KvCeilingExceeded { .. }), + "expected KvCeilingExceeded (→ context_overflow), got: {err}", + ); + let msg = err.to_string(); + assert!( + msg.contains("exceeds max-ctx ceiling"), + "error must name the ceiling overflow, not a broadcast shape: {msg}", + ); + assert!( + !msg.contains("broadcast"), + "must NOT surface a raw slice_update broadcast error: {msg}", + ); +} + +/// A small image (under 4096 soft tokens) still works — the common control case +/// (e.g. a 448×448 image → 196 soft tokens) must be unaffected by the fix. 
+#[test] +#[allow( + clippy::expect_used, + reason = "test asserts the prefill succeeds; .expect() surfaces the failure as the test message" +)] +fn small_image_prompt_under_default_still_works() { + let cache = + prefill_one_shot(Some(16_384), 213).expect("213-token small-image prefill must complete"); + assert_eq!(cache.offset(), 213); +} diff --git a/crates/rmlx-models/src/qwen3_vl_moe/model.rs b/crates/rmlx-models/src/qwen3_vl_moe/model.rs index 8db9aec..e91ec49 100644 --- a/crates/rmlx-models/src/qwen3_vl_moe/model.rs +++ b/crates/rmlx-models/src/qwen3_vl_moe/model.rs @@ -236,105 +236,157 @@ impl Qwen3VlMoeText { h.reshape(&[1, seq as i32, self.cfg.hidden_size as i32], device) } - /// Image-branch prefill: forward precomputed `inputs_embeds` `[1, seq, - /// hidden]` (text embeddings with the vision features already scattered at - /// the image-token positions) with explicit 3D M-RoPE `pos`, optionally + /// Chunked image-branch prefill: forward precomputed `inputs_embeds` `[1, + /// seq, hidden]` (text embeddings with the vision features already scattered + /// at the image-token positions) with explicit 3D M-RoPE `pos`, optionally /// injecting `deepstack_embeds[k]` additively at `visual_positions` after /// decoder layer `k` (mirrors `language.py::_deepstack_process`: the first /// `len(deepstack_embeds)` layers get an injection). /// - /// Returns logits for the last position `[1, 1, vocab]`. + /// The prompt is encoded in `prefill_chunk`-token slices so a long image + /// prompt (thousands of image soft tokens) does not run a single multi- + /// thousand-token forward in one Metal command buffer (the ~10s GPU + /// watchdog). Each chunk advances the per-layer [`KvCache`] offset; the + /// downstream chunked-prefill SDPA mask path engages automatically for + /// `base_offset > 0`. Deepstack visual injection is applied per-layer to the + /// subset of `visual_positions` that fall inside the current chunk. 
+ /// + /// `pos` holds the 3D M-RoPE position id for every token; `visual_positions` + /// is the (contiguous) image-token run, aligned 1:1 with each + /// `deepstack_embeds[k]` row. Returns logits for the final position + /// `[1, 1, vocab]`. + /// + /// `prefill_chunk` must be ≥ 1; a value ≥ `seq` collapses to a single + /// forward. #[allow(clippy::too_many_arguments)] #[allow( clippy::indexing_slicing, reason = "bounds established by construction: buffer sized at init, loop indices bounded by slice length, or layer index validated before call" )] - pub(crate) fn forward_embeds( + pub(crate) fn forward_embeds_chunked( &self, inputs_embeds: &Array, seq: i32, pos: &RopeIndex3D, - base_offset: i32, deepstack_embeds: &[Array], visual_positions: &[usize], - kv_caches: Option<&mut [KvCache]>, + prefill_chunk: usize, + kv_caches: &mut [KvCache], device: Device, ) -> Result { - let (cos_v, sin_v) = build_interleaved_mrope_tables( - pos, - self.cfg.head_dim, - f64::from(self.cfg.rope_theta), - &self.cfg.mrope_section, - )?; - let hd = self.cfg.head_dim as i32; - let cos = f32_to_bf16_arr(&cos_v, &[1, 1, seq, hd], device)?; - let sin = f32_to_bf16_arr(&sin_v, &[1, 1, seq, hd], device)?; + let chunk = prefill_chunk.max(1) as i32; + let n_deep = deepstack_embeds.len(); + let hid = self.cfg.hidden_size as i32; + // Contiguous image-token run: [vis_lo, vis_hi). Empty when there are no + // visual tokens (pure-text augmented prompt — defensive). + let vis_lo = visual_positions.first().copied().unwrap_or(0); + let vis_hi = vis_lo + visual_positions.len(); - let mask_mode = pick_attn_mask_mode(base_offset, seq); - let shared_mask: Option = if mask_mode == "array" { - Some(build_chunked_prefill_mask(base_offset, seq, device)?) 
- } else { - None - }; + let mut last_logits: Option = None; + let mut start = 0_i32; + while start < seq { + let end = (start + chunk).min(seq); + let clen = end - start; - let mut h = inputs_embeds.try_clone()?; - let n_deep = deepstack_embeds.len(); - match kv_caches { - Some(kv) => { - for (i, layer) in self.layers.iter().enumerate() { - h = layer.forward( + // Slice this chunk's embeds and 3D positions. + let chunk_embeds = + inputs_embeds.slice(&[0, start, 0], &[1, end, hid], &[1, 1, 1], device)?; + let chunk_pos = RopeIndex3D { + t: pos.t[start as usize..end as usize].to_vec(), + h: pos.h[start as usize..end as usize].to_vec(), + w: pos.w[start as usize..end as usize].to_vec(), + }; + + // Visual positions inside this chunk, mapped to chunk-local indices, + // and the matching deepstack rows. The run is contiguous, so the + // intersection [lo, hi) is a single sub-run. + let lo = (start as usize).max(vis_lo); + let hi = (end as usize).min(vis_hi); + let chunk_visual_positions: Vec = if lo < hi { + (lo - start as usize..hi - start as usize).collect() + } else { + Vec::new() + }; + + let (cos_v, sin_v) = build_interleaved_mrope_tables( + &chunk_pos, + self.cfg.head_dim, + f64::from(self.cfg.rope_theta), + &self.cfg.mrope_section, + )?; + let hd = self.cfg.head_dim as i32; + let cos = f32_to_bf16_arr(&cos_v, &[1, 1, clen, hd], device)?; + let sin = f32_to_bf16_arr(&sin_v, &[1, 1, clen, hd], device)?; + + let base_offset = kv_caches[0].offset(); + tracing::debug!( + chunk_start = start, + chunk_end = end, + clen, + base_offset, + "qwen3_vl_moe image prefill chunk" + ); + let mask_mode = pick_attn_mask_mode(base_offset, clen); + let shared_mask: Option = if mask_mode == "array" { + Some(build_chunked_prefill_mask(base_offset, clen, device)?) 
+ } else { + None + }; + + let mut h = chunk_embeds; + for (i, layer) in self.layers.iter().enumerate() { + h = layer.forward( + &h, + &cos, + &sin, + Some(&mut kv_caches[i]), + shared_mask.as_ref(), + mask_mode, + device, + )?; + if i < n_deep && lo < hi { + // Deepstack rows aligned 1:1 with visual_positions; slice the + // [lo - vis_lo, hi - vis_lo) rows for this chunk. + let row_lo = (lo - vis_lo) as i32; + let row_hi = (hi - vis_lo) as i32; + let ds = &deepstack_embeds[i]; + let ds_chunk = ds.slice(&[row_lo, 0], &[row_hi, hid], &[1, 1], device)?; + h = super::image::deepstack_inject( &h, - &cos, - &sin, - Some(&mut kv[i]), - shared_mask.as_ref(), - mask_mode, + &ds_chunk, + &chunk_visual_positions, device, )?; - if i < n_deep { - h = super::image::deepstack_inject( - &h, - &deepstack_embeds[i], - visual_positions, - device, - )?; - } } } - None => { - for (i, layer) in self.layers.iter().enumerate() { - h = layer.forward( - &h, - &cos, - &sin, - None, - shared_mask.as_ref(), - mask_mode, - device, - )?; - if i < n_deep { - h = super::image::deepstack_inject( - &h, - &deepstack_embeds[i], - visual_positions, - device, - )?; - } + + // Flush the per-chunk command buffer under the ~10s Metal watchdog. + // Eval the chunk hidden directly (forces this chunk's full forward, + // including the K/V writes) so a long image prompt does not fold all + // chunks into one buffer. Non-final chunks then skip the lm_head. 
+ h.eval()?; + if end < seq { + for c in kv_caches.iter() { + c.eval_prefill_state()?; } + } else { + let h = self.final_norm.forward(&h, device)?; + let last = h.slice(&[0, clen - 1, 0], &[1, clen, hid], &[1, 1, 1], device)?; + let logits = match &self.lm_head { + Some(lin) => lin.forward(&last, device)?, + None => self.embed_as_linear(&last, device)?, + }; + last_logits = Some(logits); } - } - let h = self.final_norm.forward(&h, device)?; - let last = h.slice( - &[0, seq - 1, 0], - &[1, seq, self.cfg.hidden_size as i32], - &[1, 1, 1], - device, - )?; - match &self.lm_head { - Some(lin) => lin.forward(&last, device), - None => self.embed_as_linear(&last, device), + start = end; } + + last_logits.ok_or_else(|| { + rmlx_core::error::Error::Model( + "qwen3_vl_moe forward_embeds_chunked: empty prompt produced no logits".into(), + ) + }) } fn embed_as_linear(&self, x: &Array, device: Device) -> Result { diff --git a/crates/rmlx-models/src/qwen3_vl_moe/vision.rs b/crates/rmlx-models/src/qwen3_vl_moe/vision.rs index 9526eb8..0e1c32d 100644 --- a/crates/rmlx-models/src/qwen3_vl_moe/vision.rs +++ b/crates/rmlx-models/src/qwen3_vl_moe/vision.rs @@ -422,6 +422,14 @@ impl Qwen3VlMoeVision { let mut deepstack_embeds = Vec::with_capacity(self.cfg.deepstack_visual_indexes.len()); for (layer_num, blk) in self.blocks.iter().enumerate() { h = blk.forward(&h, &cos, &sin, mask.as_ref(), device)?; + // Flush each block's command buffer. The ViT runs full attention over + // every patch (native tiling → tens of thousands of patches for a + // large image, an O(num_patches^2) score matrix per block); letting + // all 27 blocks accumulate into a single lazy graph overruns the ~10s + // Metal GPU watchdog. Evaluating per block keeps each command buffer + // small. No effect on small images (the eval is cheap once the block + // is already computed). 
+ h.eval()?; if let Some(ds_idx) = self .cfg .deepstack_visual_indexes @@ -429,11 +437,21 @@ impl Qwen3VlMoeVision { .position(|&x| x == layer_num) { let de = self.deepstack_mergers[ds_idx].forward(&h, device)?; + // Materialize each deepstack merger output now so it is not + // re-derived (pulling the full ViT graph) when injected into the + // first text-prefill chunk. + de.eval()?; deepstack_embeds.push(de); } } let image_embeds = self.merger.forward(&h, device)?; + // Materialize the merged image embeds before returning so the downstream + // scatter into the text sequence is its own command buffer, decoupled + // from the prefill — and so the whole ViT (a tens-of-thousands-of-patches + // full-attention graph) never folds into a single prefill buffer that + // overruns the Metal watchdog. + image_embeds.eval()?; Ok(VisionOutput { image_embeds, deepstack_embeds, diff --git a/crates/rmlx-server/src/engine/arch_generator.rs b/crates/rmlx-server/src/engine/arch_generator.rs index 40a2d65..3bcebb2 100644 --- a/crates/rmlx-server/src/engine/arch_generator.rs +++ b/crates/rmlx-server/src/engine/arch_generator.rs @@ -1072,6 +1072,11 @@ impl Generator for ArchGenerator { } else { None }, + // Effective `--max-ctx` (launch flag or per-request override) + // so the image-path KV ring is sized to fit a large + // multi-thousand-soft-token prompt and an over-cap prompt is + // rejected cleanly rather than overflowing the 4096 default. 
+ max_ctx_override, &eos_ids, &tokenizer, &mut step_fn, diff --git a/crates/rmlx-server/src/engine/image.rs b/crates/rmlx-server/src/engine/image.rs index 742b627..14d7ff0 100644 --- a/crates/rmlx-server/src/engine/image.rs +++ b/crates/rmlx-server/src/engine/image.rs @@ -339,6 +339,7 @@ pub(crate) fn run_qwen3vl_image( n_tokens: usize, device: rmlx_mlx::Device, kv_quant_override: Option, + max_ctx_override: Option, eos_ids: &[u32], tokenizer: &tokenizers::Tokenizer, step_fn: &mut dyn FnMut(&rmlx_models::ProbeStep) -> Option, @@ -531,6 +532,7 @@ pub(crate) fn run_qwen3vl_image( n_tokens, device, kv_quant, + max_ctx_override, eos_ids, step_fn, constraint, diff --git a/docs/MODELS.md b/docs/MODELS.md index c601eb2..e191583 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -456,6 +456,26 @@ size, video token id) but not yet exercised in rMLX. `max_position_embeddings` from `text_config`. +Native Qwen3-VL image tiling produces thousands of image soft tokens (a +2560×2560 image → ~6400 soft tokens at `patch_size=16`, `merge_size=2`), so the +augmented image prompt routinely exceeds the lazy `KV_MAX_SEQ_DEFAULT = 4096` +ring start. Both the image and text generate paths size the KV ring from the +effective `--max-ctx` (via `kv_max_seq_and_ceiling`, same as the other arches): +the ring grows lazily up to that ceiling so a prompt up to `--max-ctx` fits, and +a prompt over the ceiling is rejected with a clean `context_overflow` rather +than a `slice_update` broadcast error. Serve a long prompt with `--max-ctx N` ≥ +(soft tokens + text length). + +Both prefill paths are **chunked** (per-arch `prefill_chunk` = 512, plain GQA +MoE — no GDN, so it tolerates the same chunk as Gemma4) and the vision tower +evaluates per ViT block, so a long prompt does not run a single multi-thousand- +token forward in one Metal command buffer. 
Note: the ViT runs **full attention +over every image patch** for a single image (reference-faithful — mlx-vlm splits +attention per image, which is one full block for one image). A very large image +(tens of thousands of patches → an O(num_patches²) attention) can still overrun +the ~10s Metal GPU watchdog on memory-constrained Apple Silicon; this is a +vision-tower scaling limit, independent of the KV-cache sizing above. + ### Special features - **3D M-RoPE** for spatial + temporal + text position encoding.