From 9f7ccae0cfe592036dde4164bc32819a8f0ecdb1 Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Sun, 16 Mar 2025 23:26:08 +0430 Subject: [PATCH 1/9] implement remainder operation for Fixed type and add corresponding tests --- src/lib.rs | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 69d1583..2d7280f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ use num_traits::Zero; -use std::ops::{Add, Div, Mul, Sub}; +use std::ops::{Add, Div, Mul, Sub, Rem}; use stwo_prover::core::fields::m31::{M31, P}; pub mod eval; @@ -124,6 +124,15 @@ impl Div for Fixed { } } +impl Rem for Fixed { + type Output = Self; + + #[inline] + fn rem(self, rhs: Self) -> Self::Output { + Self(self.0 % rhs.0) + } +} + impl Zero for Fixed { #[inline] fn zero() -> Self { @@ -244,4 +253,29 @@ mod tests { assert_near(result, expected); } } + + #[test] + fn test_rem() { + let test_cases = vec![ + (5.0, 2.0, 1.0), + (-5.0, 2.0, -1.0), + (5.0, -2.0, 1.0), + (-5.0, -2.0, -1.0), + (7.5, 2.5, 0.0), + (3.2, 1.5, 0.2), + ]; + + for (a, b, expected) in test_cases { + let fa = Fixed::from_f64(a); + let fb = Fixed::from_f64(b); + let result = (fa % fb).to_f64(); + assert_near(result, expected); + } + } + + #[test] + fn test_zero() { + assert!(Fixed::zero().is_zero()); + assert!(!Fixed::from_f64(1.0).is_zero()); + } } From 632bea5cdad2878e08731fb1f0a51f989bebed6c Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Mon, 17 Mar 2025 11:47:33 +0430 Subject: [PATCH 2/9] refactor remainder operation to use rem method for Fixed type --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 2d7280f..54b9921 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -268,7 +268,7 @@ mod tests { for (a, b, expected) in test_cases { let fa = Fixed::from_f64(a); let fb = Fixed::from_f64(b); - let result = (fa % fb).to_f64(); + let result = (fa.rem(fb)).to_f64(); assert_near(result, expected); } } From 48f6649b3967d8b4d06d1e6a4c6191182ba9332f Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Mon, 17 Mar 2025 11:52:53 +0430 Subject: [PATCH 3/9] simplify remainder operation in tests by using % operator for Fixed type --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 54b9921..2d7280f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -268,7 +268,7 @@ mod tests { for (a, b, expected) in test_cases { let fa = Fixed::from_f64(a); let fb = Fixed::from_f64(b); - let result = (fa.rem(fb)).to_f64(); + let result = (fa % fb).to_f64(); assert_near(result, expected); } } From 9ab433c23d8f7d1ef7757051fb76e09d10a2f7d5 Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Mon, 17 Mar 2025 17:05:34 +0430 Subject: [PATCH 4/9] enhance tests for remainder operation in Fixed type with additional cases with random values --- src/lib.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 2d7280f..5a89cf7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -265,9 +265,23 @@ mod tests { (3.2, 1.5, 0.2), ]; - for (a, b, expected) in test_cases { + for (a, b, expected) in test_cases.clone() { + let fa = Fixed::from_f64(a); + let fb = Fixed::from_f64(b); + let result = (fa % fb).to_f64(); + assert_near(result, expected); + } + + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..1000 { + let a = (rng.gen::() - 0.5) * 100.0; + let b = (rng.gen::() - 0.5) * 100.0; + let expected = a % b; + let fa = Fixed::from_f64(a); let fb = Fixed::from_f64(b); + let result = (fa % fb).to_f64(); assert_near(result, expected); } From bce918b885a49722e0cab4fea53eeeeb87ea9b71 Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Mon, 14 Apr 2025 06:35:36 +0430 Subject: [PATCH 5/9] refactor: remove unused Div implementation for Fixed type --- src/lib.rs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b55808c..71b2253 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -181,15 +181,6 @@ impl Mul for Fixed { } } -impl Div for Fixed { - type Output = Self; - - #[inline] - fn div(self, rhs: Self) -> Self::Output { - Self((self.0 << DEFAULT_SCALE) / rhs.0) - } -} - impl Rem for Fixed { type Output = Self; From 2d1504215499b0672d35c46636aeac60066cc7b9 Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Tue, 15 Apr 2025 14:41:49 +0430 Subject: [PATCH 6/9] feat: add eval_fixed_rem method for fixed-point remainder constraint --- src/eval.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/eval.rs b/src/eval.rs index 8a7693a..fce05f0 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -45,6 +45,14 @@ pub trait EvalFixedPoint: EvalAtRow { self.add_constraint(remainder + aux - (divisor - Self::F::one())); } + /// Evaluates the remainder constraint for fixed-point division. + /// This method enforces that remainder is in the range [0, divisor). + fn eval_fixed_rem(&mut self, divisor: Self::F, remainder: Self::F) { + // Create an auxiliary variable to check that remainder < divisor. + let aux = self.add_intermediate(divisor.clone() - Self::F::one() - remainder.clone()); + self.add_constraint(remainder + aux - (divisor - Self::F::one())); + } + /// Evaluates reciprocal constraints for fixed-point numbers. /// Constrains: scale * scale = value * reciprocal + remainder fn eval_fixed_recip( From afac24988cba7ff86677dffaa366d462d025d58f Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Tue, 15 Apr 2025 14:44:30 +0430 Subject: [PATCH 7/9] refactor: remove unused Div import from std::ops --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 71b2253..ae48143 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ use num_traits::Zero; -use std::ops::{Add, Div, Mul, Sub, Rem}; +use std::ops::{Add, Mul, Sub, Rem}; use stwo_prover::core::fields::m31::{M31, P}; pub mod eval; From 7a9ef4f8372feeef3ae506715eb6b73f91e6bc7f Mon Sep 17 00:00:00 2001 From: Mahmoud Mohajer Date: Mon, 21 Apr 2025 15:28:48 +0430 Subject: [PATCH 8/9] feat: implement eval_fixed_rem method for fixed-point division remainder evaluation --- src/eval.rs | 57 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/src/eval.rs b/src/eval.rs index fce05f0..fc2a466 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -45,14 +45,6 @@ pub trait EvalFixedPoint: EvalAtRow { self.add_constraint(remainder + aux - (divisor - Self::F::one())); } - /// Evaluates the remainder constraint for fixed-point division. - /// This method enforces that remainder is in the range [0, divisor). - fn eval_fixed_rem(&mut self, divisor: Self::F, remainder: Self::F) { - // Create an auxiliary variable to check that remainder < divisor. - let aux = self.add_intermediate(divisor.clone() - Self::F::one() - remainder.clone()); - self.add_constraint(remainder + aux - (divisor - Self::F::one())); - } - /// Evaluates reciprocal constraints for fixed-point numbers. /// Constrains: scale * scale = value * reciprocal + remainder fn eval_fixed_recip( @@ -128,6 +120,7 @@ mod tests { Mul, Recip, Sqrt, + Rem, } impl FrameworkEval for TestEval { @@ -172,6 +165,13 @@ mod tests { let rem = eval.next_trace_mask(); eval.eval_fixed_sqrt(input, out, rem) } + Op::Rem => { + let dividend = eval.next_trace_mask(); + let divisor = eval.next_trace_mask(); + let quot = eval.next_trace_mask(); + let rem = eval.next_trace_mask(); + eval.eval_fixed_div_rem(dividend, divisor, quot, rem); + } } eval } @@ -367,4 +367,45 @@ mod tests { test_op(Op::Sqrt, vec![fixed_input], vec![sqrt_out, rem], 1); } } + + #[test] + fn test_eval_rem() { + let mut rng = StdRng::seed_from_u64(42); + + // Test regular remainder cases + for _ in 0..100 { + let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + if b.to_f64() > 0.0 { + let raw_q = a.0.div_euclid(b.0); + let raw_r = a.0.rem_euclid(b.0); + let q = Fixed(raw_q); + let r = Fixed(raw_r); + test_op(Op::Rem, vec![a, b], vec![q, r], 3); + } + } + + // Test special cases + let special_cases = vec![ + (10.0, 3.0), + (-10.0, 3.0), + (10.0, -3.0), + (-10.0, -3.0), + (0.0, 3.0), + (10.0, 1.0), + (10.0, 0.5), + ]; + + for (a0, b0) in special_cases { + let fixed_a = Fixed::from_f64(a0); + let fixed_b = Fixed::from_f64(b0); + if fixed_b.to_f64() > 0.0 { + let raw_q = fixed_a.0.div_euclid(fixed_b.0); + let raw_r = fixed_a.0.rem_euclid(fixed_b.0); + let q = Fixed(raw_q); + let r = Fixed(raw_r); + test_op(Op::Rem, vec![fixed_a, fixed_b], vec![q, r], 3); + } + } + } } From f1cfa1bed60d2128da3bfe39cd5c2d030c59b173 Mon Sep 17 00:00:00 2001 From: raphaelDkhn <113879115+raphaelDkhn@users.noreply.github.com> Date: Thu, 29 May 2025 19:31:09 +0200 Subject: [PATCH 9/9] implement div_rem + merge master --- Cargo.lock | 3 +- Cargo.toml | 4 +- rust-toolchain.toml | 2 +- src/eval.rs | 220 +++++++++++++++++++++++++------------------- src/lib.rs | 216 +++++++++++++++++++++++++++++++------------ 5 files changed, 286 insertions(+), 159 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c1380e5..00c5ba5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -629,7 +629,7 @@ dependencies = [ [[package]] name = "stwo-prover" version = "0.1.1" -source = "git+https://github.com/starkware-libs/stwo.git?rev=8902113#8902113710a45672d068f847cb7aa72913265731" +source = "git+https://github.com/starkware-libs/stwo?rev=045963c#045963c3814e605e18d3edafadc79f52de8f21bb" dependencies = [ "blake2", "blake3", @@ -640,6 +640,7 @@ dependencies = [ "itertools 0.12.1", "num-traits", "rand", + "rayon", "serde", "starknet-crypto", "starknet-ff", diff --git a/Cargo.toml b/Cargo.toml index 8e0ea17..f874aae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,9 @@ lazy_static = "1.5.0" num-traits = "0.2.19" once_cell = "1.20.3" rayon = "1.10.0" -stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", rev = "8902113" } +stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "045963c", features = [ + "parallel", +], default-features = false } serde = { version = "1.0.217", features = ["derive"] } [dev-dependencies] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index b4709af..9aa8998 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2025-01-02" \ No newline at end of file +channel = "nightly-2025-04-06" \ No newline at end of file diff --git a/src/eval.rs b/src/eval.rs index fc2a466..6ec2ac3 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,4 +1,4 @@ -use crate::{HALF_P, SCALE_FACTOR}; +use crate::HALF_P; use num_traits::One; use stwo_prover::{constraint_framework::EvalAtRow, core::fields::m31::M31}; @@ -19,12 +19,12 @@ pub trait EvalFixedPoint: EvalAtRow { &mut self, a: Self::F, b: Self::F, - scale: Self::F, + scale_factor: Self::F, quotient: Self::F, remainder: Self::F, ) { let product = self.add_intermediate(a * b); - self.eval_fixed_div_rem(product, scale, quotient, remainder); + self.eval_fixed_div_rem(product, scale_factor, quotient, remainder); } /// Evaluates constraints for signed division with remainder. @@ -45,29 +45,49 @@ pub trait EvalFixedPoint: EvalAtRow { self.add_constraint(remainder + aux - (divisor - Self::F::one())); } + /// Evaluates remainder constraints for fixed-point numbers. + /// Constrains: dividend = quotient * divisor + remainder + /// This is essentially the same as eval_fixed_div_rem but semantically focused on remainder. + fn eval_fixed_rem( + &mut self, + dividend: Self::F, + divisor: Self::F, + quotient: Self::F, + remainder: Self::F, + ) { + self.eval_fixed_div_rem(dividend, divisor, quotient, remainder); + } + /// Evaluates reciprocal constraints for fixed-point numbers. - /// Constrains: scale * scale = value * reciprocal + remainder + /// Constrains: scale_factor * scale_factor = value * reciprocal + remainder fn eval_fixed_recip( &mut self, value: Self::F, - scale: Self::F, + scale_factor: Self::F, reciprocal: Self::F, remainder: Self::F, ) { - let scale_squared = self.add_intermediate(scale.clone() * scale); + let scale_squared = self.add_intermediate(scale_factor.clone() * scale_factor); self.eval_fixed_div_rem(scale_squared, value, reciprocal, remainder); } /// Evaluates constraints for square root operations. /// Adds constraints to verify that: /// 1. The input is non-negative - /// 2. out^2 + rem = input * SCALE_FACTOR + /// 2. out^2 + rem = input * scale_factor /// /// # Parameters /// - `input`: The trace column value representing the scaled input. /// - `out`: The trace column value of the scaled square root. /// - `rem`: The trace column value of the remainder. - fn eval_fixed_sqrt(&mut self, input: Self::F, out: Self::F, rem: Self::F) { + /// - `scale_factor`: The scale_factor factor to use for fixed-point representation. + fn eval_fixed_sqrt( + &mut self, + input: Self::F, + out: Self::F, + rem: Self::F, + scale_factor: Self::F, + ) { // Constraint to ensure input is non-negative // For field elements, we check if input is in the range [0, HALF_P) // We need an auxiliary variable to ensure 0 <= input < HALF_P @@ -75,8 +95,8 @@ pub trait EvalFixedPoint: EvalAtRow { self.add_intermediate(Self::F::from(M31(HALF_P)) - Self::F::one() - input.clone()); self.add_constraint(input.clone() + aux - (Self::F::from(M31(HALF_P)) - Self::F::one())); - // Enforce the constraint: out^2 + rem = input * SCALE_FACTOR - self.add_constraint((out.clone() * out) + rem.clone() - (input * SCALE_FACTOR)); + // Enforce the constraint: out^2 + rem = input * scale_factor + self.add_constraint((out.clone() * out) + rem.clone() - (input * scale_factor)); } } @@ -85,7 +105,6 @@ impl EvalFixedPoint for T {} #[cfg(test)] mod tests { - use num_traits::Zero; use rand::{rngs::StdRng, Rng, SeedableRng}; use stwo_prover::{ @@ -93,7 +112,7 @@ mod tests { core::{ backend::{simd::SimdBackend, Col, Column}, fields::{ - m31::{BaseField, P}, + m31::{BaseField, M31, P}, qm31::SecureField, }, pcs::TreeVec, @@ -104,11 +123,10 @@ mod tests { }, }; - use crate::{Fixed, SCALE_FACTOR}; - use super::*; + use crate::Fixed; - struct TestEval { + struct TestEval { log_size: u32, op: Op, } @@ -118,12 +136,12 @@ mod tests { Add, Sub, Mul, + Rem, Recip, Sqrt, - Rem, } - impl FrameworkEval for TestEval { + impl FrameworkEval for TestEval { fn log_size(&self) -> u32 { self.log_size } @@ -133,6 +151,8 @@ mod tests { } fn evaluate(&self, mut eval: E) -> E { + let scale_factor = E::F::from(M31::from_u32_unchecked(1 << SCALE)); + match self.op { Op::Add => { let lhs = eval.next_trace_mask(); @@ -151,26 +171,26 @@ mod tests { let rhs = eval.next_trace_mask(); let out = eval.next_trace_mask(); let rem = eval.next_trace_mask(); - eval.eval_fixed_mul(lhs, rhs, SCALE_FACTOR.into(), out, rem) + eval.eval_fixed_mul(lhs, rhs, scale_factor, out, rem) + } + Op::Rem => { + let dividend = eval.next_trace_mask(); + let divisor = eval.next_trace_mask(); + let quotient = eval.next_trace_mask(); + let remainder = eval.next_trace_mask(); + eval.eval_fixed_rem(dividend, divisor, quotient, remainder) } Op::Recip => { let input = eval.next_trace_mask(); let out = eval.next_trace_mask(); let rem = eval.next_trace_mask(); - eval.eval_fixed_recip(input, SCALE_FACTOR.into(), out, rem) + eval.eval_fixed_recip(input, scale_factor, out, rem) } Op::Sqrt => { let input = eval.next_trace_mask(); let out = eval.next_trace_mask(); let rem = eval.next_trace_mask(); - eval.eval_fixed_sqrt(input, out, rem) - } - Op::Rem => { - let dividend = eval.next_trace_mask(); - let divisor = eval.next_trace_mask(); - let quot = eval.next_trace_mask(); - let rem = eval.next_trace_mask(); - eval.eval_fixed_div_rem(dividend, divisor, quot, rem); + eval.eval_fixed_sqrt(input, out, rem, scale_factor) } } eval @@ -192,11 +212,11 @@ mod tests { .collect() } - fn test_op( + fn test_op_internal( op: Op, - inputs: Vec, - expected_outputs: Vec, - tamper_col_idx: usize, // The column to tamper + inputs: &[Fixed], + expected_outputs: &[Fixed], + tamper_col_idx: usize, ) { const LOG_SIZE: u32 = 4; let domain = CanonicCoset::new(LOG_SIZE); @@ -220,13 +240,13 @@ mod tests { let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = TestEval { + let component = TestEval:: { log_size: LOG_SIZE, op, }; // Test valid trace - constraint_framework::assert_constraints( + constraint_framework::assert_constraints_on_polys( &trace_polys, domain, |eval| { @@ -239,7 +259,9 @@ mod tests { let mut invalid_trace_cols = trace_cols; if let Some(col) = invalid_trace_cols.get_mut(tamper_col_idx) { for val in col.iter_mut() { - val.0 = (val.0 + SCALE_FACTOR.0) % P; + // Calculate scale factor for tampering + let scale_factor = M31::from_u32_unchecked(1 << SCALE); + val.0 = (val.0 + scale_factor.0) % P; } } @@ -251,7 +273,7 @@ mod tests { // This should panic for invalid trace let result = std::panic::catch_unwind(|| { - constraint_framework::assert_constraints( + constraint_framework::assert_constraints_on_polys( &invalid_trace_polys, domain, |eval| { @@ -267,10 +289,10 @@ mod tests { fn test_add() { let mut rng = StdRng::seed_from_u64(42); for _ in 0..100 { - let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); - test_op(Op::Add, vec![a, b], vec![a + b], 2); + test_op_internal(Op::Add, &[a, b], &[a + b], 2); } } @@ -278,9 +300,10 @@ mod tests { fn test_sub() { let mut rng = StdRng::seed_from_u64(42); for _ in 0..100 { - let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - test_op(Op::Sub, vec![a, b], vec![a - b], 2); + let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + + test_op_internal(Op::Sub, &[a, b], &[a - b], 2); } } @@ -290,11 +313,11 @@ mod tests { // Test regular multiplication cases for _ in 0..100 { - let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); let (expected, rem) = a * b; - test_op(Op::Mul, vec![a, b], vec![expected, rem], 2); + test_op_internal(Op::Mul, &[a, b], &[expected, rem], 2); } // Test special cases @@ -311,11 +334,50 @@ mod tests { ]; for (a, b) in special_cases { - let fixed_a = Fixed::from_f64(a); - let fixed_b = Fixed::from_f64(b); + let fixed_a = Fixed::<15>::from_f64(a); + let fixed_b = Fixed::<15>::from_f64(b); let (expected, rem) = fixed_a * fixed_b; - test_op(Op::Mul, vec![fixed_a, fixed_b], vec![expected, rem], 2); + test_op_internal(Op::Mul, &[fixed_a, fixed_b], &[expected, rem], 2); + } + } + + #[test] + fn test_rem() { + let mut rng = StdRng::seed_from_u64(42); + + // Test regular remainder cases + for _ in 0..50 { + let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + + // Skip cases where divisor is too close to zero + if b.to_f64().abs() < 0.1 { + continue; + } + + let (quotient, remainder) = a.div_rem(b); + + test_op_internal(Op::Rem, &[a, b], &[quotient, remainder], 2); + } + + // Test special cases + let special_cases = vec![ + (10.0, 3.0), // 10 % 3 = 1, quotient = 3 + (7.5, 2.5), // 7.5 % 2.5 = 0, quotient = 3 + (9.0, 4.0), // 9 % 4 = 1, quotient = 2 + (8.0, 3.0), // 8 % 3 = 2, quotient = 2 + (15.0, 4.0), // 15 % 4 = 3, quotient = 3 + (20.0, 6.0), // 20 % 6 = 2, quotient = 3 + (1.5, 0.5), // 1.5 % 0.5 = 0, quotient = 3 + ]; + + for (a, b) in special_cases { + let fixed_a = Fixed::<15>::from_f64(a); + let fixed_b = Fixed::<15>::from_f64(b); + let (quotient, remainder) = fixed_a.div_rem(fixed_b); + + test_op_internal(Op::Rem, &[fixed_a, fixed_b], &[quotient, remainder], 2); } } @@ -323,12 +385,16 @@ mod tests { fn test_recip() { let mut rng = StdRng::seed_from_u64(42); - // Test regular multiplication cases + // Test regular recip cases for _ in 0..100 { - let input = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let input = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + if input.0 == 0 { + continue; // Skip division by zero + } + let (expected, rem) = input.recip(); - test_op(Op::Recip, vec![input], vec![expected, rem], 2); + test_op_internal(Op::Recip, &[input], &[expected, rem], 1); } // Test special cases @@ -342,70 +408,30 @@ mod tests { ]; for input in special_cases { - let fixed_input = Fixed::from_f64(input); + let fixed_input = Fixed::<15>::from_f64(input); let (expected, rem) = fixed_input.recip(); - test_op(Op::Recip, vec![fixed_input], vec![expected, rem], 1); + test_op_internal(Op::Recip, &[fixed_input], &[expected, rem], 1); } } #[test] - fn test_eval_sqrt() { + fn test_sqrt() { let test_cases = vec![1.0, 4.0, 9.0, 2.0, 0.5, 0.25, 0.0]; for input in test_cases { - let fixed_input = Fixed::from_f64(input); + let fixed_input = Fixed::<15>::from_f64(input); let (sqrt_out, rem) = fixed_input.sqrt(); - test_op(Op::Sqrt, vec![fixed_input], vec![sqrt_out, rem], 1); + test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1); } let mut rng = StdRng::seed_from_u64(43); for _ in 0..50 { let input_val: f64 = rng.gen_range(0.0..100.0); - let fixed_input = Fixed::from_f64(input_val); + let fixed_input = Fixed::<15>::from_f64(input_val); let (sqrt_out, rem) = fixed_input.sqrt(); - test_op(Op::Sqrt, vec![fixed_input], vec![sqrt_out, rem], 1); - } - } - - #[test] - fn test_eval_rem() { - let mut rng = StdRng::seed_from_u64(42); - - // Test regular remainder cases - for _ in 0..100 { - let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - if b.to_f64() > 0.0 { - let raw_q = a.0.div_euclid(b.0); - let raw_r = a.0.rem_euclid(b.0); - let q = Fixed(raw_q); - let r = Fixed(raw_r); - test_op(Op::Rem, vec![a, b], vec![q, r], 3); - } - } - // Test special cases - let special_cases = vec![ - (10.0, 3.0), - (-10.0, 3.0), - (10.0, -3.0), - (-10.0, -3.0), - (0.0, 3.0), - (10.0, 1.0), - (10.0, 0.5), - ]; - - for (a0, b0) in special_cases { - let fixed_a = Fixed::from_f64(a0); - let fixed_b = Fixed::from_f64(b0); - if fixed_b.to_f64() > 0.0 { - let raw_q = fixed_a.0.div_euclid(fixed_b.0); - let raw_r = fixed_a.0.rem_euclid(fixed_b.0); - let q = Fixed(raw_q); - let r = Fixed(raw_r); - test_op(Op::Rem, vec![fixed_a, fixed_b], vec![q, r], 3); - } + test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1); } } } diff --git a/src/lib.rs b/src/lib.rs index 67ca2bf..57fdb45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,25 +1,20 @@ use num_traits::Zero; use serde::{Deserialize, Serialize}; -use std::ops::{Add, Mul, Sub, Rem}; +use std::ops::{Add, Mul, Rem, Sub}; use stwo_prover::core::fields::m31::{M31, P}; pub mod eval; -// Number of bits used for decimal precision. -pub const DEFAULT_SCALE: u32 = 12; -// Scale factor = 2^DEFAULT_SCALE, used for fixed-point arithmetic. -pub const SCALE_FACTOR: M31 = M31::from_u32_unchecked(1 << DEFAULT_SCALE); // Half the prime modulus. pub const HALF_P: u32 = P / 2; -// Mask for remainder in fixed-point operations (2^DEFAULT_SCALE - 1) -const REMAINDER_MASK: i64 = (1 << DEFAULT_SCALE) - 1; -/// Integer representation of fixed-point Basefield. +/// Integer representation of fixed-point Basefield with parametrized scale. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] -pub struct Fixed(pub i64); +pub struct Fixed(pub i64); -impl Fixed { - const SCALE_FACTOR: i64 = 1 << DEFAULT_SCALE; +impl Fixed { + const SCALE_FACTOR: i64 = 1 << SCALE; + const HALF_SCALE_FACTOR: i64 = 1 << (SCALE - 1); #[inline] pub fn from_f64(value: f64) -> Self { @@ -68,6 +63,17 @@ impl Fixed { } } + /// Computes both quotient and remainder for division. + /// Returns (quotient, remainder) where dividend = quotient * divisor + remainder + /// Note: The quotient is stored as an unscaled integer for constraint compatibility. + #[inline] + pub fn div_rem(self, rhs: Self) -> (Self, Self) { + assert!(rhs.0 != 0, "Division by zero"); + let quotient = self.0 / rhs.0; + let remainder = self.0 % rhs.0; + (Self(quotient), Self(remainder)) + } + /// Computes the reciprocal (1/x) of a fixed-point number /// /// Returns a tuple of (quotient, remainder) where: @@ -77,9 +83,9 @@ impl Fixed { pub fn recip(self) -> (Self, Self) { assert!(self.0 != 0, "Division by zero"); - let scale_squared = Self::SCALE_FACTOR * Self::SCALE_FACTOR; - let quotient = scale_squared / self.0; - let remainder = scale_squared % self.0; + let scale_factor_squared = Self::SCALE_FACTOR * Self::SCALE_FACTOR; + let quotient = scale_factor_squared / self.0; + let remainder = scale_factor_squared % self.0; (Self(quotient), Self(remainder)) } @@ -105,7 +111,7 @@ impl Fixed { } // Calculate value to compute sqrt of: self * SCALE_FACTOR - let input_scaled = (self.0 as u64) << DEFAULT_SCALE; + let input_scaled = (self.0 as u64) << SCALE; // Compute integer square root let sqrt_val = int_sqrt(input_scaled); @@ -115,6 +121,22 @@ impl Fixed { (Self(sqrt_val as i64), Self(remainder as i64)) } + + /// Convert this Fixed value to a Fixed with a different scale + pub fn convert_to(self) -> Fixed { + if TARGET_SCALE == SCALE { + // Same scale, just change the type + Fixed(self.0) + } else if TARGET_SCALE > SCALE { + // Going to higher precision + let shift = TARGET_SCALE - SCALE; + Fixed(self.0 << shift) + } else { + // Going to lower precision + let shift = SCALE - TARGET_SCALE; + Fixed(self.0 >> shift) + } + } } /// Returns the floor of the square root of `n`. @@ -151,7 +173,7 @@ pub fn int_sqrt(n: u64) -> u64 { } } -impl Add for Fixed { +impl Add for Fixed { type Output = Self; #[inline] @@ -160,7 +182,7 @@ impl Add for Fixed { } } -impl Sub for Fixed { +impl Sub for Fixed { type Output = Self; #[inline] @@ -169,29 +191,34 @@ impl Sub for Fixed { } } -impl Mul for Fixed { +impl Mul for Fixed { type Output = (Self, Self); #[inline] fn mul(self, rhs: Self) -> Self::Output { let product = self.0 * rhs.0; - ( - Self(product >> DEFAULT_SCALE), - Self(product & REMAINDER_MASK), - ) + + let quotient = (product + Self::HALF_SCALE_FACTOR) >> SCALE; + + // Calculate remainder to maintain: product = quotient * scale + remainder + let scaled_quotient = quotient << SCALE; + let remainder = product - scaled_quotient; + + (Self(quotient), Self(remainder)) } } -impl Rem for Fixed { +impl Rem for Fixed { type Output = Self; #[inline] fn rem(self, rhs: Self) -> Self::Output { + assert!(rhs.0 != 0, "Division by zero in remainder operation"); Self(self.0 % rhs.0) } } -impl Zero for Fixed { +impl Zero for Fixed { #[inline] fn zero() -> Self { Self(0) @@ -209,7 +236,7 @@ mod tests { use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; - const EPSILON: f64 = 1e-2; + const EPSILON: f64 = 1e-3; fn assert_near(a: f64, b: f64) { assert!((a - b).abs() < EPSILON, "Expected {} to be near {}", a, b); @@ -217,11 +244,11 @@ mod tests { #[test] fn test_negative() { - let a = Fixed::from_f64(-3.5); - let b = Fixed::from_f64(2.0); + let a = Fixed::<15>::from_f64(-3.5); + let b = Fixed::<15>::from_f64(2.0); assert_near(a.to_f64(), -3.5); - assert_near((a + b.clone()).to_f64(), -1.5); + assert_near((a + b).to_f64(), -1.5); assert_near((a - b).to_f64(), -5.5); } @@ -233,8 +260,8 @@ mod tests { let a = (rng.gen::() - 0.5) * 200.0; let b = (rng.gen::() - 0.5) * 200.0; - let fa = Fixed::from_f64(a); - let fb = Fixed::from_f64(b); + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); assert_near((fa + fb).to_f64(), a + b); } @@ -248,8 +275,8 @@ mod tests { let a = (rng.gen::() - 0.5) * 200.0; let b = (rng.gen::() - 0.5) * 200.0; - let fa = Fixed::from_f64(a); - let fb = Fixed::from_f64(b); + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); assert_near((fa - fb).to_f64(), a - b); } @@ -263,8 +290,8 @@ mod tests { let a = (rng.gen::() - 0.5) * 10.0; let b = (rng.gen::() - 0.5) * 10.0; - let fa = Fixed::from_f64(a); - let fb = Fixed::from_f64(b); + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); let (q, _) = fa * fb; let expected = a * b; @@ -283,7 +310,7 @@ mod tests { continue; } - let fixed_a = Fixed::from_f64(a); + let fixed_a = Fixed::<15>::from_f64(a); let (recip, _) = fixed_a.recip(); let expected = 1.0 / a; @@ -301,7 +328,7 @@ mod tests { ]; for (a, expected) in test_cases { - let fixed_a = Fixed::from_f64(a); + let fixed_a = Fixed::<15>::from_f64(a); let (recip, _) = fixed_a.recip(); assert_near(recip.to_f64(), expected); } @@ -324,7 +351,7 @@ mod tests { } for input in test_cases { - let fixed_input = Fixed::from_f64(input); + let fixed_input = Fixed::<15>::from_f64(input); if input < 0.0 { let (result, remainder) = fixed_input.sqrt(); @@ -341,40 +368,111 @@ mod tests { #[test] fn test_rem() { + let mut rng = StdRng::seed_from_u64(42); + + // Test random cases + for _ in 0..100 { + let a = (rng.gen::() - 0.5) * 20.0; + let b = (rng.gen::() - 0.5) * 20.0; + + // Skip cases where divisor is too close to zero + if b.abs() < 0.1 { + continue; + } + + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); + + let remainder = fa % fb; + let expected = a % b; + + assert_near(remainder.to_f64(), expected); + } + + // Test specific cases let test_cases = vec![ - (5.0, 2.0, 1.0), - (-5.0, 2.0, -1.0), - (5.0, -2.0, 1.0), - (-5.0, -2.0, -1.0), - (7.5, 2.5, 0.0), - (3.2, 1.5, 0.2), + (10.0, 3.0), // 10 % 3 = 1 + (7.5, 2.5), // 7.5 % 2.5 = 0 + (9.0, 4.0), // 9 % 4 = 1 + (-10.0, 3.0), // -10 % 3 = -1 (or 2, depending on implementation) + (10.0, -3.0), // 10 % -3 = 1 (or -2, depending on implementation) + (-10.0, -3.0), // -10 % -3 = -1 (or 2, depending on implementation) ]; - for (a, b, expected) in test_cases.clone() { - let fa = Fixed::from_f64(a); - let fb = Fixed::from_f64(b); - let result = (fa % fb).to_f64(); - assert_near(result, expected); + for (a, b) in test_cases { + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); + let remainder = fa % fb; + let expected = a % b; + + assert_near(remainder.to_f64(), expected); } + } + #[test] + fn test_div_rem() { let mut rng = StdRng::seed_from_u64(42); - for _ in 0..1000 { - let a = (rng.gen::() - 0.5) * 100.0; - let b = (rng.gen::() - 0.5) * 100.0; - let expected = a % b; + for _ in 0..100 { + let a = (rng.gen::() - 0.5) * 20.0; + let b = (rng.gen::() - 0.5) * 20.0; + println!("a {:?}", a); + println!("b {:?}", b); + + // Skip cases where divisor is too close to zero + if b.abs() < 0.1 { + continue; + } + + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); + + let (quotient, remainder) = fa.div_rem(fb); + + // Verify: dividend = quotient * divisor + remainder + let reconstructed = quotient.0 * fb.0 + remainder.0; + assert_eq!(reconstructed, fa.0); - let fa = Fixed::from_f64(a); - let fb = Fixed::from_f64(b); + // Check individual results + let expected_quotient = (a / b).trunc(); + let expected_remainder = a % b; - let result = (fa % fb).to_f64(); - assert_near(result, expected); + // The quotient from div_rem is stored as an unscaled integer + assert_eq!(quotient.0 as f64, expected_quotient); + assert_near(remainder.to_f64(), expected_remainder); } } #[test] - fn test_zero() { - assert!(Fixed::zero().is_zero()); - assert!(!Fixed::from_f64(1.0).is_zero()); + fn test_different_scales() { + // Test with 15-bit scale + let scale_15 = Fixed::<15>::from_f64(1.5); + assert_near(scale_15.to_f64(), 1.5); + + // Test with 8-bit scale (less precision) + let scale8 = Fixed::<8>::from_f64(1.5); + assert_near(scale8.to_f64(), 1.5); + + // Test with 24-bit scale (more precision) + let scale24 = Fixed::<24>::from_f64(1.5); + assert_near(scale24.to_f64(), 1.5); + + // Test conversion between scales + let from_15_to_8 = scale_15.convert_to::<8>(); + assert_near(from_15_to_8.to_f64(), 1.5); + + let from_8_to_24 = scale8.convert_to::<24>(); + assert_near(from_8_to_24.to_f64(), 1.5); + + // Multiplication with different scales + let a8 = Fixed::<8>::from_f64(2.5); + let b8 = Fixed::<8>::from_f64(3.0); + let (result8, _) = a8 * b8; + assert_near(result8.to_f64(), 7.5); + + let a24 = Fixed::<24>::from_f64(2.5); + let b24 = Fixed::<24>::from_f64(3.0); + let (result24, _) = a24 * b24; + assert_near(result24.to_f64(), 7.5); } }