Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion crates/backend/poly/src/eq_mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,105 @@ fn eval_eq_with_packed_output<F: Field, EF: ExtensionField<F>, const INITIALIZED
}
}

#[inline]
fn eval_eq_with_packed_output_dual<F: Field, EF: ExtensionField<F>>(
eval_a: &[EF],
eval_b: &[EF],
out: &mut [EF::ExtensionPacking],
scalar_a: EF::ExtensionPacking,
scalar_b: EF::ExtensionPacking,
) {
debug_assert_eq!(eval_a.len(), eval_b.len());
debug_assert_eq!(out.len(), 1 << eval_a.len());

match eval_a.len() {
0 => {
out[0] = scalar_a + scalar_b;
}
1 => {
let [a0, a1] = eval_eq_1(eval_a, scalar_a);
let [b0, b1] = eval_eq_1(eval_b, scalar_b);
out[0] = a0 + b0;
out[1] = a1 + b1;
}
2 => {
let eq_a = eval_eq_2(eval_a, scalar_a);
let eq_b = eval_eq_2(eval_b, scalar_b);
for i in 0..4 {
out[i] = eq_a[i] + eq_b[i];
}
}
3 => {
let eq_a = eval_eq_3(eval_a, scalar_a);
let eq_b = eval_eq_3(eval_b, scalar_b);
for i in 0..8 {
out[i] = eq_a[i] + eq_b[i];
}
}
_ => {
let (low, high) = out.split_at_mut(out.len() / 2);
let sa1 = scalar_a * eval_a[0];
let sa0 = scalar_a - sa1;
let sb1 = scalar_b * eval_b[0];
let sb0 = scalar_b - sb1;
eval_eq_with_packed_output_dual::<F, EF>(&eval_a[1..], &eval_b[1..], low, sa0, sb0);
eval_eq_with_packed_output_dual::<F, EF>(&eval_a[1..], &eval_b[1..], high, sa1, sb1);
}
}
}

pub fn compute_eval_eq_packed_dual<EF>(
eval_a: &[EF],
eval_b: &[EF],
out: &mut [EF::ExtensionPacking],
scalar_a: EF,
scalar_b: EF,
) where
EF: ExtensionField<PF<EF>>,
{
let packing_width = packing_width::<EF>();
let log_packing_width = log2_strict_usize(packing_width);

assert_eq!(eval_a.len(), eval_b.len());
assert!(log_packing_width <= eval_a.len());
assert_eq!(out.len(), 1 << (eval_a.len() - log_packing_width));

if eval_a.len() <= log_packing_width + 1 + LOG_NUM_THREADS {
let mut output_no_packing = EF::zero_vec(1 << eval_a.len());
eval_eq_basic::<_, _, _, false>(eval_a, &mut output_no_packing, scalar_a);
eval_eq_basic::<_, _, _, true>(eval_b, &mut output_no_packing, scalar_b);
out.par_iter_mut()
.zip(output_no_packing.par_chunks_exact(packing_width))
.for_each(|(out_elem, chunk)| {
*out_elem = EF::ExtensionPacking::from_ext_slice(chunk);
});
} else {
let eval_len_min_packing = eval_a.len() - log_packing_width;

let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED);
let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED);
let out_chunk_size = out.len() / NUM_THREADS_PADDED;

parallel_buffer_a[0] = packed_eq_poly(&eval_a[eval_len_min_packing..], scalar_a);
fill_buffer(eval_a[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_a);

parallel_buffer_b[0] = packed_eq_poly(&eval_b[eval_len_min_packing..], scalar_b);
fill_buffer(eval_b[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_b);

out.par_chunks_exact_mut(out_chunk_size)
.enumerate()
.for_each(|(i, out_chunk)| {
eval_eq_with_packed_output_dual::<PF<EF>, EF>(
&eval_a[LOG_NUM_THREADS..eval_len_min_packing],
&eval_b[LOG_NUM_THREADS..eval_len_min_packing],
out_chunk,
parallel_buffer_a[i],
parallel_buffer_b[i],
);
});
}
}

