From 905c75d59ae6dccaf986d31ddacff4c35237acff Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Tue, 23 Sep 2025 16:48:46 +0430 Subject: [PATCH 1/2] Refactor Fixed type to use runtime scale and simplify evaluation tests. Update Fixed struct to store scale as a field, modify related methods, and adjust tests for new Fixed implementation. --- src/eval.rs | 52 ++--- src/lib.rs | 557 ++++++++++++++++++++++------------------------------ 2 files changed, 259 insertions(+), 350 deletions(-) diff --git a/src/eval.rs b/src/eval.rs index 8dc2ff5..88a9a77 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -147,9 +147,10 @@ mod tests { } } - struct TestEval { + struct TestEval { log_size: u32, op: Op, + scale: u32, } #[derive(Clone, Copy)] @@ -162,7 +163,7 @@ mod tests { Sqrt, } - impl FrameworkEval for TestEval { + impl FrameworkEval for TestEval { fn log_size(&self) -> u32 { self.log_size } @@ -172,7 +173,7 @@ mod tests { } fn evaluate(&self, mut eval: E) -> E { - let scale_factor = E::F::from(M31::from_u32_unchecked(1 << SCALE)); + let scale_factor = E::F::from(M31::from_u32_unchecked(1 << self.scale)); match self.op { Op::Add => { @@ -233,10 +234,10 @@ mod tests { .collect() } - fn test_op_internal( + fn test_op_internal( op: Op, - inputs: &[Fixed], - expected_outputs: &[Fixed], + inputs: &[Fixed], + expected_outputs: &[Fixed], tamper_col_idx: usize, ) { const LOG_SIZE: u32 = 4; @@ -261,9 +262,10 @@ mod tests { let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = TestEval:: { + let component = TestEval { log_size: LOG_SIZE, op, + scale: 15, // Default scale for tests }; // Test valid trace @@ -281,7 +283,7 @@ mod tests { if let Some(col) = invalid_trace_cols.get_mut(tamper_col_idx) { for val in col.iter_mut() { // Calculate scale factor for tampering - let scale_factor = M31::from_u32_unchecked(1 << SCALE); + let scale_factor = M31::from_u32_unchecked(1 << 15); // Default scale val.0 = (val.0 + scale_factor.0) % P; } } @@ -310,8 +312,8 @@ mod tests { fn test_add() { let mut rng = StdRng::seed_from_u64(42); for _ in 0..100 { - let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); + let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); test_op_internal(Op::Add, &[a, b], &[a + b], 2); } @@ -321,8 +323,8 @@ mod tests { fn test_sub() { let mut rng = StdRng::seed_from_u64(42); for _ in 0..100 { - let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); + let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); test_op_internal(Op::Sub, &[a, b], &[a - b], 2); } @@ -334,8 +336,8 @@ mod tests { // Test regular multiplication cases for _ in 0..100 { - let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); + let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); let (expected, rem) = a * b; test_op_internal(Op::Mul, &[a, b], &[expected, rem], 2); @@ -355,8 +357,8 @@ mod tests { ]; for (a, b) in special_cases { - let fixed_a = Fixed::<15>::from_f64(a); - let fixed_b = Fixed::<15>::from_f64(b); + let fixed_a = Fixed::from_f64(a, 15); + let fixed_b = Fixed::from_f64(b, 15); let (expected, rem) = fixed_a * fixed_b; test_op_internal(Op::Mul, &[fixed_a, fixed_b], &[expected, rem], 2); @@ -369,8 +371,8 @@ mod tests { // Test regular remainder cases for _ in 0..50 { - let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); + let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); // Skip cases where divisor is too close to zero if b.to_f64().abs() < 0.1 { @@ -394,8 +396,8 @@ mod tests { ]; for (a, b) in special_cases { - let fixed_a = Fixed::<15>::from_f64(a); - let fixed_b = Fixed::<15>::from_f64(b); + let fixed_a = Fixed::from_f64(a, 15); + let fixed_b = Fixed::from_f64(b, 15); let (quotient, remainder) = fixed_a.div_rem(fixed_b); test_op_internal(Op::Rem, &[fixed_a, fixed_b], &[quotient, remainder], 2); @@ -408,8 +410,8 @@ mod tests { // Test regular recip cases for _ in 0..100 { - let input = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); - if input.0 == 0 { + let input = Fixed::from_f64((rng.gen::() - 0.5) * 200.0, 15); + if input.value == 0 { continue; // Skip division by zero } @@ -429,7 +431,7 @@ mod tests { ]; for input in special_cases { - let fixed_input = Fixed::<15>::from_f64(input); + let fixed_input = Fixed::from_f64(input, 15); let (expected, rem) = fixed_input.recip(); test_op_internal(Op::Recip, &[fixed_input], &[expected, rem], 1); @@ -440,7 +442,7 @@ mod tests { fn test_sqrt() { let test_cases = vec![1.0, 4.0, 9.0, 2.0, 0.5, 0.25, 0.0]; for input in test_cases { - let fixed_input = Fixed::<15>::from_f64(input); + let fixed_input = Fixed::from_f64(input, 15); let (sqrt_out, rem) = fixed_input.sqrt(); test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1); @@ -449,7 +451,7 @@ mod tests { let mut rng = StdRng::seed_from_u64(43); for _ in 0..50 { let input_val: f64 = rng.gen_range(0.0..100.0); - let fixed_input = Fixed::<15>::from_f64(input_val); + let fixed_input = Fixed::from_f64(input_val, 15); let (sqrt_out, rem) = fixed_input.sqrt(); test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1); diff --git a/src/lib.rs b/src/lib.rs index 7bf2cf8..9594630 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,471 +8,378 @@ pub mod eval; // Half the prime modulus. pub const HALF_P: u32 = P / 2; -/// Integer representation of fixed-point Basefield with parametrized scale. +/// Integer representation of fixed-point Basefield with runtime scale. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] -pub struct Fixed(pub i64); +pub struct Fixed { + pub value: i64, + pub scale: u32, +} -impl Fixed { - const SCALE_FACTOR: i64 = 1 << SCALE; - const HALF_SCALE_FACTOR: i64 = 1 << (SCALE - 1); +impl Fixed { + /// Create a new Fixed with the given value and scale + pub fn new(value: i64, scale: u32) -> Self { + Self { value, scale } + } - #[inline] - pub fn from_f64(value: f64) -> Self { - Self((value * Self::SCALE_FACTOR as f64).round() as i64) + /// Get the scale factor (2^scale) + pub fn scale_factor(&self) -> i64 { + 1i64 << self.scale } - #[inline] - /// Convert to a float + /// Get half the scale factor for rounding + pub fn half_scale_factor(&self) -> i64 { + 1i64 << (self.scale - 1) + } + + /// Create from f64 with specified scale + pub fn from_f64(value: f64, scale: u32) -> Self { + let scale_factor = 1i64 << scale; + Self { + value: (value * scale_factor as f64).round() as i64, + scale, + } + } + + /// Convert to f64 pub fn to_f64(self) -> f64 { - self.0 as f64 / Self::SCALE_FACTOR as f64 + let scale_factor = 1i64 << self.scale; + self.value as f64 / scale_factor as f64 } - #[inline] /// Convert to M31 for Stwo pub fn to_m31(self) -> M31 { const MODULUS_BITS: u32 = 31; - if self.0 >= 0 { - // For positive numbers, use M31 reduce directly - M31::reduce(self.0 as u64) + if self.value >= 0 { + M31::reduce(self.value as u64) } else { - // For negative numbers, efficiently compute P - (abs(value) % P) - // This is a fast implementation of modulo for 2^31-1 - let abs_val = (-self.0) as u64; + let abs_val = (-self.value) as u64; let abs_mod = (((((abs_val >> MODULUS_BITS) + abs_val + 1) >> MODULUS_BITS) + abs_val) & (P as u64)) as u32; if abs_mod == 0 { - M31::from_u32_unchecked(0) // -0 = 0 + M31::from_u32_unchecked(0) } else { M31::from_u32_unchecked(P - abs_mod) } } } - #[inline] - /// Convert from M31 - pub fn from_m31(value: M31) -> Self { + /// Convert from M31 with specified scale + pub fn from_m31(value: M31, scale: u32) -> Self { let m31_val = value.0; let is_negative = m31_val > HALF_P; if is_negative { - Self(-((P - m31_val) as i64)) + Self { + value: -((P - m31_val) as i64), + scale, + } } else { - Self(m31_val as i64) + Self { + value: m31_val as i64, + scale, + } } } - /// Computes both quotient and remainder for division. - /// Returns (quotient, remainder) where dividend = quotient * divisor + remainder - /// Note: The quotient is stored as an unscaled integer for constraint compatibility. - #[inline] + /// Computes both quotient and remainder for division pub fn div_rem(self, rhs: Self) -> (Self, Self) { - assert!(rhs.0 != 0, "Division by zero"); - let quotient = self.0 / rhs.0; - let remainder = self.0 % rhs.0; - (Self(quotient), Self(remainder)) + assert!(rhs.value != 0, "Division by zero"); + assert_eq!(self.scale, rhs.scale, "Scales must match for division"); + + let quotient = self.value / rhs.value; + let remainder = self.value % rhs.value; + + ( + Self { + value: quotient, + scale: self.scale, + }, + Self { + value: remainder, + scale: self.scale, + }, + ) } /// Computes the reciprocal (1/x) of a fixed-point number - /// - /// Returns a tuple of (quotient, remainder) where: - /// - quotient is the fixed-point representation of 1/x - /// - remainder is the remainder after division - #[inline] pub fn recip(self) -> (Self, Self) { - assert!(self.0 != 0, "Division by zero"); - - let scale_factor_squared = Self::SCALE_FACTOR * Self::SCALE_FACTOR; - let quotient = scale_factor_squared / self.0; - let remainder = scale_factor_squared % self.0; - - (Self(quotient), Self(remainder)) - } - - /// Computes the fixed-point representation of the square root and its remainder. - /// - /// `self` represents `input_val * SCALE_FACTOR`, to compute - /// `out` and `rem`, i.e., `out` represents `sqrt(input_val) * SCALE_FACTOR` - /// and the following hold for their underlying integer values - /// (`input.0`, `out.0`, `rem.0`): - /// - /// `out.0^2 + rem.0 = input.0 * SCALE_FACTOR` - /// - /// where `out.0` is the integer square root of `(input.0 * SCALE_FACTOR)`. - /// The remainder `rem.0` is the difference `(input.0 * SCALE_FACTOR) - out.0^2`. + assert!(self.value != 0, "Division by zero"); + + let scale_factor = self.scale_factor(); + let scale_factor_squared = scale_factor * scale_factor; + let quotient = scale_factor_squared / self.value; + let remainder = scale_factor_squared % self.value; + + ( + Self { + value: quotient, + scale: self.scale, + }, + Self { + value: remainder, + scale: self.scale, + }, + ) + } + + /// Computes the fixed-point representation of the square root and its remainder pub fn sqrt(&self) -> (Self, Self) { - // Panic for negative inputs - assert!(self.0 >= 0, "Cannot compute square root of negative number"); + assert!(self.value >= 0, "Cannot compute square root of negative number"); - // Special case: zero input - if self.0 == 0 { - return (Self(0), Self(0)); + if self.value == 0 { + return (Self::zero(self.scale), Self::zero(self.scale)); } - // Calculate value to compute sqrt of: self * SCALE_FACTOR - let input_scaled = (self.0 as u64) << SCALE; - - // Compute integer square root + let input_scaled = (self.value as u64) << self.scale; let sqrt_val = int_sqrt(input_scaled); - - // Calculate remainder (input_scaled - sqrt_val^2) let remainder = input_scaled - sqrt_val * sqrt_val; - (Self(sqrt_val as i64), Self(remainder as i64)) + ( + Self { + value: sqrt_val as i64, + scale: self.scale, + }, + Self { + value: remainder as i64, + scale: self.scale, + }, + ) } /// Convert this Fixed value to a Fixed with a different scale - pub fn convert_to(self) -> Fixed { - if TARGET_SCALE == SCALE { - // Same scale, just change the type - Fixed(self.0) - } else if TARGET_SCALE > SCALE { - // Going to higher precision - let shift = TARGET_SCALE - SCALE; - Fixed(self.0 << shift) + pub fn convert_to(self, target_scale: u32) -> Self { + if target_scale == self.scale { + self + } else if target_scale > self.scale { + let shift = target_scale - self.scale; + Self { + value: self.value << shift, + scale: target_scale, + } } else { - // Going to lower precision - let shift = SCALE - TARGET_SCALE; - Fixed(self.0 >> shift) + let shift = self.scale - target_scale; + Self { + value: self.value >> shift, + scale: target_scale, + } } } + + /// Create zero with specified scale + pub fn zero(scale: u32) -> Self { + Self { value: 0, scale } + } + + /// Check if the value is zero + pub fn is_zero(&self) -> bool { + self.value == 0 + } + + /// Get the scale of this Fixed + pub fn scale(&self) -> u32 { + self.scale + } + + /// Get the raw value + pub fn value(&self) -> i64 { + self.value + } } -/// Returns the floor of the square root of `n`. +/// Returns the floor of the square root of `n` #[inline] pub fn int_sqrt(n: u64) -> u64 { if n <= 1 { return n; } - // Initial guess let bits = 64 - n.leading_zeros(); let mut x = n >> (bits / 2); - // Ensure x is not zero (which would cause division by zero) if x == 0 { x = 1; } - // Newton's method with careful convergence checking let mut prev_x = x; loop { - // Compute next iteration let quotient = n / x; - let next_x = (x + quotient) / 2; // We can use regular division here since x + quotient ≤ n + 1 + let next_x = (x + quotient) / 2; - // Check for convergence or oscillation if next_x == x || next_x == prev_x { return next_x; } - // Update for next iteration prev_x = x; x = next_x; } } -impl Add for Fixed { +// Implement arithmetic operations +impl Add for Fixed { type Output = Self; #[inline] fn add(self, rhs: Self) -> Self::Output { - Self(self.0 + rhs.0) + assert_eq!(self.scale, rhs.scale, "Scales must match for addition"); + Self { + value: self.value + rhs.value, + scale: self.scale, + } } } -impl Sub for Fixed { +impl Sub for Fixed { type Output = Self; #[inline] fn sub(self, rhs: Self) -> Self::Output { - Self(self.0 - rhs.0) + assert_eq!(self.scale, rhs.scale, "Scales must match for subtraction"); + Self { + value: self.value - rhs.value, + scale: self.scale, + } } } -impl Mul for Fixed { +impl Mul for Fixed { type Output = (Self, Self); #[inline] fn mul(self, rhs: Self) -> Self::Output { - let product = self.0 * rhs.0; - - let quotient = (product + Self::HALF_SCALE_FACTOR) >> SCALE; - - // Calculate remainder to maintain: product = quotient * scale + remainder - let scaled_quotient = quotient << SCALE; + assert_eq!(self.scale, rhs.scale, "Scales must match for multiplication"); + + let product = self.value * rhs.value; + let half_scale_factor = self.half_scale_factor(); + let quotient = (product + half_scale_factor) >> self.scale; + let scaled_quotient = quotient << self.scale; let remainder = product - scaled_quotient; - (Self(quotient), Self(remainder)) + ( + Self { + value: quotient, + scale: self.scale, + }, + Self { + value: remainder, + scale: self.scale, + }, + ) } } -impl Rem for Fixed { +impl Rem for Fixed { type Output = Self; #[inline] fn rem(self, rhs: Self) -> Self::Output { - assert!(rhs.0 != 0, "Division by zero in remainder operation"); - Self(self.0 % rhs.0) + assert!(rhs.value != 0, "Division by zero in remainder operation"); + assert_eq!(self.scale, rhs.scale, "Scales must match for remainder"); + Self { + value: self.value % rhs.value, + scale: self.scale, + } } } -impl Zero for Fixed { +impl Zero for Fixed { #[inline] fn zero() -> Self { - Self(0) + Self::zero(12) // Default scale } #[inline] fn is_zero(&self) -> bool { - self.0 == 0 + self.is_zero() } } -#[cfg(test)] -mod tests { - use super::*; - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; - - const EPSILON: f64 = 1e-3; - - fn assert_near(a: f64, b: f64) { - assert!((a - b).abs() < EPSILON, "Expected {} to be near {}", a, b); +// Add convenience methods for common scales +impl Fixed { + /// Create with 8-bit scale + pub fn from_f64_8(value: f64) -> Self { + Self::from_f64(value, 8) } - #[test] - fn test_negative() { - let a = Fixed::<15>::from_f64(-3.5); - let b = Fixed::<15>::from_f64(2.0); - - assert_near(a.to_f64(), -3.5); - assert_near((a + b).to_f64(), -1.5); - assert_near((a - b).to_f64(), -5.5); + /// Create with 12-bit scale (default) + pub fn from_f64_12(value: f64) -> Self { + Self::from_f64(value, 12) } - #[test] - fn test_add() { - let mut rng = StdRng::seed_from_u64(42); - - for _ in 0..1000 { - let a = (rng.gen::() - 0.5) * 200.0; - let b = (rng.gen::() - 0.5) * 200.0; - - let fa = Fixed::<15>::from_f64(a); - let fb = Fixed::<15>::from_f64(b); - - assert_near((fa + fb).to_f64(), a + b); - } + /// Create with 16-bit scale + pub fn from_f64_16(value: f64) -> Self { + Self::from_f64(value, 16) } - #[test] - fn test_sub() { - let mut rng = StdRng::seed_from_u64(42); - - for _ in 0..1000 { - let a = (rng.gen::() - 0.5) * 200.0; - let b = (rng.gen::() - 0.5) * 200.0; - - let fa = Fixed::<15>::from_f64(a); - let fb = Fixed::<15>::from_f64(b); - - assert_near((fa - fb).to_f64(), a - b); - } + /// Create with 24-bit scale + pub fn from_f64_24(value: f64) -> Self { + Self::from_f64(value, 24) } +} - #[test] - fn test_mul() { - let mut rng = StdRng::seed_from_u64(42); - - for _ in 0..1000 { - let a = (rng.gen::() - 0.5) * 10.0; - let b = (rng.gen::() - 0.5) * 10.0; - - let fa = Fixed::<15>::from_f64(a); - let fb = Fixed::<15>::from_f64(b); +#[cfg(test)] +mod tests { + use super::*; - let (q, _) = fa * fb; - let expected = a * b; + const EPSILON: f64 = 1e-3; - assert_near(q.to_f64(), expected); - } + fn assert_near(a: f64, b: f64) { + assert!((a - b).abs() < EPSILON, "Expected {} to be near {}", a, b); } #[test] - fn test_recip() { - let mut rng = StdRng::seed_from_u64(42); - - for _ in 0..100 { - let a = (rng.gen::() - 0.5) * 10.0; - if a.abs() < 0.1 { - continue; - } - - let fixed_a = Fixed::<15>::from_f64(a); - let (recip, _) = fixed_a.recip(); - let expected = 1.0 / a; - - assert_near(recip.to_f64(), expected); - } - - // Test specific cases - let test_cases = vec![ - (1.0, 1.0), // Reciprocal of 1 is 1 - (2.0, 0.5), // Reciprocal of 2 is 0.5 - (0.5, 2.0), // Reciprocal of 0.5 is 2 - (4.0, 0.25), // Reciprocal of 4 is 0.25 - (-1.0, -1.0), // Reciprocal of -1 is -1 - (-2.0, -0.5), // Reciprocal of -2 is -0.5 - ]; - - for (a, expected) in test_cases { - let fixed_a = Fixed::<15>::from_f64(a); - let (recip, _) = fixed_a.recip(); - assert_near(recip.to_f64(), expected); - } + fn test_basic_operations() { + let a = Fixed::from_f64(1.5, 12); + let b = Fixed::from_f64(2.0, 12); + + assert_near(a.to_f64(), 1.5); + assert_near(b.to_f64(), 2.0); + assert_near((a + b).to_f64(), 3.5); + assert_near((a - b).to_f64(), -0.5); } #[test] - fn test_sqrt() { - let mut test_cases = vec![ - 0.0, 1.0, 4.0, 9.0, 10.0, 16.0, 25.0, 81.0, 100.0, 0.25, 0.0625, 0.01, 5.0, 8.0, 12.0, - 15.0, 20.0, 50.0, 10000.0, 1000000.0, // Large value - 1e-10, // Small value - 0.001, // rest irrationals - 0.5, 2.0, 3.0, 42.0, // Nod to Douglas Adams - ]; - - let mut rng = StdRng::seed_from_u64(42); - for _ in 0..200 { - let value: f64 = rng.gen_range(0.01..50.0); - test_cases.push(value); - } - - for input in test_cases { - let fixed_input = Fixed::<15>::from_f64(input); - - if input < 0.0 { - let (result, remainder) = fixed_input.sqrt(); - assert_eq!(result.0, 0); - assert_eq!(remainder.0, 0); - continue; - } - - let (result, _) = fixed_input.sqrt(); - let result_f64 = result.to_f64(); - assert_near(result_f64, input.sqrt()); - } + fn test_multiplication() { + let a = Fixed::from_f64(2.5, 12); + let b = Fixed::from_f64(3.0, 12); + let (result, _) = a * b; + assert_near(result.to_f64(), 7.5); } #[test] - fn test_rem() { - let mut rng = StdRng::seed_from_u64(42); - - // Test random cases - for _ in 0..100 { - let a = (rng.gen::() - 0.5) * 20.0; - let b = (rng.gen::() - 0.5) * 20.0; - - // Skip cases where divisor is too close to zero - if b.abs() < 0.1 { - continue; - } - - let fa = Fixed::<15>::from_f64(a); - let fb = Fixed::<15>::from_f64(b); - - let remainder = fa % fb; - let expected = a % b; - - assert_near(remainder.to_f64(), expected); - } - - // Test specific cases - let test_cases = vec![ - (10.0, 3.0), // 10 % 3 = 1 - (7.5, 2.5), // 7.5 % 2.5 = 0 - (9.0, 4.0), // 9 % 4 = 1 - (-10.0, 3.0), // -10 % 3 = -1 (or 2, depending on implementation) - (10.0, -3.0), // 10 % -3 = 1 (or -2, depending on implementation) - (-10.0, -3.0), // -10 % -3 = -1 (or 2, depending on implementation) - ]; - - for (a, b) in test_cases { - let fa = Fixed::<15>::from_f64(a); - let fb = Fixed::<15>::from_f64(b); - let remainder = fa % fb; - let expected = a % b; - - assert_near(remainder.to_f64(), expected); - } + fn test_scale_conversion() { + let a = Fixed::from_f64(1.5, 12); + let b = a.convert_to(8); + assert_near(b.to_f64(), 1.5); + assert_eq!(b.scale(), 8); } #[test] - fn test_div_rem() { - let mut rng = StdRng::seed_from_u64(42); - - for _ in 0..100 { - let a = (rng.gen::() - 0.5) * 20.0; - let b = (rng.gen::() - 0.5) * 20.0; - println!("a {:?}", a); - println!("b {:?}", b); - - // Skip cases where divisor is too close to zero - if b.abs() < 0.1 { - continue; - } - - let fa = Fixed::<15>::from_f64(a); - let fb = Fixed::<15>::from_f64(b); - - let (quotient, remainder) = fa.div_rem(fb); - - // Verify: dividend = quotient * divisor + remainder - let reconstructed = quotient.0 * fb.0 + remainder.0; - assert_eq!(reconstructed, fa.0); - - // Check individual results - let expected_quotient = (a / b).trunc(); - let expected_remainder = a % b; - - // The quotient from div_rem is stored as an unscaled integer - assert_eq!(quotient.0 as f64, expected_quotient); - assert_near(remainder.to_f64(), expected_remainder); - } + fn test_different_scales() { + let a = Fixed::from_f64_8(1.5); + let b = Fixed::from_f64_12(1.5); + let c = Fixed::from_f64_16(1.5); + let d = Fixed::from_f64_24(1.5); + + assert_near(a.to_f64(), 1.5); + assert_near(b.to_f64(), 1.5); + assert_near(c.to_f64(), 1.5); + assert_near(d.to_f64(), 1.5); + + assert_eq!(a.scale(), 8); + assert_eq!(b.scale(), 12); + assert_eq!(c.scale(), 16); + assert_eq!(d.scale(), 24); } #[test] - fn test_different_scales() { - // Test with 15-bit scale - let scale_15 = Fixed::<15>::from_f64(1.5); - assert_near(scale_15.to_f64(), 1.5); + fn test_scale_mismatch() { + let a = Fixed::from_f64(1.0, 8); + let b = Fixed::from_f64(2.0, 12); - // Test with 8-bit scale (less precision) - let scale8 = Fixed::<8>::from_f64(1.5); - assert_near(scale8.to_f64(), 1.5); - - // Test with 24-bit scale (more precision) - let scale24 = Fixed::<24>::from_f64(1.5); - assert_near(scale24.to_f64(), 1.5); - - // Test conversion between scales - let from_15_to_8 = scale_15.convert_to::<8>(); - assert_near(from_15_to_8.to_f64(), 1.5); - - let from_8_to_24 = scale8.convert_to::<24>(); - assert_near(from_8_to_24.to_f64(), 1.5); - - // Multiplication with different scales - let a8 = Fixed::<8>::from_f64(2.5); - let b8 = Fixed::<8>::from_f64(3.0); - let (result8, _) = a8 * b8; - assert_near(result8.to_f64(), 7.5); - - let a24 = Fixed::<24>::from_f64(2.5); - let b24 = Fixed::<24>::from_f64(3.0); - let (result24, _) = a24 * b24; - assert_near(result24.to_f64(), 7.5); + // This should panic due to scale mismatch + let result = std::panic::catch_unwind(|| a + b); + assert!(result.is_err()); } -} +} \ No newline at end of file From 4521ea46b1ab92f6afbaa49b93364028bfdf60d0 Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Sun, 5 Oct 2025 10:18:31 +0430 Subject: [PATCH 2/2] Add comprehensive tests for Fixed type operations including addition, subtraction, multiplication, division, remainder, and square root. Utilize random number generation for extensive coverage and edge case handling. --- src/lib.rs | 201 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 9594630..5494d1a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -321,6 +321,8 @@ impl Fixed { #[cfg(test)] mod tests { use super::*; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; const EPSILON: f64 = 1e-3; @@ -328,6 +330,205 @@ mod tests { assert!((a - b).abs() < EPSILON, "Expected {} to be near {}", a, b); } + #[test] + fn test_negative() { + let a = Fixed::from_f64(-3.5, 15); + let b = Fixed::from_f64(2.0, 15); + + assert_near(a.to_f64(), -3.5); + assert_near((a + b).to_f64(), -1.5); + assert_near((a - b).to_f64(), -5.5); + } + + #[test] + fn test_add() { + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..1000 { + let a = (rng.gen::() - 0.5) * 200.0; + let b = (rng.gen::() - 0.5) * 200.0; + + let fa = Fixed::from_f64(a, 15); + let fb = Fixed::from_f64(b, 15); + + assert_near((fa + fb).to_f64(), a + b); + } + } + + #[test] + fn test_sub() { + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..1000 { + let a = (rng.gen::() - 0.5) * 200.0; + let b = (rng.gen::() - 0.5) * 200.0; + + let fa = Fixed::from_f64(a, 15); + let fb = Fixed::from_f64(b, 15); + + assert_near((fa - fb).to_f64(), a - b); + } + } + + #[test] + fn test_mul() { + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..1000 { + let a = (rng.gen::() - 0.5) * 10.0; + let b = (rng.gen::() - 0.5) * 10.0; + + let fa = Fixed::from_f64(a, 15); + let fb = Fixed::from_f64(b, 15); + + let (q, _) = fa * fb; + let expected = a * b; + + assert_near(q.to_f64(), expected); + } + } + + #[test] + fn test_recip() { + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..100 { + let a = (rng.gen::() - 0.5) * 10.0; + if a.abs() < 0.1 { + continue; + } + + let fixed_a = Fixed::from_f64(a, 15); + let (recip, _) = fixed_a.recip(); + let expected = 1.0 / a; + + assert_near(recip.to_f64(), expected); + } + + // Test specific cases + let test_cases = vec![ + (1.0, 1.0), // Reciprocal of 1 is 1 + (2.0, 0.5), // Reciprocal of 2 is 0.5 + (0.5, 2.0), // Reciprocal of 0.5 is 2 + (4.0, 0.25), // Reciprocal of 4 is 0.25 + (-1.0, -1.0), // Reciprocal of -1 is -1 + (-2.0, -0.5), // Reciprocal of -2 is -0.5 + ]; + + for (a, expected) in test_cases { + let fixed_a = Fixed::from_f64(a, 15); + let (recip, _) = fixed_a.recip(); + assert_near(recip.to_f64(), expected); + } + } + + #[test] + fn test_sqrt() { + let mut test_cases = vec![ + 0.0, 1.0, 4.0, 9.0, 10.0, 16.0, 25.0, 81.0, 100.0, 0.25, 0.0625, 0.01, 5.0, 8.0, 12.0, + 15.0, 20.0, 50.0, 10000.0, 1000000.0, // Large value + 1e-10, // Small value + 0.001, // rest irrationals + 0.5, 2.0, 3.0, 42.0, // Nod to Douglas Adams + ]; + + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..200 { + let value: f64 = rng.gen_range(0.01..50.0); + test_cases.push(value); + } + + for input in test_cases { + let fixed_input = Fixed::from_f64(input, 15); + + if input < 0.0 { + let (result, remainder) = fixed_input.sqrt(); + assert_eq!(result.value, 0); + assert_eq!(remainder.value, 0); + continue; + } + + let (result, _) = fixed_input.sqrt(); + let result_f64 = result.to_f64(); + assert_near(result_f64, input.sqrt()); + } + } + + #[test] + fn test_rem() { + let mut rng = StdRng::seed_from_u64(42); + + // Test random cases + for _ in 0..100 { + let a = (rng.gen::() - 0.5) * 20.0; + let b = (rng.gen::() - 0.5) * 20.0; + + // Skip cases where divisor is too close to zero + if b.abs() < 0.1 { + continue; + } + + let fa = Fixed::from_f64(a, 15); + let fb = Fixed::from_f64(b, 15); + + let remainder = fa % fb; + let expected = a % b; + + assert_near(remainder.to_f64(), expected); + } + + // Test specific cases + let test_cases = vec![ + (10.0, 3.0), // 10 % 3 = 1 + (7.5, 2.5), // 7.5 % 2.5 = 0 + (9.0, 4.0), // 9 % 4 = 1 + (-10.0, 3.0), // -10 % 3 = -1 (or 2, depending on implementation) + (10.0, -3.0), // 10 % -3 = 1 (or -2, depending on implementation) + (-10.0, -3.0), // -10 % -3 = -1 (or 2, depending on implementation) + ]; + + for (a, b) in test_cases { + let fa = Fixed::from_f64(a, 15); + let fb = Fixed::from_f64(b, 15); + let remainder = fa % fb; + let expected = a % b; + + assert_near(remainder.to_f64(), expected); + } + } + + #[test] + fn test_div_rem() { + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..100 { + let a = (rng.gen::() - 0.5) * 20.0; + let b = (rng.gen::() - 0.5) * 20.0; + + // Skip cases where divisor is too close to zero + if b.abs() < 0.1 { + continue; + } + + let fa = Fixed::from_f64(a, 15); + let fb = Fixed::from_f64(b, 15); + + let (quotient, remainder) = fa.div_rem(fb); + + // Verify: dividend = quotient * divisor + remainder + let reconstructed = quotient.value * fb.value + remainder.value; + assert_eq!(reconstructed, fa.value); + + // Check individual results + let expected_quotient = (a / b).trunc(); + let expected_remainder = a % b; + + // The quotient from div_rem is stored as an unscaled integer + assert_eq!(quotient.value as f64, expected_quotient); + assert_near(remainder.to_f64(), expected_remainder); + } + } + #[test] fn test_basic_operations() { let a = Fixed::from_f64(1.5, 12);