Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions crates/rmlx-models/src/gemma4/decoder_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,21 @@ impl DecoderLayer {
offset: i32,
cache: Option<&mut KvCache>,
kv_is_rotating: bool,
bidi_overlay: Option<&Array>,
device: Device,
) -> Result<(Array, Option<(Array, Array)>)> {
// Attention sub-layer.
let residual = x.try_clone()?;
let h = self.input_norm.forward(x, device)?;
let (h, new_kv) =
self.attn
.forward(&h, shared_kv, offset, cache, kv_is_rotating, device)?;
let (h, new_kv) = self.attn.forward(
&h,
shared_kv,
offset,
cache,
kv_is_rotating,
bidi_overlay,
device,
)?;
let h = self.post_attn_norm.forward(&h, device)?;
let h = add(&residual, &h, device)?;

Expand Down
78 changes: 77 additions & 1 deletion crates/rmlx-models/src/gemma4/layers/mask_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
//! the mask key dim from `k_seq` to `k_seq + 1` and making
//! `guard_invariant_producer_offset_matches_k_seq` RED.

use rmlx_mlx::Device;
use rmlx_mlx::{Array, Device, Dtype};

use super::{build_attn_mask, consumer_effective_offset, producer_effective_offset, LayerType};

Expand Down Expand Up @@ -53,6 +53,7 @@ fn full_attn_verify_block_mask_matches_producer_k_len() {
producer_offset + seq,
false, // attn_is_rotating
window,
None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay
Device::Cpu,
)
.unwrap();
Expand Down Expand Up @@ -100,6 +101,7 @@ fn sliding_attn_verify_block_mask_matches_capped_k_len() {
effective_offset + seq,
true, // attn_is_rotating
window,
None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay
Device::Cpu,
)
.unwrap();
Expand Down Expand Up @@ -162,6 +164,7 @@ fn guard_invariant_producer_offset_matches_k_seq() {
effective_offset + seq,
false,
window,
None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay
Device::Cpu,
)
.unwrap();
Expand Down Expand Up @@ -234,6 +237,7 @@ fn guard_invariant_consumer_mask_matches_shared_k_len() {
total_kv_len,
false, // attn_is_rotating
window,
None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay
Device::Cpu,
)
.unwrap();
Expand Down Expand Up @@ -285,6 +289,7 @@ fn guard_invariant_regressed_consumer_offset_inflates_mask() {
base_offset + seq,
false,
window,
None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay
Device::Cpu,
)
.unwrap();
Expand Down Expand Up @@ -332,6 +337,7 @@ fn guard_invariant_regressed_base_offset_inflates_mask() {
base_offset_desynced + seq,
false,
window,
None, // bidi_overlay: tests cover causal/SWA mask sizing, not the image overlay
Device::Cpu,
)
.unwrap();
Expand All @@ -343,3 +349,73 @@ fn guard_invariant_regressed_base_offset_inflates_mask() {
"base_offset-desynced mask is one key too long — the #32 guard trigger"
);
}