/// Computes the equality polynomial evaluations via a simple recursive algorithm.
///
/// Unlike [`eval_eq_basic`], this function makes heavy use of packed values to speed up computations.
Expand Down Expand Up @@ -968,10 +1067,19 @@ fn base_eval_eq_packed_with_packed_output<F, EF, const INITIALIZED: bool>(
F: Field,
EF: ExtensionField<F>,
{
// Ensure that the output buffer size is correct:
// It should be of size `2^n`, where `n` is the number of variables.
let width = F::Packing::WIDTH;
let log_packing_width = log2_strict_usize(width);
debug_assert_eq!(out.len(), 1 << eval_points.len());
debug_assert!(log_packing_width <= eval_points.len());

match eval_points.len() {
0 => unreachable!(),
0 => {
debug_assert_eq!(F::Packing::WIDTH, 1);
let base_vals = F::Packing::pack_slice(eq_evals.as_slice());
scale_and_add_pf::<F, EF, INITIALIZED>(out, base_vals, packed_scalar);
}
1 => {
let eq_evaluations = eval_eq_1(eval_points, eq_evals);
scale_and_add_pf::<F, EF, INITIALIZED>(out, eq_evaluations.as_slice(), packed_scalar);
Expand Down Expand Up @@ -1248,4 +1356,28 @@ mod tests {
}
}
}

#[test]
fn test_compute_eval_eq_packed_dual() {
let packing_width = <F as Field>::Packing::WIDTH;
let log_packing_width = log2_strict_usize(packing_width);
let mut rng = StdRng::seed_from_u64(42);

for n_vars in log_packing_width..22 {
let eval_a: Vec<EF> = (0..n_vars).map(|_| rng.random()).collect();
let eval_b: Vec<EF> = (0..n_vars).map(|_| rng.random()).collect();
let scalar_a: EF = rng.random();
let scalar_b: EF = rng.random();

let packed_len = 1 << (n_vars - log_packing_width);
let mut out_dual = EFPacking::<EF>::zero_vec(packed_len);
compute_eval_eq_packed_dual::<EF>(&eval_a, &eval_b, &mut out_dual, scalar_a, scalar_b);

let mut out_separate = EFPacking::<EF>::zero_vec(packed_len);
compute_eval_eq_packed::<EF, false>(&eval_a, &mut out_separate, scalar_a);
compute_eval_eq_packed::<EF, true>(&eval_b, &mut out_separate, scalar_b);

assert_eq!(out_dual, out_separate, "Mismatch at n_vars={}", n_vars);
}
}
}
6 changes: 1 addition & 5 deletions crates/backend/sumcheck/src/product_computation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ pub fn run_product_sumcheck<EF: ExtensionField<PF<EF>>>(
assert!(n_rounds >= 1);
let first_sumcheck_poly = match (pol_a, pol_b) {
(MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) => {
if EF::DIMENSION == 5 {
compute_product_sumcheck_polynomial_base_ext_packed::<5, _, _, _, EF>(evals, weights, sum)
} else {
unimplemented!()
}
compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::<EF>::to_ext_iter([e]).collect())
}
(MleRef::ExtensionPacked(evals), MleRef::ExtensionPacked(weights)) => {
compute_product_sumcheck_polynomial(evals, weights, sum, |e| EFPacking::<EF>::to_ext_iter([e]).collect())
Expand Down
2 changes: 1 addition & 1 deletion crates/lean_prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub const SECURITY_BITS: usize = 124; // TODO 128 bits security

pub const GRINDING_BITS: usize = 16;
pub const MAX_NUM_VARIABLES_TO_SEND_COEFFS: usize = 8;
pub const WHIR_INITIAL_FOLDING_FACTOR: usize = 7;
pub const WHIR_INITIAL_FOLDING_FACTOR: usize = 8;
pub const WHIR_SUBSEQUENT_FOLDING_FACTOR: usize = 5;
pub const RS_DOMAIN_INITIAL_REDUCTION_FACTOR: usize = 5;

Expand Down
22 changes: 8 additions & 14 deletions crates/lean_prover/src/trace_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,25 @@ pub fn get_execution_trace(
let poseidon_trace = traces.get_mut(&Table::poseidon16()).unwrap();
fill_trace_poseidon_16(&mut poseidon_trace.columns);

// For permute=0 rows, override unconstrained output columns with memory values
// so the lookup matches. Same when half_output=1.
// For half_output=1 rows: override outputs_left[4..7] with memory values
// so the lookup passes (the lookup reads memory[res+4..7] which matches).
// outputs_right removed — no override needed for those.
{
let split = POSEIDON_16_COL_OUTPUT_LEFT + HALF_DIGEST_LEN;
let (left, right) = poseidon_trace.columns.split_at_mut(split);
let half_output_col = &left[POSEIDON_16_COL_FLAG_HALF_OUTPUT];
let permute_col = &left[POSEIDON_16_COL_FLAG_PERMUTE];
let res_col = &left[POSEIDON_16_COL_INDEX_INPUT_RES];
const N: usize = HALF_DIGEST_LEN + DIGEST_LEN;
const N: usize = HALF_DIGEST_LEN;
let cols: &mut [Vec<F>; N] = (&mut right[..N]).try_into().unwrap();

transposed_par_iter_mut(cols)
.zip(half_output_col)
.zip(permute_col)
.zip(res_col)
.for_each(|(((row, &half), &permute), &res)| {
if permute == F::ZERO {
.for_each(|((row, &half), &res)| {
if half == F::ONE {
let base = res.to_usize();
if half == F::ONE {
for j in 0..HALF_DIGEST_LEN {
*row[j] = memory_padded[base + HALF_DIGEST_LEN + j];
}
}
for j in 0..DIGEST_LEN {
*row[HALF_DIGEST_LEN + j] = memory_padded[base + DIGEST_LEN + j];
for j in 0..HALF_DIGEST_LEN {
*row[j] = memory_padded[base + HALF_DIGEST_LEN + j];
}
}
});
Expand Down
2 changes: 1 addition & 1 deletion crates/lean_vm/src/core/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,6 @@ mod tests {
for (table, max_log_n_rows) in MAX_LOG_N_ROWS_PER_TABLE {
max_surface += (table.n_columns() as u64) << (max_log_n_rows as u64);
}
assert!(max_surface <= 1 << 30); // Maximum data we can commit via WHIR using an initial folding factor of 7, and rate = 1/2
assert!(max_surface <= 1 << 31); // Maximum data we can commit via WHIR using an initial folding factor of 8, and rate = 1/2
}
}
12 changes: 3 additions & 9 deletions crates/lean_vm/src/tables/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST: ColIndex = 6;
pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND: ColIndex = 7;
pub const POSEIDON_16_COL_FLAG_PERMUTE: ColIndex = 8;
pub const POSEIDON_16_COL_INPUT_START: ColIndex = 9;
pub const POSEIDON_16_COL_OUTPUT_LEFT: ColIndex = num_cols_poseidon_16() - 16;
pub const POSEIDON_16_COL_OUTPUT_RIGHT: ColIndex = num_cols_poseidon_16() - 8;
pub const POSEIDON_16_COL_OUTPUT_LEFT: ColIndex = num_cols_poseidon_16() - 8;
/// Non-committed columns ("virtual"):
pub const POSEIDON_16_COL_INDEX_INPUT_LEFT: ColIndex = num_cols_poseidon_16();
pub const POSEIDON_16_COL_DOMAINSEP: ColIndex = num_cols_poseidon_16() + 1;
Expand Down Expand Up @@ -171,7 +170,7 @@ impl<const BUS: bool> TableT for Poseidon16Precompile<BUS> {
buses.extend(memory_lookups_consecutive(
POSEIDON_16_COL_INDEX_INPUT_RES,
POSEIDON_16_COL_OUTPUT_LEFT,
DIGEST_LEN * 2,
DIGEST_LEN, // was DIGEST_LEN * 2 (included outputs_right)
));
buses
}
Expand All @@ -193,7 +192,6 @@ impl<const BUS: bool> TableT for Poseidon16Precompile<BUS> {
*perm.effective_index_left_first = F::from_usize(zero_vec_ptr);
*perm.effective_index_left_second = F::from_usize(zero_vec_ptr + HALF_DIGEST_LEN);
*perm.flag_permute = F::ZERO;
perm.outputs_right.iter_mut().for_each(|x| **x = F::ZERO);
row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr);
row[POSEIDON_16_COL_DOMAINSEP] = F::from_usize(POSEIDON_DOMAINSEP_BASE);

Expand Down Expand Up @@ -308,7 +306,7 @@ impl<const BUS: bool> Air for Poseidon16Precompile<BUS> {
0
}
fn n_constraints(&self) -> usize {
2 * BUS as usize + 99
2 * BUS as usize + 91 // was 99, removed 8 flag_permute * (state[i+8] - outputs_right[i]) constraints
}
fn eval<AB: AirBuilder>(&self, builder: &mut AB, extra_data: &Self::ExtraData) {
let cols: Poseidon1Cols16<AB::IF> = {
Expand Down Expand Up @@ -378,7 +376,6 @@ pub(super) struct Poseidon1Cols16<T> {
pub partial_rounds: [T; PARTIAL_ROUNDS],
pub ending_full_rounds: [[T; WIDTH]; HALF_FINAL_FULL_ROUNDS - 1],
pub outputs_left: [T; WIDTH / 2],
pub outputs_right: [T; WIDTH / 2],
}

fn eval_poseidon1_16<AB: AirBuilder>(builder: &mut AB, local: &Poseidon1Cols16<AB::IF>) {
Expand Down Expand Up @@ -438,7 +435,6 @@ fn eval_poseidon1_16<AB: AirBuilder>(builder: &mut AB, local: &Poseidon1Cols16<A
&local.inputs,
&mut state,
&local.outputs_left,
&local.outputs_right,
&final_constants[2 * (HALF_FINAL_FULL_ROUNDS - 1)],
&final_constants[2 * (HALF_FINAL_FULL_ROUNDS - 1) + 1],
local.flag_half_output,
Expand Down Expand Up @@ -486,7 +482,6 @@ fn eval_last_2_full_rounds_16<AB: AirBuilder>(
initial_state: &[AB::IF; WIDTH],
state: &mut [AB::IF; WIDTH],
outputs_left: &[AB::IF; WIDTH / 2],
outputs_right: &[AB::IF; WIDTH / 2],
round_constants_1: &[F; WIDTH],
round_constants_2: &[F; WIDTH],
flag_half_output: AB::IF,
Expand All @@ -513,7 +508,6 @@ fn eval_last_2_full_rounds_16<AB: AirBuilder>(
};
builder.assert_zero(compression_gate * (state[i] + initial_state[i] - outputs_left[i]));
builder.assert_zero(flag_permute * (state[i] - outputs_left[i]));
builder.assert_zero(flag_permute * (state[i + WIDTH / 2] - outputs_right[i]));
}
}

Expand Down
4 changes: 1 addition & 3 deletions crates/lean_vm/src/tables/poseidon/trace_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ pub(super) fn generate_trace_rows_for_perm<F: Algebra<KoalaBear> + Copy>(perm: &
&mut state,
&inputs,
&mut perm.outputs_left,
&mut perm.outputs_right,
flag_permute,
&poseidon1_final_constants()[2 * n_ending_full_rounds],
&poseidon1_final_constants()[2 * n_ending_full_rounds + 1],
Expand Down Expand Up @@ -140,7 +139,6 @@ fn generate_last_2_full_rounds<F: Algebra<KoalaBear> + Copy>(
state: &mut [F; WIDTH],
inputs: &[F; WIDTH],
outputs_left: &mut [&mut F; WIDTH / 2],
outputs_right: &mut [&mut F; WIDTH / 2],
flag_permute: F,
round_constants_1: &[KoalaBear; WIDTH],
round_constants_2: &[KoalaBear; WIDTH],
Expand All @@ -160,6 +158,6 @@ fn generate_last_2_full_rounds<F: Algebra<KoalaBear> + Copy>(
for i in 0..(WIDTH / 2) {
let compression_value = state[i] + inputs[i];
*outputs_left[i] = (F::ONE - flag_permute) * compression_value + flag_permute * state[i];
*outputs_right[i] = flag_permute * state[i + WIDTH / 2];
// outputs_right removed — only outputs_left is committed
}
}
2 changes: 1 addition & 1 deletion crates/rec_aggregation/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn compile_main_program_self_referential() -> Bytecode {
if actual_log_size == log_size_guess {
return bytecode;
}
println!(
eprintln!(
"Wrong guess at `compile_main_program_self_referential` (log_size {log_size_guess}->{actual_log_size})"
);
log_size_guess = actual_log_size;
Expand Down
3 changes: 3 additions & 0 deletions crates/rec_aggregation/zkdsl_implem/whir.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ def decompose_and_verify_merkle_batch_with_height(
if num_chunks == 5:
decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 5, circle_values, answers)
return
if num_chunks == 32:
decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 32, circle_values, answers)
return
print(num_chunks)
assert False, "decompose_and_verify_merkle_batch called with unsupported num_chunks"

Expand Down
6 changes: 2 additions & 4 deletions crates/sub_protocols/src/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ pub fn prove_generic_logup(
let memory_domainsep_packed = PFPacking::<EF>::from(F::from_usize(LOGUP_MEMORY_DOMAINSEP));
let bytecode_domainsep_packed = PFPacking::<EF>::from(F::from_usize(LOGUP_BYTECODE_DOMAINSEP));

let min_section_log = log_bytecode.min(tables_log_heights_sorted.last().unwrap().1);
if min_section_log < ENDIANNESS_PIVOT_GKR {
tracing::info!("TODO: suboptimal GKR pivot (could be improved).");
}
let log_bytecode_section = log_bytecode.max(tables_log_heights_sorted[0].1);
let min_section_log = log_bytecode_section.min(tables_log_heights_sorted.last().unwrap().1);
let pivot = ENDIANNESS_PIVOT_GKR.min(min_section_log);
let chunk_size = 1usize << pivot;
let chunk_shift = usize::BITS as usize - pivot;
Expand Down
Loading
Loading