Skip to content

Commit 4cc57d4

Browse files
Faten848Barnadrot
andcommitted
avoid unnecessary allocation in initial Merkle tree
Co-Authored-By: Borna <94551425+Barnadrot@users.noreply.github.com>
1 parent 9b44b59 commit 4cc57d4

3 files changed

Lines changed: 47 additions & 100 deletions

File tree

crates/whir/src/commit.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ use crate::*;
99

1010
#[derive(Debug, Clone)]
1111
pub enum MerkleData<EF: ExtensionField<PF<EF>>> {
12-
Base(RoundMerkleTree<PF<EF>, PF<EF>>),
13-
Extension(RoundMerkleTree<PF<EF>, EF>),
12+
Base(RoundMerkleTree<PF<EF>>),
13+
Extension(RoundMerkleTree<PF<EF>>),
1414
}
1515

1616
impl<EF: ExtensionField<PF<EF>>> MerkleData<EF> {

crates/whir/src/matrix.rs

Lines changed: 1 addition & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
use std::{
44
borrow::{Borrow, BorrowMut},
5-
iter,
65
marker::PhantomData,
76
ops::Deref,
87
};
98

10-
use field::{ExtensionField, Field, PackedValue};
9+
use field::PackedValue;
1110
use itertools::Itertools;
1211

1312
pub trait Matrix<T: Send + Sync + Clone>: Send + Sync {
@@ -123,88 +122,6 @@ impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
123122
}
124123
}
125124

126-
#[derive(Debug, Clone)]
127-
pub struct FlatMatrixView<F, EF, Inner>(Inner, PhantomData<(F, EF)>);
128-
129-
impl<F, EF, Inner> FlatMatrixView<F, EF, Inner> {
130-
pub const fn new(inner: Inner) -> Self {
131-
Self(inner, PhantomData)
132-
}
133-
}
134-
135-
impl<F, EF, Inner> Deref for FlatMatrixView<F, EF, Inner> {
136-
type Target = Inner;
137-
138-
fn deref(&self) -> &Self::Target {
139-
&self.0
140-
}
141-
}
142-
143-
impl<F, EF, Inner> Matrix<F> for FlatMatrixView<F, EF, Inner>
144-
where
145-
F: Field,
146-
EF: ExtensionField<F>,
147-
Inner: Matrix<EF>,
148-
{
149-
fn width(&self) -> usize {
150-
self.0.width() * EF::DIMENSION
151-
}
152-
153-
fn height(&self) -> usize {
154-
self.0.height()
155-
}
156-
157-
unsafe fn row_subseq_unchecked(
158-
&self,
159-
r: usize,
160-
start: usize,
161-
end: usize,
162-
) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
163-
// We can skip the first start / EF::DIMENSION elements in the row.
164-
let len = end - start;
165-
let inner_start = start / EF::DIMENSION;
166-
unsafe {
167-
// Safety: The caller must ensure that r < self.height(), start <= end and end < self.width().
168-
FlatIter {
169-
inner: self
170-
.0
171-
// We set end to be the width of the inner matrix and use take to ensure we get the right
172-
// number of elements.
173-
.row_subseq_unchecked(r, inner_start, self.0.width())
174-
.into_iter()
175-
.peekable(),
176-
idx: start,
177-
_phantom: PhantomData,
178-
}
179-
.take(len)
180-
}
181-
}
182-
}
183-
184-
pub struct FlatIter<F, I: Iterator> {
185-
inner: iter::Peekable<I>,
186-
idx: usize,
187-
_phantom: PhantomData<F>,
188-
}
189-
190-
impl<F, EF, I> Iterator for FlatIter<F, I>
191-
where
192-
F: Field,
193-
EF: ExtensionField<F>,
194-
I: Iterator<Item = EF>,
195-
{
196-
type Item = F;
197-
fn next(&mut self) -> Option<Self::Item> {
198-
if self.idx == EF::DIMENSION {
199-
self.idx = 0;
200-
self.inner.next();
201-
}
202-
let value = self.inner.peek()?.as_basis_coefficients_slice()[self.idx];
203-
self.idx += 1;
204-
Some(value)
205-
}
206-
}
207-
208125
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
209126
pub struct Dimensions {
210127
/// Number of columns in the matrix.

crates/whir/src/merkle.rs

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,58 +19,88 @@ use utils::log2_ceil_usize;
1919

2020
use crate::DenseMatrix;
2121
use crate::Dimensions;
22-
use crate::FlatMatrixView;
2322
use crate::Matrix;
2423
pub use symetric::DIGEST_ELEMS;
2524

26-
pub(crate) type RoundMerkleTree<F, EF> = WhirMerkleTree<F, FlatMatrixView<F, EF, DenseMatrix<EF>>, DIGEST_ELEMS>;
25+
pub(crate) type RoundMerkleTree<F> = WhirMerkleTree<F, DenseMatrix<F>, DIGEST_ELEMS>;
2726

2827
#[allow(clippy::missing_transmute_annotations)]
2928
pub(crate) fn merkle_commit<F: Field, EF: ExtensionField<F>>(
3029
matrix: DenseMatrix<EF>,
3130
full_n_cols: usize,
3231
effective_n_cols: usize,
33-
) -> ([F; DIGEST_ELEMS], RoundMerkleTree<F, EF>) {
34-
let perm = default_koalabear_poseidon1_16();
32+
) -> ([F; DIGEST_ELEMS], RoundMerkleTree<F>) {
3533
if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() {
3634
let matrix = unsafe { std::mem::transmute::<_, DenseMatrix<QuinticExtensionFieldKB>>(matrix) };
37-
let view = FlatMatrixView::new(matrix);
3835
let dim = <QuinticExtensionFieldKB as BasedVectorSpace<KoalaBear>>::DIMENSION;
36+
let dft_base_width = matrix.width * dim;
3937
let full_base_width = full_n_cols * dim;
4038
let effective_base_width = effective_n_cols * dim;
41-
let tree =
42-
WhirMerkleTree::new::<PFPacking<KoalaBear>, _, 16, 8>(&perm, view, full_base_width, effective_base_width);
39+
let base_values = QuinticExtensionFieldKB::flatten_to_base(matrix.values);
40+
let base_matrix = DenseMatrix::<KoalaBear>::new(base_values, dft_base_width);
41+
let tree = build_merkle_tree_koalabear(base_matrix, full_base_width, effective_base_width);
4342
let root: [_; DIGEST_ELEMS] = tree.root();
4443
let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) };
45-
let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree<F, EF>>(tree) };
44+
let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree<F>>(tree) };
4645
(root, tree)
4746
} else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() {
4847
let matrix = unsafe { std::mem::transmute::<_, DenseMatrix<KoalaBear>>(matrix) };
49-
let tree = WhirMerkleTree::new::<PFPacking<KoalaBear>, _, 16, 8>(&perm, matrix, full_n_cols, effective_n_cols);
48+
let tree = build_merkle_tree_koalabear(matrix, full_n_cols, effective_n_cols);
5049
let root: [_; DIGEST_ELEMS] = tree.root();
5150
let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) };
52-
let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree<F, EF>>(tree) };
51+
let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree<F>>(tree) };
5352
(root, tree)
5453
} else {
5554
unimplemented!()
5655
}
5756
}
5857

