Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions crates/backend/poly/src/eq_mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ const LOG_NUM_THREADS: usize = 5;
/// The number of threads to spawn for parallel computations.
const NUM_THREADS: usize = 1 << LOG_NUM_THREADS;

/// Log2 of the largest output tile (counted in unpacked entries) that
/// `compute_eval_eq_base_packed_batched` hands to a single parallel task:
/// it caps the number of trailing variables handled per tile at 14.
const LOG_BATCHED_TILE_SIZE: usize = 14;

/// Given `evals` = (α_1, ..., α_n), returns a multilinear polynomial P in n variables,
/// defined on the boolean hypercube by: ∀ (x_1, ..., x_n) ∈ {0, 1}^n,
/// P(x_1, ..., x_n) = Π_{i=1}^{n} (x_i.α_i + (1 - x_i).(1 - α_i))
Expand Down Expand Up @@ -371,6 +373,64 @@ pub fn compute_eval_eq_base_packed<F, EF, const INITIALIZED: bool>(
}
}

/// Batched variant of [`compute_eval_eq_base_packed`]: applies every `(point, scalar)` pair
/// taken from `evals` / `scalars` to the same packed output buffer `out`.
///
/// Let `n` be the number of variables of the points, `w = log2(F::Packing::WIDTH)` and
/// `k = min(n, LOG_BATCHED_TILE_SIZE)`.
///
/// - If `k <= w` or `k >= n`, the function simply calls
///   `compute_eval_eq_base_packed::<F, EF, true>` once per pair, in order.
/// - Otherwise `out` is split into `2^(n - k)` tiles of `2^(k - w)` packed elements. For every
///   pair, the contribution of the first `n - k` variables (scaled by the pair's scalar) is
///   tabulated once up front; each tile is then processed in parallel, looping over all pairs
///   while the tile is being worked on.
///
/// NOTE(review): the `true` const argument is presumably the "output already initialized"
/// flag, i.e. results are added to the existing content of `out` — confirm against
/// `compute_eval_eq_base_packed`.
///
/// If `evals` is empty the function returns immediately without inspecting `out`.
///
/// # Panics
///
/// - if `evals` and `scalars` have different lengths;
/// - if the points in `evals` do not all have the same number of variables;
/// - if `F::Packing::WIDTH` is not a power of two or exceeds `2^n`;
/// - if `out.len() != 2^(n - w)`.
#[inline]
pub fn compute_eval_eq_base_packed_batched<F, EF>(
    evals: &[MultilinearPoint<F>],
    out: &mut [EF::ExtensionPacking],
    scalars: &[EF],
) where
    F: Field,
    EF: ExtensionField<F>,
{
    assert_eq!(evals.len(), scalars.len());
    if evals.is_empty() {
        return;
    }

    let n = evals[0].len();
    // Every slice offset below is derived from `n`, so a point with a different number of
    // variables would either hit an opaque slice-index panic or be sliced at the wrong
    // positions. Reject such input up front with a clear message.
    assert!(
        evals.iter().all(|eval| eval.len() == n),
        "all evaluation points must have the same number of variables"
    );
    let packing_width = F::Packing::WIDTH;
    let log_packing_width = log2_strict_usize(packing_width);
    assert!(log_packing_width <= n);
    assert_eq!(out.len(), 1 << (n - log_packing_width));

    // Number of trailing variables handled inside one tile.
    let k = n.min(LOG_BATCHED_TILE_SIZE);

    // Tiling needs at least one prefix variable (`k < n`) and a tile larger than a single
    // packed element (`k > log_packing_width`); otherwise process the pairs one by one.
    if k <= log_packing_width || k >= n {
        for (eval, &scalar) in evals.iter().zip(scalars) {
            compute_eval_eq_base_packed::<F, EF, true>(eval, out, scalar);
        }
        return;
    }

    let n_prefix_levels = n - k;
    let tile_packed_size = 1 << (k - log_packing_width);

    // Per pair: the scaled table over the prefix variables (one entry per tile), the
    // "middle" coordinates expanded inside each tile, and the packed table over the last
    // `log_packing_width` coordinates.
    let per_query: Vec<_> = evals
        .iter()
        .zip(scalars)
        .map(|(eval, &scalar)| {
            let middle = &eval[n_prefix_levels..n - log_packing_width];
            let eq_suffix = packed_eq_poly::<F, F>(&eval[n - log_packing_width..], F::ONE);
            // SAFETY: the buffer is handed straight to `eval_eq_basic` with
            // `INITIALIZED = false`, which is expected to overwrite all
            // `2^n_prefix_levels` entries before any of them is read.
            // NOTE(review): confirm this contract against `eval_eq_basic`.
            let mut eq_prefix: Vec<EF> = unsafe { uninitialized_vec(1 << n_prefix_levels) };
            eval_eq_basic::<F, F, EF, false>(&eval[..n_prefix_levels], &mut eq_prefix, scalar);
            (eq_prefix, middle, eq_suffix)
        })
        .collect();

    // `out.len() / tile_packed_size == 2^n_prefix_levels`, so `tile_idx` is always a valid
    // index into each `eq_prefix`.
    out.par_chunks_exact_mut(tile_packed_size)
        .enumerate()
        .for_each(|(tile_idx, out_tile)| {
            for (eq_prefix, middle, eq_suffix) in &per_query {
                base_eval_eq_packed_with_packed_output::<F, EF, true>(
                    middle,
                    out_tile,
                    *eq_suffix,
                    EF::ExtensionPacking::from(eq_prefix[tile_idx]),
                );
            }
        });
}

