diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index cf660c4e6..c6c15e59d 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -12,6 +12,8 @@ const LOG_NUM_THREADS: usize = 5; /// The number of threads to spawn for parallel computations. const NUM_THREADS: usize = 1 << LOG_NUM_THREADS; +const LOG_BATCHED_TILE_SIZE: usize = 14; + /// Given `evals` = (α_1, ..., α_n), returns a multilinear polynomial P in n variables, /// defined on the boolean hypercube by: ∀ (x_1, ..., x_n) ∈ {0, 1}^n, /// P(x_1, ..., x_n) = Π_{i=1}^{n} (x_i.α_i + (1 - x_i).(1 - α_i)) @@ -371,6 +373,64 @@ pub fn compute_eval_eq_base_packed( } } +#[inline] +pub fn compute_eval_eq_base_packed_batched( + evals: &[MultilinearPoint], + out: &mut [EF::ExtensionPacking], + scalars: &[EF], +) where + F: Field, + EF: ExtensionField, +{ + assert_eq!(evals.len(), scalars.len()); + if evals.is_empty() { + return; + } + + let n = evals[0].len(); + let packing_width = F::Packing::WIDTH; + let log_packing_width = log2_strict_usize(packing_width); + assert!(log_packing_width <= n); + assert_eq!(out.len(), 1 << (n - log_packing_width)); + + let k = n.min(LOG_BATCHED_TILE_SIZE); + + if k <= log_packing_width || k >= n { + for (eval, &scalar) in evals.iter().zip(scalars) { + compute_eval_eq_base_packed::(eval, out, scalar); + } + return; + } + + let n_prefix_levels = n - k; + let tile_packed_size = 1 << (k - log_packing_width); + + let per_query: Vec<_> = evals + .iter() + .zip(scalars) + .map(|(eval, &scalar)| { + let middle = &eval[n_prefix_levels..n - log_packing_width]; + let eq_suffix = packed_eq_poly::(&eval[n - log_packing_width..], F::ONE); + let mut eq_prefix: Vec = unsafe { uninitialized_vec(1 << n_prefix_levels) }; + eval_eq_basic::(&eval[..n_prefix_levels], &mut eq_prefix, scalar); + (eq_prefix, middle, eq_suffix) + }) + .collect(); + + out.par_chunks_exact_mut(tile_packed_size) + .enumerate() + .for_each(|(tile_idx, out_tile)| { + for (eq_prefix, middle, eq_suffix) in &per_query { + base_eval_eq_packed_with_packed_output::( + middle, + out_tile, + *eq_suffix, + EF::ExtensionPacking::from(eq_prefix[tile_idx]), + ); + } + }); +} + /// Fills the `buffer` with evaluations of the equality polynomial /// of degree `points.len()` multiplied by the value at `buffer[0]`. /// diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 91db24aea..a18792d34 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -85,6 +85,19 @@ impl Trace { } } + fn with_capacity(cycles: usize, table_rows: &BTreeMap) -> Self { + Self { + pcs: Vec::with_capacity(cycles), + fps: Vec::with_capacity(cycles), + tables: BTreeMap::from_iter((0..N_TABLES).map(|i| { + let cap = table_rows.get(&ALL_TABLES[i]).copied().unwrap_or(0); + (ALL_TABLES[i], TableTrace::with_column_capacity(&ALL_TABLES[i], cap)) + })), + counts: InstructionCounts::default(), + pending_deref_hints: Vec::new(), + } + } + fn merge(&mut self, other: Self) { self.pcs.extend(other.pcs); self.fps.extend(other.fps); @@ -423,6 +436,13 @@ fn handle_parallel_batch( let shared: &[Option] = &*left; let segment_slices: Vec<&mut [Option]> = right.chunks_mut(stride).take(n_par).collect(); + let seg0_table_rows: BTreeMap = trace + .tables + .iter() + .map(|(t, tt)| (*t, tt.columns.first().map_or(0, |c| c.len()))) + .collect(); + let seg0_cycles = trace.pcs.len(); + type SegResult = Result<(Trace, Vec<(usize, F)>), RunnerError>; let results: Vec = segment_slices .into_par_iter() @@ -431,7 +451,7 @@ fn handle_parallel_batch( let seg_start = split_at + i * stride; let mut seg_mem = SegmentMemory::new(shared, seg_slice, seg_start); let fp_i = batch.batch_fp + (i + 1) * stride; - let mut seg_trace = Trace::new(); + let mut seg_trace = Trace::with_capacity(seg0_cycles, &seg0_table_rows); let mut seg_pc = batch.batch_pc; let mut seg_fp = fp_i; let mut seg_ap = fp_i + batch.frame_size; diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index cbb773c61..8420fd176 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -61,6 +61,16 @@ impl TableTrace { log_n_rows: 0, // filled later } } + + pub fn with_column_capacity(air: &A, capacity: usize) -> Self { + Self { + columns: (0..air.n_columns_total()) + .map(|_| Vec::with_capacity(capacity)) + .collect(), + non_padded_n_rows: 0, + log_n_rows: 0, + } + } } pub fn sort_tables_by_height(tables_log_heights: &BTreeMap) -> Vec<(Table, usize)> { diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 9a09e7988..8b8b4031c 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -366,14 +366,11 @@ where assert_eq!(combination_randomness.len(), points.len()); assert_eq!(evaluations.len(), points.len()); - // Parallel update of weight buffer - - points - .iter() - .zip(combination_randomness.iter()) - .for_each(|(point, &rand)| { - compute_eval_eq_base_packed::<_, _, true>(point, self.weights.as_extension_packed_mut().unwrap(), rand); - }); + compute_eval_eq_base_packed_batched::, EF>( + points, + self.weights.as_extension_packed_mut().unwrap(), + combination_randomness, + ); // Accumulate the weighted sum (cheap, done sequentially) self.sum += combination_randomness