diff --git a/Cargo.lock b/Cargo.lock index 074326c39..e09053b24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -995,6 +995,7 @@ version = "0.1.0" dependencies = [ "once_cell", "p3", + "rand", "rand_core", "serde", ] diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 213b1bdc3..15add3b51 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -286,7 +286,7 @@ pub(crate) fn infer_tower_product_witness( #[cfg(test)] mod tests { - use ff_ext::{FieldInto, GoldilocksExt2}; + use ff_ext::{BabyBearExt4, FieldInto, GoldilocksExt2}; use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair, @@ -302,7 +302,7 @@ mod tests { #[test] fn test_infer_tower_witness() { - type E = GoldilocksExt2; + type E = BabyBearExt4; let num_product_fanin = 2; let last_layer: Vec> = vec![ vec![E::ONE, E::from_canonical_u64(2u64)].into_mle(), @@ -454,7 +454,7 @@ mod tests { #[test] fn test_infer_tower_logup_witness() { - type E = GoldilocksExt2; + type E = BabyBearExt4; let num_vars = 2; let q: Vec> = vec![ vec![1, 2, 3, 4] diff --git a/ff_ext/Cargo.toml b/ff_ext/Cargo.toml index b3be82054..10af45207 100644 --- a/ff_ext/Cargo.toml +++ b/ff_ext/Cargo.toml @@ -14,6 +14,7 @@ once_cell = "1.21.3" p3.workspace = true rand_core.workspace = true serde.workspace = true +rand.workspace = true [features] nightly-features = ["p3/nightly-features"] diff --git a/ff_ext/src/babybear.rs b/ff_ext/src/babybear.rs index 44f02f743..f7fd9a1c4 100644 --- a/ff_ext/src/babybear.rs +++ b/ff_ext/src/babybear.rs @@ -208,4 +208,46 @@ pub mod impl_babybear { .collect() } } + + #[cfg(test)] + mod tests { + use p3::{ + babybear::BabyBear, + field::{FieldAlgebra, FieldExtensionAlgebra}, + }; + use rand::thread_rng; + + use crate::{BabyBearExt4, FromUniformBytes}; + + #[test] + fn test_ext_mul() { + for (a_limbs, b_limbs, c_limbs) in vec![ + (vec![0, 1, 0, 0], vec![0, 0, 0, 1], vec![11, 0, 0, 0]), // x*x^3 = 11 + (vec![0, 0, 1, 0], vec![0, 0, 0, 1], vec![0, 11, 0, 0]), // x^2*x^3 = 11*x + (vec![1, 2, 0, 0], vec![0, 0, 3, 4], vec![88, 0, 3, 10]), /* (1+2x)*(3x^2+4x^3) = 88 + 3x^2+10x^3 */ + (vec![1, 2, 3, 4], vec![5, 6, 7, 8], vec![676, 588, 386, 60]), + ] { + let a = BabyBearExt4::from_base_iter( + a_limbs.into_iter().map(BabyBear::from_canonical_u32), + ); + let b = BabyBearExt4::from_base_iter( + b_limbs.into_iter().map(BabyBear::from_canonical_u32), + ); + let c = a * b; + assert_eq!( + c, + BabyBearExt4::from_base_iter( + c_limbs.into_iter().map(BabyBear::from_canonical_u32) + ) + ); + } + + // print one random example + let mut rng = thread_rng(); + let a = BabyBearExt4::random(&mut rng); + let b = BabyBearExt4::random(&mut rng); + let c = a * b; + println!("a: {:?}, b: {:?}, c: {:?}", a, b, c) + } + } } diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 52bf9ef2e..6feb70628 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -482,11 +482,20 @@ where #[cfg(test)] mod test { - use ff_ext::GoldilocksExt2; + use crate::util::codeword_fold_with_challenge; + use ff_ext::{BabyBearExt4, ExtensionField, FromUniformBytes, GoldilocksExt2, PoseidonField}; + use itertools::Itertools; + use p3::{ + babybear::BabyBear, + commit::Mmcs, + field::{Field, FieldAlgebra, TwoAdicField}, + matrix::{Matrix, bitrev::BitReversableMatrix, dense::RowMajorMatrix}, + }; use crate::{ - basefold::Basefold, + basefold::{Basefold, poseidon2_merkle_tree}, test_util::{run_batch_commit_open_verify, run_commit_open_verify}, + util::test::rand_vec, }; use super::BasefoldRSParams; @@ -509,4 +518,60 @@ mod test { run_batch_commit_open_verify::(20, 21, 64); } } + + #[test] + fn test_fri_fold() { + type E = BabyBearExt4; + type F = BabyBear; + + let mut rng = rand::thread_rng(); + // fold a codeword of length 2^16 using random challenge + let codeword_log2_size = 16; + let codeword = E::random_vec(1 << codeword_log2_size, &mut rng); + + let twiddle = F::GENERATOR.exp_power_of_2(F::TWO_ADICITY - codeword_log2_size); + let inv_2 = F::from_canonical_u64(2).inverse(); + + let challenge = E::random(&mut rng); + + codeword + .chunks(2) + .zip(twiddle.powers()) + .for_each(|(chunk, coeff)| { + codeword_fold_with_challenge(chunk, challenge, coeff, inv_2); + }) + } + + #[test] + fn test_bit_reverse() { + let v = (0..8).collect_vec(); + + let m = RowMajorMatrix::new(v, 1); + assert_eq!( + m.bit_reverse_rows().to_row_major_matrix().values, + vec![0b000, 0b100, 0b010, 0b110, 0b001, 0b101, 0b011, 0b111] + ); + } + + #[test] + fn test_poseidon2_mmcs() { + type E = BabyBearExt4; + + let mut rng = rand::thread_rng(); + let base_mmcs: <::BaseField as PoseidonField>::MMCS = + poseidon2_merkle_tree::(); + + // commit to two matrices whose layouts are (2^10, 4) and (2^14, 10) + let matrices = vec![(1 << 10, 4), (1 << 14, 10)] + .into_iter() + .map(|(rows, cols)| { + RowMajorMatrix::<::BaseField>::new( + rand_vec(rows * cols, &mut rng), + cols, + ) + }) + .collect_vec(); + + base_mmcs.commit(matrices); + } } diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index d563334c9..c7de4f0ed 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -479,3 +479,34 @@ where Some(next_challenge) } + +#[cfg(test)] +mod tests { + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use itertools::Itertools; + use p3::{ + babybear::BabyBear, + matrix::{Matrix, dense::RowMajorMatrix}, + }; + use rand::thread_rng; + + type E = BabyBearExt4; + type F = BabyBear; + + #[test] + fn test_matrix_multiply_vector() { + let num_rows = 1 << 10; + let num_cols = 32; + + let mut rng = thread_rng(); + let matrix = RowMajorMatrix::new(F::random_vec(num_rows * num_cols, &mut rng), num_cols); + let v = E::random_vec(num_cols, &mut rng); + + // matrix multiply vector + // codeword[i] = sum_j matrix[i][j] * v[j] + let _codeword = matrix + .rows() + .map(|row| v.iter().zip(row).map(|(a, b)| *a * b).sum::()) + .collect_vec(); + } +} diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 4d6ec071f..3db3d0bf6 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -292,18 +292,19 @@ where mod tests { use std::collections::VecDeque; - use ff_ext::GoldilocksExt2; + use ff_ext::{BabyBearExt4, GoldilocksExt2}; use itertools::izip; use p3::{ commit::{ExtensionMmcs, Mmcs}, goldilocks::Goldilocks, }; - use rand::rngs::OsRng; + use rand::{rngs::OsRng, thread_rng}; use transcript::BasicTranscript; use crate::{ - basefold::commit_phase::basefold_fri_round, util::merkle_tree::poseidon2_merkle_tree, + basefold::commit_phase::basefold_fri_round, + util::{merkle_tree::poseidon2_merkle_tree, test::rand_vec}, }; use super::*; @@ -369,4 +370,28 @@ mod tests { &codeword_from_folded_rmm.values ); } + + #[test] + fn test_lde_batch() { + type E = BabyBearExt4; + let mut rng = thread_rng(); + let dft: Radix2DitParallel<::BaseField> = Default::default(); + + let width = 10; + let added_bits = vec![1, 2, 3]; + for log2_n in 1..22 { + let matrix = DenseMatrix::new(rand_vec(width * (1 << log2_n), &mut rng), width); + for added_bit in added_bits.iter() { + let dur = std::time::Instant::now(); + dft.lde_batch(matrix.clone(), *added_bit); + println!( + "lde(matrix {}x{}, {}) took {:?}", + width, + 1 << log2_n, + added_bit, + dur.elapsed() + ); + } + } + } } diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 764a469ca..7f65107b0 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1253,3 +1253,39 @@ macro_rules! commutative_op_mle_pair { commutative_op_mle_pair!(|$a, $b| $op, |out| out) }; } + +#[cfg(test)] +mod tests { + use crate::mle::{IntoMLE, MultilinearExtension}; + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use itertools::Itertools; + use rand::thread_rng; + + type E = BabyBearExt4; + + #[test] + fn test_fix_var() { + let mut rng = thread_rng(); + let mle: MultilinearExtension<'_, E> = MultilinearExtension::random(3, &mut rng); + let mle_clone = mle.clone(); + let point = E::random(&mut rng); + + let m1 = mle.fix_variables(&[point]); + let m2 = mle_clone + .as_view() + .get_base_field_vec() + .chunks(2) + .map(|chunk| { + // eq(1,r)*f(1) + eq(0,r)*f(0) + // r*f(1) + (1-r)*f(0) + let a = chunk[0]; + let b = chunk[1]; + point * (b - a) + a + }) + .collect_vec() + .into_mle(); + + assert_eq!(m1.num_vars(), m2.num_vars()); + assert_eq!(m1, m2,); + } +}