/// Fills the `buffer` with evaluations of the equality polynomial
/// of degree `points.len()` multiplied by the value at `buffer[0]`.
///
Expand Down
22 changes: 21 additions & 1 deletion crates/lean_vm/src/execution/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ impl Trace {
}
}

/// Builds an empty trace whose storage is reserved up front.
///
/// `pcs` and `fps` get room for `cycles` entries. For each of the first `N_TABLES` entries
/// of `ALL_TABLES`, a `TableTrace` is created with a per-column capacity equal to that
/// table's entry in `table_rows` (or `0` when the table is absent from the map).
/// Instruction counts start at their default and no deref hints are pending.
fn with_capacity(cycles: usize, table_rows: &BTreeMap<Table, usize>) -> Self {
    let tables = (0..N_TABLES)
        .map(|idx| ALL_TABLES[idx])
        .map(|table| {
            // Tables missing from `table_rows` get no reserved rows.
            let rows = table_rows.get(&table).copied().unwrap_or_default();
            (table, TableTrace::with_column_capacity(&table, rows))
        })
        .collect();

    Self {
        pcs: Vec::with_capacity(cycles),
        fps: Vec::with_capacity(cycles),
        tables,
        counts: InstructionCounts::default(),
        pending_deref_hints: Vec::new(),
    }
}

fn merge(&mut self, other: Self) {
self.pcs.extend(other.pcs);
self.fps.extend(other.fps);
Expand Down Expand Up @@ -423,6 +436,13 @@ fn handle_parallel_batch(
let shared: &[Option<F>] = &*left;
let segment_slices: Vec<&mut [Option<F>]> = right.chunks_mut(stride).take(n_par).collect();

let seg0_table_rows: BTreeMap<Table, usize> = trace
.tables
.iter()
.map(|(t, tt)| (*t, tt.columns.first().map_or(0, |c| c.len())))
.collect();
let seg0_cycles = trace.pcs.len();

type SegResult = Result<(Trace, Vec<(usize, F)>), RunnerError>;
let results: Vec<SegResult> = segment_slices
.into_par_iter()
Expand All @@ -431,7 +451,7 @@ fn handle_parallel_batch(
let seg_start = split_at + i * stride;
let mut seg_mem = SegmentMemory::new(shared, seg_slice, seg_start);
let fp_i = batch.batch_fp + (i + 1) * stride;
let mut seg_trace = Trace::new();
let mut seg_trace = Trace::with_capacity(seg0_cycles, &seg0_table_rows);
let mut seg_pc = batch.batch_pc;
let mut seg_fp = fp_i;
let mut seg_ap = fp_i + batch.frame_size;
Expand Down
10 changes: 10 additions & 0 deletions crates/lean_vm/src/tables/table_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ impl TableTrace {
log_n_rows: 0, // filled later
}
}

/// Creates an empty table trace with one column per `air.n_columns_total()`, each column
/// being an empty vector with room reserved for `capacity` entries.
///
/// Both `non_padded_n_rows` and `log_n_rows` start at `0`.
pub fn with_column_capacity<A: TableT>(air: &A, capacity: usize) -> Self {
    let n_columns = air.n_columns_total();
    let columns = (0..n_columns)
        .map(|_| Vec::with_capacity(capacity))
        .collect();

    Self {
        columns,
        non_padded_n_rows: 0,
        log_n_rows: 0,
    }
}
}

pub fn sort_tables_by_height(tables_log_heights: &BTreeMap<Table, usize>) -> Vec<(Table, usize)> {
Expand Down
13 changes: 5 additions & 8 deletions crates/whir/src/open.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,11 @@ where
assert_eq!(combination_randomness.len(), points.len());
assert_eq!(evaluations.len(), points.len());

// Parallel update of weight buffer

points
.iter()
.zip(combination_randomness.iter())
.for_each(|(point, &rand)| {
compute_eval_eq_base_packed::<_, _, true>(point, self.weights.as_extension_packed_mut().unwrap(), rand);
});
compute_eval_eq_base_packed_batched::<PF<EF>, EF>(
points,
self.weights.as_extension_packed_mut().unwrap(),
combination_randomness,
);

// Accumulate the weighted sum (cheap, done sequentially)
self.sum += combination_randomness
Expand Down
Loading