diff --git a/src/kernels/arithmetic/simd.rs b/src/kernels/arithmetic/simd.rs index 1deb89e..ca49e86 100644 --- a/src/kernels/arithmetic/simd.rs +++ b/src/kernels/arithmetic/simd.rs @@ -32,7 +32,7 @@ use num_traits::{One, PrimInt, ToPrimitive, WrappingAdd, WrappingMul, WrappingSu use crate::enums::operators::ArithmeticOperator; use crate::kernels::bitmask::simd::all_true_mask_simd; -use crate::utils::simd_mask; +use crate::utils::{simd_mask, write_simd_mask_bits}; /// SIMD integer arithmetic kernel for dense arrays (no nulls). /// Vectorised operations with scalar fallback for power operations and array tails. @@ -187,15 +187,7 @@ pub fn int_masked_body_simd( }; r.copy_to_slice(&mut out[i..i + LANES]); // Write the out_mask based on the op - let valid_bits = valid.to_bitmask(); - for l in 0..LANES { - let idx = i + l; - if idx < n { - unsafe { - out_mask.set_unchecked(idx, ((valid_bits >> l) & 1) == 1); - } - } - } + write_simd_mask_bits(out_mask, i, valid); i += LANES; } // Scalar tail @@ -320,13 +312,7 @@ pub fn int_masked_body_simd( } _ => m_src, }; - let mbits = final_mask.to_bitmask(); - for l in 0..LANES { - let idx = i + l; - if idx < n { - unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) }; - } - } + write_simd_mask_bits(out_mask, i, final_mask); i += LANES; } @@ -387,17 +373,19 @@ pub fn float_masked_body_f32_simd( type M = ::Mask; let n = lhs.len(); - let mut i = 0; let dense = all_true_mask_simd::(mask); + if dense { + float_dense_body_f32_simd::(op, lhs, rhs, out); + out_mask.fill(true); + return; + } + + let mut i = 0; while i + LANES <= n { let a = Simd::::from_slice(&lhs[i..i + LANES]); let b = Simd::::from_slice(&rhs[i..i + LANES]); - let m: Mask = if dense { - Mask::splat(true) - } else { - simd_mask::(mask, i, n) - }; + let m: Mask = simd_mask::(mask, i, n); let res = match op { ArithmeticOperator::Add => a + b, @@ -412,19 +400,13 @@ pub fn float_masked_body_f32_simd( let selected = m.select(res, Simd::::splat(0.0)); selected.copy_to_slice(&mut out[i..i + LANES]); - let mbits = m.to_bitmask(); - for l in 0..LANES { - let idx = i + l; - if idx < n { - unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) }; - } - } + write_simd_mask_bits(out_mask, i, m); i += LANES; } - // Tail often caused by `n % LANES =! 0`; uses scalar fallback + // Tail often caused by `n % LANES != 0`; uses scalar fallback for j in i..n { - let valid = dense || unsafe { mask.get_unchecked(j) }; + let valid = unsafe { mask.get_unchecked(j) }; if valid { out[j] = match op { ArithmeticOperator::Add => lhs[j] + rhs[j], @@ -486,13 +468,7 @@ pub fn float_masked_body_f64_simd( let selected = m.select(res, Simd::::splat(0.0)); selected.copy_to_slice(&mut out[i..i + LANES]); - let mbits = m.to_bitmask(); - for l in 0..LANES { - let idx = i + l; - if idx < n { - unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) }; - } - } + write_simd_mask_bits(out_mask, i, m); i += LANES; } @@ -635,13 +611,7 @@ pub fn fma_masked_body_f32_simd( let selected = m.select(res, Simd::::splat(0.0)); selected.copy_to_slice(&mut out[i..i + LANES]); - let mbits = m.to_bitmask(); - for l in 0..LANES { - let idx = i + l; - if idx < n { - unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) }; - } - } + write_simd_mask_bits(out_mask, i, m); i += LANES; } @@ -694,13 +664,7 @@ pub fn fma_masked_body_f64_simd( let selected = m.select(res, Simd::::splat(0.0)); selected.copy_to_slice(&mut out[i..i + LANES]); - let mbits = m.to_bitmask(); - for l in 0..LANES { - let idx = i + l; - if idx < n { - unsafe { out_mask.set_unchecked(idx, ((mbits >> l) & 1) == 1) }; - } - } + write_simd_mask_bits(out_mask, i, m); i += LANES; } diff --git a/src/kernels/bitmask/simd.rs b/src/kernels/bitmask/simd.rs index 899a0e6..8f0b067 100644 --- a/src/kernels/bitmask/simd.rs +++ b/src/kernels/bitmask/simd.rs @@ -48,6 +48,7 @@ use core::simd::Simd; use crate::{Bitmask, BitmaskVT}; +use crate::kernels::arithmetic::simd::{W8, W16, W32, W64}; use crate::enums::operators::{LogicalOperator, UnaryOperator}; use crate::kernels::bitmask::{ @@ -313,38 +314,46 @@ where let (rhs_mask, rhs_off, rlen) = rhs; debug_assert_eq!(len, rlen, "in_mask: window length mismatch"); - // Scan rhs to see which values are present (true, false) - let mut has_true = false; - let mut has_false = false; - for i in 0..len { - let v = unsafe { rhs_mask.get_unchecked(rhs_off + i) }; - if v { - has_true = true; - } else { - has_false = true; - } - if has_true && has_false { - break; + if len == 0 { + return Bitmask::new_set_all(0, false); + } + + // Check which boolean values are present in rhs using word-level ops. + // Trailing bits in the last word must be masked off to avoid false positives. + let n_words = (len + 63) / 64; + let trailing = len & 63; + let mut any_set = 0u64; + let mut any_unset = 0u64; + unsafe { + let rp = rhs_mask.bits.as_ptr().cast::().add(rhs_off / 64); + for k in 0..n_words { + let mut w = *rp.add(k); + if k == n_words - 1 && trailing != 0 { + let valid_mask = (1u64 << trailing) - 1; + w &= valid_mask; + any_set |= w; + any_unset |= (!w) & valid_mask; + } else { + any_set |= w; + any_unset |= !w; + } + if any_set != 0 && any_unset != 0 { + break; + } } } + let has_true = any_set != 0; + let has_false = any_unset != 0; match (has_true, has_false) { - (true, true) => { - // Set contains both: every bit is in the set - Bitmask::new_set_all(len, true) - } - (true, false) => { - // Only 'true' in rhs: output bit is set iff lhs bit is true - lhs_mask.slice_clone(lhs_off, len) - } - (false, true) => { - // Only 'false' in rhs: output bit is set iff lhs bit is false - not_mask_simd::((lhs_mask, lhs_off, len)) - } - (false, false) => { - // Set is empty: all bits false - Bitmask::new_set_all(len, false) - } + // Set contains both values: every bit is a member + (true, true) => Bitmask::new_set_all(len, true), + // Only true in rhs: output bit is set iff lhs bit is true + (true, false) => lhs_mask.slice_clone(lhs_off, len), + // Only false in rhs: output bit is set iff lhs bit is false + (false, true) => not_mask_simd::((lhs_mask, lhs_off, len)), + // Empty set: no bits are members + (false, false) => Bitmask::new_set_all(len, false), } } @@ -385,49 +394,34 @@ where ao, bo ); } - let a_words = ao / 64; - let b_words = bo / 64; let n_words = (len + 63) / 64; let mut out = Bitmask::new_set_all(len, false); - { - use core::simd::Simd; - let simd_chunks = n_words / LANES; - let tail_words = n_words % LANES; - for chunk in 0..simd_chunks { - let mut arr_a = [0u64; LANES]; - let mut arr_b = [0u64; LANES]; - for lane in 0..LANES { - arr_a[lane] = unsafe { am.word_unchecked(a_words + chunk * LANES + lane) }; - arr_b[lane] = unsafe { bm.word_unchecked(b_words + chunk * LANES + lane) }; - } - let sa = Simd::::from_array(arr_a); - let sb = Simd::::from_array(arr_b); - let eq = !(sa ^ sb); - for lane in 0..LANES { - unsafe { - out.set_word_unchecked(chunk * LANES + lane, eq[lane]); - } + unsafe { + let ap = am.bits.as_ptr().cast::().add(ao / 64); + let bp = bm.bits.as_ptr().cast::().add(bo / 64); + let dp = out.bits.as_mut_ptr().cast::(); + let aw = std::slice::from_raw_parts(ap, n_words); + let bw = std::slice::from_raw_parts(bp, n_words); + + #[cfg(feature = "simd")] + { + let mut i = 0; + while i + LANES <= n_words { + let sa = Simd::::from_slice(&aw[i..i + LANES]); + let sb = Simd::::from_slice(&bw[i..i + LANES]); + let eq = !(sa ^ sb); + std::ptr::copy_nonoverlapping(eq.as_array().as_ptr(), dp.add(i), LANES); + i += LANES; } - } - let base = simd_chunks * LANES; - for k in 0..tail_words { - let wa = unsafe { am.word_unchecked(a_words + base + k) }; - let wb = unsafe { bm.word_unchecked(b_words + base + k) }; - let eq = !(wa ^ wb); - unsafe { - out.set_word_unchecked(base + k, eq); + for k in i..n_words { + *dp.add(k) = !(aw[k] ^ bw[k]); } } - } - #[cfg(not(feature = "simd"))] - { - for k in 0..n_words { - let wa = unsafe { am.word_unchecked(a_words + k) }; - let wb = unsafe { bm.word_unchecked(b_words + k) }; - let eq = !(wa ^ wb); - unsafe { - out.set_word_unchecked(k, eq); + #[cfg(not(feature = "simd"))] + { + for k in 0..n_words { + *dp.add(k) = !(aw[k] ^ bw[k]); } } } @@ -479,11 +473,10 @@ where !all_eq_mask_simd::(a, b) } -/// Vectorised equality test across entire bitmask windows with early termination optimisation. +/// Vectorised equality test across entire bitmask windows with early termination. /// /// Performs bulk equality comparison between two bitmask windows using SIMD comparison operations. -/// The implementation processes multiple words simultaneously and uses early termination to avoid -/// unnecessary work when differences are detected. +/// Processes multiple words simultaneously and terminates early when differences are detected. /// /// # Type Parameters /// - `LANES`: Number of u64 lanes to process simultaneously for vectorised comparison @@ -499,77 +492,80 @@ pub fn all_eq_mask_simd(a: BitmaskVT<'_>, b: BitmaskVT<'_>) where { let (am, ao, len) = a; + let (bm, bo, blen) = b; + debug_assert_eq!(len, blen, "BitWindow length mismatch in all_eq_mask"); - // Mask < 64 bits early exit - if len < 64 { - for i in 0..len { - if a.0.get(a.1 + i) != unsafe { b.0.get_unchecked(b.1 + i) } { - return false; - } - } + if len == 0 { return true; } - let (bm, bo, blen) = b; - debug_assert_eq!(len, blen, "BitWindow length mismatch in all_eq_mask"); + // Short masks: single word comparison with trailing bit mask + if len < 64 { + let wa = unsafe { am.word_unchecked(ao / 64) }; + let wb = unsafe { bm.word_unchecked(bo / 64) }; + let valid_mask = (1u64 << len) - 1; + return (wa & valid_mask) == (wb & valid_mask); + } + if ao % 64 != 0 || bo % 64 != 0 { panic!( "all_eq_mask_simd: offsets must be 64-bit aligned (got a: {}, b: {})", ao, bo ); } - let a_words = ao / 64; - let b_words = bo / 64; let n_words = (len + 63) / 64; let trailing = len & 63; - use core::simd::Simd; - use std::simd::prelude::SimdPartialEq; - - let simd_chunks = n_words / LANES; - let tail_words = n_words % LANES; - - for chunk in 0..simd_chunks { - let mut arr_a = [0u64; LANES]; - let mut arr_b = [0u64; LANES]; - for lane in 0..LANES { - arr_a[lane] = unsafe { am.word_unchecked(a_words + chunk * LANES + lane) }; - arr_b[lane] = unsafe { bm.word_unchecked(b_words + chunk * LANES + lane) }; - } - let sa = Simd::::from_array(arr_a); - let sb = Simd::::from_array(arr_b); - let eq_mask = sa.simd_eq(sb); - if !eq_mask.all() { - return false; + unsafe { + let aw = std::slice::from_raw_parts(am.bits.as_ptr().cast::().add(ao / 64), n_words); + let bw = std::slice::from_raw_parts(bm.bits.as_ptr().cast::().add(bo / 64), n_words); + + #[cfg(feature = "simd")] + { + use std::simd::prelude::SimdPartialEq; + let mut i = 0; + while i + LANES <= n_words { + let sa = Simd::::from_slice(&aw[i..i + LANES]); + let sb = Simd::::from_slice(&bw[i..i + LANES]); + if !sa.simd_eq(sb).all() { + return false; + } + i += LANES; + } + for k in i..n_words { + if k == n_words - 1 && trailing != 0 { + let mask = (1u64 << trailing) - 1; + if (aw[k] & mask) != (bw[k] & mask) { + return false; + } + } else if aw[k] != bw[k] { + return false; + } + } } - } - - let base = simd_chunks * LANES; - for k in 0..tail_words { - let idx = base + k; - let wa = unsafe { am.word_unchecked(a_words + idx) }; - let wb = unsafe { bm.word_unchecked(b_words + idx) }; - // For the last (possibly partial) word, mask slack bits - if idx == n_words - 1 && trailing != 0 { - let mask = (1u64 << trailing) - 1; - if (wa & mask) != (wb & mask) { - return false; + #[cfg(not(feature = "simd"))] + { + for k in 0..n_words { + if k == n_words - 1 && trailing != 0 { + let mask = (1u64 << trailing) - 1; + if (aw[k] & mask) != (bw[k] & mask) { + return false; + } + } else if aw[k] != bw[k] { + return false; + } } - } else if wa != wb { - return false; } } true } -/// Vectorised population count (number of set bits) with SIMD reduction for optimal performance. +/// Vectorised population count with SIMD reduction. /// -/// Computes the total number of set bits in a bitmask window using SIMD population count instructions -/// followed by horizontal reduction. This implementation provides significant performance improvements -/// for large bitmasks through parallel processing of multiple words. +/// Counts set bits in a bitmask window using SIMD popcount with horizontal reduction. /// /// # Type Parameters -/// - `LANES`: Number of u64 lanes to process simultaneously for vectorised popcount operations +/// - `LANES`: Number of u64 lanes to process simultaneously /// /// # Parameters /// - `m`: Bitmask window as `(mask, offset, length)` tuple @@ -585,34 +581,39 @@ where let word_start = offset / 64; let mut acc = 0usize; - { - use core::simd::Simd; - use std::simd::prelude::SimdUint; - - let simd_chunks = n_words / LANES; - let tail_words = n_words % LANES; + unsafe { + let words = std::slice::from_raw_parts( + mask.bits.as_ptr().cast::().add(word_start), + n_words, + ); - for chunk in 0..simd_chunks { - let mut arr = [0u64; LANES]; - for lane in 0..LANES { - arr[lane] = unsafe { mask.word_unchecked(word_start + chunk * LANES + lane) }; + #[cfg(feature = "simd")] + { + use std::simd::prelude::SimdUint; + let mut i = 0; + while i + LANES <= n_words { + let v = Simd::::from_slice(&words[i..i + LANES]); + acc += v.count_ones().reduce_sum() as usize; + i += LANES; + } + for k in i..n_words { + if k == n_words - 1 && len % 64 != 0 { + let slack_mask = (1u64 << (len % 64)) - 1; + acc += (words[k] & slack_mask).count_ones() as usize; + } else { + acc += words[k].count_ones() as usize; + } } - let v = Simd::::from_array(arr); - let counts = v.count_ones(); - acc += counts.reduce_sum() as usize; } - - // Tail scalar loop for any remaining words - let base = simd_chunks * LANES; - for k in 0..tail_words { - let word = unsafe { mask.word_unchecked(word_start + base + k) }; - // Mask off slack bits in final word if needed - if base + k == n_words - 1 && len % 64 != 0 { - let valid = len % 64; - let slack_mask = (1u64 << valid) - 1; - acc += (word & slack_mask).count_ones() as usize; - } else { - acc += word.count_ones() as usize; + #[cfg(not(feature = "simd"))] + { + for k in 0..n_words { + if k == n_words - 1 && len % 64 != 0 { + let slack_mask = (1u64 << (len % 64)) - 1; + acc += (words[k] & slack_mask).count_ones() as usize; + } else { + acc += words[k].count_ones() as usize; + } } } } @@ -624,14 +625,15 @@ where pub fn all_true_mask_simd(mask: &Bitmask) -> bool where { - if mask.len < 64 { - for i in 0..mask.len { - if !unsafe { mask.get_unchecked(i) } { - return false; - } - } + if mask.len == 0 { return true; } + // Short masks: single word comparison + if mask.len < 64 { + let w = unsafe { mask.word_unchecked(0) }; + let valid_mask = (1u64 << mask.len) - 1; + return (w & valid_mask) == valid_mask; + } let n_bits = mask.len; let n_words = (n_bits + 63) / 64; let words: &[u64] = @@ -670,14 +672,15 @@ where pub fn all_false_mask_simd(mask: &Bitmask) -> bool where { - if mask.len < 64 { - for i in 0..mask.len { - if unsafe { mask.get_unchecked(i) } { - return false; - } - } + if mask.len == 0 { return true; } + // Short masks: single word comparison + if mask.len < 64 { + let w = unsafe { mask.word_unchecked(0) }; + let valid_mask = (1u64 << mask.len) - 1; + return (w & valid_mask) == 0; + } let n_bits = mask.len; let n_words = (n_bits + 63) / 64; let words: &[u64] = @@ -709,6 +712,59 @@ where true } + +/// Generates a SIMD equality mask function for a given element type and lane count. +/// Processes LANES elements per iteration, with a scalar tail for the remainder. +macro_rules! impl_simd_eq_mask { + ($fn_name:ident, $t:ty, $lanes:expr) => { + pub fn $fn_name(data: &[$t], field_mask: $t, target: $t) -> Bitmask { + use vec64::Vec64; + use std::simd::cmp::SimdPartialEq; + let n = data.len(); + let n_bytes = (n + 7) / 8; + let mut bytes = Vec64::::with_capacity(n_bytes); + bytes.resize(n_bytes, 0); + + let mask_vec = Simd::<$t, $lanes>::splat(field_mask); + let target_vec = Simd::<$t, $lanes>::splat(target); + + let chunks = n / $lanes; + for i in 0..chunks { + let d = Simd::<$t, $lanes>::from_slice(&data[i * $lanes..]); + let masked = d & mask_vec; + let cmp = masked.simd_eq(target_vec); + let bits = cmp.to_bitmask() as u64; + let bit_pos = i * $lanes; + let byte_idx = bit_pos / 8; + let bit_shift = bit_pos % 8; + // Write result bits. For LANES >= 8 bit_shift is always 0 + // since LANES is a power of 2. For LANES < 8, sub-byte + // results OR into position within the byte. + let shifted = bits << bit_shift; + for b in 0..(($lanes + 7) / 8) { + bytes[byte_idx + b] |= (shifted >> (b * 8)) as u8; + } + } + + // Scalar tail + let start = chunks * $lanes; + for j in start..n { + if (data[j] & field_mask) == target { + bytes[j / 8] |= 1 << (j % 8); + } + } + + Bitmask::new(bytes, n) + } + }; +} + +impl_simd_eq_mask!(simd_eq_mask_u8, u8, W8); +impl_simd_eq_mask!(simd_eq_mask_u16, u16, W16); +impl_simd_eq_mask!(simd_eq_mask_u32, u32, W32); +impl_simd_eq_mask!(simd_eq_mask_u64, u64, W64); + + #[cfg(test)] mod tests { use crate::{Bitmask, BitmaskVT}; diff --git a/src/utils.rs b/src/utils.rs index e5784cb..d22e417 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -211,12 +211,61 @@ pub fn simd_mask( ) -> Mask where { - let mut bits = [false; N]; - for l in 0..N { - let idx = offset + l; - bits[l] = idx < len && unsafe { mask.get_unchecked(idx) }; + // Extract the packed bits covering this SIMD chunk from the bitmask. + // The bitmask is LSB-ordered which matches Mask::from_bitmask convention. + let word_idx = offset / 64; + let bit_shift = offset % 64; + let raw = unsafe { mask.word_unchecked(word_idx) } >> bit_shift; + + // If the chunk straddles a word boundary, pull in bits from the next word + let raw = if bit_shift > 0 && word_idx + 1 < (mask.len + 63) / 64 { + raw | (unsafe { mask.word_unchecked(word_idx + 1) } << (64 - bit_shift)) + } else { + raw + }; + + // Zero out lanes that are beyond the array length + let remaining = if offset < len { len - offset } else { 0 }; + let raw = if remaining < N && remaining < 64 { + raw & ((1u64 << remaining) - 1) + } else { + raw + }; + + Mask::from_bitmask(raw) +} + +/// Writes a SIMD mask's packed bits directly into the output bitmask at the given offset. +/// This is the write-side complement to `simd_mask`, avoiding per-lane `set_unchecked` calls. +#[inline(always)] +pub fn write_simd_mask_bits( + out_mask: &mut Bitmask, + offset: usize, + m: Mask, +) +where +{ + let mbits = m.to_bitmask(); + let word_idx = offset / 64; + let bit_shift = offset % 64; + + unsafe { + // Read-modify-write the target word + let existing = out_mask.word_unchecked(word_idx); + // Clear the N bits at bit_shift, then OR in the new bits + let lane_mask = if N >= 64 { !0u64 } else { (1u64 << N) - 1 }; + let cleared = existing & !(lane_mask << bit_shift); + out_mask.set_word_unchecked(word_idx, cleared | (mbits << bit_shift)); + + // If the chunk straddles a word boundary, write the overflow to the next word + if bit_shift > 0 && bit_shift + N > 64 { + let overflow_bits = N - (64 - bit_shift); + let next_existing = out_mask.word_unchecked(word_idx + 1); + let overflow_mask = (1u64 << overflow_bits) - 1; + let cleared_next = next_existing & !overflow_mask; + out_mask.set_word_unchecked(word_idx + 1, cleared_next | (mbits >> (64 - bit_shift))); + } } - Mask::from_array(bits) } /// Checks the mask capacity is large enough