diff --git a/crates/rmlx-models/src/gemma4/decoder_layer.rs b/crates/rmlx-models/src/gemma4/decoder_layer.rs index e0b3b58..f8b6f0f 100644 --- a/crates/rmlx-models/src/gemma4/decoder_layer.rs +++ b/crates/rmlx-models/src/gemma4/decoder_layer.rs @@ -56,14 +56,21 @@ impl DecoderLayer { offset: i32, cache: Option<&mut KvCache>, kv_is_rotating: bool, + bidi_overlay: Option<&Array>, device: Device, ) -> Result<(Array, Option<(Array, Array)>)> { // Attention sub-layer. let residual = x.try_clone()?; let h = self.input_norm.forward(x, device)?; - let (h, new_kv) = - self.attn - .forward(&h, shared_kv, offset, cache, kv_is_rotating, device)?; + let (h, new_kv) = self.attn.forward( + &h, + shared_kv, + offset, + cache, + kv_is_rotating, + bidi_overlay, + device, + )?; let h = self.post_attn_norm.forward(&h, device)?; let h = add(&residual, &h, device)?; diff --git a/crates/rmlx-models/src/gemma4/layers/mask_tests.rs b/crates/rmlx-models/src/gemma4/layers/mask_tests.rs index 6b85d3c..649dcef 100644 --- a/crates/rmlx-models/src/gemma4/layers/mask_tests.rs +++ b/crates/rmlx-models/src/gemma4/layers/mask_tests.rs @@ -23,7 +23,7 @@ //! the mask key dim from `k_seq` to `k_seq + 1` and making //! `guard_invariant_producer_offset_matches_k_seq` RED. 
-use rmlx_mlx::Device; +use rmlx_mlx::{Array, Device, Dtype}; use super::{build_attn_mask, consumer_effective_offset, producer_effective_offset, LayerType}; @@ -53,6 +53,7 @@ fn full_attn_verify_block_mask_matches_producer_k_len() { producer_offset + seq, false, // attn_is_rotating window, + None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay Device::Cpu, ) .unwrap(); @@ -100,6 +101,7 @@ fn sliding_attn_verify_block_mask_matches_capped_k_len() { effective_offset + seq, true, // attn_is_rotating window, + None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay Device::Cpu, ) .unwrap(); @@ -162,6 +164,7 @@ fn guard_invariant_producer_offset_matches_k_seq() { effective_offset + seq, false, window, + None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay Device::Cpu, ) .unwrap(); @@ -234,6 +237,7 @@ fn guard_invariant_consumer_mask_matches_shared_k_len() { total_kv_len, false, // attn_is_rotating window, + None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay Device::Cpu, ) .unwrap(); @@ -285,6 +289,7 @@ fn guard_invariant_regressed_consumer_offset_inflates_mask() { base_offset + seq, false, window, + None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay Device::Cpu, ) .unwrap(); @@ -332,6 +337,7 @@ fn guard_invariant_regressed_base_offset_inflates_mask() { base_offset_desynced + seq, false, window, + None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay Device::Cpu, ) .unwrap(); @@ -343,3 +349,73 @@ fn guard_invariant_regressed_base_offset_inflates_mask() { "base_offset-desynced mask is one key too long — the #32 guard trigger" ); } + +/// Inside a sliding-window layer, a vision (image) bidi overlay must override the +/// window: an intra-image-block pair that the sliding window would block (the two +/// positions are farther apart than `window`) ends up allowed (`0.0`) in the +/// combined mask. 
Pins "image-bidi overrides the sliding window inside the block" +/// against a regression that would re-apply the window cap to image soft tokens. +#[test] +#[allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::indexing_slicing, + clippy::float_cmp, + reason = "test asserts on a known-good mask cell; unwrap/expect failures are the assertion; mask cells are exact 0.0 / -1e30" +)] +fn swa_prefill_image_bidi_overrides_window() { + // Single-shot prefill (offset == 0), seq == 4, the whole sequence one image + // block. Window = 2, so the SWA mask alone blocks query 3 / key 0 (distance + // 3 >= window). The bidi overlay opens every (i, j) in the block to 0.0, so + // the element-wise `maximum` combine must leave (3, 0) allowed. + let seq = 4i32; + let window = 2usize; + let n = seq as usize; + + // Overlay: 0.0 everywhere (one image block spanning the full sequence), + // matching the additive convention build_vision_bidi_overlay emits. + let overlay = Array::from_f32_slice(&vec![0.0_f32; n * n], &[1, 1, seq, seq]) + .unwrap() + .astype(Dtype::Bf16, Device::Cpu) + .unwrap(); + + let (mask, mode) = build_attn_mask( + LayerType::SlidingAttention, + seq, + 0, // effective_offset == 0 (single-shot prefill) + seq, + false, // attn_is_rotating + window, + Some(&overlay), + Device::Cpu, + ) + .unwrap(); + + assert_eq!(mode, "array", "SWA prefill uses array mode"); + let mask = mask.expect("array mode must carry a mask array"); + let shape = mask.shape(); + assert_eq!(shape, &[1, 1, seq, seq], "combined mask is [1,1,seq,seq]"); + + let grid: Vec = mask + .astype(Dtype::F32, Device::Cpu) + .unwrap() + .to_bytes() + .unwrap() + .chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect(); + + // Out-of-window intra-block cell (query 3, key 0): distance 3 >= window 2, so + // the SWA mask alone would block it (-1e30). The overlay opens it to 0.0. 
+ assert_eq!( + grid[3 * n], + 0.0, + "image-bidi overlay must override the sliding window inside the block" + ); + // Future-direction intra-block cell (query 0, key 3) is also opened by the + // bidi overlay even though the causal/SWA mask blocks the future. + assert_eq!( + grid[3], 0.0, + "bidi overlay opens the in-future intra-block pair too" + ); +} diff --git a/crates/rmlx-models/src/gemma4/layers/mod.rs b/crates/rmlx-models/src/gemma4/layers/mod.rs index 6d5c79b..81d8ff6 100644 --- a/crates/rmlx-models/src/gemma4/layers/mod.rs +++ b/crates/rmlx-models/src/gemma4/layers/mod.rs @@ -147,6 +147,7 @@ impl Attention { clippy::indexing_slicing, reason = "bounds established by construction: buffer sized at init, loop indices bounded by slice length, or layer index validated before call" )] + #[allow(clippy::too_many_arguments)] pub(super) fn forward( &self, x: &Array, @@ -154,6 +155,7 @@ impl Attention { offset: i32, cache: Option<&mut KvCache>, kv_is_rotating: bool, + bidi_overlay: Option<&Array>, device: Device, ) -> Result<(Array, Option<(Array, Array)>)> { let shape = x.shape(); // [batch, seq, hidden] @@ -325,6 +327,7 @@ impl Attention { total_kv_len_pre, attn_is_rotating, self.sliding_window, + bidi_overlay, device, )?; // Cache-holding layer: route through the shared-KV variant of @@ -381,6 +384,7 @@ impl Attention { total_kv_len_pre, attn_is_rotating, self.sliding_window, + bidi_overlay, device, )?; let k_new = k.try_clone()?; @@ -464,6 +468,7 @@ impl Attention { total_kv_len, attn_is_rotating, self.sliding_window, + bidi_overlay, device, )?; // Guard (issue #32 part 2): the array-mode consumer mask's key dim @@ -552,6 +557,7 @@ fn build_attn_mask( total_kv_len: i32, attn_is_rotating: bool, sliding_window: usize, + bidi_overlay: Option<&Array>, device: Device, ) -> Result<(Option, &'static str)> { match layer_type { @@ -572,16 +578,26 @@ fn build_attn_mask( } else { // SWA prefill — banded-causal mask sized by the // capped effective offset. 
- let mask = Some(crate::layers::build_swa_prefill_mask( + let mask = crate::layers::build_swa_prefill_mask( effective_offset, seq, sliding_window, device, - )?); - Ok((mask, "array")) + )?; + let mask = combine_bidi_overlay(mask, bidi_overlay, device)?; + Ok((Some(mask), "array")) } } LayerType::FullAttention => { + // Bidirectional vision blocks require an explicit array mask even at + // offset==0 (where the default would be the cheap "causal" mode), so + // the intra-image-block openings can be merged in. + if bidi_overlay.is_some() && seq > 1 { + let causal = + crate::layers::build_chunked_prefill_mask(effective_offset, seq, device)?; + let mask = combine_bidi_overlay(causal, bidi_overlay, device)?; + return Ok((Some(mask), "array")); + } let mode = crate::layers::pick_attn_mask_mode(effective_offset, seq); let mask = if mode == "array" { Some(crate::layers::build_chunked_prefill_mask( @@ -597,6 +613,44 @@ fn build_attn_mask( } } +/// Merge a bidirectional-attention overlay into a prefill mask. +/// +/// Both masks use additive convention: `0.0` = attend allowed, large-negative = +/// blocked. A key/query pair is allowed if **either** the causal/SWA rule allows +/// it **or** both positions lie in the same vision (image) block — so the merge +/// is element-wise `maximum`. The overlay is `[1, 1, seq, seq]` and must match +/// the prefill mask's `[1, 1, seq, offset+seq]` only when `offset == 0` (the +/// single-shot image prefill path); the caller guarantees that. +fn combine_bidi_overlay( + causal: Array, + bidi_overlay: Option<&Array>, + device: Device, +) -> Result { + match bidi_overlay { + None => Ok(causal), + Some(overlay) => { + // Shapes agree only on the offset==0 single-shot image prefill, which + // is the only path that supplies an overlay. If the key dim differs + // (chunked prefill with a prior offset), skip the merge defensively + // and keep the causal mask. 
+ if causal.shape() == overlay.shape() { + rmlx_mlx::maximum(&causal, overlay, device) + } else { + // Unreachable under the single-shot prefill invariant. If it + // ever fires the image block silently reverts to causal + // attention (the unified-vision colour-corruption failure mode), + // so make the violation loud instead of degrading silently. + tracing::warn!( + causal = ?causal.shape(), + overlay = ?overlay.shape(), + "gemma4 bidi overlay shape mismatch — image block fell back to causal" + ); + Ok(causal) + } + } + } +} + /// Expand K/V from [B, kv_heads, S, D] to [B, q_heads, S, D] by repeating. #[allow( clippy::indexing_slicing, diff --git a/crates/rmlx-models/src/gemma4/mod.rs b/crates/rmlx-models/src/gemma4/mod.rs index 1e3f191..895ace8 100644 --- a/crates/rmlx-models/src/gemma4/mod.rs +++ b/crates/rmlx-models/src/gemma4/mod.rs @@ -63,5 +63,5 @@ pub use vision::unified::{ }; pub use vision::{ build_inputs_embeds, load_multimodal_embedder, load_vision_tower, MultimodalEmbedder, - VisionModel, IMAGE_TOKEN_ID, + VisionModel, BOI_TOKEN_ID, EOI_TOKEN_ID, IMAGE_TOKEN_ID, }; diff --git a/crates/rmlx-models/src/gemma4/model.rs b/crates/rmlx-models/src/gemma4/model.rs index 00f4d62..7b3d189 100644 --- a/crates/rmlx-models/src/gemma4/model.rs +++ b/crates/rmlx-models/src/gemma4/model.rs @@ -54,6 +54,82 @@ fn cache_base_offset(caches: Option<&[KvCache]>) -> i32 { .map_or(0, |c| c.offset()) } +/// `` / `` marker token ids, as `i32` for the +/// id-scan loop below. The canonical source of truth is `super::vision` +/// (`u32`); cast here at the i32 boundary so there is one definition. +const BOI_TOKEN_ID: i32 = super::vision::BOI_TOKEN_ID as i32; +const EOI_TOKEN_ID: i32 = super::vision::EOI_TOKEN_ID as i32; + +/// Build the bidirectional-attention overlay for vision (image) soft-token +/// blocks during a single-shot prefill. 
+/// +/// Gemma 4 conditions image soft tokens with **bidirectional** attention within +/// each image block: every soft token of an image attends to every other soft +/// token of the same image, not just the causal prefix. (Text tokens stay +/// causal.) This is essential for the encoder-free unified embedder, whose raw +/// projected patches carry no pre-integrated spatial/colour context — read +/// causally they are misinterpreted (e.g. solid colours misnamed). The SigLIP +/// tower path already integrates the image in its ViT, so the overlay is a +/// faithful no-harm addition there. +/// +/// The overlay is `[1, 1, seq, seq]` additive (`0.0` = allowed, large-negative = +/// blocked): cell `(i, j)` is `0.0` when positions `i` and `j` lie strictly +/// between the same matching `` / `` pair, else +/// large-negative. The mask builder merges it (element-wise `maximum`) with each +/// layer's causal/SWA prefill mask, so an intra-block pair is allowed even when +/// it is "in the future" of the causal mask. +/// +/// Returns `None` when `seq == 1` (decode), when reading the ids fails, or when +/// the sequence contains no complete image block. +#[allow( + clippy::indexing_slicing, + reason = "chunks_exact(4) guarantees 4-byte slabs; data index i*n+j bounded by the n×n allocation" +)] +fn build_vision_bidi_overlay(ids_arr: &Array, seq: i32, device: Device) -> Option { + if seq <= 1 { + return None; + } + let raw = ids_arr.to_bytes().ok()?; + let ids: Vec = raw + .chunks_exact(4) + .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect(); + if ids.len() != seq as usize { + return None; + } + + // Collect (start, end) exclusive ranges of soft tokens inside each + // .. pair. 
+ let mut blocks: Vec<(usize, usize)> = Vec::new(); + let mut open: Option = None; + for (i, &t) in ids.iter().enumerate() { + if t == BOI_TOKEN_ID { + open = Some(i + 1); + } else if t == EOI_TOKEN_ID { + if let Some(start) = open.take() { + if i > start { + blocks.push((start, i)); + } + } + } + } + if blocks.is_empty() { + return None; + } + + let n = seq as usize; + let mut data = vec![-1e30_f32; n * n]; + for (start, end) in blocks { + for i in start..end { + for j in start..end { + data[i * n + j] = 0.0; + } + } + } + let m = Array::from_f32_slice(&data, &[1, 1, seq, seq]).ok()?; + m.astype(rmlx_mlx::Dtype::Bf16, device).ok() +} + // --------------------------------------------------------------------------- // Full model // --------------------------------------------------------------------------- @@ -139,6 +215,7 @@ impl Gemma4Text { let (new_h, new_kv) = layer.forward( &h, shared_kv, per_layer, 0, // base_offset = 0 (fresh forward, no cache) None, false, // no cache → no rotating cache + None, // text/speculative path: causal only (no image bidi overlay) device, )?; h = new_h; @@ -265,7 +342,10 @@ impl Gemma4Text { scalar_f32((self.cfg.hidden_size as f32).sqrt()).astype(h_raw.dtype(), device)?; let h = multiply(&h_raw, &embed_scale, device)?; let h = h.reshape(&[1, seq, self.cfg.hidden_size as i32], device)?; - self.forward_h(h, ids_arr, seq, caches, device) + // Pure-text path: no image markers, so the vision bidi overlay is never + // needed. Passing `has_image = false` skips the device→host id sync that + // `build_vision_bidi_overlay` would otherwise force on every prefill. + self.forward_h(h, ids_arr, seq, caches, false, device) } /// forward pass from precomputed, already-scaled `inputs_embeds`. 
@@ -286,7 +366,10 @@ impl Gemma4Text { caches: Option<&mut [KvCache]>, device: Device, ) -> Result { - self.forward_h(embeds, ids_arr, seq, caches, device) + // Image path: the embeds carry scattered vision features, so the bidi + // overlay over image soft-token blocks is relevant. `has_image = true` + // gates the (single-shot prefill only) overlay build. + self.forward_h(embeds, ids_arr, seq, caches, true, device) } /// Shared decoder trunk + LM head over a precomputed scaled hidden state. @@ -304,11 +387,25 @@ impl Gemma4Text { ids_arr: &Array, seq: i32, caches: Option<&mut [KvCache]>, + has_image: bool, device: Device, ) -> Result { // Current sequence offset (0 when no cache). let base_offset = cache_base_offset(caches.as_deref()); + // Bidirectional attention overlay for image soft-token blocks. Only the + // single-shot prefill (base_offset == 0, full sequence in this call) can + // open the intra-image block; for chunked prefill / decode the overlay + // is None and attention stays causal. `has_image` is false on the + // pure-text path, so text prefill never forces a device→host id sync + // (`build_vision_bidi_overlay` reads the ids back to scan for markers). + let bidi_overlay = if has_image && base_offset == 0 { + build_vision_bidi_overlay(ids_arr, seq, device) + } else { + None + }; + let bidi_overlay = bidi_overlay.as_ref(); + // Per-position per-layer inputs. // ids_arr is [seq], h is [1, seq, hidden] — both span the full call. 
let per_layer_inputs = self.compute_per_layer_inputs(ids_arr, &h, device)?; @@ -350,6 +447,7 @@ impl Gemma4Text { base_offset, None, kv_is_rotating[layer_idx], + bidi_overlay, device, )?; h = new_h; @@ -382,6 +480,7 @@ impl Gemma4Text { base_offset, cache, kv_is_rotating[layer_idx], + bidi_overlay, device, )?; h = new_h; @@ -486,6 +585,7 @@ impl Gemma4Text { base_offset, None, kv_is_rotating[layer_idx], + None, device, )?; h = new_h; @@ -517,6 +617,7 @@ impl Gemma4Text { base_offset, cache, kv_is_rotating[layer_idx], + None, device, )?; h = new_h; @@ -626,6 +727,7 @@ impl Gemma4Text { base_offset, None, kv_is_rotating[layer_idx], + None, device, )?; h = new_h; @@ -657,6 +759,7 @@ impl Gemma4Text { base_offset, cache, kv_is_rotating[layer_idx], + None, device, )?; h = new_h; @@ -762,6 +865,7 @@ impl Gemma4Text { base_offset, cache, kv_is_rotating[layer_idx], + None, device, )?; h = new_h; @@ -937,3 +1041,19 @@ impl Gemma4Text { pub(super) fn apply_softcap(logits: &Array, cap: f32, device: Device) -> Result { softcap_fused(logits, cap, device) } + +/// Test-only re-export of [`build_vision_bidi_overlay`] so the sibling +/// `model_tests` module can assert the overlay allow/block pattern without a +/// full model forward pass. +#[cfg(test)] +pub(super) fn build_vision_bidi_overlay_for_test( + ids_arr: &Array, + seq: i32, + device: Device, +) -> Option { + build_vision_bidi_overlay(ids_arr, seq, device) +} + +#[cfg(test)] +#[path = "model_tests.rs"] +mod model_tests; diff --git a/crates/rmlx-models/src/gemma4/model_tests.rs b/crates/rmlx-models/src/gemma4/model_tests.rs new file mode 100644 index 0000000..61b961e --- /dev/null +++ b/crates/rmlx-models/src/gemma4/model_tests.rs @@ -0,0 +1,125 @@ +//! Model-free guards for the Gemma4 vision bidirectional-attention overlay. +//! +//! Gemma 4 conditions each image's soft tokens with bidirectional attention +//! (every soft token of an image attends to every other soft token of that +//! image). 
The encoder-free unified embedder produces raw projected patches +//! with no pre-integrated context, so reading them causally mis-conditions the +//! decoder (chromatic colours misnamed, spatial layout hallucinated). These +//! tests pin the overlay shape and the exact allow/block pattern so that +//! regression to a causal-only image block is caught without a model. + +use rmlx_mlx::{Array, Device, Dtype}; + +use super::build_vision_bidi_overlay_for_test as build_vision_bidi_overlay; + +const BOI: i32 = 255_999; +const EOI: i32 = 258_882; +const IMG: i32 = 258_880; + +/// Read the `[1,1,seq,seq]` overlay to a host `f32` grid for assertions. +#[allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::indexing_slicing, + reason = "test asserts the overlay materialises; chunks_exact(4) yields 4-byte slabs" +)] +fn overlay_grid(ids: &[i32]) -> Vec { + let seq = ids.len() as i32; + let arr = Array::from_i32_slice(ids, &[seq]).unwrap(); + let m = build_vision_bidi_overlay(&arr, seq, Device::Cpu).expect("overlay present"); + let raw = m + .astype(Dtype::F32, Device::Cpu) + .unwrap() + .to_bytes() + .unwrap(); + raw.chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect() +} + +/// Image soft tokens inside one ` .. ` block attend bidirectionally to +/// each other; everything else (text↔text, text↔image, the markers themselves) +/// is blocked by the overlay (the causal mask supplies those allowances). +#[test] +#[allow( + clippy::indexing_slicing, + clippy::float_cmp, + reason = "fixed-size test grid indexed by construction; overlay cells are exact 0.0 / -1e30" +)] +fn overlay_opens_image_block_bidirectionally() { + // [BOS, BOI, IMG, IMG, IMG, EOI, text] + let ids = [2, BOI, IMG, IMG, IMG, EOI, 100]; + let n = ids.len(); + let grid = overlay_grid(&ids); + + // Soft-token positions are indices 2,3,4 (strictly between BOI@1 and EOI@5). 
+ let soft = [2usize, 3, 4]; + for &i in &soft { + for &j in &soft { + // Bidirectional: allowed (0.0) for every ordered pair, including j>i. + assert_eq!(grid[i * n + j], 0.0, "soft ({i},{j}) must be open"); + } + } + // A future-looking soft pair (j>i) is the discriminating case: under a + // causal-only mask this would be blocked. + assert_eq!( + grid[2 * n + 4], + 0.0, + "soft token must see a later soft token" + ); + + // Text/markers stay closed in the overlay (causal mask handles them). + for i in 0..n { + for j in 0..n { + let both_soft = soft.contains(&i) && soft.contains(&j); + if !both_soft { + assert!( + grid[i * n + j] < -1.0e9, + "non-soft pair ({i},{j}) must be blocked by the overlay" + ); + } + } + } +} + +/// Two image blocks each open only within themselves — soft tokens of image A +/// never attend to soft tokens of image B. +#[test] +#[allow( + clippy::indexing_slicing, + clippy::float_cmp, + reason = "fixed-size test grid indexed by construction; overlay cells are exact 0.0 / -1e30" +)] +fn overlay_does_not_cross_image_blocks() { + // [BOS, BOI, IMG, IMG, EOI, BOI, IMG, IMG, EOI] + let ids = [2, BOI, IMG, IMG, EOI, BOI, IMG, IMG, EOI]; + let n = ids.len(); + let grid = overlay_grid(&ids); + + let block_a = [2usize, 3]; + let block_b = [6usize, 7]; + for &i in &block_a { + for &j in &block_a { + assert_eq!(grid[i * n + j], 0.0); + } + for &j in &block_b { + assert!(grid[i * n + j] < -1.0e9, "block A must not see block B"); + assert!(grid[j * n + i] < -1.0e9, "block B must not see block A"); + } + } +} + +/// No image block (pure text) and decode (seq == 1) produce no overlay, so the +/// attention path stays fully causal. 
+#[test] +#[allow(clippy::unwrap_used, reason = "test fixture arrays always materialise")] +fn overlay_absent_without_image_block() { + let text_only = [2, 100, 200, 300]; + let seq = text_only.len() as i32; + let arr = Array::from_i32_slice(&text_only, &[seq]).unwrap(); + assert!(build_vision_bidi_overlay(&arr, seq, Device::Cpu).is_none()); + + let one = [2]; + let arr1 = Array::from_i32_slice(&one, &[1]).unwrap(); + assert!(build_vision_bidi_overlay(&arr1, 1, Device::Cpu).is_none()); +} diff --git a/crates/rmlx-models/src/gemma4/vision/mod.rs b/crates/rmlx-models/src/gemma4/vision/mod.rs index 296a98d..0b83c62 100644 --- a/crates/rmlx-models/src/gemma4/vision/mod.rs +++ b/crates/rmlx-models/src/gemma4/vision/mod.rs @@ -777,6 +777,14 @@ pub fn load_multimodal_embedder( /// must equal the vision tower's `num_soft_tokens` for the scatter to align. pub const IMAGE_TOKEN_ID: u32 = 258880; +/// Gemma4 `` (begin-of-image) marker token id. Wraps each +/// image's soft-token run on the left. +pub const BOI_TOKEN_ID: u32 = 255_999; + +/// Gemma4 `` marker token id. Wraps each image's soft-token run +/// on the right. +pub const EOI_TOKEN_ID: u32 = 258_882; + /// Build the merged `inputs_embeds` for a Gemma4 image prompt. 
/// /// Faithful host port of mlx-vlm `gemma4.py::Model.get_input_embeddings`: diff --git a/crates/rmlx-models/src/gemma4/vision/unified.rs b/crates/rmlx-models/src/gemma4/vision/unified.rs index 28fcc11..6175dc2 100644 --- a/crates/rmlx-models/src/gemma4/vision/unified.rs +++ b/crates/rmlx-models/src/gemma4/vision/unified.rs @@ -296,63 +296,7 @@ impl UnifiedVisionEmbedder { reason = "bounds established by construction: all indices derived from the host-computed patch grid" )] fn patchify_and_merge(&self, pv: &Gemma4PixelValues) -> Result<(Vec, Vec, Vec)> { - let p = self.cfg.patch_size; // 16 - let k = self.cfg.pooling_kernel_size; // 3 - let h = pv.height; - let w = pv.width; - if !h.is_multiple_of(p * k) || !w.is_multiple_of(p * k) { - return Err(Error::Model(format!( - "gemma4_unified vision: image {h}x{w} not divisible by model_patch_size {}", - p * k - ))); - } - let p_h = h / p; // teacher rows - let p_w = w / p; // teacher cols - let m_h = p_h / k; // model rows - let m_w = p_w / k; // model cols - let n_model = m_h * m_w; - let model_patch = p * k; // 48 - let patch_dim = model_patch * model_patch * 3; // 6912 - let n_pixels = h * w; - - // Build merged patches directly in the reference target layout. The - // upstream `patches_merge` reshapes the k×k kernel group to - // `(length, k, k, p, p, 3)` then permutes to `(length, k, p, k, p, 3)` - // and flattens — i.e. the 6912-vector interior is ordered - // **`[ky, ry, kx, rx, ch]`**. That makes the model patch a *contiguous* - // (k*p)×(k*p) image: full row = `ky*p + ry`, full col = `kx*p + rx`. - // (Ordering `[ky, kx, ry, rx, ch]` would tile 3×3 blocks instead and - // scramble fine detail — OCR fails, color survives.) 
- let mut merged = vec![0.0_f32; n_model * patch_dim]; - let mut x_idx = vec![0i32; n_model]; - let mut y_idx = vec![0i32; n_model]; - for my in 0..m_h { - for mx in 0..m_w { - let model_i = my * m_w + mx; - // model-patch position = (min teacher_x // k, min teacher_y // k) - // = (mx, my) since teacher cols/rows in a kernel are contiguous. - x_idx[model_i] = mx as i32; - y_idx[model_i] = my as i32; - let dst = model_i * patch_dim; - for ky in 0..k { - for ry in 0..p { - for kx in 0..k { - for rx in 0..p { - let y = (my * k + ky) * p + ry; // my*48 + ky*16 + ry - let x = (mx * k + kx) * p + rx; // mx*48 + kx*16 + rx - for ch in 0..3 { - let src = ch * n_pixels + y * w + x; - // interior index: [ky, ry, kx, rx, ch] over dims [k, p, k, p, 3] - let off = ((((ky * p + ry) * k + kx) * p + rx) * 3) + ch; - merged[dst + off] = pv.pixel_values[src]; - } - } - } - } - } - } - } - Ok((merged, x_idx, y_idx)) + patchify_and_merge_impl(&self.cfg, pv) } /// Gather `pos_embedding[x, 0, :] + pos_embedding[y, 1, :]` per model patch. @@ -381,6 +325,80 @@ impl UnifiedVisionEmbedder { } } +/// Host patchify (16px teacher patches) + `patches_merge` (k×k -> model patch) +/// core, factored out of [`UnifiedVisionEmbedder::patchify_and_merge`] so the +/// channel/value layout is covered by a model-free numerical test. +/// +/// Faithful to `convert_image_to_patches` (teacher-patch interior `[ry, rx, ch]`) +/// and `patches_merge` (model-patch interior `[ky, ry, kx, rx, ch]` over dims +/// `[k, p, k, p, 3]`; position = top-left teacher position // k). Returns the +/// flat `[n_model * patch_dim]` f32 plus per-model-patch `(x, y)` positions. 
+#[allow( + clippy::indexing_slicing, + reason = "bounds established by construction: all indices derived from the host-computed patch grid" +)] +fn patchify_and_merge_impl( + cfg: &UnifiedVisionConfig, + pv: &Gemma4PixelValues, +) -> Result<(Vec, Vec, Vec)> { + let p = cfg.patch_size; // 16 + let k = cfg.pooling_kernel_size; // 3 + let h = pv.height; + let w = pv.width; + if !h.is_multiple_of(p * k) || !w.is_multiple_of(p * k) { + return Err(Error::Model(format!( + "gemma4_unified vision: image {h}x{w} not divisible by model_patch_size {}", + p * k + ))); + } + let p_h = h / p; // teacher rows + let p_w = w / p; // teacher cols + let m_h = p_h / k; // model rows + let m_w = p_w / k; // model cols + let n_model = m_h * m_w; + let model_patch = p * k; // 48 + let patch_dim = model_patch * model_patch * 3; // 6912 + let n_pixels = h * w; + + // Build merged patches directly in the reference target layout. The upstream + // `patches_merge` reshapes the k×k kernel group to `(length, k, k, p, p, 3)` + // then permutes to `(length, k, p, k, p, 3)` and flattens — i.e. the + // 6912-vector interior is ordered **`[ky, ry, kx, rx, ch]`**, making the + // model patch a *contiguous* (k*p)×(k*p) image: full row = `ky*p + ry`, full + // col = `kx*p + rx`. (Ordering `[ky, kx, ry, rx, ch]` would tile 3×3 blocks + // instead and scramble fine detail — OCR fails, colour survives.) + let mut merged = vec![0.0_f32; n_model * patch_dim]; + let mut x_idx = vec![0i32; n_model]; + let mut y_idx = vec![0i32; n_model]; + for my in 0..m_h { + for mx in 0..m_w { + let model_i = my * m_w + mx; + // model-patch position = (min teacher_x // k, min teacher_y // k) + // = (mx, my) since teacher cols/rows in a kernel are contiguous. 
+ x_idx[model_i] = mx as i32; + y_idx[model_i] = my as i32; + let dst = model_i * patch_dim; + for ky in 0..k { + for ry in 0..p { + for kx in 0..k { + for rx in 0..p { + let y = (my * k + ky) * p + ry; // my*48 + ky*16 + ry + let x = (mx * k + kx) * p + rx; // mx*48 + kx*16 + rx + for ch in 0..3 { + let src = ch * n_pixels + y * w + x; + // interior index: [ky, ry, kx, rx, ch] over dims [k, p, k, p, 3] + let off = ((((ky * p + ry) * k + kx) * p + rx) * 3) + ch; + merged[dst + off] = pv.pixel_values[src]; + } + } + } + } + } + } + } + Ok((merged, x_idx, y_idx)) +} + // --------------------------------------------------------------------------- // Loader // --------------------------------------------------------------------------- @@ -419,11 +437,14 @@ pub fn load_unified_vision_embedder( .any(|(_, h)| h.safetensors().is_ok_and(|st| st.tensor(name).is_ok())) }; + // patch_ln1 / patch_ln2 / pos_norm are PyTorch `nn.LayerNorm` constructed + // with the default eps (1e-5) in the reference embedder — NOT the model's + // `rms_norm_eps` (1e-6, which only governs the `embed_vision` RMSNorm). let layer_norm = |prefix: &str| -> Result { Ok(LayerNorm { weight: load_f32(&shards, &format!("{prefix}.weight"))?, bias: load_f32(&shards, &format!("{prefix}.bias"))?, - eps: cfg.rms_norm_eps, + eps: 1e-5, }) }; diff --git a/crates/rmlx-models/src/gemma4/vision/unified_tests.rs b/crates/rmlx-models/src/gemma4/vision/unified_tests.rs index 8feca42..0864653 100644 --- a/crates/rmlx-models/src/gemma4/vision/unified_tests.rs +++ b/crates/rmlx-models/src/gemma4/vision/unified_tests.rs @@ -96,3 +96,140 @@ fn is_unified_arch_false_for_missing_dir() { let p = Path::new("/nonexistent/gemma4-unified-test-dir"); assert!(!is_unified_arch(p)); } + +/// Build a solid-colour CHW `[1, 3, H, W]` pixel buffer (the shared +/// preprocessor output: rescaled `[0,1]`, channels-first, RGB). 
+#[allow( + clippy::indexing_slicing, + reason = "buffer sized 3*n at allocation; c*n+i bounded by c<3, i Gemma4PixelValues { + let n = h * w; + let mut pixel_values = vec![0.0_f32; 3 * n]; + for (c, &v) in rgb.iter().enumerate() { + for i in 0..n { + pixel_values[c * n + i] = v; + } + } + Gemma4PixelValues { + pixel_values, + height: h, + width: w, + num_soft_tokens: 0, + } +} + +/// Model-free numerical guard for the unified patchify front-end — the test that +/// would have caught a channel-order / value-scaling defect (the first suspected +/// cause of the unified-vision colour corruption). For a solid RGB input every +/// merged-patch slot must equal the source +/// channel value exactly, and the three channels must remain in RGB order +/// (interior index `% 3` selects R/G/B). A pure-green input must therefore yield +/// nonzero values only in the green-derived (`off % 3 == 1`) slots. +#[test] +#[allow( + clippy::indexing_slicing, + clippy::float_cmp, + clippy::expect_used, + reason = "fixed-size solid-colour buffer indexed by construction; rescaled pixel values are exact 0.0/1.0" +)] +fn patchify_preserves_channel_values_and_order() { + let cfg = cfg_12b(); + // One 48x48 model patch (k*p = 48): smallest valid image. + let (h, w) = (48usize, 48usize); + + for (name, rgb) in [ + ("red", [1.0, 0.0, 0.0]), + ("green", [0.0, 1.0, 0.0]), + ("blue", [0.0, 0.0, 1.0]), + ("white", [1.0, 1.0, 1.0]), + ("yellow", [1.0, 1.0, 0.0]), + ] { + let pv = solid_chw(h, w, rgb); + let (merged, x_idx, y_idx) = + patchify_and_merge_impl(&cfg, &pv).expect("patchify must succeed"); + assert_eq!(x_idx, vec![0]); + assert_eq!(y_idx, vec![0]); + let patch_dim = cfg.patch_dim(); + assert_eq!(merged.len(), patch_dim); + + // Every interior slot equals the source channel value (no scaling, + // inversion, or channel swap); channel = interior index % 3. 
+ for (off, &val) in merged.iter().enumerate() { + let ch = off % 3; + assert!( + (val - rgb[ch]).abs() < 1e-6, + "{name}: slot {off} (ch {ch}) = {val}, expected {}", + rgb[ch] + ); + } + // Channel-order sanity: a channel that is 0 in the source must be 0 in + // every derived slot (e.g. pure green leaves R- and B-slots at 0). + for (ch, &src_val) in rgb.iter().enumerate() { + if src_val == 0.0 { + let any_nonzero = merged.iter().skip(ch).step_by(3).any(|&v| v != 0.0); + assert!( + !any_nonzero, + "{name}: channel {ch} is 0 in source but nonzero after patchify" + ); + } + } + } +} + +/// Channel routing within a single 48x48 patch must be spatially faithful: a +/// half-red / half-blue image (left columns red, right columns blue) lands red +/// in the left half of the contiguous model-patch image and blue in the right +/// half — proving the `[ky, ry, kx, rx, ch]` interior layout reconstructs the +/// original pixel grid (not a scrambled 3×3 tiling). +#[test] +#[allow( + clippy::indexing_slicing, + clippy::float_cmp, + clippy::expect_used, + reason = "fixed-size half-red/half-blue buffer indexed by construction; pixel values are exact 0.0/1.0" +)] +fn patchify_interior_layout_is_contiguous_image() { + let cfg = cfg_12b(); + let (h, w) = (48usize, 48usize); + let n = h * w; + let mut pixel_values = vec![0.0_f32; 3 * n]; + // CHW: R channel = 1 on left half, B channel = 1 on right half. + for y in 0..h { + for x in 0..w { + let i = y * w + x; + if x < w / 2 { + pixel_values[i] = 1.0; // R + } else { + pixel_values[2 * n + i] = 1.0; // B + } + } + } + let pv = Gemma4PixelValues { + pixel_values, + height: h, + width: w, + num_soft_tokens: 0, + }; + let (merged, _, _) = patchify_and_merge_impl(&cfg, &pv).expect("patchify must succeed"); + + // The merged 6912-vector is a contiguous 48x48x3 image: index + // (row*48 + col)*3 + ch. Left columns must be red, right columns blue. 
+ let side = cfg.model_patch_size; // 48 + for row in 0..side { + for col in 0..side { + let base = (row * side + col) * 3; + let (r, g, b) = (merged[base], merged[base + 1], merged[base + 2]); + assert_eq!(g, 0.0, "no green expected at ({row},{col})"); + if col < side / 2 { + assert_eq!((r, b), (1.0, 0.0), "left half must be red at ({row},{col})"); + } else { + assert_eq!( + (r, b), + (0.0, 1.0), + "right half must be blue at ({row},{col})" + ); + } + } + } +} diff --git a/crates/rmlx-server/src/engine/image.rs b/crates/rmlx-server/src/engine/image.rs index 8892489..67de39c 100644 --- a/crates/rmlx-server/src/engine/image.rs +++ b/crates/rmlx-server/src/engine/image.rs @@ -9,10 +9,11 @@ use rmlx_core::Error; // ── image-prompt construction ───────────────────────────────────────── -/// Gemma4 `<|image>` begin-of-image marker token id. -pub(crate) const GEMMA4_BOI_TOKEN_ID: u32 = 255_999; -/// Gemma4 `` end-of-image marker token id. -pub(crate) const GEMMA4_EOI_TOKEN_ID: u32 = 258_882; +/// Gemma4 begin-/end-of-image marker token ids. Re-exported from the single +/// source of truth in `rmlx_models::gemma4` so the server and the model agree +/// on these correctness-critical ids. +pub(crate) use rmlx_models::gemma4::BOI_TOKEN_ID as GEMMA4_BOI_TOKEN_ID; +pub(crate) use rmlx_models::gemma4::EOI_TOKEN_ID as GEMMA4_EOI_TOKEN_ID; /// bundle of the multimodal components needed to turn image bytes /// into scattered `inputs_embeds`. Loaded once per model. One variant per diff --git a/docs/MODELS.md b/docs/MODELS.md index fe76853..ca57906 100644 --- a/docs/MODELS.md +++ b/docs/MODELS.md @@ -785,10 +785,29 @@ Per-image pipeline (faithful port of HF `gemma4_unified` `patch_ln1/ln2/pos_norm` are true **LayerNorm** (mean-subtraction, weight+bias), not RMSNorm — verified against the snapshot's `.weight`+`.bias` tensors and the -upstream class. 
Color, spatial layout (4-quadrant, left/right/top/bottom), and -object counting are exact on the real 12B; fine-grained OCR is weaker than the -e4b SigLIP tower — an architectural property of the encoder-free 35M projection -(it lacks the semantic richness of a full vision encoder), not a port defect. +upstream class. They use the PyTorch `nn.LayerNorm` default `eps = 1e-5` (not the +model's `rms_norm_eps = 1e-6`, which governs only the `embed_vision` RMSNorm). + +**Bidirectional vision attention (required).** Gemma 4 conditions each image's +soft tokens with **bidirectional** attention: every soft token of an image +attends to every other soft token of the *same* image, not just the causal +prefix (text stays causal). The text decoder builds a per-prefill overlay +(`build_vision_bidi_overlay`, keyed off the begin-/end-of-image (`BOI_TOKEN_ID`/`EOI_TOKEN_ID`) +markers) that opens the intra-image block and merges it (element-wise `maximum`) +with each layer's causal/SWA mask. This is **load-bearing** for the encoder-free +path: the raw projected patches carry no pre-integrated context, so reading them +causally mis-conditions the decoder (chromatic colours misnamed, spatial layout +hallucinated). The SigLIP tower path (e4b/26b/31b) already integrates the image +in its ViT, so the overlay is a no-harm addition there (e4b vision unchanged). + +With bidirectional attention, chromatic colour, spatial layout (4-quadrant, +left/right/top/bottom, borders), and object counting are correct on the real 12B. +Two inherent limits of the encoder-free projection remain (faithful to HF, not a +port defect): fine-grained OCR is weaker than the e4b SigLIP tower; and +**achromatic** inputs (pure white / gray / black) are indistinguishable — for any +`(c, c, c)` pixel, `patch_ln1` normalises away the absolute level, so white, gray +and black map to one embedding (the model reads them as dark/black). Brightness +discrimination needs the SigLIP tower. ### Unified (encoder-free) audio — `Gemma4UnifiedForConditionalGeneration` (12B)