diff --git a/README.md b/README.md index fb2c83d..5a79f83 100644 --- a/README.md +++ b/README.md @@ -3,20 +3,33 @@ A fixed-point arithmetic library providing constrained fixed-point operations for [Stwo](https://github.com/starkware-libs/stwo.git)-based circuits. The library implements fixed-point arithmetic using M31 field elements with configurable decimal precision. +## Features + +- Type-level scale parameter using Rust's const generics +- Fixed-point operations (addition, subtraction, multiplication, reciprocal, square root) +- Circuit constraints for all operations +- Zero memory overhead for scale information + ## Usage ### Basic Arithmetic ```rust - -// Create fixed-point numbers -let lhs = Fixed::from_f64(3.14); -let rhs = Fixed::from_f64(2.0); +// Using 15 bits of precision +let a = Fixed::<15>::from_f64(3.14); +let b = Fixed::<15>::from_f64(2.0); // Basic arithmetic operations -let sum = lhs + rhs; -let diff = lhs - rhs; -let (prod, rem) = lhs * rhs; +let sum = a + b; +let diff = a - b; +let (prod, rem) = a * b; + +// Using custom scales +let high_precision = Fixed::<24>::from_f64(0.12345678); +let low_precision = Fixed::<8>::from_f64(42.5); + +// Convert between scales +let converted = high_precision.convert_to::<8>(); ``` ### In Circuit Constraints @@ -24,7 +37,7 @@ let (prod, rem) = lhs * rhs; To use fixed-point operations in your Stwo Prover circuit, use the constraint evaluation functions: ```rust -use numerair::eval::{eval_add, eval_mul, eval_sub}; +use numerair::eval::EvalFixedPoint; // In your circuit component's evaluate function: fn evaluate(&self, mut eval: E) -> E { @@ -32,14 +45,27 @@ fn evaluate(&self, mut eval: E) -> E { let rhs = eval.next_trace_mask(); let rem = eval.next_trace_mask(); let res = eval.next_trace_mask(); + + // Get scale factor for the specific scale you're using + let scale_factor = E::F::from(M31::from_u32_unchecked(1 << 15)); // For Fixed<15> - // Constrain mul. - eval.eval_fixed_mul(lhs, rhs, SCALE_FACTOR.into(), res, rem) + // Constrain mul using EvalFixedPoint trait + eval.eval_fixed_mul(lhs, rhs, scale_factor, res, rem); eval } ``` +## How It Works + +The fixed-point representation uses a const generic parameter to determine the number of bits used for the fractional part: + +- `SCALE`: Type-level constant that defines the number of bits for decimal precision +- `SCALE_FACTOR`: The value 2^SCALE (automatically calculated), represents 1.0 in fixed-point +- `HALF_SCALE_FACTOR`: The value 2^(SCALE-1) (automatically calculated), used for rounding + +A value `x` in floating point is represented as `floor(x * 2^SCALE)` in fixed-point. + ## Contributing Contributions are welcome! Please submit pull requests with: diff --git a/src/eval.rs b/src/eval.rs index 857b3c2..acb73a5 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,4 +1,4 @@ -use crate::{HALF_P, SCALE_FACTOR}; +use crate::HALF_P; use num_traits::One; use stwo_prover::{constraint_framework::EvalAtRow, core::fields::m31::M31}; @@ -19,12 +19,12 @@ pub trait EvalFixedPoint: EvalAtRow { &mut self, a: Self::F, b: Self::F, - scale: Self::F, + scale_factor: Self::F, quotient: Self::F, remainder: Self::F, ) { let product = self.add_intermediate(a * b); - self.eval_fixed_div_rem(product, scale, quotient, remainder); + self.eval_fixed_div_rem(product, scale_factor, quotient, remainder); } /// Evaluates constraints for signed division with remainder. @@ -46,28 +46,35 @@ pub trait EvalFixedPoint: EvalAtRow { } /// Evaluates reciprocal constraints for fixed-point numbers. - /// Constrains: scale * scale = value * reciprocal + remainder + /// Constrains: scale_factor * scale_factor = value * reciprocal + remainder fn eval_fixed_recip( &mut self, value: Self::F, - scale: Self::F, + scale_factor: Self::F, reciprocal: Self::F, remainder: Self::F, ) { - let scale_squared = self.add_intermediate(scale.clone() * scale); + let scale_squared = self.add_intermediate(scale_factor.clone() * scale_factor); self.eval_fixed_div_rem(scale_squared, value, reciprocal, remainder); } /// Evaluates constraints for square root operations. /// Adds constraints to verify that: /// 1. The input is non-negative - /// 2. out^2 + rem = input * SCALE_FACTOR + /// 2. out^2 + rem = input * scale_factor /// /// # Parameters /// - `input`: The trace column value representing the scaled input. /// - `out`: The trace column value of the scaled square root. /// - `rem`: The trace column value of the remainder. - fn eval_fixed_sqrt(&mut self, input: Self::F, out: Self::F, rem: Self::F) { + /// - `scale_factor`: The scale_factor factor to use for fixed-point representation. + fn eval_fixed_sqrt( + &mut self, + input: Self::F, + out: Self::F, + rem: Self::F, + scale_factor: Self::F, + ) { // Constraint to ensure input is non-negative // For field elements, we check if input is in the range [0, HALF_P) // We need an auxiliary variable to ensure 0 <= input < HALF_P @@ -75,8 +82,8 @@ pub trait EvalFixedPoint: EvalAtRow { self.add_intermediate(Self::F::from(M31(HALF_P)) - Self::F::one() - input.clone()); self.add_constraint(input.clone() + aux - (Self::F::from(M31(HALF_P)) - Self::F::one())); - // Enforce the constraint: out^2 + rem = input * SCALE_FACTOR - self.add_constraint((out.clone() * out) + rem.clone() - (input * SCALE_FACTOR)); + // Enforce the constraint: out^2 + rem = input * scale_factor + self.add_constraint((out.clone() * out) + rem.clone() - (input * scale_factor)); } } @@ -85,7 +92,6 @@ impl EvalFixedPoint for T {} #[cfg(test)] mod tests { - use num_traits::Zero; use rand::{rngs::StdRng, Rng, SeedableRng}; use stwo_prover::{ @@ -93,7 +99,7 @@ mod tests { core::{ backend::{simd::SimdBackend, Col, Column}, fields::{ - m31::{BaseField, P}, + m31::{BaseField, M31, P}, qm31::SecureField, }, pcs::TreeVec, @@ -104,11 +110,10 @@ mod tests { }, }; - use crate::{Fixed, SCALE_FACTOR}; - + use crate::Fixed; use super::*; - struct TestEval { + struct TestEval { log_size: u32, op: Op, } @@ -122,7 +127,7 @@ mod tests { Sqrt, } - impl FrameworkEval for TestEval { + impl FrameworkEval for TestEval { fn log_size(&self) -> u32 { self.log_size } @@ -132,6 +137,8 @@ mod tests { } fn evaluate(&self, mut eval: E) -> E { + let scale_factor = E::F::from(M31::from_u32_unchecked(1 << SCALE)); + match self.op { Op::Add => { let lhs = eval.next_trace_mask(); @@ -150,19 +157,19 @@ mod tests { let rhs = eval.next_trace_mask(); let out = eval.next_trace_mask(); let rem = eval.next_trace_mask(); - eval.eval_fixed_mul(lhs, rhs, SCALE_FACTOR.into(), out, rem) + eval.eval_fixed_mul(lhs, rhs, scale_factor, out, rem) } Op::Recip => { let input = eval.next_trace_mask(); let out = eval.next_trace_mask(); let rem = eval.next_trace_mask(); - eval.eval_fixed_recip(input, SCALE_FACTOR.into(), out, rem) + eval.eval_fixed_recip(input, scale_factor, out, rem) } Op::Sqrt => { let input = eval.next_trace_mask(); let out = eval.next_trace_mask(); let rem = eval.next_trace_mask(); - eval.eval_fixed_sqrt(input, out, rem) + eval.eval_fixed_sqrt(input, out, rem, scale_factor) } } eval @@ -184,11 +191,11 @@ mod tests { .collect() } - fn test_op( + fn test_op_internal( op: Op, - inputs: Vec, - expected_outputs: Vec, - tamper_col_idx: usize, // The column to tamper + inputs: &[Fixed], + expected_outputs: &[Fixed], + tamper_col_idx: usize, ) { const LOG_SIZE: u32 = 4; let domain = CanonicCoset::new(LOG_SIZE); @@ -212,7 +219,7 @@ mod tests { let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = TestEval { + let component = TestEval:: { log_size: LOG_SIZE, op, }; @@ -231,7 +238,9 @@ mod tests { let mut invalid_trace_cols = trace_cols; if let Some(col) = invalid_trace_cols.get_mut(tamper_col_idx) { for val in col.iter_mut() { - val.0 = (val.0 + SCALE_FACTOR.0) % P; + // Calculate scale factor for tampering + let scale_factor = M31::from_u32_unchecked(1 << SCALE); + val.0 = (val.0 + scale_factor.0) % P; } } @@ -259,10 +268,10 @@ mod tests { fn test_add() { let mut rng = StdRng::seed_from_u64(42); for _ in 0..100 { - let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); - test_op(Op::Add, vec![a, b], vec![a + b], 2); + test_op_internal(Op::Add, &[a, b], &[a + b], 2); } } @@ -270,9 +279,10 @@ mod tests { fn test_sub() { let mut rng = StdRng::seed_from_u64(42); for _ in 0..100 { - let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - test_op(Op::Sub, vec![a, b], vec![a - b], 2); + let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + + test_op_internal(Op::Sub, &[a, b], &[a - b], 2); } } @@ -282,11 +292,11 @@ mod tests { // Test regular multiplication cases for _ in 0..100 { - let a = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); - let b = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let a = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + let b = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); let (expected, rem) = a * b; - test_op(Op::Mul, vec![a, b], vec![expected, rem], 2); + test_op_internal(Op::Mul, &[a, b], &[expected, rem], 2); } // Test special cases @@ -303,11 +313,11 @@ mod tests { ]; for (a, b) in special_cases { - let fixed_a = Fixed::from_f64(a); - let fixed_b = Fixed::from_f64(b); + let fixed_a = Fixed::<15>::from_f64(a); + let fixed_b = Fixed::<15>::from_f64(b); let (expected, rem) = fixed_a * fixed_b; - test_op(Op::Mul, vec![fixed_a, fixed_b], vec![expected, rem], 2); + test_op_internal(Op::Mul, &[fixed_a, fixed_b], &[expected, rem], 2); } } @@ -315,12 +325,16 @@ mod tests { fn test_recip() { let mut rng = StdRng::seed_from_u64(42); - // Test regular multiplication cases + // Test regular recip cases for _ in 0..100 { - let input = Fixed::from_f64((rng.gen::() - 0.5) * 200.0); + let input = Fixed::<15>::from_f64((rng.gen::() - 0.5) * 200.0); + if input.0 == 0 { + continue; // Skip division by zero + } + let (expected, rem) = input.recip(); - test_op(Op::Recip, vec![input], vec![expected, rem], 2); + test_op_internal(Op::Recip, &[input], &[expected, rem], 1); } // Test special cases @@ -334,29 +348,30 @@ mod tests { ]; for input in special_cases { - let fixed_input = Fixed::from_f64(input); + let fixed_input = Fixed::<15>::from_f64(input); let (expected, rem) = fixed_input.recip(); - test_op(Op::Recip, vec![fixed_input], vec![expected, rem], 1); + test_op_internal(Op::Recip, &[fixed_input], &[expected, rem], 1); } } #[test] - fn test_eval_sqrt() { + fn test_sqrt() { let test_cases = vec![1.0, 4.0, 9.0, 2.0, 0.5, 0.25, 0.0]; for input in test_cases { - let fixed_input = Fixed::from_f64(input); + let fixed_input = Fixed::<15>::from_f64(input); let (sqrt_out, rem) = fixed_input.sqrt(); - test_op(Op::Sqrt, vec![fixed_input], vec![sqrt_out, rem], 1); + test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1); } let mut rng = StdRng::seed_from_u64(43); for _ in 0..50 { let input_val: f64 = rng.gen_range(0.0..100.0); - let fixed_input = Fixed::from_f64(input_val); + let fixed_input = Fixed::<15>::from_f64(input_val); let (sqrt_out, rem) = fixed_input.sqrt(); - test_op(Op::Sqrt, vec![fixed_input], vec![sqrt_out, rem], 1); + + test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1); } } } diff --git a/src/lib.rs b/src/lib.rs index 67fba19..58f11da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,20 +5,16 @@ use stwo_prover::core::fields::m31::{M31, P}; pub mod eval; -// Number of bits used for decimal precision. -pub const DEFAULT_SCALE: u32 = 12; -pub const HALF_SCALE: u32 = 1 << (DEFAULT_SCALE - 1); -// Scale factor = 2^DEFAULT_SCALE, used for fixed-point arithmetic. -pub const SCALE_FACTOR: M31 = M31::from_u32_unchecked(1 << DEFAULT_SCALE); // Half the prime modulus. pub const HALF_P: u32 = P / 2; -/// Integer representation of fixed-point Basefield. +/// Integer representation of fixed-point Basefield with parametrized scale. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] -pub struct Fixed(pub i64); +pub struct Fixed(pub i64); -impl Fixed { - const SCALE_FACTOR: i64 = 1 << DEFAULT_SCALE; +impl Fixed { + const SCALE_FACTOR: i64 = 1 << SCALE; + const HALF_SCALE_FACTOR: i64 = 1 << (SCALE - 1); #[inline] pub fn from_f64(value: f64) -> Self { @@ -76,9 +72,9 @@ impl Fixed { pub fn recip(self) -> (Self, Self) { assert!(self.0 != 0, "Division by zero"); - let scale_squared = Self::SCALE_FACTOR * Self::SCALE_FACTOR; - let quotient = scale_squared / self.0; - let remainder = scale_squared % self.0; + let scale_factor_squared = Self::SCALE_FACTOR * Self::SCALE_FACTOR; + let quotient = scale_factor_squared / self.0; + let remainder = scale_factor_squared % self.0; (Self(quotient), Self(remainder)) } @@ -104,7 +100,7 @@ impl Fixed { } // Calculate value to compute sqrt of: self * SCALE_FACTOR - let input_scaled = (self.0 as u64) << DEFAULT_SCALE; + let input_scaled = (self.0 as u64) << SCALE; // Compute integer square root let sqrt_val = int_sqrt(input_scaled); @@ -112,7 +108,26 @@ impl Fixed { // Calculate remainder (input_scaled - sqrt_val^2) let remainder = input_scaled - sqrt_val * sqrt_val; - (Self(sqrt_val as i64), Self(remainder as i64)) + ( + Self(sqrt_val as i64), + Self(remainder as i64), + ) + } + + /// Convert this Fixed value to a Fixed with a different scale + pub fn convert_to(self) -> Fixed { + if TARGET_SCALE == SCALE { + // Same scale, just change the type + Fixed(self.0) + } else if TARGET_SCALE > SCALE { + // Going to higher precision + let shift = TARGET_SCALE - SCALE; + Fixed(self.0 << shift) + } else { + // Going to lower precision + let shift = SCALE - TARGET_SCALE; + Fixed(self.0 >> shift) + } } } @@ -150,7 +165,7 @@ pub fn int_sqrt(n: u64) -> u64 { } } -impl Add for Fixed { +impl Add for Fixed { type Output = Self; #[inline] @@ -159,7 +174,7 @@ impl Add for Fixed { } } -impl Sub for Fixed { +impl Sub for Fixed { type Output = Self; #[inline] @@ -168,24 +183,24 @@ impl Sub for Fixed { } } -impl Mul for Fixed { +impl Mul for Fixed { type Output = (Self, Self); #[inline] fn mul(self, rhs: Self) -> Self::Output { let product = self.0 * rhs.0; - let quotient = (product + HALF_SCALE as i64) >> DEFAULT_SCALE; + let quotient = (product + Self::HALF_SCALE_FACTOR) >> SCALE; // Calculate remainder to maintain: product = quotient * scale + remainder - let scaled_quotient = quotient << DEFAULT_SCALE; + let scaled_quotient = quotient << SCALE; let remainder = product - scaled_quotient; (Self(quotient), Self(remainder)) } } -impl Zero for Fixed { +impl Zero for Fixed { #[inline] fn zero() -> Self { Self(0) @@ -211,11 +226,11 @@ mod tests { #[test] fn test_negative() { - let a = Fixed::from_f64(-3.5); - let b = Fixed::from_f64(2.0); + let a = Fixed::<15>::from_f64(-3.5); + let b = Fixed::<15>::from_f64(2.0); assert_near(a.to_f64(), -3.5); - assert_near((a + b.clone()).to_f64(), -1.5); + assert_near((a + b).to_f64(), -1.5); assert_near((a - b).to_f64(), -5.5); } @@ -227,8 +242,8 @@ mod tests { let a = (rng.gen::() - 0.5) * 200.0; let b = (rng.gen::() - 0.5) * 200.0; - let fa = Fixed::from_f64(a); - let fb = Fixed::from_f64(b); + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); assert_near((fa + fb).to_f64(), a + b); } @@ -242,8 +257,8 @@ mod tests { let a = (rng.gen::() - 0.5) * 200.0; let b = (rng.gen::() - 0.5) * 200.0; - let fa = Fixed::from_f64(a); - let fb = Fixed::from_f64(b); + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); assert_near((fa - fb).to_f64(), a - b); } @@ -257,8 +272,8 @@ mod tests { let a = (rng.gen::() - 0.5) * 10.0; let b = (rng.gen::() - 0.5) * 10.0; - let fa = Fixed::from_f64(a); - let fb = Fixed::from_f64(b); + let fa = Fixed::<15>::from_f64(a); + let fb = Fixed::<15>::from_f64(b); let (q, _) = fa * fb; let expected = a * b; @@ -277,12 +292,11 @@ mod tests { continue; } - let fixed_a = Fixed::from_f64(a); - let (_recip, _) = fixed_a.recip(); - let _expected = 1.0 / a; + let fixed_a = Fixed::<15>::from_f64(a); + let (recip, _) = fixed_a.recip(); + let expected = 1.0 / a; - // assert_near(recip.to_f64(), expected); - // TODO (@raphaelDkhn) uncomment when we will parametizing DEFAULT_SCALE. + assert_near(recip.to_f64(), expected); } // Test specific cases @@ -296,7 +310,7 @@ mod tests { ]; for (a, expected) in test_cases { - let fixed_a = Fixed::from_f64(a); + let fixed_a = Fixed::<15>::from_f64(a); let (recip, _) = fixed_a.recip(); assert_near(recip.to_f64(), expected); } @@ -319,7 +333,7 @@ mod tests { } for input in test_cases { - let fixed_input = Fixed::from_f64(input); + let fixed_input = Fixed::<15>::from_f64(input); if input < 0.0 { let (result, remainder) = fixed_input.sqrt(); @@ -333,4 +347,37 @@ mod tests { assert_near(result_f64, input.sqrt()); } } + + #[test] + fn test_different_scales() { + // Test with 15-bit scale + let scale_15 = Fixed::<15>::from_f64(1.5); + assert_near(scale_15.to_f64(), 1.5); + + // Test with 8-bit scale (less precision) + let scale8 = Fixed::<8>::from_f64(1.5); + assert_near(scale8.to_f64(), 1.5); + + // Test with 24-bit scale (more precision) + let scale24 = Fixed::<24>::from_f64(1.5); + assert_near(scale24.to_f64(), 1.5); + + // Test conversion between scales + let from_15_to_8 = scale_15.convert_to::<8>(); + assert_near(from_15_to_8.to_f64(), 1.5); + + let from_8_to_24 = scale8.convert_to::<24>(); + assert_near(from_8_to_24.to_f64(), 1.5); + + // Multiplication with different scales + let a8 = Fixed::<8>::from_f64(2.5); + let b8 = Fixed::<8>::from_f64(3.0); + let (result8, _) = a8 * b8; + assert_near(result8.to_f64(), 7.5); + + let a24 = Fixed::<24>::from_f64(2.5); + let b24 = Fixed::<24>::from_f64(3.0); + let (result24, _) = a24 * b24; + assert_near(result24.to_f64(), 7.5); + } }