Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,69 @@
A fixed-point arithmetic library providing constrained fixed-point operations for [Stwo](https://github.com/starkware-libs/stwo.git)-based circuits.
The library implements fixed-point arithmetic using M31 field elements with configurable decimal precision.

## Features

- Type-level scale parameter using Rust's const generics
- Fixed-point operations (addition, subtraction, multiplication, reciprocal, square root)
- Circuit constraints for all operations
- Zero memory overhead for scale information

## Usage

### Basic Arithmetic

```rust

// Create fixed-point numbers
let lhs = Fixed::from_f64(3.14);
let rhs = Fixed::from_f64(2.0);
// Using 15 bits of precision
let a = Fixed::<15>::from_f64(3.14);
let b = Fixed::<15>::from_f64(2.0);

// Basic arithmetic operations
let sum = lhs + rhs;
let diff = lhs - rhs;
let (prod, rem) = lhs * rhs;
let sum = a + b;
let diff = a - b;
let (prod, rem) = a * b;

// Using custom scales
let high_precision = Fixed::<24>::from_f64(0.12345678);
let low_precision = Fixed::<8>::from_f64(42.5);

// Convert between scales
let converted = high_precision.convert_to::<8>();
```

### In Circuit Constraints

To use fixed-point operations in your Stwo Prover circuit, use the constraint evaluation functions:

```rust
use numerair::eval::{eval_add, eval_mul, eval_sub};
use numerair::eval::EvalFixedPoint;

// In your circuit component's evaluate function:
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let lhs = eval.next_trace_mask();
let rhs = eval.next_trace_mask();
let rem = eval.next_trace_mask();
let res = eval.next_trace_mask();

// Get scale factor for the specific scale you're using
let scale_factor = E::F::from(M31::from_u32_unchecked(1 << 15)); // For Fixed<15>

// Constrain mul.
eval.eval_fixed_mul(lhs, rhs, SCALE_FACTOR.into(), res, rem)
// Constrain mul using EvalFixedPoint trait
eval.eval_fixed_mul(lhs, rhs, scale_factor, res, rem);

eval
}
```

## How It Works

The fixed-point representation uses a const generic parameter to determine the number of bits used for the fractional part:

- `SCALE`: Type-level constant that defines the number of bits for decimal precision
- `SCALE_FACTOR`: The value 2^SCALE (automatically calculated), represents 1.0 in fixed-point
- `HALF_SCALE_FACTOR`: The value 2^(SCALE-1) (automatically calculated), used for rounding

A value `x` in floating point is represented as `floor(x * 2^SCALE)` in fixed-point.

## Contributing

Contributions are welcome! Please submit pull requests with:
Expand Down
109 changes: 62 additions & 47 deletions src/eval.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{HALF_P, SCALE_FACTOR};
use crate::HALF_P;
use num_traits::One;
use stwo_prover::{constraint_framework::EvalAtRow, core::fields::m31::M31};

Expand All @@ -19,12 +19,12 @@ pub trait EvalFixedPoint: EvalAtRow {
&mut self,
a: Self::F,
b: Self::F,
scale: Self::F,
scale_factor: Self::F,
quotient: Self::F,
remainder: Self::F,
) {
let product = self.add_intermediate(a * b);
self.eval_fixed_div_rem(product, scale, quotient, remainder);
self.eval_fixed_div_rem(product, scale_factor, quotient, remainder);
}

/// Evaluates constraints for signed division with remainder.
Expand All @@ -46,37 +46,44 @@ pub trait EvalFixedPoint: EvalAtRow {
}

/// Evaluates reciprocal constraints for fixed-point numbers.
/// Constrains: scale * scale = value * reciprocal + remainder
/// Constrains: scale_factor * scale_factor = value * reciprocal + remainder
fn eval_fixed_recip(
&mut self,
value: Self::F,
scale: Self::F,
scale_factor: Self::F,
reciprocal: Self::F,
remainder: Self::F,
) {
let scale_squared = self.add_intermediate(scale.clone() * scale);
let scale_squared = self.add_intermediate(scale_factor.clone() * scale_factor);
self.eval_fixed_div_rem(scale_squared, value, reciprocal, remainder);
}

/// Evaluates constraints for square root operations.
/// Adds constraints to verify that:
/// 1. The input is non-negative
/// 2. out^2 + rem = input * SCALE_FACTOR
/// 2. out^2 + rem = input * scale_factor
///
/// # Parameters
/// - `input`: The trace column value representing the scaled input.
/// - `out`: The trace column value of the scaled square root.
/// - `rem`: The trace column value of the remainder.
fn eval_fixed_sqrt(&mut self, input: Self::F, out: Self::F, rem: Self::F) {
/// - `scale_factor`: The scale_factor factor to use for fixed-point representation.
fn eval_fixed_sqrt(
&mut self,
input: Self::F,
out: Self::F,
rem: Self::F,
scale_factor: Self::F,
) {
// Constraint to ensure input is non-negative
// For field elements, we check if input is in the range [0, HALF_P)
// We need an auxiliary variable to ensure 0 <= input < HALF_P
let aux =
self.add_intermediate(Self::F::from(M31(HALF_P)) - Self::F::one() - input.clone());
self.add_constraint(input.clone() + aux - (Self::F::from(M31(HALF_P)) - Self::F::one()));

// Enforce the constraint: out^2 + rem = input * SCALE_FACTOR
self.add_constraint((out.clone() * out) + rem.clone() - (input * SCALE_FACTOR));
// Enforce the constraint: out^2 + rem = input * scale_factor
self.add_constraint((out.clone() * out) + rem.clone() - (input * scale_factor));
}
}

Expand All @@ -85,15 +92,14 @@ impl<T: EvalAtRow> EvalFixedPoint for T {}

#[cfg(test)]
mod tests {

use num_traits::Zero;
use rand::{rngs::StdRng, Rng, SeedableRng};
use stwo_prover::{
constraint_framework::{self, preprocessed_columns::IsFirst, FrameworkEval},
core::{
backend::{simd::SimdBackend, Col, Column},
fields::{
m31::{BaseField, P},
m31::{BaseField, M31, P},
qm31::SecureField,
},
pcs::TreeVec,
Expand All @@ -104,11 +110,10 @@ mod tests {
},
};

use crate::{Fixed, SCALE_FACTOR};

use crate::Fixed;
use super::*;

struct TestEval {
struct TestEval<const SCALE: u32 = 15> {
log_size: u32,
op: Op,
}
Expand All @@ -122,7 +127,7 @@ mod tests {
Sqrt,
}

impl FrameworkEval for TestEval {
impl<const SCALE: u32> FrameworkEval for TestEval<SCALE> {
fn log_size(&self) -> u32 {
self.log_size
}
Expand All @@ -132,6 +137,8 @@ mod tests {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let scale_factor = E::F::from(M31::from_u32_unchecked(1 << SCALE));

match self.op {
Op::Add => {
let lhs = eval.next_trace_mask();
Expand All @@ -150,19 +157,19 @@ mod tests {
let rhs = eval.next_trace_mask();
let out = eval.next_trace_mask();
let rem = eval.next_trace_mask();
eval.eval_fixed_mul(lhs, rhs, SCALE_FACTOR.into(), out, rem)
eval.eval_fixed_mul(lhs, rhs, scale_factor, out, rem)
}
Op::Recip => {
let input = eval.next_trace_mask();
let out = eval.next_trace_mask();
let rem = eval.next_trace_mask();
eval.eval_fixed_recip(input, SCALE_FACTOR.into(), out, rem)
eval.eval_fixed_recip(input, scale_factor, out, rem)
}
Op::Sqrt => {
let input = eval.next_trace_mask();
let out = eval.next_trace_mask();
let rem = eval.next_trace_mask();
eval.eval_fixed_sqrt(input, out, rem)
eval.eval_fixed_sqrt(input, out, rem, scale_factor)
}
}
eval
Expand All @@ -184,11 +191,11 @@ mod tests {
.collect()
}

fn test_op(
fn test_op_internal<const SCALE: u32>(
op: Op,
inputs: Vec<Fixed>,
expected_outputs: Vec<Fixed>,
tamper_col_idx: usize, // The column to tamper
inputs: &[Fixed<SCALE>],
expected_outputs: &[Fixed<SCALE>],
tamper_col_idx: usize,
) {
const LOG_SIZE: u32 = 4;
let domain = CanonicCoset::new(LOG_SIZE);
Expand All @@ -212,7 +219,7 @@ mod tests {

let trace_polys = trace.map_cols(|c| c.interpolate());

let component = TestEval {
let component = TestEval::<SCALE> {
log_size: LOG_SIZE,
op,
};
Expand All @@ -231,7 +238,9 @@ mod tests {
let mut invalid_trace_cols = trace_cols;
if let Some(col) = invalid_trace_cols.get_mut(tamper_col_idx) {
for val in col.iter_mut() {
val.0 = (val.0 + SCALE_FACTOR.0) % P;
// Calculate scale factor for tampering
let scale_factor = M31::from_u32_unchecked(1 << SCALE);
val.0 = (val.0 + scale_factor.0) % P;
}
}

Expand Down Expand Up @@ -259,20 +268,21 @@ mod tests {
fn test_add() {
let mut rng = StdRng::seed_from_u64(42);
for _ in 0..100 {
let a = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let a = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);

test_op(Op::Add, vec![a, b], vec![a + b], 2);
test_op_internal(Op::Add, &[a, b], &[a + b], 2);
}
}

#[test]
fn test_sub() {
let mut rng = StdRng::seed_from_u64(42);
for _ in 0..100 {
let a = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
test_op(Op::Sub, vec![a, b], vec![a - b], 2);
let a = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);

test_op_internal(Op::Sub, &[a, b], &[a - b], 2);
}
}

Expand All @@ -282,11 +292,11 @@ mod tests {

// Test regular multiplication cases
for _ in 0..100 {
let a = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let a = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let b = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let (expected, rem) = a * b;

test_op(Op::Mul, vec![a, b], vec![expected, rem], 2);
test_op_internal(Op::Mul, &[a, b], &[expected, rem], 2);
}

// Test special cases
Expand All @@ -303,24 +313,28 @@ mod tests {
];

for (a, b) in special_cases {
let fixed_a = Fixed::from_f64(a);
let fixed_b = Fixed::from_f64(b);
let fixed_a = Fixed::<15>::from_f64(a);
let fixed_b = Fixed::<15>::from_f64(b);
let (expected, rem) = fixed_a * fixed_b;

test_op(Op::Mul, vec![fixed_a, fixed_b], vec![expected, rem], 2);
test_op_internal(Op::Mul, &[fixed_a, fixed_b], &[expected, rem], 2);
}
}

#[test]
fn test_recip() {
let mut rng = StdRng::seed_from_u64(42);

// Test regular multiplication cases
// Test regular recip cases
for _ in 0..100 {
let input = Fixed::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
let input = Fixed::<15>::from_f64((rng.gen::<f64>() - 0.5) * 200.0);
if input.0 == 0 {
continue; // Skip division by zero
}

let (expected, rem) = input.recip();

test_op(Op::Recip, vec![input], vec![expected, rem], 2);
test_op_internal(Op::Recip, &[input], &[expected, rem], 1);
}

// Test special cases
Expand All @@ -334,29 +348,30 @@ mod tests {
];

for input in special_cases {
let fixed_input = Fixed::from_f64(input);
let fixed_input = Fixed::<15>::from_f64(input);
let (expected, rem) = fixed_input.recip();

test_op(Op::Recip, vec![fixed_input], vec![expected, rem], 1);
test_op_internal(Op::Recip, &[fixed_input], &[expected, rem], 1);
}
}

#[test]
fn test_eval_sqrt() {
fn test_sqrt() {
let test_cases = vec![1.0, 4.0, 9.0, 2.0, 0.5, 0.25, 0.0];
for input in test_cases {
let fixed_input = Fixed::from_f64(input);
let fixed_input = Fixed::<15>::from_f64(input);
let (sqrt_out, rem) = fixed_input.sqrt();

test_op(Op::Sqrt, vec![fixed_input], vec![sqrt_out, rem], 1);
test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1);
}

let mut rng = StdRng::seed_from_u64(43);
for _ in 0..50 {
let input_val: f64 = rng.gen_range(0.0..100.0);
let fixed_input = Fixed::from_f64(input_val);
let fixed_input = Fixed::<15>::from_f64(input_val);
let (sqrt_out, rem) = fixed_input.sqrt();
test_op(Op::Sqrt, vec![fixed_input], vec![sqrt_out, rem], 1);

test_op_internal(Op::Sqrt, &[fixed_input], &[sqrt_out, rem], 1);
}
}
}
Loading
Loading