58+
#[instrument(name = "build merkle tree", skip_all)]
59+
fn build_merkle_tree_koalabear(
60+
leaf: DenseMatrix<KoalaBear>,
61+
full_base_width: usize,
62+
effective_base_width: usize,
63+
) -> RoundMerkleTree<KoalaBear> {
64+
let perm = default_koalabear_poseidon1_16();
65+
let n_zero_suffix_rate_chunks = (full_base_width - effective_base_width) / 8;
66+
let first_layer = if n_zero_suffix_rate_chunks >= 2 {
67+
let scalar_state = symetric::precompute_zero_suffix_state::<KoalaBear, _, 16, 8, DIGEST_ELEMS>(
68+
&perm,
69+
n_zero_suffix_rate_chunks,
70+
);
71+
let packed_state: [PFPacking<KoalaBear>; 16] =
72+
std::array::from_fn(|i| PFPacking::<KoalaBear>::from_fn(|_| scalar_state[i]));
73+
first_digest_layer_with_initial_state::<PFPacking<KoalaBear>, _, _, DIGEST_ELEMS, 16, 8>(
74+
&perm,
75+
&leaf,
76+
&packed_state,
77+
effective_base_width,
78+
)
79+
} else {
80+
first_digest_layer::<PFPacking<KoalaBear>, _, _, DIGEST_ELEMS, 16, 8>(&perm, &leaf, full_base_width)
81+
};
82+
let tree = symetric::merkle::MerkleTree::from_first_layer::<PFPacking<KoalaBear>, _, 16>(&perm, first_layer);
83+
WhirMerkleTree {
84+
leaf,
85+
tree,
86+
full_leaf_base_width: full_base_width,
87+
}
88+
}
89+
5990
#[allow(clippy::missing_transmute_annotations)]
6091
pub(crate) fn merkle_open<F: Field, EF: ExtensionField<F>>(
61-
merkle_tree: &RoundMerkleTree<F, EF>,
92+
merkle_tree: &RoundMerkleTree<F>,
6293
index: usize,
6394
) -> (Vec<EF>, Vec<[F; DIGEST_ELEMS]>) {
6495
if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() {
65-
let merkle_tree =
66-
unsafe { std::mem::transmute::<_, &RoundMerkleTree<KoalaBear, QuinticExtensionFieldKB>>(merkle_tree) };
96+
let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree<KoalaBear>>(merkle_tree) };
6797
let (inner_leaf, proof) = merkle_tree.open(index);
6898
let leaf = QuinticExtensionFieldKB::reconstitute_from_base(inner_leaf);
6999
let leaf = unsafe { std::mem::transmute::<_, Vec<EF>>(leaf) };
70100
let proof = unsafe { std::mem::transmute::<_, Vec<[F; DIGEST_ELEMS]>>(proof) };
71101
(leaf, proof)
72102
} else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() {
73-
let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree<KoalaBear, KoalaBear>>(merkle_tree) };
103+
let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree<KoalaBear>>(merkle_tree) };
74104
let (inner_leaf, proof) = merkle_tree.open(index);
75105
let leaf = KoalaBear::reconstitute_from_base(inner_leaf);
76106
let leaf = unsafe { std::mem::transmute::<_, Vec<EF>>(leaf) };

0 commit comments

Comments
 (0)