diff --git a/Cargo.toml b/Cargo.toml index 886ad44..27f353d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,7 +76,7 @@ c_ffi_tests = ['cc'] # Highly recommend keeping these off unless required # E.g., constrained or embedded environments, as they add combinatorial # weight to the binary and enum match arms -extended_categorical = ["extended_numeric_types"] +extended_categorical = [] # Adds UInt8, UInt16, Int8, Int16 types. # diff --git a/pyo3/Cargo.toml b/pyo3/Cargo.toml index 5eb9c15..d0762fb 100644 --- a/pyo3/Cargo.toml +++ b/pyo3/Cargo.toml @@ -22,7 +22,7 @@ name = "minarrow_pyo3" crate-type = ["cdylib", "rlib"] [dependencies] -minarrow = { version = "0.8.1", features = ["large_string"] } +minarrow = { version = "0.9.1", features = ["large_string"] } pyo3 = { version = "0.23" } thiserror = "2" diff --git a/src/enums/array.rs b/src/enums/array.rs index f3ec583..d77c006 100644 --- a/src/enums/array.rs +++ b/src/enums/array.rs @@ -1908,6 +1908,395 @@ impl Array { crate::traits::print::value_to_string(self, idx) } + /// Extract the element at `idx` as a `Scalar`, or `None` if out of bounds. + /// + /// Returns `Scalar::Null` for null elements. 
+ #[cfg(feature = "scalar_type")] + pub fn get_scalar(&self, idx: usize) -> Option { + use crate::Scalar; + if idx >= self.len() { + return None; + } + let is_null = self.null_mask().is_some_and(|m| !m.get(idx)); + if is_null { + return Some(Scalar::Null); + } + match self { + Array::NumericArray(num) => match num { + #[cfg(feature = "extended_numeric_types")] + NumericArray::Int8(a) => Some(Scalar::Int8(a.data[idx])), + #[cfg(feature = "extended_numeric_types")] + NumericArray::Int16(a) => Some(Scalar::Int16(a.data[idx])), + NumericArray::Int32(a) => Some(Scalar::Int32(a.data[idx])), + NumericArray::Int64(a) => Some(Scalar::Int64(a.data[idx])), + #[cfg(feature = "extended_numeric_types")] + NumericArray::UInt8(a) => Some(Scalar::UInt8(a.data[idx])), + #[cfg(feature = "extended_numeric_types")] + NumericArray::UInt16(a) => Some(Scalar::UInt16(a.data[idx])), + NumericArray::UInt32(a) => Some(Scalar::UInt32(a.data[idx])), + NumericArray::UInt64(a) => Some(Scalar::UInt64(a.data[idx])), + NumericArray::Float32(a) => Some(Scalar::Float32(a.data[idx])), + NumericArray::Float64(a) => Some(Scalar::Float64(a.data[idx])), + NumericArray::Null => Some(Scalar::Null), + }, + Array::TextArray(text) => match text { + TextArray::String32(a) => Some(Scalar::String32(a.get_str(idx)?.to_owned())), + #[cfg(feature = "large_string")] + TextArray::String64(a) => Some(Scalar::String64(a.get_str(idx)?.to_owned())), + TextArray::Categorical32(a) => Some(Scalar::String32(a.get_str(idx)?.to_owned())), + #[cfg(feature = "extended_categorical")] + TextArray::Categorical8(a) => Some(Scalar::String32(a.get_str(idx)?.to_owned())), + #[cfg(feature = "extended_categorical")] + TextArray::Categorical16(a) => Some(Scalar::String32(a.get_str(idx)?.to_owned())), + #[cfg(feature = "extended_categorical")] + TextArray::Categorical64(a) => Some(Scalar::String32(a.get_str(idx)?.to_owned())), + TextArray::Null => Some(Scalar::Null), + }, + Array::BooleanArray(a) => Some(Scalar::Boolean(a.get(idx)?)), + 
#[cfg(feature = "datetime")] + Array::TemporalArray(temp) => match temp { + crate::TemporalArray::Datetime32(a) => Some(Scalar::Datetime32(a.data[idx])), + crate::TemporalArray::Datetime64(a) => Some(Scalar::Datetime64(a.data[idx])), + crate::TemporalArray::Null => Some(Scalar::Null), + }, + Array::Null => Some(Scalar::Null), + } + } + + /// Create an all-null array of the given ArrowType with `n_rows` elements. + /// + /// The data buffer is zero-filled and every element is masked as null. + /// For datetime types, set the time_unit on the returned array afterwards. + pub fn null_array(arrow_type: &ArrowType, n_rows: usize) -> Array { + let mask = Bitmask::with_capacity(n_rows); + match arrow_type { + ArrowType::Null => Array::Null, + ArrowType::Boolean => { + Array::BooleanArray(Arc::new(BooleanArray::from_vec(vec![false; n_rows], Some(mask)))) + } + #[cfg(feature = "extended_numeric_types")] + ArrowType::Int8 => { + Array::from_int8(IntegerArray::new(Vec64::from_slice(&vec![0i8; n_rows]), Some(mask))) + } + #[cfg(feature = "extended_numeric_types")] + ArrowType::Int16 => { + Array::from_int16(IntegerArray::new(Vec64::from_slice(&vec![0i16; n_rows]), Some(mask))) + } + ArrowType::Int32 => { + Array::from_int32(IntegerArray::new(Vec64::from_slice(&vec![0i32; n_rows]), Some(mask))) + } + ArrowType::Int64 => { + Array::from_int64(IntegerArray::new(Vec64::from_slice(&vec![0i64; n_rows]), Some(mask))) + } + #[cfg(feature = "extended_numeric_types")] + ArrowType::UInt8 => { + Array::NumericArray(NumericArray::UInt8(Arc::new(IntegerArray::new(Vec64::from_slice(&vec![0u8; n_rows]), Some(mask))))) + } + #[cfg(feature = "extended_numeric_types")] + ArrowType::UInt16 => { + Array::NumericArray(NumericArray::UInt16(Arc::new(IntegerArray::new(Vec64::from_slice(&vec![0u16; n_rows]), Some(mask))))) + } + ArrowType::UInt32 => { + Array::NumericArray(NumericArray::UInt32(Arc::new(IntegerArray::new(Vec64::from_slice(&vec![0u32; n_rows]), Some(mask))))) + } + ArrowType::UInt64 => { 
+ Array::NumericArray(NumericArray::UInt64(Arc::new(IntegerArray::new(Vec64::from_slice(&vec![0u64; n_rows]), Some(mask))))) + } + ArrowType::Float32 => { + Array::NumericArray(NumericArray::Float32(Arc::new(FloatArray::new(Vec64::from_slice(&vec![0.0f32; n_rows]), Some(mask))))) + } + ArrowType::Float64 => { + Array::from_float64(FloatArray::new(Vec64::from_slice(&vec![0.0f64; n_rows]), Some(mask))) + } + ArrowType::String => { + let strs: Vec<&str> = vec![""; n_rows]; + let mut arr = StringArray::::from_slice(&strs); + arr.null_mask = Some(mask); + Array::from_string32(arr) + } + #[cfg(feature = "large_string")] + ArrowType::LargeString => { + let strs: Vec<&str> = vec![""; n_rows]; + let mut arr = StringArray::::from_slice(&strs); + arr.null_mask = Some(mask); + Array::TextArray(TextArray::String64(Arc::new(arr))) + } + ArrowType::Dictionary(cat_idx) => { + let strs: Vec<&str> = vec![""; n_rows]; + match cat_idx { + CategoricalIndexType::UInt32 => { + let mut arr = CategoricalArray::::from_vec(strs, None); + arr.null_mask = Some(mask); + Array::TextArray(TextArray::Categorical32(Arc::new(arr))) + } + #[cfg(feature = "extended_categorical")] + _ => Array::Null, + } + } + #[cfg(feature = "datetime")] + ArrowType::Date32 | ArrowType::Time32(_) | ArrowType::Duration32(_) => { + Array::TemporalArray(crate::TemporalArray::Datetime32(Arc::new( + crate::DatetimeArray::new(Vec64::from_slice(&vec![0i32; n_rows]), Some(mask), None), + ))) + } + #[cfg(feature = "datetime")] + ArrowType::Date64 + | ArrowType::Time64(_) + | ArrowType::Duration64(_) + | ArrowType::Timestamp(_, _) => { + Array::TemporalArray(crate::TemporalArray::Datetime64(Arc::new( + crate::DatetimeArray::new(Vec64::from_slice(&vec![0i64; n_rows]), Some(mask), None), + ))) + } + #[cfg(feature = "datetime")] + ArrowType::Interval(_) => Array::Null, + ArrowType::Utf8View => { + let strs: Vec<&str> = vec![""; n_rows]; + let mut arr = StringArray::::from_slice(&strs); + arr.null_mask = Some(mask); + 
Array::from_string32(arr) + } + } + } + + /// Build an array from a slice of Scalars. + /// + /// All scalars must be the same type. The type is inferred from the first + /// non-Null element. If all elements are Null, returns `Array::Null`. + #[cfg(feature = "scalar_type")] + pub fn from_scalars(scalars: &[crate::Scalar]) -> Array { + use crate::Scalar; + if scalars.is_empty() { + return Array::default(); + } + + // Find the first non-null to determine type + let template = scalars.iter().find(|s| !matches!(s, Scalar::Null)); + let Some(template) = template else { + return Array::Null; + }; + + match template { + Scalar::Float64(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Float64(v) => data.push(*v), + Scalar::Null => { data.push(0.0); mask.set(i, false); } + _ => data.push(s.f64()), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::from_float64(FloatArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })) + } + Scalar::Float32(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Float32(v) => data.push(*v), + Scalar::Null => { data.push(0.0); mask.set(i, false); } + _ => data.push(s.f64() as f32), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::NumericArray(NumericArray::Float32(Arc::new(FloatArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })))) + } + Scalar::Int32(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Int32(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(s.f64() as i32), + } + } + let has_nulls = 
mask.count_zeros() > 0; + Array::from_int32(IntegerArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })) + } + Scalar::Int64(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Int64(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(s.f64() as i64), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::from_int64(IntegerArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })) + } + Scalar::UInt32(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::UInt32(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(s.f64() as u32), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::NumericArray(NumericArray::UInt32(Arc::new(IntegerArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })))) + } + Scalar::UInt64(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::UInt64(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(s.f64() as u64), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::NumericArray(NumericArray::UInt64(Arc::new(IntegerArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })))) + } + Scalar::Boolean(_) => { + let mut data = Vec::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Boolean(v) => data.push(*v), + Scalar::Null => { data.push(false); mask.set(i, false); } + _ => 
data.push(false), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::BooleanArray(Arc::new(BooleanArray::from_vec(data, if has_nulls { Some(mask) } else { None }))) + } + Scalar::String32(_) => { + let strs: Vec = scalars.iter().map(|s| match s { + Scalar::String32(v) => v.clone(), + #[cfg(feature = "large_string")] + Scalar::String64(v) => v.clone(), + Scalar::Null => String::new(), + _ => String::new(), + }).collect(); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + if matches!(s, Scalar::Null) { mask.set(i, false); } + } + let refs: Vec<&str> = strs.iter().map(|s| s.as_str()).collect(); + let mut arr = StringArray::::from_slice(&refs); + let has_nulls = mask.count_zeros() > 0; + if has_nulls { arr.null_mask = Some(mask); } + Array::from_string32(arr) + } + #[cfg(feature = "large_string")] + Scalar::String64(_) => { + let strs: Vec = scalars.iter().map(|s| match s { + Scalar::String64(v) | Scalar::String32(v) => v.clone(), + Scalar::Null => String::new(), + _ => String::new(), + }).collect(); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + if matches!(s, Scalar::Null) { mask.set(i, false); } + } + let refs: Vec<&str> = strs.iter().map(|s| s.as_str()).collect(); + let mut arr = StringArray::::from_slice(&refs); + let has_nulls = mask.count_zeros() > 0; + if has_nulls { arr.null_mask = Some(mask); } + Array::TextArray(TextArray::String64(Arc::new(arr))) + } + #[cfg(feature = "datetime")] + Scalar::Datetime32(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Datetime32(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(0), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::TemporalArray(crate::TemporalArray::Datetime32(Arc::new( + 
crate::DatetimeArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None }, None), + ))) + } + #[cfg(feature = "datetime")] + Scalar::Datetime64(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Datetime64(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(0), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::TemporalArray(crate::TemporalArray::Datetime64(Arc::new( + crate::DatetimeArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None }, None), + ))) + } + #[cfg(feature = "datetime")] + Scalar::Interval => Array::Null, + #[cfg(feature = "extended_numeric_types")] + Scalar::Int8(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Int8(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(s.f64() as i8), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::from_int8(IntegerArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })) + } + #[cfg(feature = "extended_numeric_types")] + Scalar::Int16(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::Int16(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(s.f64() as i16), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::from_int16(IntegerArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })) + } + #[cfg(feature = "extended_numeric_types")] + Scalar::UInt8(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = 
Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::UInt8(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(s.f64() as u8), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::from_uint8(IntegerArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })) + } + #[cfg(feature = "extended_numeric_types")] + Scalar::UInt16(_) => { + let mut data = Vec64::::with_capacity(scalars.len()); + let mut mask = Bitmask::new_set_all(scalars.len(), true); + for (i, s) in scalars.iter().enumerate() { + match s { + Scalar::UInt16(v) => data.push(*v), + Scalar::Null => { data.push(0); mask.set(i, false); } + _ => data.push(s.f64() as u16), + } + } + let has_nulls = mask.count_zeros() > 0; + Array::from_uint16(IntegerArray::new(crate::Buffer::from_vec64(data), if has_nulls { Some(mask) } else { None })) + } + Scalar::Null => Array::Null, + } + } + /// Compare two elements within the same array by index. /// /// Uses total ordering for floats via `total_cmp()`. Nulls sort last: @@ -2080,11 +2469,11 @@ impl Array { TextArray::Categorical32(arr) => { Arc::make_mut(arr).set_null_mask(Some(mask)); } - #[cfg(all(feature = "extended_categorical", feature = "extended_numeric_types"))] + #[cfg(feature = "extended_categorical")] TextArray::Categorical8(arr) => { Arc::make_mut(arr).set_null_mask(Some(mask)); } - #[cfg(all(feature = "extended_categorical", feature = "extended_numeric_types"))] + #[cfg(feature = "extended_categorical")] TextArray::Categorical16(arr) => { Arc::make_mut(arr).set_null_mask(Some(mask)); } diff --git a/src/enums/operators.rs b/src/enums/operators.rs index 6681b83..6ac1b88 100644 --- a/src/enums/operators.rs +++ b/src/enums/operators.rs @@ -27,6 +27,13 @@ pub enum ArithmeticOperator { /// For integers, uses repeated multiplication. For floating-point, uses `pow()` function. 
/// Negative exponents on integers may yield zero due to truncation. Power, + /// Floor division (`lhs // rhs`) + /// + /// Rounds the quotient towards negative infinity. For unsigned integers this is + /// identical to truncation division. For signed integers, when the remainder is + /// non-zero and the operands have different signs, the result is one less than + /// truncation division. For floating-point, equivalent to `(lhs / rhs).floor()`. + FloorDiv, } /// Comparison operators for binary predicates. diff --git a/src/enums/value/conversions.rs b/src/enums/value/conversions.rs index 9163070..24aa202 100644 --- a/src/enums/value/conversions.rs +++ b/src/enums/value/conversions.rs @@ -874,6 +874,20 @@ impl TryFrom for NumericArrayV { }), } } + Value::Array(inner) => { + let arr = Arc::try_unwrap(inner).unwrap_or_else(|arc| (*arc).clone()); + match arr { + Array::NumericArray(num_arr) => { + let len = num_arr.len(); + Ok(NumericArrayV::new(num_arr, 0, len)) + } + _ => Err(MinarrowError::TypeError { + from: "Value", + to: "NumericArrayV", + message: Some("Array is not a NumericArray".to_owned()), + }), + } + } _ => Err(MinarrowError::TypeError { from: "Value", to: "NumericArrayV", diff --git a/src/ffi/arrow_c_ffi.rs b/src/ffi/arrow_c_ffi.rs index e80b2f5..ed9aa93 100644 --- a/src/ffi/arrow_c_ffi.rs +++ b/src/ffi/arrow_c_ffi.rs @@ -553,9 +553,9 @@ fn export_categorical_array_to_c( let mut field = schema.fields[0].clone(); field.dtype = match index_bits { - #[cfg(all(feature = "extended_categorical", feature = "extended_numeric_types"))] + #[cfg(feature = "extended_categorical")] 8 => ArrowType::Dictionary(crate::ffi::arrow_dtype::CategoricalIndexType::UInt8), - #[cfg(all(feature = "extended_categorical", feature = "extended_numeric_types"))] + #[cfg(feature = "extended_categorical")] 16 => ArrowType::Dictionary(crate::ffi::arrow_dtype::CategoricalIndexType::UInt16), 32 => ArrowType::Dictionary(crate::ffi::arrow_dtype::CategoricalIndexType::UInt32), #[cfg(feature = 
"extended_categorical")] @@ -1483,14 +1483,12 @@ unsafe fn import_categorical( // Build codes & wrap match index_type { - #[cfg(feature = "extended_numeric_types")] #[cfg(feature = "extended_categorical")] CategoricalIndexType::UInt8 => { let codes_buf = unsafe { build_codes::(codes_ptr, len, ownership) }; let arr = CategoricalArray::::new(codes_buf, dict_strings, null_mask); Arc::new(Array::TextArray(TextArray::Categorical8(Arc::new(arr)))) } - #[cfg(feature = "extended_numeric_types")] #[cfg(feature = "extended_categorical")] CategoricalIndexType::UInt16 => { let codes_buf = unsafe { build_codes::(codes_ptr, len, ownership) }; @@ -1502,7 +1500,6 @@ unsafe fn import_categorical( let arr = CategoricalArray::::new(codes_buf, dict_strings, null_mask); Arc::new(Array::TextArray(TextArray::Categorical32(Arc::new(arr)))) } - #[cfg(feature = "extended_numeric_types")] #[cfg(feature = "extended_categorical")] CategoricalIndexType::UInt64 => { let codes_buf = unsafe { build_codes::(codes_ptr, len, ownership) }; diff --git a/src/ffi/arrow_dtype.rs b/src/ffi/arrow_dtype.rs index 8818205..a0eb799 100644 --- a/src/ffi/arrow_dtype.rs +++ b/src/ffi/arrow_dtype.rs @@ -117,8 +117,7 @@ pub enum ArrowType { /// - Smaller widths reduce memory footprint for low-cardinality data. /// - Larger widths enable more distinct categories without overflow. /// - Variant availability depends on feature flags: -/// - `UInt8` and `UInt16` require both `extended_categorical` and `extended_numeric_types`. -/// - `UInt64` requires `extended_categorical`. +/// - `UInt8`, `UInt16`, and `UInt64` require `extended_categorical`. /// - `UInt32` is always available. 
/// /// ## Interoperability @@ -127,12 +126,12 @@ pub enum ArrowType { #[derive(PartialEq, Clone, Debug)] pub enum CategoricalIndexType { - #[cfg(all(feature = "extended_categorical", feature = "extended_numeric_types"))] + #[cfg(feature = "extended_categorical")] UInt8, - #[cfg(all(feature = "extended_categorical", feature = "extended_numeric_types"))] + #[cfg(feature = "extended_categorical")] UInt16, UInt32, - #[cfg(all(feature = "extended_categorical"))] + #[cfg(feature = "extended_categorical")] UInt64, } diff --git a/src/kernels/arithmetic/simd.rs b/src/kernels/arithmetic/simd.rs index b69b25e..1deb89e 100644 --- a/src/kernels/arithmetic/simd.rs +++ b/src/kernels/arithmetic/simd.rs @@ -63,7 +63,7 @@ pub fn int_dense_body_simd( ArithmeticOperator::Multiply => a * b, ArithmeticOperator::Divide => a / b, // Panics if divisor is zero ArithmeticOperator::Remainder => a % b, // Panics if divisor is zero - ArithmeticOperator::Power => { + ArithmeticOperator::Power | ArithmeticOperator::FloorDiv => { vectorisable = 0; break; } @@ -88,6 +88,15 @@ pub fn int_dense_body_simd( } acc } + ArithmeticOperator::FloorDiv => { + if rhs[idx] == T::zero() { + panic!("Floor division by zero") + } else { + let d = lhs[idx] / rhs[idx]; + let r = lhs[idx] % rhs[idx]; + if r != T::zero() && (lhs[idx] ^ rhs[idx]) < T::zero() { d - T::one() } else { d } + } + } }; } } @@ -159,6 +168,22 @@ pub fn int_masked_body_simd( let r = div_zero.select(Simd::splat(T::zero()), r); (r, valid) } + ArithmeticOperator::FloorDiv => { + let div_zero = b.simd_eq(Simd::splat(T::zero())); + let valid = !div_zero; + // Per-lane floor division with sign correction + let mut tmp = [T::zero(); LANES]; + for l in 0..LANES { + if b[l] == T::zero() { + tmp[l] = T::zero(); + } else { + let d = a[l] / b[l]; + let r = a[l] % b[l]; + tmp[l] = if r != T::zero() && (a[l] ^ b[l]) < T::zero() { d - T::one() } else { d }; + } + } + (Simd::::from_array(tmp), valid) + } }; r.copy_to_slice(&mut out[i..i + LANES]); // Write the 
out_mask based on the op @@ -217,6 +242,21 @@ pub fn int_masked_body_simd( } } } + ArithmeticOperator::FloorDiv => { + if rhs[idx] == T::zero() { + out[idx] = T::zero(); + unsafe { + out_mask.set_unchecked(idx, false); + } + } else { + let d = lhs[idx] / rhs[idx]; + let r = lhs[idx] % rhs[idx]; + out[idx] = if r != T::zero() && (lhs[idx] ^ rhs[idx]) < T::zero() { d - T::one() } else { d }; + unsafe { + out_mask.set_unchecked(idx, true); + } + } + } } } return; @@ -254,6 +294,18 @@ pub fn int_masked_body_simd( } Simd::::from_array(tmp) } + ArithmeticOperator::FloorDiv => { + // Per-lane floor division with sign correction + let mut tmp = [T::zero(); LANES]; + for l in 0..LANES { + if b[l] != T::zero() { + let d = a[l] / b[l]; + let r = a[l] % b[l]; + tmp[l] = if r != T::zero() && (a[l] ^ b[l]) < T::zero() { d - T::one() } else { d }; + } + } + Simd::::from_array(tmp) + } }; // apply source validity mask, write results @@ -262,8 +314,8 @@ pub fn int_masked_body_simd( // write out-mask bits: combine source mask with div-by-zero validity let final_mask = match op { - ArithmeticOperator::Divide | ArithmeticOperator::Remainder => { - // For div/rem: valid iff source is valid AND not dividing by zero + ArithmeticOperator::Divide | ArithmeticOperator::Remainder | ArithmeticOperator::FloorDiv => { + // Valid iff source is valid and not dividing by zero m_src & !div_zero } _ => m_src, @@ -301,6 +353,15 @@ pub fn int_masked_body_simd( } } ArithmeticOperator::Power => (lhs[j].pow(rhs[j].to_u32().unwrap_or(0)), true), + ArithmeticOperator::FloorDiv => { + if rhs[j] == T::zero() { + (T::zero(), false) + } else { + let d = lhs[j] / rhs[j]; + let r = lhs[j] % rhs[j]; + if r != T::zero() && (lhs[j] ^ rhs[j]) < T::zero() { (d - T::one(), true) } else { (d, true) } + } + } }; out[j] = result; unsafe { out_mask.set_unchecked(j, final_valid) }; @@ -345,6 +406,7 @@ pub fn float_masked_body_f32_simd( ArithmeticOperator::Divide => a / b, ArithmeticOperator::Remainder => a % b, 
ArithmeticOperator::Power => (b * a.ln()).exp(), + ArithmeticOperator::FloorDiv => (a / b).floor(), }; let selected = m.select(res, Simd::::splat(0.0)); @@ -371,6 +433,7 @@ pub fn float_masked_body_f32_simd( ArithmeticOperator::Divide => lhs[j] / rhs[j], ArithmeticOperator::Remainder => lhs[j] % rhs[j], ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(), + ArithmeticOperator::FloorDiv => (lhs[j] / rhs[j]).floor(), }; unsafe { out_mask.set_unchecked(j, true) }; } else { @@ -417,6 +480,7 @@ pub fn float_masked_body_f64_simd( ArithmeticOperator::Divide => a / b, ArithmeticOperator::Remainder => a % b, ArithmeticOperator::Power => (b * a.ln()).exp(), + ArithmeticOperator::FloorDiv => (a / b).floor(), }; let selected = m.select(res, Simd::::splat(0.0)); @@ -443,6 +507,7 @@ pub fn float_masked_body_f64_simd( ArithmeticOperator::Divide => lhs[j] / rhs[j], ArithmeticOperator::Remainder => lhs[j] % rhs[j], ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(), + ArithmeticOperator::FloorDiv => (lhs[j] / rhs[j]).floor(), }; unsafe { out_mask.set_unchecked(j, true) }; } else { @@ -474,6 +539,7 @@ pub fn float_dense_body_f32_simd( ArithmeticOperator::Divide => a / b, ArithmeticOperator::Remainder => a % b, ArithmeticOperator::Power => (b * a.ln()).exp(), + ArithmeticOperator::FloorDiv => (a / b).floor(), }; res.copy_to_slice(&mut out[i..i + LANES]); i += LANES; @@ -488,6 +554,7 @@ pub fn float_dense_body_f32_simd( ArithmeticOperator::Divide => lhs[j] / rhs[j], ArithmeticOperator::Remainder => lhs[j] % rhs[j], ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(), + ArithmeticOperator::FloorDiv => (lhs[j] / rhs[j]).floor(), }; } } @@ -514,6 +581,7 @@ pub fn float_dense_body_f64_simd( ArithmeticOperator::Divide => a / b, ArithmeticOperator::Remainder => a % b, ArithmeticOperator::Power => (b * a.ln()).exp(), + ArithmeticOperator::FloorDiv => (a / b).floor(), }; res.copy_to_slice(&mut out[i..i + LANES]); i += LANES; @@ -528,6 +596,7 @@ pub fn 
float_dense_body_f64_simd( ArithmeticOperator::Divide => lhs[j] / rhs[j], ArithmeticOperator::Remainder => lhs[j] % rhs[j], ArithmeticOperator::Power => (rhs[j] * lhs[j].ln()).exp(), + ArithmeticOperator::FloorDiv => (lhs[j] / rhs[j]).floor(), }; } } diff --git a/src/kernels/arithmetic/std.rs b/src/kernels/arithmetic/std.rs index 4bd72ca..a6abea7 100644 --- a/src/kernels/arithmetic/std.rs +++ b/src/kernels/arithmetic/std.rs @@ -54,6 +54,16 @@ pub fn int_dense_body_std lhs[i].pow(rhs[i].to_u32().unwrap_or(0)), + ArithmeticOperator::FloorDiv => { + if rhs[i] == T::zero() { + panic!("Floor division by zero") + } else { + let d = lhs[i] / rhs[i]; + let r = lhs[i] % rhs[i]; + // If remainder is non-zero and signs differ, floor toward -inf + if r != T::zero() && (lhs[i] ^ rhs[i]) < T::zero() { d - T::one() } else { d } + } + } }; } } @@ -93,6 +103,15 @@ pub fn int_masked_body_std (lhs[i].pow(rhs[i].to_u32().unwrap_or(0)), true), + ArithmeticOperator::FloorDiv => { + if rhs[i] == T::zero() { + (T::zero(), false) + } else { + let d = lhs[i] / rhs[i]; + let r = lhs[i] % rhs[i]; + if r != T::zero() && (lhs[i] ^ rhs[i]) < T::zero() { (d - T::one(), true) } else { (d, true) } + } + } }; out[i] = result; unsafe { @@ -121,6 +140,7 @@ pub fn float_dense_body_std(op: ArithmeticOperator, lhs: &[T], rhs: &[ ArithmeticOperator::Divide => lhs[i] / rhs[i], ArithmeticOperator::Remainder => lhs[i] % rhs[i], ArithmeticOperator::Power => (rhs[i] * lhs[i].ln()).exp(), + ArithmeticOperator::FloorDiv => (lhs[i] / rhs[i]).floor(), }; } } @@ -148,6 +168,7 @@ pub fn float_masked_body_std( ArithmeticOperator::Divide => lhs[i] / rhs[i], ArithmeticOperator::Remainder => lhs[i] % rhs[i], ArithmeticOperator::Power => (rhs[i] * lhs[i].ln()).exp(), + ArithmeticOperator::FloorDiv => (lhs[i] / rhs[i]).floor(), }; unsafe { out_mask.set_unchecked(i, true); diff --git a/src/kernels/arithmetic/string.rs b/src/kernels/arithmetic/string.rs index c3b6988..1acf7bf 100644 --- a/src/kernels/arithmetic/string.rs 
+++ b/src/kernels/arithmetic/string.rs @@ -40,9 +40,6 @@ use crate::{Bitmask, Vec64}; use num_traits::ToPrimitive; use crate::enums::operators::ArithmeticOperator::{self}; -#[cfg(feature = "str_arithmetic")] -use crate::kernels::string::string_predicate_masks; - #[cfg(feature = "str_arithmetic")] use crate::utils::{ confirm_mask_capacity, estimate_categorical_cardinality, estimate_string_cardinality, @@ -673,8 +670,10 @@ where let lmask_ref = lmask_slice.as_ref(); let rmask_ref = rmask_slice.as_ref(); - // build per‐position validity - let (lmask, rmask, mut out_mask) = string_predicate_masks(lmask_ref, rmask_ref, llen); + // build per-position validity + let lmask = lmask_ref; + let rmask = rmask_ref; + let mut out_mask = Bitmask::new_set_all(llen, false); let _ = confirm_mask_capacity(llen, lmask)?; let _ = confirm_mask_capacity(llen, rmask)?; diff --git a/src/kernels/string.rs b/src/kernels/string.rs index 961123a..f1ff3cd 100644 --- a/src/kernels/string.rs +++ b/src/kernels/string.rs @@ -32,17 +32,15 @@ use crate::enums::error::KernelError; use crate::utils::confirm_mask_capacity; use std::marker::PhantomData; -/// Helper for predicate kernels: produce optional input masks and a fresh output mask -#[inline(always)] -pub fn string_predicate_masks<'a>( - lhs_mask: Option<&'a Bitmask>, - rhs_mask: Option<&'a Bitmask>, - len: usize, -) -> (Option<&'a Bitmask>, Option<&'a Bitmask>, Bitmask) { - let out = Bitmask::new_set_all(len, false); - (lhs_mask, rhs_mask, out) +/// Side for string padding operations. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PadSide { + Left, + Right, + Both, } + // Concatenation /// Concatenates corresponding string pairs from two string arrays element-wise. 
@@ -71,8 +69,9 @@ pub fn concat_str_str(lhs: StringAVT, rhs: StringAVT) -> Strin let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let _ = confirm_mask_capacity(larr.len(), lmask); let _ = confirm_mask_capacity(rarr.len(), rmask); @@ -158,8 +157,9 @@ pub fn concat_dict_dict( let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let _ = confirm_mask_capacity(larr.data.len(), lmask)?; let _ = confirm_mask_capacity(rarr.data.len(), rmask)?; @@ -245,8 +245,9 @@ pub fn concat_str_dict( let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let _ = confirm_mask_capacity(larr.len(), lmask)?; let _ = confirm_mask_capacity(rarr.data.len(), rmask)?; @@ -453,8 +454,9 @@ macro_rules! str_cat_predicate { let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let data = binary_str_pred_loop!( len, @@ -504,8 +506,9 @@ macro_rules! 
cat_cat_predicate { let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let data = binary_str_pred_loop!( len, @@ -556,8 +559,9 @@ macro_rules! dict_str_predicate { let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let _ = confirm_mask_capacity(larr.data.len(), lmask)?; let _ = confirm_mask_capacity(rarr.len(), rmask)?; @@ -666,8 +670,9 @@ pub fn regex_str_str<'a, T: Integer, U: Integer>( let (larr, loff, llen) = lhs; let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let data = regex_match_loop!(len, lmask, rmask, out_mask, larr, loff, rarr, roff); Ok(BooleanArray { @@ -707,8 +712,9 @@ pub fn regex_dict_str<'a, U: Integer, T: Integer>( let (larr, loff, llen) = lhs; let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let data = regex_match_loop!(len, lmask, rmask, out_mask, larr, loff, rarr, roff); Ok(BooleanArray { @@ -748,8 +754,9 @@ pub fn regex_str_dict<'a, T: Integer, U: Integer>( let (larr, loff, llen) = lhs; let (rarr, roff, rlen) = rhs; let len = 
llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let data = regex_match_loop!(len, lmask, rmask, out_mask, larr, loff, rarr, roff); Ok(BooleanArray { @@ -792,8 +799,9 @@ pub fn regex_dict_dict<'a, T: Integer>( let (larr, loff, llen) = lhs; let (rarr, roff, rlen) = rhs; let len = llen.min(rlen); - let (lmask, rmask, mut out_mask) = - string_predicate_masks(larr.null_mask.as_ref(), rarr.null_mask.as_ref(), len); + let lmask = larr.null_mask.as_ref(); + let rmask = rarr.null_mask.as_ref(); + let mut out_mask = Bitmask::new_set_all(len, false); let data = regex_match_loop!(len, lmask, rmask, out_mask, larr, loff, rarr, roff); Ok(BooleanArray { @@ -1065,205 +1073,1139 @@ pub fn count_distinct_string(window: StringAVT) -> usize { set.len() } -#[cfg(test)] -mod tests { - use crate::{CategoricalArray, StringArray, vec64}; +// Unary string transformations - use super::*; +/// Generates a unary string transform kernel for StringArray. +/// The transform closure receives `&str` and returns `String`. +macro_rules! 
unary_str_transform { + ($fn_name:ident, $doc:expr, $transform:expr) => { + #[doc = $doc] + pub fn $fn_name(input: StringAVT) -> Result, KernelError> { + let (arr, offset, len) = input; - // --- Helper constructors + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); - fn str_array(vals: &[&str]) -> StringArray { - StringArray::::from_slice(vals) - } + let input_bytes = arr.offsets[offset + len].to_usize() + - arr.offsets[offset].to_usize(); + let mut offsets = Vec64::::with_capacity(len + 1); + unsafe { offsets.set_len(len + 1); } + let mut data = Vec64::::with_capacity(input_bytes); - fn dict_array(vals: &[&str]) -> CategoricalArray { - let owned: Vec<&str> = vals.to_vec(); - CategoricalArray::::from_values(owned) - } + offsets[0] = T::zero(); + let mut cur = 0usize; - fn bm(bools: &[bool]) -> Bitmask { - Bitmask::from_bools(bools) - } + for i in 0..len { + let valid = mask_opt.as_ref() + .map_or(true, |m| unsafe { m.get_unchecked(i) }); + + if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + let t: String = ($transform)(s); + data.extend_from_slice(t.as_bytes()); + cur += t.len(); + } + offsets[i + 1] = T::from_usize(cur); + } - // --- Concat + Ok(StringArray { + offsets: offsets.into(), + data: data.into(), + null_mask: mask_opt, + }) + } + }; +} - #[test] - fn test_concat_str_str() { - let a = str_array::(&["foo", "bar", ""]); - let b = str_array::(&["baz", "qux", "quux"]); - let out = concat_str_str((&a, 0, a.len()), (&b, 0, b.len())); - assert_eq!(out.get(0), Some("foobaz")); - assert_eq!(out.get(1), Some("barqux")); - assert_eq!(out.get(2), Some("quux")); - assert!(out.null_mask.as_ref().unwrap().all_set()); - } +/// Generates a unary string transform kernel for CategoricalArray. +/// Transforms dictionary values once and remaps indices. +macro_rules! 
unary_dict_transform { + ($fn_name:ident, $doc:expr, $transform:expr) => { + #[doc = $doc] + pub fn $fn_name(input: CategoricalAVT) -> Result, KernelError> { + let (arr, offset, len) = input; + + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + // Transform each dictionary value once, build old->new index mapping + #[cfg(feature = "fast_hash")] + let mut seen: AHashMap = AHashMap::with_capacity(arr.unique_values.len()); + #[cfg(not(feature = "fast_hash"))] + let mut seen: std::collections::HashMap = + std::collections::HashMap::with_capacity(arr.unique_values.len()); + + let mut new_unique = Vec64::::new(); + let mut idx_map = Vec64::::with_capacity(arr.unique_values.len()); + + for old_val in arr.unique_values.iter() { + let t: String = ($transform)(old_val.as_str()); + let new_idx = match seen.get(&t) { + Some(&ix) => ix, + None => { + let ix = T::from_usize(new_unique.len()); + new_unique.push(t.clone()); + seen.insert(t, ix); + ix + } + }; + idx_map.push(new_idx); + } - #[test] - fn test_concat_str_str_chunk() { - let a = str_array::(&["XXX", "foo", "bar", ""]); - let b = str_array::(&["YYY", "baz", "qux", "quux"]); - // Window is [1..4) for both, i.e., ["foo", "bar", ""] - let out = concat_str_str((&a, 1, 3), (&b, 1, 3)); - assert_eq!(out.get(0), Some("foobaz")); - assert_eq!(out.get(1), Some("barqux")); - assert_eq!(out.get(2), Some("quux")); - assert!(out.null_mask.as_ref().unwrap().all_set()); - } + // Remap indices for the window + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + idx_map[arr.data[offset + i].to_usize()] + } else { + T::zero() + }; + } - #[test] - fn test_concat_dict_dict() { - let a = dict_array::(&["x", "y"]); - let b = dict_array::(&["1", 
"2"]); - let out = concat_dict_dict((&a, 0, a.len()), (&b, 0, b.len())).unwrap(); - let s0 = out.get(0).unwrap(); - let s1 = out.get(1).unwrap(); - assert!(["x1", "y2"].contains(&s0)); - assert!(["x1", "y2"].contains(&s1)); - assert!(out.null_mask.as_ref().unwrap().all_set()); - } + Ok(CategoricalArray { + data: data.into(), + unique_values: new_unique, + null_mask: mask_opt, + }) + } + }; +} - #[test] - fn test_concat_dict_dict_chunk() { - let a = dict_array::(&["foo", "x", "y", "bar"]); - let b = dict_array::(&["A", "1", "2", "B"]); - let out = concat_dict_dict((&a, 1, 2), (&b, 1, 2)).unwrap(); - let s0 = out.get(0).unwrap(); - let s1 = out.get(1).unwrap(); - assert!(["x1", "y2"].contains(&s0)); - assert!(["x1", "y2"].contains(&s1)); - assert!(out.null_mask.as_ref().unwrap().all_set()); - } +unary_str_transform!(to_uppercase_str, + "Converts each string element to uppercase.", + |s: &str| s.to_uppercase() +); +unary_dict_transform!(to_uppercase_dict, + "Converts each categorical string element to uppercase.", + |s: &str| s.to_uppercase() +); + +unary_str_transform!(to_lowercase_str, + "Converts each string element to lowercase.", + |s: &str| s.to_lowercase() +); +unary_dict_transform!(to_lowercase_dict, + "Converts each categorical string element to lowercase.", + |s: &str| s.to_lowercase() +); + +unary_str_transform!(trim_str, + "Trims leading and trailing whitespace from each string element.", + |s: &str| s.trim().to_owned() +); +unary_dict_transform!(trim_dict, + "Trims leading and trailing whitespace from each categorical string element.", + |s: &str| s.trim().to_owned() +); + +unary_str_transform!(ltrim_str, + "Trims leading whitespace from each string element.", + |s: &str| s.trim_start().to_owned() +); +unary_dict_transform!(ltrim_dict, + "Trims leading whitespace from each categorical string element.", + |s: &str| s.trim_start().to_owned() +); + +unary_str_transform!(rtrim_str, + "Trims trailing whitespace from each string element.", + |s: &str| 
s.trim_end().to_owned() +); +unary_dict_transform!(rtrim_dict, + "Trims trailing whitespace from each categorical string element.", + |s: &str| s.trim_end().to_owned() +); + +unary_str_transform!(reverse_str, + "Reverses each string element by Unicode characters.", + |s: &str| s.chars().rev().collect::() +); +unary_dict_transform!(reverse_dict, + "Reverses each categorical string element by Unicode characters.", + |s: &str| s.chars().rev().collect::() +); + +// Byte length + +/// Computes the byte length of each string in a StringArray slice. +pub fn byte_length_str( + input: StringAVT, +) -> Result, KernelError> { + let (array, offset, len) = input; - #[test] - fn test_concat_str_dict() { - let a = str_array::(&["ab", "cd", ""]); - let b = dict_array::(&["xy", "zq", ""]); - let out = concat_str_dict((&a, 0, a.len()), (&b, 0, b.len())).unwrap(); - assert_eq!(out.get(0), Some("abxy")); - assert_eq!(out.get(1), Some("cdzq")); - assert_eq!(out.get(2), Some("")); - assert!(out.null_mask.as_ref().unwrap().all_set()); - } + let mask_opt = array.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); - #[test] - fn test_concat_str_dict_chunk() { - let a = str_array::(&["dummy", "ab", "cd", ""]); - let b = dict_array::(&["dummy", "xy", "zq", ""]); - let out = concat_str_dict((&a, 1, 3), (&b, 1, 3)).unwrap(); - assert_eq!(out.get(0), Some("abxy")); - assert_eq!(out.get(1), Some("cdzq")); - assert_eq!(out.get(2), Some("")); - assert!(out.null_mask.as_ref().unwrap().all_set()); + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + if valid { + let start = array.offsets[offset + i].to_usize(); + let end = array.offsets[offset + i + 1].to_usize(); + data[i] = T::from_usize(end - start); + } else { + data[i] = T::zero(); + } } - 
#[test] - fn test_concat_dict_str() { - let a = dict_array::(&["hi", "ho"]); - let b = str_array::(&["yo", "no"]); - let out = concat_dict_str((&a, 0, a.len()), (&b, 0, b.len())).unwrap(); - assert_eq!(out.get(0), Some("yohi")); - assert_eq!(out.get(1), Some("noho")); - assert!(out.null_mask.as_ref().unwrap().all_set()); - } + Ok(IntegerArray { + data: data.into(), + null_mask: mask_opt, + }) +} - #[test] - fn test_concat_dict_str_chunk() { - let a = dict_array::(&["dummy", "hi", "ho", "zzz"]); - let b = str_array::(&["dummy", "yo", "no", "xxx"]); - let out = concat_dict_str((&a, 1, 2), (&b, 1, 2)).unwrap(); - assert_eq!(out.get(0), Some("yohi")); - assert_eq!(out.get(1), Some("noho")); - assert!(out.null_mask.as_ref().unwrap().all_set()); - } +/// Computes the byte length of each string in a CategoricalArray slice. +pub fn byte_length_dict( + input: CategoricalAVT, +) -> Result, KernelError> { + let (array, offset, len) = input; - // --- String predicates + let mask_opt = array.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); - #[test] - fn test_contains_str_str() { - let s = str_array::(&["abc", "def", "ghijk"]); - let p = str_array::(&["b", "x", "jk"]); - let out = contains_str_str((&s, 0, s.len()), (&p, 0, p.len())); - assert_eq!(out.get(0), Some(true)); - assert_eq!(out.get(1), Some(false)); - assert_eq!(out.get(2), Some(true)); + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + T::from_usize(unsafe { array.get_str_unchecked(offset + i) }.len()) + } else { + T::zero() + }; } - #[test] - fn test_contains_str_str_chunk() { - let s = str_array::(&["dummy", "abc", "def", "ghijk"]); - let p = str_array::(&["dummy", "b", "x", "jk"]); - let out = contains_str_str((&s, 1, 3), (&p, 1, 3)); - 
assert_eq!(out.get(0), Some(true)); - assert_eq!(out.get(1), Some(false)); - assert_eq!(out.get(2), Some(true)); - } + Ok(IntegerArray { + data: data.into(), + null_mask: mask_opt, + }) +} - #[test] - fn test_starts_with_str_str() { - let s = str_array::(&["apricot", "banana", "apple"]); - let p = str_array::(&["ap", "ba", "a"]); - let out = starts_with_str_str((&s, 0, s.len()), (&p, 0, p.len())); - assert_eq!(out.get(0), Some(true)); - assert_eq!(out.get(1), Some(true)); - assert_eq!(out.get(2), Some(true)); - } +// Find and count - #[test] - fn test_starts_with_str_str_chunk() { - let s = str_array::(&["dummy", "apricot", "banana", "apple"]); - let p = str_array::(&["dummy", "ap", "ba", "a"]); - let out = starts_with_str_str((&s, 1, 3), (&p, 1, 3)); - assert_eq!(out.get(0), Some(true)); - assert_eq!(out.get(1), Some(true)); - assert_eq!(out.get(2), Some(true)); - } +/// Finds the first byte index of `needle` in each string element. Returns -1 if not found. +pub fn find_str( + input: StringAVT, + needle: &str, +) -> Result, KernelError> { + let (arr, offset, len) = input; - #[test] - fn test_ends_with_str_str() { - let s = str_array::(&["robot", "fast", "last"]); - let p = str_array::(&["ot", "st", "ast"]); - let out = ends_with_str_str((&s, 0, s.len()), (&p, 0, p.len())); - assert_eq!(out.get(0), Some(true)); - assert_eq!(out.get(1), Some(true)); - assert_eq!(out.get(2), Some(true)); - } + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); - #[test] - fn test_ends_with_str_str_chunk() { - let s = str_array::(&["dummy", "robot", "fast", "last"]); - let p = str_array::(&["dummy", "ot", "st", "ast"]); - let out = ends_with_str_str((&s, 1, 3), (&p, 1, 3)); - assert_eq!(out.get(0), Some(true)); - assert_eq!(out.get(1), Some(true)); - assert_eq!(out.get(2), Some(true)); + let mut data = Vec64::::with_capacity(len); + unsafe { 
data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + s.find(needle).map_or(-1, |pos| pos as i32) + } else { + 0 + }; } - #[test] - fn test_contains_str_dict() { - let s = str_array::(&["abcde", "xyz", "qrstuv"]); - let p = dict_array::(&["c", "z", "tu"]); - let out = contains_str_dict((&s, 0, s.len()), (&p, 0, p.len())).unwrap(); - assert_eq!(out.get(0), Some(true)); - assert_eq!(out.get(1), Some(true)); - assert_eq!(out.get(2), Some(true)); - } + Ok(IntegerArray { + data: data.into(), + null_mask: mask_opt, + }) +} - #[test] - fn test_contains_str_dict_chunk() { - let s = str_array::(&["dummy", "abcde", "xyz", "qrstuv"]); - let p = dict_array::(&["dummy", "c", "z", "tu"]); - let out = contains_str_dict((&s, 1, 3), (&p, 1, 3)).unwrap(); - assert_eq!(out.get(0), Some(true)); - assert_eq!(out.get(1), Some(true)); - assert_eq!(out.get(2), Some(true)); - } +/// Finds the first byte index of `needle` in each categorical string element. Returns -1 if not found. 
+pub fn find_dict( + input: CategoricalAVT, + needle: &str, +) -> Result, KernelError> { + let (arr, offset, len) = input; - #[test] - fn test_contains_dict_dict() { - let s = dict_array::(&["cdef", "foo", "bar"]); - let p = dict_array::(&["cd", "oo", "baz"]); - let out = contains_dict_dict((&s, 0, s.len()), (&p, 0, p.len())).unwrap(); + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + s.find(needle).map_or(-1, |pos| pos as i32) + } else { + 0 + }; + } + + Ok(IntegerArray { + data: data.into(), + null_mask: mask_opt, + }) +} + +/// Counts non-overlapping occurrences of `needle` in each string element. +pub fn count_match_str( + input: StringAVT, + needle: &str, +) -> Result, KernelError> { + let (arr, offset, len) = input; + + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + s.matches(needle).count() as i32 + } else { + 0 + }; + } + + Ok(IntegerArray { + data: data.into(), + null_mask: mask_opt, + }) +} + +/// Counts non-overlapping occurrences of `needle` in each categorical string element. 
+pub fn count_match_dict( + input: CategoricalAVT, + needle: &str, +) -> Result, KernelError> { + let (arr, offset, len) = input; + + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + s.matches(needle).count() as i32 + } else { + 0 + }; + } + + Ok(IntegerArray { + data: data.into(), + null_mask: mask_opt, + }) +} + +// Substring + +/// Extracts a character-based substring from each string element. +/// `start` is the 0-based character offset. `opt_len` limits the number of characters taken. +pub fn substring_str( + input: StringAVT, + start: usize, + opt_len: Option, +) -> Result, KernelError> { + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let input_bytes = arr.offsets[offset + len].to_usize() + - arr.offsets[offset].to_usize(); + let mut offsets = Vec64::::with_capacity(len + 1); + unsafe { offsets.set_len(len + 1); } + let mut data = Vec64::::with_capacity(input_bytes); + + offsets[0] = T::zero(); + let mut cur = 0usize; + + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + let chars = s.chars().skip(start); + let t: String = match opt_len { + Some(n) => chars.take(n).collect(), + None => chars.collect(), + }; + data.extend_from_slice(t.as_bytes()); + cur += t.len(); + } + offsets[i + 1] = T::from_usize(cur); + } + + Ok(StringArray { + offsets: 
offsets.into(), + data: data.into(), + null_mask: mask_opt, + }) +} + +/// Extracts a character-based substring from each categorical string element. +pub fn substring_dict( + input: CategoricalAVT, + start: usize, + opt_len: Option, +) -> Result, KernelError> { + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + #[cfg(feature = "fast_hash")] + let mut seen: AHashMap = AHashMap::with_capacity(arr.unique_values.len()); + #[cfg(not(feature = "fast_hash"))] + let mut seen: std::collections::HashMap = + std::collections::HashMap::with_capacity(arr.unique_values.len()); + + let mut new_unique = Vec64::::new(); + let mut idx_map = Vec64::::with_capacity(arr.unique_values.len()); + + for old_val in arr.unique_values.iter() { + let chars = old_val.chars().skip(start); + let t: String = match opt_len { + Some(n) => chars.take(n).collect(), + None => chars.collect(), + }; + let new_idx = match seen.get(&t) { + Some(&ix) => ix, + None => { + let ix = T::from_usize(new_unique.len()); + new_unique.push(t.clone()); + seen.insert(t, ix); + ix + } + }; + idx_map.push(new_idx); + } + + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + idx_map[arr.data[offset + i].to_usize()] + } else { + T::zero() + }; + } + + Ok(CategoricalArray { + data: data.into(), + unique_values: new_unique, + null_mask: mask_opt, + }) +} + +// Replace + +/// Replaces all occurrences of `from` with `to` in each string element. 
+pub fn replace_str( + input: StringAVT, + from: &str, + to: &str, +) -> Result, KernelError> { + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let input_bytes = arr.offsets[offset + len].to_usize() + - arr.offsets[offset].to_usize(); + let mut offsets = Vec64::::with_capacity(len + 1); + unsafe { offsets.set_len(len + 1); } + let mut data = Vec64::::with_capacity(input_bytes); + + offsets[0] = T::zero(); + let mut cur = 0usize; + + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + let t = s.replace(from, to); + data.extend_from_slice(t.as_bytes()); + cur += t.len(); + } + offsets[i + 1] = T::from_usize(cur); + } + + Ok(StringArray { + offsets: offsets.into(), + data: data.into(), + null_mask: mask_opt, + }) +} + +/// Replaces all occurrences of `from` with `to` in each categorical string element. 
+pub fn replace_dict( + input: CategoricalAVT, + from: &str, + to: &str, +) -> Result, KernelError> { + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + #[cfg(feature = "fast_hash")] + let mut seen: AHashMap = AHashMap::with_capacity(arr.unique_values.len()); + #[cfg(not(feature = "fast_hash"))] + let mut seen: std::collections::HashMap = + std::collections::HashMap::with_capacity(arr.unique_values.len()); + + let mut new_unique = Vec64::::new(); + let mut idx_map = Vec64::::with_capacity(arr.unique_values.len()); + + for old_val in arr.unique_values.iter() { + let t = old_val.replace(from, to); + let new_idx = match seen.get(&t) { + Some(&ix) => ix, + None => { + let ix = T::from_usize(new_unique.len()); + new_unique.push(t.clone()); + seen.insert(t, ix); + ix + } + }; + idx_map.push(new_idx); + } + + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + idx_map[arr.data[offset + i].to_usize()] + } else { + T::zero() + }; + } + + Ok(CategoricalArray { + data: data.into(), + unique_values: new_unique, + null_mask: mask_opt, + }) +} + +// Repeat + +/// Repeats each string element `n` times. 
+pub fn repeat_str( + input: StringAVT, + n: usize, +) -> Result, KernelError> { + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let input_bytes = arr.offsets[offset + len].to_usize() + - arr.offsets[offset].to_usize(); + let mut offsets = Vec64::::with_capacity(len + 1); + unsafe { offsets.set_len(len + 1); } + let mut data = Vec64::::with_capacity(input_bytes * n); + + offsets[0] = T::zero(); + let mut cur = 0usize; + + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + let bytes = s.as_bytes(); + for _ in 0..n { + data.extend_from_slice(bytes); + } + cur += bytes.len() * n; + } + offsets[i + 1] = T::from_usize(cur); + } + + Ok(StringArray { + offsets: offsets.into(), + data: data.into(), + null_mask: mask_opt, + }) +} + +/// Repeats each categorical string element `n` times. 
+pub fn repeat_dict( + input: CategoricalAVT, + n: usize, +) -> Result, KernelError> { + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + // Repeat is 1:1 on dictionary values for n >= 1 (s.repeat(n) is injective at fixed n), but for n == 0 every value collapses to "" and new_unique gains duplicates — TODO dedupe or short-circuit n == 0 + let mut new_unique = Vec64::::new(); + for old_val in arr.unique_values.iter() { + new_unique.push(old_val.repeat(n)); + } + + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + arr.data[offset + i] + } else { + T::zero() + }; + } + + Ok(CategoricalArray { + data: data.into(), + unique_values: new_unique, + null_mask: mask_opt, + }) +} + +// Pad + +/// Pads each string element to `width` characters using `fill_char`. +pub fn pad_str( + input: StringAVT, + width: usize, + fill_char: char, + side: PadSide, +) -> Result, KernelError> { + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let max_padded = width * fill_char.len_utf8(); + let mut offsets = Vec64::::with_capacity(len + 1); + unsafe { offsets.set_len(len + 1); } + let mut data = Vec64::::with_capacity(len * max_padded); + + offsets[0] = T::zero(); + let mut cur = 0usize; + + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + let char_len = s.chars().count(); + if char_len >= width { + data.extend_from_slice(s.as_bytes()); + cur += s.len(); + } else { + let padding = width - char_len; + let pad_bytes = fill_char.len_utf8(); + let mut pad_buf = [0u8; 4]; + let pad_slice = 
fill_char.encode_utf8(&mut pad_buf); + match side { + PadSide::Left => { + for _ in 0..padding { data.extend_from_slice(pad_slice.as_bytes()); } + data.extend_from_slice(s.as_bytes()); + } + PadSide::Right => { + data.extend_from_slice(s.as_bytes()); + for _ in 0..padding { data.extend_from_slice(pad_slice.as_bytes()); } + } + PadSide::Both => { + let left = padding / 2; + let right = padding - left; + for _ in 0..left { data.extend_from_slice(pad_slice.as_bytes()); } + data.extend_from_slice(s.as_bytes()); + for _ in 0..right { data.extend_from_slice(pad_slice.as_bytes()); } + } + } + cur += s.len() + padding * pad_bytes; + } + } + offsets[i + 1] = T::from_usize(cur); + } + + Ok(StringArray { + offsets: offsets.into(), + data: data.into(), + null_mask: mask_opt, + }) +} + +/// Pads each categorical string element to `width` characters using `fill_char`. +pub fn pad_dict( + input: CategoricalAVT, + width: usize, + fill_char: char, + side: PadSide, +) -> Result, KernelError> { + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let pad_one = |s: &str| -> String { + let char_len = s.chars().count(); + if char_len >= width { + return s.to_owned(); + } + let padding = width - char_len; + match side { + PadSide::Left => { + let mut out = String::with_capacity(s.len() + padding * fill_char.len_utf8()); + for _ in 0..padding { out.push(fill_char); } + out.push_str(s); + out + } + PadSide::Right => { + let mut out = String::with_capacity(s.len() + padding * fill_char.len_utf8()); + out.push_str(s); + for _ in 0..padding { out.push(fill_char); } + out + } + PadSide::Both => { + let left = padding / 2; + let right = padding - left; + let mut out = String::with_capacity(s.len() + padding * fill_char.len_utf8()); + for _ in 0..left { out.push(fill_char); } + out.push_str(s); + for _ in 0..right 
{ out.push(fill_char); } + out + } + } + }; + + // NOTE(review): padding is NOT injective — a short value can pad into an existing longer one (e.g. "a" and "aa" with fill 'a', width 2 both become "aa"), so new_unique may contain duplicates; dedupe via a seen-map as in replace_dict + let mut new_unique = Vec64::::new(); + for old_val in arr.unique_values.iter() { + new_unique.push(pad_one(old_val)); + } + + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + arr.data[offset + i] + } else { + T::zero() + }; + } + + Ok(CategoricalArray { + data: data.into(), + unique_values: new_unique, + null_mask: mask_opt, + }) +} + +// Join (aggregation) + +/// Joins all non-null string elements with `delimiter`, returning a single string. +/// Returns `None` if all elements are null. +pub fn join_str(input: StringAVT, delimiter: &str) -> Option { + let (arr, offset, len) = input; + let mut parts: Vec<&str> = Vec::with_capacity(len); + for i in offset..offset + len { + let valid = arr.null_mask.as_ref().map_or(true, |b| unsafe { b.get_unchecked(i) }); + if valid { + parts.push(unsafe { arr.get_str_unchecked(i) }); + } + } + if parts.is_empty() { None } else { Some(parts.join(delimiter)) } +} + +/// Joins all non-null categorical string elements with `delimiter`, returning a single string. +/// Returns `None` if all elements are null. +pub fn join_dict(input: CategoricalAVT, delimiter: &str) -> Option { + let (arr, offset, len) = input; + let mut parts: Vec<&str> = Vec::with_capacity(len); + for i in offset..offset + len { + let valid = arr.null_mask.as_ref().map_or(true, |b| unsafe { b.get_unchecked(i) }); + if valid { + parts.push(unsafe { arr.get_str_unchecked(i) }); + } + } + if parts.is_empty() { None } else { Some(parts.join(delimiter)) } +} + +// Regex replace + +/// Replaces all regex matches in each string element with `replacement`. 
+#[cfg(feature = "regex")] +pub fn regex_replace_str( + input: StringAVT, + pattern: &str, + replacement: &str, +) -> Result, KernelError> { + let re = Regex::new(pattern).map_err(|_| { + KernelError::InvalidArguments("Invalid regex pattern".to_string()) + })?; + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + let input_bytes = arr.offsets[offset + len].to_usize() + - arr.offsets[offset].to_usize(); + let mut offsets = Vec64::::with_capacity(len + 1); + unsafe { offsets.set_len(len + 1); } + let mut data = Vec64::::with_capacity(input_bytes); + + offsets[0] = T::zero(); + let mut cur = 0usize; + + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + if valid { + let s = unsafe { arr.get_str_unchecked(offset + i) }; + let t = re.replace_all(s, replacement).into_owned(); + data.extend_from_slice(t.as_bytes()); + cur += t.len(); + } + offsets[i + 1] = T::from_usize(cur); + } + + Ok(StringArray { + offsets: offsets.into(), + data: data.into(), + null_mask: mask_opt, + }) +} + +/// Replaces all regex matches in each categorical string element with `replacement`. 
+#[cfg(feature = "regex")] +pub fn regex_replace_dict( + input: CategoricalAVT, + pattern: &str, + replacement: &str, +) -> Result, KernelError> { + let re = Regex::new(pattern).map_err(|_| { + KernelError::InvalidArguments("Invalid regex pattern".to_string()) + })?; + let (arr, offset, len) = input; + let mask_opt = arr.null_mask.as_ref().map(|orig| { + let mut m = Bitmask::new_set_all(len, true); + for i in 0..len { + unsafe { m.set_unchecked(i, orig.get_unchecked(offset + i)); } + } + m + }); + + #[cfg(feature = "fast_hash")] + let mut seen: AHashMap = AHashMap::with_capacity(arr.unique_values.len()); + #[cfg(not(feature = "fast_hash"))] + let mut seen: std::collections::HashMap = + std::collections::HashMap::with_capacity(arr.unique_values.len()); + + let mut new_unique = Vec64::::new(); + let mut idx_map = Vec64::::with_capacity(arr.unique_values.len()); + + for old_val in arr.unique_values.iter() { + let t = re.replace_all(old_val, replacement).into_owned(); + let new_idx = match seen.get(&t) { + Some(&ix) => ix, + None => { + let ix = T::from_usize(new_unique.len()); + new_unique.push(t.clone()); + seen.insert(t, ix); + ix + } + }; + idx_map.push(new_idx); + } + + let mut data = Vec64::::with_capacity(len); + unsafe { data.set_len(len); } + for i in 0..len { + let valid = mask_opt.as_ref().map_or(true, |m| unsafe { m.get_unchecked(i) }); + data[i] = if valid { + idx_map[arr.data[offset + i].to_usize()] + } else { + T::zero() + }; + } + + Ok(CategoricalArray { + data: data.into(), + unique_values: new_unique, + null_mask: mask_opt, + }) +} + +#[cfg(test)] +mod tests { + use crate::{CategoricalArray, StringArray, vec64}; + + use super::*; + + // --- Helper constructors + + fn str_array(vals: &[&str]) -> StringArray { + StringArray::::from_slice(vals) + } + + fn dict_array(vals: &[&str]) -> CategoricalArray { + let owned: Vec<&str> = vals.to_vec(); + CategoricalArray::::from_values(owned) + } + + fn bm(bools: &[bool]) -> Bitmask { + Bitmask::from_bools(bools) + 
} + + // --- Concat + + #[test] + fn test_concat_str_str() { + let a = str_array::(&["foo", "bar", ""]); + let b = str_array::(&["baz", "qux", "quux"]); + let out = concat_str_str((&a, 0, a.len()), (&b, 0, b.len())); + assert_eq!(out.get(0), Some("foobaz")); + assert_eq!(out.get(1), Some("barqux")); + assert_eq!(out.get(2), Some("quux")); + assert!(out.null_mask.as_ref().unwrap().all_set()); + } + + #[test] + fn test_concat_str_str_chunk() { + let a = str_array::(&["XXX", "foo", "bar", ""]); + let b = str_array::(&["YYY", "baz", "qux", "quux"]); + // Window is [1..4) for both, i.e., ["foo", "bar", ""] + let out = concat_str_str((&a, 1, 3), (&b, 1, 3)); + assert_eq!(out.get(0), Some("foobaz")); + assert_eq!(out.get(1), Some("barqux")); + assert_eq!(out.get(2), Some("quux")); + assert!(out.null_mask.as_ref().unwrap().all_set()); + } + + #[test] + fn test_concat_dict_dict() { + let a = dict_array::(&["x", "y"]); + let b = dict_array::(&["1", "2"]); + let out = concat_dict_dict((&a, 0, a.len()), (&b, 0, b.len())).unwrap(); + let s0 = out.get(0).unwrap(); + let s1 = out.get(1).unwrap(); + assert!(["x1", "y2"].contains(&s0)); + assert!(["x1", "y2"].contains(&s1)); + assert!(out.null_mask.as_ref().unwrap().all_set()); + } + + #[test] + fn test_concat_dict_dict_chunk() { + let a = dict_array::(&["foo", "x", "y", "bar"]); + let b = dict_array::(&["A", "1", "2", "B"]); + let out = concat_dict_dict((&a, 1, 2), (&b, 1, 2)).unwrap(); + let s0 = out.get(0).unwrap(); + let s1 = out.get(1).unwrap(); + assert!(["x1", "y2"].contains(&s0)); + assert!(["x1", "y2"].contains(&s1)); + assert!(out.null_mask.as_ref().unwrap().all_set()); + } + + #[test] + fn test_concat_str_dict() { + let a = str_array::(&["ab", "cd", ""]); + let b = dict_array::(&["xy", "zq", ""]); + let out = concat_str_dict((&a, 0, a.len()), (&b, 0, b.len())).unwrap(); + assert_eq!(out.get(0), Some("abxy")); + assert_eq!(out.get(1), Some("cdzq")); + assert_eq!(out.get(2), Some("")); + 
assert!(out.null_mask.as_ref().unwrap().all_set()); + } + + #[test] + fn test_concat_str_dict_chunk() { + let a = str_array::(&["dummy", "ab", "cd", ""]); + let b = dict_array::(&["dummy", "xy", "zq", ""]); + let out = concat_str_dict((&a, 1, 3), (&b, 1, 3)).unwrap(); + assert_eq!(out.get(0), Some("abxy")); + assert_eq!(out.get(1), Some("cdzq")); + assert_eq!(out.get(2), Some("")); + assert!(out.null_mask.as_ref().unwrap().all_set()); + } + + #[test] + fn test_concat_dict_str() { + let a = dict_array::(&["hi", "ho"]); + let b = str_array::(&["yo", "no"]); + let out = concat_dict_str((&a, 0, a.len()), (&b, 0, b.len())).unwrap(); + assert_eq!(out.get(0), Some("yohi")); + assert_eq!(out.get(1), Some("noho")); + assert!(out.null_mask.as_ref().unwrap().all_set()); + } + + #[test] + fn test_concat_dict_str_chunk() { + let a = dict_array::(&["dummy", "hi", "ho", "zzz"]); + let b = str_array::(&["dummy", "yo", "no", "xxx"]); + let out = concat_dict_str((&a, 1, 2), (&b, 1, 2)).unwrap(); + assert_eq!(out.get(0), Some("yohi")); + assert_eq!(out.get(1), Some("noho")); + assert!(out.null_mask.as_ref().unwrap().all_set()); + } + + // --- String predicates + + #[test] + fn test_contains_str_str() { + let s = str_array::(&["abc", "def", "ghijk"]); + let p = str_array::(&["b", "x", "jk"]); + let out = contains_str_str((&s, 0, s.len()), (&p, 0, p.len())); + assert_eq!(out.get(0), Some(true)); + assert_eq!(out.get(1), Some(false)); + assert_eq!(out.get(2), Some(true)); + } + + #[test] + fn test_contains_str_str_chunk() { + let s = str_array::(&["dummy", "abc", "def", "ghijk"]); + let p = str_array::(&["dummy", "b", "x", "jk"]); + let out = contains_str_str((&s, 1, 3), (&p, 1, 3)); + assert_eq!(out.get(0), Some(true)); + assert_eq!(out.get(1), Some(false)); + assert_eq!(out.get(2), Some(true)); + } + + #[test] + fn test_starts_with_str_str() { + let s = str_array::(&["apricot", "banana", "apple"]); + let p = str_array::(&["ap", "ba", "a"]); + let out = starts_with_str_str((&s, 0, 
s.len()), (&p, 0, p.len())); + assert_eq!(out.get(0), Some(true)); + assert_eq!(out.get(1), Some(true)); + assert_eq!(out.get(2), Some(true)); + } + + #[test] + fn test_starts_with_str_str_chunk() { + let s = str_array::(&["dummy", "apricot", "banana", "apple"]); + let p = str_array::(&["dummy", "ap", "ba", "a"]); + let out = starts_with_str_str((&s, 1, 3), (&p, 1, 3)); + assert_eq!(out.get(0), Some(true)); + assert_eq!(out.get(1), Some(true)); + assert_eq!(out.get(2), Some(true)); + } + + #[test] + fn test_ends_with_str_str() { + let s = str_array::(&["robot", "fast", "last"]); + let p = str_array::(&["ot", "st", "ast"]); + let out = ends_with_str_str((&s, 0, s.len()), (&p, 0, p.len())); + assert_eq!(out.get(0), Some(true)); + assert_eq!(out.get(1), Some(true)); + assert_eq!(out.get(2), Some(true)); + } + + #[test] + fn test_ends_with_str_str_chunk() { + let s = str_array::(&["dummy", "robot", "fast", "last"]); + let p = str_array::(&["dummy", "ot", "st", "ast"]); + let out = ends_with_str_str((&s, 1, 3), (&p, 1, 3)); + assert_eq!(out.get(0), Some(true)); + assert_eq!(out.get(1), Some(true)); + assert_eq!(out.get(2), Some(true)); + } + + #[test] + fn test_contains_str_dict() { + let s = str_array::(&["abcde", "xyz", "qrstuv"]); + let p = dict_array::(&["c", "z", "tu"]); + let out = contains_str_dict((&s, 0, s.len()), (&p, 0, p.len())).unwrap(); + assert_eq!(out.get(0), Some(true)); + assert_eq!(out.get(1), Some(true)); + assert_eq!(out.get(2), Some(true)); + } + + #[test] + fn test_contains_str_dict_chunk() { + let s = str_array::(&["dummy", "abcde", "xyz", "qrstuv"]); + let p = dict_array::(&["dummy", "c", "z", "tu"]); + let out = contains_str_dict((&s, 1, 3), (&p, 1, 3)).unwrap(); + assert_eq!(out.get(0), Some(true)); + assert_eq!(out.get(1), Some(true)); + assert_eq!(out.get(2), Some(true)); + } + + #[test] + fn test_contains_dict_dict() { + let s = dict_array::(&["cdef", "foo", "bar"]); + let p = dict_array::(&["cd", "oo", "baz"]); + let out = 
contains_dict_dict((&s, 0, s.len()), (&p, 0, p.len())).unwrap(); assert_eq!(out.get(0), Some(true)); assert_eq!(out.get(1), Some(true)); assert_eq!(out.get(2), Some(false)); @@ -1685,4 +2627,293 @@ mod tests { let result = max_categorical_array((&cat, 0, indices.len())); assert_eq!(result, Some("zebra".to_string())); // Only positions 0 and 2 valid: "zebra", "dog" -> "zebra" is larger } + + // --- Unary transforms + + #[test] + fn test_to_uppercase_str() { + let a = str_array::(&["hello", "World", "FOO"]); + let out = to_uppercase_str((&a, 0, a.len())).unwrap(); + assert_eq!(out.get(0), Some("HELLO")); + assert_eq!(out.get(1), Some("WORLD")); + assert_eq!(out.get(2), Some("FOO")); + } + + #[test] + fn test_to_uppercase_str_chunk() { + let a = str_array::(&["skip", "hello", "world", "skip"]); + let out = to_uppercase_str((&a, 1, 2)).unwrap(); + assert_eq!(out.get(0), Some("HELLO")); + assert_eq!(out.get(1), Some("WORLD")); + } + + #[test] + fn test_to_uppercase_dict() { + let a = dict_array::(&["hello", "world"]); + let out = to_uppercase_dict((&a, 0, a.data.len())).unwrap(); + assert_eq!(out.get_str(0), Some("HELLO")); + assert_eq!(out.get_str(1), Some("WORLD")); + } + + #[test] + fn test_to_lowercase_str() { + let a = str_array::(&["HELLO", "World"]); + let out = to_lowercase_str((&a, 0, a.len())).unwrap(); + assert_eq!(out.get(0), Some("hello")); + assert_eq!(out.get(1), Some("world")); + } + + #[test] + fn test_trim_str() { + let a = str_array::(&[" hello ", "world ", " foo"]); + let out = trim_str((&a, 0, a.len())).unwrap(); + assert_eq!(out.get(0), Some("hello")); + assert_eq!(out.get(1), Some("world")); + assert_eq!(out.get(2), Some("foo")); + } + + #[test] + fn test_ltrim_str() { + let a = str_array::(&[" hello", " world "]); + let out = ltrim_str((&a, 0, a.len())).unwrap(); + assert_eq!(out.get(0), Some("hello")); + assert_eq!(out.get(1), Some("world ")); + } + + #[test] + fn test_rtrim_str() { + let a = str_array::(&["hello ", " world "]); + let out = 
rtrim_str((&a, 0, a.len())).unwrap(); + assert_eq!(out.get(0), Some("hello")); + assert_eq!(out.get(1), Some(" world")); + } + + #[test] + fn test_reverse_str() { + let a = str_array::(&["abc", "hello"]); + let out = reverse_str((&a, 0, a.len())).unwrap(); + assert_eq!(out.get(0), Some("cba")); + assert_eq!(out.get(1), Some("olleh")); + } + + #[test] + fn test_reverse_dict() { + let a = dict_array::(&["abc", "xy"]); + let out = reverse_dict((&a, 0, a.data.len())).unwrap(); + assert_eq!(out.get_str(0), Some("cba")); + assert_eq!(out.get_str(1), Some("yx")); + } + + // --- Byte length + + #[test] + fn test_byte_length_str() { + let a = str_array::(&["hi", "café"]); + let out = byte_length_str((&a, 0, a.len())).unwrap(); + assert_eq!(out.data[0], 2u32); // "hi" = 2 bytes + assert_eq!(out.data[1], 5u32); // "café" = 5 bytes (é is 2 bytes in UTF-8) + } + + #[test] + fn test_byte_length_dict() { + let a = dict_array::(&["hi", "café"]); + let out = byte_length_dict((&a, 0, a.data.len())).unwrap(); + assert_eq!(out.data[0], 2u32); + assert_eq!(out.data[1], 5u32); + } + + // --- Find and count + + #[test] + fn test_find_str() { + let a = str_array::(&["hello world", "foo bar", "baz"]); + let out = find_str((&a, 0, a.len()), "world").unwrap(); + assert_eq!(out.data[0], 6); + assert_eq!(out.data[1], -1); + assert_eq!(out.data[2], -1); + } + + #[test] + fn test_find_dict() { + let a = dict_array::(&["hello", "world"]); + let out = find_dict((&a, 0, a.data.len()), "llo").unwrap(); + assert_eq!(out.data[0], 2); + assert_eq!(out.data[1], -1); + } + + #[test] + fn test_count_match_str() { + let a = str_array::(&["abcabc", "abc", "xyz"]); + let out = count_match_str((&a, 0, a.len()), "abc").unwrap(); + assert_eq!(out.data[0], 2); + assert_eq!(out.data[1], 1); + assert_eq!(out.data[2], 0); + } + + #[test] + fn test_count_match_dict() { + let a = dict_array::(&["aaa", "a"]); + let out = count_match_dict((&a, 0, a.data.len()), "a").unwrap(); + assert_eq!(out.data[0], 3); + 
assert_eq!(out.data[1], 1); + } + + // --- Substring + + #[test] + fn test_substring_str() { + let a = str_array::(&["hello world", "foo"]); + let out = substring_str((&a, 0, a.len()), 6, None).unwrap(); + assert_eq!(out.get(0), Some("world")); + assert_eq!(out.get(1), Some("")); + } + + #[test] + fn test_substring_str_with_len() { + let a = str_array::(&["hello world"]); + let out = substring_str((&a, 0, a.len()), 0, Some(5)).unwrap(); + assert_eq!(out.get(0), Some("hello")); + } + + #[test] + fn test_substring_dict() { + let a = dict_array::(&["hello", "world"]); + let out = substring_dict((&a, 0, a.data.len()), 1, Some(3)).unwrap(); + assert_eq!(out.get_str(0), Some("ell")); + assert_eq!(out.get_str(1), Some("orl")); + } + + // --- Replace + + #[test] + fn test_replace_str() { + let a = str_array::(&["hello world", "foo bar foo"]); + let out = replace_str((&a, 0, a.len()), "foo", "baz").unwrap(); + assert_eq!(out.get(0), Some("hello world")); + assert_eq!(out.get(1), Some("baz bar baz")); + } + + #[test] + fn test_replace_dict() { + let a = dict_array::(&["aXb", "cXd"]); + let out = replace_dict((&a, 0, a.data.len()), "X", "Y").unwrap(); + assert_eq!(out.get_str(0), Some("aYb")); + assert_eq!(out.get_str(1), Some("cYd")); + } + + // --- Repeat + + #[test] + fn test_repeat_str() { + let a = str_array::(&["ab", "x"]); + let out = repeat_str((&a, 0, a.len()), 3).unwrap(); + assert_eq!(out.get(0), Some("ababab")); + assert_eq!(out.get(1), Some("xxx")); + } + + #[test] + fn test_repeat_dict() { + let a = dict_array::(&["ha"]); + let out = repeat_dict((&a, 0, a.data.len()), 2).unwrap(); + assert_eq!(out.get_str(0), Some("haha")); + } + + // --- Pad + + #[test] + fn test_pad_str_left() { + let a = str_array::(&["hi", "hello"]); + let out = pad_str((&a, 0, a.len()), 5, ' ', PadSide::Left).unwrap(); + assert_eq!(out.get(0), Some(" hi")); + assert_eq!(out.get(1), Some("hello")); // already 5 chars + } + + #[test] + fn test_pad_str_right() { + let a = str_array::(&["hi"]); 
+ let out = pad_str((&a, 0, a.len()), 5, '-', PadSide::Right).unwrap(); + assert_eq!(out.get(0), Some("hi---")); + } + + #[test] + fn test_pad_str_both() { + let a = str_array::(&["hi"]); + let out = pad_str((&a, 0, a.len()), 6, '*', PadSide::Both).unwrap(); + assert_eq!(out.get(0), Some("**hi**")); + } + + #[test] + fn test_pad_dict() { + let a = dict_array::(&["x"]); + let out = pad_dict((&a, 0, a.data.len()), 3, '0', PadSide::Left).unwrap(); + assert_eq!(out.get_str(0), Some("00x")); + } + + // --- Join + + #[test] + fn test_join_str() { + let a = str_array::(&["a", "b", "c"]); + assert_eq!(join_str((&a, 0, a.len()), ", "), Some("a, b, c".to_owned())); + } + + #[test] + fn test_join_str_chunk() { + let a = str_array::(&["skip", "a", "b", "skip"]); + assert_eq!(join_str((&a, 1, 2), "-"), Some("a-b".to_owned())); + } + + #[test] + fn test_join_dict() { + let a = dict_array::(&["x", "y"]); + assert_eq!(join_dict((&a, 0, a.data.len()), "+"), Some("x+y".to_owned())); + } + + // --- Regex replace (feature-gated) + + #[cfg(feature = "regex")] + #[test] + fn test_regex_replace_str() { + let a = str_array::(&["abc123def456", "no digits"]); + let out = regex_replace_str((&a, 0, a.len()), r"\d+", "N").unwrap(); + assert_eq!(out.get(0), Some("abcNdefN")); + assert_eq!(out.get(1), Some("no digits")); + } + + #[cfg(feature = "regex")] + #[test] + fn test_regex_replace_dict() { + let a = dict_array::(&["foo123"]); + let out = regex_replace_dict((&a, 0, a.data.len()), r"\d+", "").unwrap(); + assert_eq!(out.get_str(0), Some("foo")); + } + + // --- Null handling tests + + #[test] + fn test_to_uppercase_str_with_nulls() { + let mut a = str_array::(&["hello", "skip", "world"]); + a.null_mask = Some(bm(&[true, false, true])); + let out = to_uppercase_str((&a, 0, a.len())).unwrap(); + assert_eq!(out.get(0), Some("HELLO")); + assert!(!out.null_mask.as_ref().unwrap().get(1)); + assert_eq!(out.get(2), Some("WORLD")); + } + + #[test] + fn test_find_str_with_nulls() { + let mut a = 
str_array::(&["hello", "skip"]); + a.null_mask = Some(bm(&[true, false])); + let out = find_str((&a, 0, a.len()), "ello").unwrap(); + assert_eq!(out.data[0], 1); + assert_eq!(out.data[1], 0); + assert!(!out.null_mask.as_ref().unwrap().get(1)); + } + + #[test] + fn test_join_str_with_nulls() { + let mut a = str_array::(&["a", "SKIP", "b"]); + a.null_mask = Some(bm(&[true, false, true])); + assert_eq!(join_str((&a, 0, a.len()), ","), Some("a,b".to_owned())); + } } diff --git a/src/structs/views/array_view.rs b/src/structs/views/array_view.rs index b8eb7f3..255f337 100644 --- a/src/structs/views/array_view.rs +++ b/src/structs/views/array_view.rs @@ -209,6 +209,18 @@ impl ArrayV { } } + /// Returns the value at logical index `i` as a `Scalar`, respecting nulls. + /// + /// Delegates to `Array::get_scalar` with the view's offset applied. + #[cfg(feature = "scalar_type")] + #[inline] + pub fn get_scalar(&self, i: usize) -> Option { + if i >= self.len { + return None; + } + self.array.get_scalar(self.offset + i) + } + /// Returns a new window view into a sub-range of this view. #[inline] pub fn slice(&self, offset: usize, len: usize) -> Self {