diff --git a/src/eval.rs b/src/eval.rs index acb73a5..824fc96 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -45,6 +45,19 @@ pub trait EvalFixedPoint: EvalAtRow { self.add_constraint(remainder + aux - (divisor - Self::F::one())); } + /// Evaluates remainder constraints for fixed-point numbers. + /// Constrains: dividend = quotient * divisor + remainder + /// This is essentially the same as eval_fixed_div_rem but semantically focused on remainder. + fn eval_fixed_rem( + &mut self, + dividend: Self::F, + divisor: Self::F, + quotient: Self::F, + remainder: Self::F, + ) { + self.eval_fixed_div_rem(dividend, divisor, quotient, remainder); + } + /// Evaluates reciprocal constraints for fixed-point numbers. /// Constrains: scale_factor * scale_factor = value * reciprocal + remainder fn eval_fixed_recip( @@ -123,6 +136,7 @@ mod tests { Add, Sub, Mul, + Rem, Recip, Sqrt, } @@ -159,6 +173,13 @@ mod tests { let rem = eval.next_trace_mask(); eval.eval_fixed_mul(lhs, rhs, scale_factor, out, rem) } + Op::Rem => { + let dividend = eval.next_trace_mask(); + let divisor = eval.next_trace_mask(); + let quotient = eval.next_trace_mask(); + let remainder = eval.next_trace_mask(); + eval.eval_fixed_rem(dividend, divisor, quotient, remainder) + } Op::Recip => { let input = eval.next_trace_mask(); let out = eval.next_trace_mask(); @@ -321,6 +342,45 @@ mod tests { } } + #[test] + fn test_rem() { + let mut rng = StdRng::seed_from_u64(42); + + // Test regular remainder cases + for _ in 0..50 { + let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + + // Skip cases where divisor is too close to zero + if b.to_f64().abs() < 0.1 { + continue; + } + + let (quotient, remainder) = a.div_rem(b); + + test_op_internal(Op::Rem, &[a, b], &[quotient, remainder], 2); + } + + // Test special cases + let special_cases = vec![ + (10.0, 3.0), // 10 % 3 = 1, quotient = 3 + (7.5, 2.5), // 7.5 % 2.5 = 0, quotient = 3 + (9.0, 4.0), // 9 % 4 = 1, quotient = 2 + (8.0, 3.0), // 8 % 3 = 2, quotient = 2 + (15.0, 4.0), // 15 % 4 = 3, quotient = 3 + (20.0, 6.0), // 20 % 6 = 2, quotient = 3 + (1.5, 0.5), // 1.5 % 0.5 = 0, quotient = 3 + ]; + + for (a, b) in special_cases { + let fixed_a = Fixed::<15>::from_f64(a); + let fixed_b = Fixed::<15>::from_f64(b); + let (quotient, remainder) = fixed_a.div_rem(fixed_b); + + test_op_internal(Op::Rem, &[fixed_a, fixed_b], &[quotient, remainder], 2); + } + } + #[test] fn test_recip() { let mut rng = StdRng::seed_from_u64(42); diff --git a/src/lib.rs b/src/lib.rs index 58f11da..57fdb45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ use num_traits::Zero; use serde::{Deserialize, Serialize}; -use std::ops::{Add, Mul, Sub}; +use std::ops::{Add, Mul, Rem, Sub}; use stwo_prover::core::fields::m31::{M31, P}; pub mod eval; @@ -63,6 +63,17 @@ impl Fixed { } } + /// Computes both quotient and remainder for division. + /// Returns (quotient, remainder) where dividend = quotient * divisor + remainder + /// Note: The quotient is stored as an unscaled integer for constraint compatibility. + #[inline] + pub fn div_rem(self, rhs: Self) -> (Self, Self) { + assert!(rhs.0 != 0, "Division by zero"); + let quotient = self.0 / rhs.0; + let remainder = self.0 % rhs.0; + (Self(quotient), Self(remainder)) + } + /// Computes the reciprocal (1/x) of a fixed-point number /// /// Returns a tuple of (quotient, remainder) where: @@ -108,10 +119,7 @@ impl Fixed { // Calculate remainder (input_scaled - sqrt_val^2) let remainder = input_scaled - sqrt_val * sqrt_val; - ( - Self(sqrt_val as i64), - Self(remainder as i64), - ) + (Self(sqrt_val as i64), Self(remainder as i64)) } /// Convert this Fixed value to a Fixed with a different scale @@ -200,6 +208,16 @@ impl Mul for Fixed { } } +impl Rem for Fixed { + type Output = Self; + + #[inline] + fn rem(self, rhs: Self) -> Self::Output { + assert!(rhs.0 != 0, "Division by zero in remainder operation"); + Self(self.0 % rhs.0) + } +} + impl Zero for Fixed { #[inline] fn zero() -> Self { @@ -348,6 +366,83 @@ mod tests { } } + #[test] + fn test_rem() { + let mut rng = StdRng::seed_from_u64(42); + + // Test random cases + for _ in 0..100 { + let a = (rng.gen::() - 0.5) * 20.0; + let b = (rng.gen::() - 0.5) * 20.0; + + // Skip cases where divisor is too close to zero + if b.abs() < 0.1 { + continue; + } + + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); + + let remainder = fa % fb; + let expected = a % b; + + assert_near(remainder.to_f64(), expected); + } + + // Test specific cases + let test_cases = vec![ + (10.0, 3.0), // 10 % 3 = 1 + (7.5, 2.5), // 7.5 % 2.5 = 0 + (9.0, 4.0), // 9 % 4 = 1 + (-10.0, 3.0), // -10 % 3 = -1 (or 2, depending on implementation) + (10.0, -3.0), // 10 % -3 = 1 (or -2, depending on implementation) + (-10.0, -3.0), // -10 % -3 = -1 (or 2, depending on implementation) + ]; + + for (a, b) in test_cases { + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); + let remainder = fa % fb; + let expected = a % b; + + assert_near(remainder.to_f64(), expected); + } + } + + #[test] + fn test_div_rem() { + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..100 { + let a = (rng.gen::() - 0.5) * 20.0; + let b = (rng.gen::() - 0.5) * 20.0; + println!("a {:?}", a); + println!("b {:?}", b); + + // Skip cases where divisor is too close to zero + if b.abs() < 0.1 { + continue; + } + + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); + + let (quotient, remainder) = fa.div_rem(fb); + + // Verify: dividend = quotient * divisor + remainder + let reconstructed = quotient.0 * fb.0 + remainder.0; + assert_eq!(reconstructed, fa.0); + + // Check individual results + let expected_quotient = (a / b).trunc(); + let expected_remainder = a % b; + + // The quotient from div_rem is stored as an unscaled integer + assert_eq!(quotient.0 as f64, expected_quotient); + assert_near(remainder.to_f64(), expected_remainder); + } + } + #[test] fn test_different_scales() { // Test with 15-bit scale