Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 27 additions & 25 deletions src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,10 @@ mod tests {
}
}

struct TestEval<const SCALE: u32 = 15> {
struct TestEval {
log_size: u32,
op: Op,
scale: u32,
}

#[derive(Clone, Copy)]
Expand All @@ -162,7 +163,7 @@ mod tests {
Sqrt,
}

impl<const SCALE: u32> FrameworkEval for TestEval<SCALE> {
impl FrameworkEval for TestEval {
fn log_size(&self) -> u32 {
self.log_size
}
Expand All @@ -172,7 +173,7 @@ mod tests {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let scale_factor = E::F::from(M31::from_u32_unchecked(1 << SCALE));
let scale_factor = E::F::from(M31::from_u32_unchecked(1 << self.scale));

match self.op {
Op::Add => {
Expand Down Expand Up @@ -233,10 +234,10 @@ mod tests {
.collect()
}

fn test_op_internal<const SCALE: u32>(
fn test_op_internal(
op: Op,
inputs: &[Fixed<SCALE>],
expected_outputs: &[Fixed<SCALE>],
inputs: &[Fixed],
expected_outputs: &[Fixed],
tamper_col_idx: usize,
) {
const LOG_SIZE: u32 = 4;
Expand All @@ -261,9 +262,10 @@ mod tests {

let trace_polys = trace.map_cols(|c| c.interpolate());

let component = TestEval::<SCALE> {
let component = TestEval {
log_size: LOG_SIZE,
op,
scale: 15, // Default scale for tests
};

// Test valid trace
Expand All @@ -281,7 +283,7 @@ mod tests {
if let Some(col) = invalid_trace_cols.get_mut(tamper_col_idx) {
for val in col.iter_mut() {
// Calculate scale factor for tampering
let scale_factor = M31::from_u32_unchecked(1 << SCALE);
let scale_factor = M31::from_u32_unchecked(1 << 15); // Default scale
val.0 = (val.0 + scale_factor.0) % P;
}
}
Expand Down Expand Up @@ -310,8 +312,8 @@ mod tests {
fn test_add() {
let mut rng = StdRng::seed_from_u64(42);
for _ in 0..100 {
let a = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let a = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);
let b = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);

test_op_internal(Op::Add, &[a, b], &[a + b], 2);
}
Expand All @@ -321,8 +323,8 @@ mod tests {
fn test_sub() {
let mut rng = StdRng::seed_from_u64(42);
for _ in 0..100 {
let a = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let a = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);
let b = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);

test_op_internal(Op::Sub, &[a, b], &[a - b], 2);
}
Expand All @@ -334,8 +336,8 @@ mod tests {

// Test regular multiplication cases
for _ in 0..100 {
let a = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let a = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);
let b = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);
let (expected, rem) = a * b;

test_op_internal(Op::Mul, &[a, b], &[expected, rem], 2);
Expand All @@ -355,8 +357,8 @@ mod tests {
];

for (a, b) in special_cases {
let fixed_a = Fixed::<15>::from_f64(a);
let fixed_b = Fixed::<15>::from_f64(b);
let fixed_a = Fixed::from_f64(a, 15);
let fixed_b = Fixed::from_f64(b, 15);
let (expected, rem) = fixed_a * fixed_b;

test_op_internal(Op::Mul, &[fixed_a, fixed_b], &[expected, rem], 2);
Expand All @@ -369,8 +371,8 @@ mod tests {

// Test regular remainder cases
for _ in 0..50 {
let a = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let a = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);
let b = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);

// Skip cases where divisor is too close to zero
if b.to_f64().abs() < 0.1 {
Expand All @@ -394,8 +396,8 @@ mod tests {
];

for (a, b) in special_cases {
let fixed_a = Fixed::<15>::from_f64(a);
let fixed_b = Fixed::<15>::from_f64(b);
let fixed_a = Fixed::from_f64(a, 15);
let fixed_b = Fixed::from_f64(b, 15);
let (quotient, remainder) = fixed_a.div_rem(fixed_b);

test_op_internal(Op::Rem, &[fixed_a, fixed_b], &[quotient, remainder], 2);
Expand All @@ -408,8 +410,8 @@ mod tests {

// Test regular recip cases
for _ in 0..100 {
let input = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
if input.0 == 0 {
let input = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0, 15);
if input.value == 0 {
continue; // Skip division by zero
}

Expand All @@ -429,7 +431,7 @@ mod tests {
];

for input in special_cases {
let fixed_input = Fixed::<15>::from_f64(input);
let fixed_input = Fixed::from_f64(input, 15);
let (expected, rem) = fixed_input.recip();

test_op_internal(Op::Recip, &[fixed_input], &[expected, rem], 1);
Expand All @@ -440,7 +442,7 @@ mod tests {
fn test_sqrt() {
let test_cases = vec![1.0, 4.0, 9.0, 2.0, 0.5, 0.25, 0.0];
for input in test_cases {
let fixed_input = Fixed::<15>::from_f64(input);
let fixed_input = Fixed::from_f64(input, 15);
let (sqrt_out, rem) = fixed_input.sqrt();

test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1);
Expand All @@ -449,7 +451,7 @@ mod tests {
let mut rng = StdRng::seed_from_u64(43);
for _ in 0..50 {
let input_val: f64 = rng.gen_range(0.0..100.0);
let fixed_input = Fixed::<15>::from_f64(input_val);
let fixed_input = Fixed::from_f64(input_val, 15);
let (sqrt_out, rem) = fixed_input.sqrt();

test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1);
Expand Down
Loading
Loading