Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ pub trait EvalFixedPoint: EvalAtRow {
self.add_constraint(remainder + aux - (divisor - Self::F::one()));
}

/// Evaluates remainder constraints for fixed-point numbers.
/// Constrains: dividend = quotient * divisor + remainder
/// This is essentially the same as eval_fixed_div_rem but semantically focused on remainder.
fn eval_fixed_rem(
&mut self,
dividend: Self::F,
divisor: Self::F,
quotient: Self::F,
remainder: Self::F,
) {
self.eval_fixed_div_rem(dividend, divisor, quotient, remainder);
}

/// Evaluates reciprocal constraints for fixed-point numbers.
/// Constrains: scale_factor * scale_factor = value * reciprocal + remainder
fn eval_fixed_recip(
Expand Down Expand Up @@ -123,6 +136,7 @@ mod tests {
Add,
Sub,
Mul,
Rem,
Recip,
Sqrt,
}
Expand Down Expand Up @@ -159,6 +173,13 @@ mod tests {
let rem = eval.next_trace_mask();
eval.eval_fixed_mul(lhs, rhs, scale_factor, out, rem)
}
Op::Rem => {
let dividend = eval.next_trace_mask();
let divisor = eval.next_trace_mask();
let quotient = eval.next_trace_mask();
let remainder = eval.next_trace_mask();
eval.eval_fixed_rem(dividend, divisor, quotient, remainder)
}
Op::Recip => {
let input = eval.next_trace_mask();
let out = eval.next_trace_mask();
Expand Down Expand Up @@ -321,6 +342,45 @@ mod tests {
}
}

#[test]
fn test_rem() {
let mut rng = StdRng::seed_from_u64(42);

// Test regular remainder cases
for _ in 0..50 {
let a = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);

// Skip cases where divisor is too close to zero
if b.to_f64().abs() < 0.1 {
continue;
}

let (quotient, remainder) = a.div_rem(b);

test_op_internal(Op::Rem, &[a, b], &[quotient, remainder], 2);
}

// Test special cases
let special_cases = vec![
(10.0, 3.0), // 10 % 3 = 1, quotient = 3
(7.5, 2.5), // 7.5 % 2.5 = 0, quotient = 3
(9.0, 4.0), // 9 % 4 = 1, quotient = 2
(8.0, 3.0), // 8 % 3 = 2, quotient = 2
(15.0, 4.0), // 15 % 4 = 3, quotient = 3
(20.0, 6.0), // 20 % 6 = 2, quotient = 3
(1.5, 0.5), // 1.5 % 0.5 = 0, quotient = 3
];

for (a, b) in special_cases {
let fixed_a = Fixed::<15>::from_f64(a);
let fixed_b = Fixed::<15>::from_f64(b);
let (quotient, remainder) = fixed_a.div_rem(fixed_b);

test_op_internal(Op::Rem, &[fixed_a, fixed_b], &[quotient, remainder], 2);
}
}

#[test]
fn test_recip() {
let mut rng = StdRng::seed_from_u64(42);
Expand Down
105 changes: 100 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use num_traits::Zero;
use serde::{Deserialize, Serialize};
use std::ops::{Add, Mul, Sub};
use std::ops::{Add, Mul, Rem, Sub};
use stwo_prover::core::fields::m31::{M31, P};

pub mod eval;
Expand Down Expand Up @@ -63,6 +63,17 @@ impl<const SCALE: u32> Fixed<SCALE> {
}
}

/// Computes both quotient and remainder for division.
/// Returns (quotient, remainder) where dividend = quotient * divisor + remainder
/// Note: The quotient is stored as an unscaled integer for constraint compatibility.
#[inline]
pub fn div_rem(self, rhs: Self) -> (Self, Self) {
assert!(rhs.0 != 0, "Division by zero");
let quotient = self.0 / rhs.0;
let remainder = self.0 % rhs.0;
(Self(quotient), Self(remainder))
}

/// Computes the reciprocal (1/x) of a fixed-point number
///
/// Returns a tuple of (quotient, remainder) where:
Expand Down Expand Up @@ -108,10 +119,7 @@ impl<const SCALE: u32> Fixed<SCALE> {
// Calculate remainder (input_scaled - sqrt_val^2)
let remainder = input_scaled - sqrt_val * sqrt_val;

(
Self(sqrt_val as i64),
Self(remainder as i64),
)
(Self(sqrt_val as i64), Self(remainder as i64))
}

/// Convert this Fixed value to a Fixed with a different scale
Expand Down Expand Up @@ -200,6 +208,16 @@ impl<const SCALE: u32> Mul for Fixed<SCALE> {
}
}

impl<const SCALE: u32> Rem for Fixed<SCALE> {
type Output = Self;

#[inline]
fn rem(self, rhs: Self) -> Self::Output {
assert!(rhs.0 != 0, "Division by zero in remainder operation");
Self(self.0 % rhs.0)
}
}

impl<const SCALE: u32> Zero for Fixed<SCALE> {
#[inline]
fn zero() -> Self {
Expand Down Expand Up @@ -348,6 +366,83 @@ mod tests {
}
}

#[test]
fn test_rem() {
let mut rng = StdRng::seed_from_u64(42);

// Test random cases
for _ in 0..100 {
let a = (rng.gen::<f64>() - 0.5) * 20.0;
let b = (rng.gen::<f64>() - 0.5) * 20.0;

// Skip cases where divisor is too close to zero
if b.abs() < 0.1 {
continue;
}

let fa = Fixed::<15>::from_f64(a);
let fb = Fixed::<15>::from_f64(b);

let remainder = fa % fb;
let expected = a % b;

assert_near(remainder.to_f64(), expected);
}

// Test specific cases
let test_cases = vec![
(10.0, 3.0), // 10 % 3 = 1
(7.5, 2.5), // 7.5 % 2.5 = 0
(9.0, 4.0), // 9 % 4 = 1
(-10.0, 3.0), // -10 % 3 = -1 (or 2, depending on implementation)
(10.0, -3.0), // 10 % -3 = 1 (or -2, depending on implementation)
(-10.0, -3.0), // -10 % -3 = -1 (or 2, depending on implementation)
];

for (a, b) in test_cases {
let fa = Fixed::<15>::from_f64(a);
let fb = Fixed::<15>::from_f64(b);
let remainder = fa % fb;
let expected = a % b;

assert_near(remainder.to_f64(), expected);
}
}

#[test]
fn test_div_rem() {
let mut rng = StdRng::seed_from_u64(42);

for _ in 0..100 {
let a = (rng.gen::<f64>() - 0.5) * 20.0;
let b = (rng.gen::<f64>() - 0.5) * 20.0;
println!("a {:?}", a);
println!("b {:?}", b);

// Skip cases where divisor is too close to zero
if b.abs() < 0.1 {
continue;
}

let fa = Fixed::<15>::from_f64(a);
let fb = Fixed::<15>::from_f64(b);

let (quotient, remainder) = fa.div_rem(fb);

// Verify: dividend = quotient * divisor + remainder
let reconstructed = quotient.0 * fb.0 + remainder.0;
assert_eq!(reconstructed, fa.0);

// Check individual results
let expected_quotient = (a / b).trunc();
let expected_remainder = a % b;

// The quotient from div_rem is stored as an unscaled integer
assert_eq!(quotient.0 as f64, expected_quotient);
assert_near(remainder.to_f64(), expected_remainder);
}
}

#[test]
fn test_different_scales() {
// Test with 15-bit scale
Expand Down