/// Regression test: a vision (image) bidi overlay wins over the sliding window.
///
/// On a `SlidingAttention` layer the banded-causal prefill mask blocks any pair
/// whose positions are at least `window` apart. When both positions belong to
/// the same image block the overlay must re-open that pair, so the combined
/// mask holds `0.0` there. This guards against a change that would re-apply
/// the window cap to image soft tokens.
#[test]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::indexing_slicing,
    clippy::float_cmp,
    reason = "test asserts on a known-good mask cell; unwrap/expect failures are the assertion; mask cells are exact 0.0 / -1e30"
)]
fn swa_prefill_image_bidi_overrides_window() {
    // Scenario: single-shot prefill (effective offset 0) over 4 tokens that
    // together form one image block, with a window of 2. Taken alone, the SWA
    // mask blocks query 3 -> key 0 because their distance (3) is >= window.
    let seq: i32 = 4;
    let window: usize = 2;
    let side = seq as usize;

    // An all-zero overlay marks every (query, key) pair as "allowed" in the
    // additive convention, i.e. one image block covering the whole sequence.
    let zeros = vec![0.0_f32; side * side];
    let overlay_f32 = Array::from_f32_slice(&zeros, &[1, 1, seq, seq]).unwrap();
    let overlay = overlay_f32.astype(Dtype::Bf16, Device::Cpu).unwrap();

    let (mask, mode) = build_attn_mask(
        LayerType::SlidingAttention,
        seq,
        0, // effective_offset == 0 (single-shot prefill)
        seq,
        false, // attn_is_rotating
        window,
        Some(&overlay),
        Device::Cpu,
    )
    .unwrap();

    assert_eq!(mode, "array", "SWA prefill uses array mode");
    let mask = mask.expect("array mode must carry a mask array");
    let dims = mask.shape();
    assert_eq!(dims, &[1, 1, seq, seq], "combined mask is [1,1,seq,seq]");

    // Read the mask back as a flat row-major grid of f32 cells by widening to
    // F32 and decoding the raw little-endian bytes four at a time.
    let mask_f32 = mask.astype(Dtype::F32, Device::Cpu).unwrap();
    let raw = mask_f32.to_bytes().unwrap();
    let grid: Vec<f32> = raw
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect();

    // Cell (query 3, key 0): outside the window, so only the overlay can have
    // opened it.
    assert_eq!(
        grid[3 * side],
        0.0,
        "image-bidi overlay must override the sliding window inside the block"
    );
    // Cell (query 0, key 3): a future key, blocked by causality on its own, yet
    // still inside the image block and therefore opened by the overlay.
    assert_eq!(
        grid[3], 0.0,
        "bidi overlay opens the in-future intra-block pair too"
    );
}
60 changes: 57 additions & 3 deletions crates/rmlx-models/src/gemma4/layers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,15 @@ impl Attention {
clippy::indexing_slicing,
reason = "bounds established by construction: buffer sized at init, loop indices bounded by slice length, or layer index validated before call"
)]
#[allow(clippy::too_many_arguments)]
pub(super) fn forward(
&self,
x: &Array,
shared_kv: Option<(&Array, &Array)>,
offset: i32,
cache: Option<&mut KvCache>,
kv_is_rotating: bool,
bidi_overlay: Option<&Array>,
device: Device,
) -> Result<(Array, Option<(Array, Array)>)> {
let shape = x.shape(); // [batch, seq, hidden]
Expand Down Expand Up @@ -325,6 +327,7 @@ impl Attention {
total_kv_len_pre,
attn_is_rotating,
self.sliding_window,
bidi_overlay,
device,
)?;
// Cache-holding layer: route through the shared-KV variant of
Expand Down Expand Up @@ -381,6 +384,7 @@ impl Attention {
total_kv_len_pre,
attn_is_rotating,
self.sliding_window,
bidi_overlay,
device,
)?;
let k_new = k.try_clone()?;
Expand Down Expand Up @@ -464,6 +468,7 @@ impl Attention {
total_kv_len,
attn_is_rotating,
self.sliding_window,
bidi_overlay,
device,
)?;
// Guard (issue #32 part 2): the array-mode consumer mask's key dim
Expand Down Expand Up @@ -552,6 +557,7 @@ fn build_attn_mask(
total_kv_len: i32,
attn_is_rotating: bool,
sliding_window: usize,
bidi_overlay: Option<&Array>,
device: Device,
) -> Result<(Option<Array>, &'static str)> {
match layer_type {
Expand All @@ -572,16 +578,26 @@ fn build_attn_mask(
} else {
// SWA prefill — banded-causal mask sized by the
// capped effective offset.
let mask = Some(crate::layers::build_swa_prefill_mask(
let mask = crate::layers::build_swa_prefill_mask(
effective_offset,
seq,
sliding_window,
device,
)?);
Ok((mask, "array"))
)?;
let mask = combine_bidi_overlay(mask, bidi_overlay, device)?;
Ok((Some(mask), "array"))
}
}
LayerType::FullAttention => {
// Bidirectional vision blocks require an explicit array mask even at
// offset==0 (where the default would be the cheap "causal" mode), so
// the intra-image-block openings can be merged in.
if bidi_overlay.is_some() && seq > 1 {
let causal =
crate::layers::build_chunked_prefill_mask(effective_offset, seq, device)?;
let mask = combine_bidi_overlay(causal, bidi_overlay, device)?;
return Ok((Some(mask), "array"));
}
let mode = crate::layers::pick_attn_mask_mode(effective_offset, seq);
let mask = if mode == "array" {
Some(crate::layers::build_chunked_prefill_mask(
Expand All @@ -597,6 +613,44 @@ fn build_attn_mask(
}
}

/// Fold an optional bidirectional-attention overlay into a prefill mask.
///
/// Both inputs follow the additive-mask convention (`0.0` = may attend,
/// large-negative = blocked). A query/key pair should stay open when *either*
/// the causal/SWA rule permits it *or* the overlay opens it (both positions in
/// the same vision block), which is exactly an element-wise `maximum`.
///
/// Behaviour by case:
/// * `bidi_overlay` is `None` — `causal` is returned untouched.
/// * shapes are equal — returns `rmlx_mlx::maximum(&causal, overlay, device)`.
/// * shapes differ — a `tracing` warning is emitted and `causal` is returned
///   untouched, so the image block is attended causally.
///
/// The overlay is expected to be `[1, 1, seq, seq]`, which lines up with the
/// prefill mask's `[1, 1, seq, offset + seq]` only when `offset == 0` (the
/// single-shot image prefill); callers are expected to uphold that.
fn combine_bidi_overlay(
    causal: Array,
    bidi_overlay: Option<&Array>,
    device: Device,
) -> Result<Array> {
    // No overlay supplied: nothing to merge.
    let Some(overlay) = bidi_overlay else {
        return Ok(causal);
    };

    if causal.shape() != overlay.shape() {
        // Not expected under the single-shot prefill invariant (e.g. it would
        // take a chunked prefill with a prior offset to change the key dim).
        // Falling back to the causal mask here is the unified-vision
        // colour-corruption failure mode, so log it loudly rather than
        // degrading without a trace.
        tracing::warn!(
            causal = ?causal.shape(),
            overlay = ?overlay.shape(),
            "gemma4 bidi overlay shape mismatch — image block fell back to causal"
        );
        return Ok(causal);
    }

    // Shapes agree: open every cell that either mask leaves open.
    rmlx_mlx::maximum(&causal, overlay, device)
}

/// Expand K/V from [B, kv_heads, S, D] to [B, q_heads, S, D] by repeating.
#[allow(
clippy::indexing_slicing,
Expand Down
2 changes: 1 addition & 1 deletion crates/rmlx-models/src/gemma4/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ pub use vision::unified::{
};
pub use vision::{
build_inputs_embeds, load_multimodal_embedder, load_vision_tower, MultimodalEmbedder,
VisionModel, IMAGE_TOKEN_ID,
VisionModel, BOI_TOKEN_ID, EOI_TOKEN_ID, IMAGE_TOKEN_ID,
};
Loading
Loading