diff --git a/program-libs/batched-merkle-tree/src/errors.rs b/program-libs/batched-merkle-tree/src/errors.rs index e09b5bc21b..a322777757 100644 --- a/program-libs/batched-merkle-tree/src/errors.rs +++ b/program-libs/batched-merkle-tree/src/errors.rs @@ -51,8 +51,6 @@ pub enum BatchedMerkleTreeError { NonInclusionCheckFailed, #[error("Bloom filter must be zeroed prior to reusing a batch.")] BloomFilterNotZeroed, - #[error("Cannot zero out complete or more than complete root history.")] - CannotZeroCompleteRootHistory, #[error("Account error {0}")] AccountError(#[from] AccountError), } @@ -72,7 +70,6 @@ impl From for u32 { BatchedMerkleTreeError::TreeIsFull => 14310, BatchedMerkleTreeError::NonInclusionCheckFailed => 14311, BatchedMerkleTreeError::BloomFilterNotZeroed => 14312, - BatchedMerkleTreeError::CannotZeroCompleteRootHistory => 14313, BatchedMerkleTreeError::Hasher(e) => e.into(), BatchedMerkleTreeError::ZeroCopy(e) => e.into(), BatchedMerkleTreeError::MerkleTreeMetadata(e) => e.into(), diff --git a/program-libs/batched-merkle-tree/src/merkle_tree.rs b/program-libs/batched-merkle-tree/src/merkle_tree.rs index 7517684478..94aa28ec05 100644 --- a/program-libs/batched-merkle-tree/src/merkle_tree.rs +++ b/program-libs/batched-merkle-tree/src/merkle_tree.rs @@ -747,11 +747,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> { /// - now all roots containing values nullified in the final B0 root update are zeroed /// - B0 is safe to clear /// - fn zero_out_roots( - &mut self, - sequence_number: u64, - first_safe_root_index: u32, - ) -> Result<(), BatchedMerkleTreeError> { + fn zero_out_roots(&mut self, sequence_number: u64, first_safe_root_index: u32) { // 1. Check whether overlapping roots exist. let overlapping_roots_exits = sequence_number > self.sequence_number; if overlapping_roots_exits { @@ -761,13 +757,10 @@ impl<'a> BatchedMerkleTreeAccount<'a> { // the update of the previous batch therfore allow anyone to prove // inclusion of values nullified in the previous batch. let num_remaining_roots = sequence_number - self.sequence_number; - if num_remaining_roots >= self.root_history.len() as u64 { - return Err(BatchedMerkleTreeError::CannotZeroCompleteRootHistory); - } // 2.2. Zero out roots oldest to first safe root index. // Skip one iteration we don't need to zero out // the first safe root. - for _ in 0..num_remaining_roots { + for _ in 1..num_remaining_roots { self.root_history[oldest_root_index] = [0u8; 32]; oldest_root_index += 1; oldest_root_index %= self.root_history.len(); @@ -778,7 +771,6 @@ impl<'a> BatchedMerkleTreeAccount<'a> { "Zeroing out roots failed." ); } - Ok(()) } /// Zero out bloom filter of previous batch if 50% of the @@ -816,34 +808,23 @@ impl<'a> BatchedMerkleTreeAccount<'a> { let current_batch_is_half_full = num_inserted_elements >= batch_size / 2; current_batch_is_half_full && current_batch_is_not_inserted }; - let sequence_number = self.sequence_number; - let root_history_len = self.metadata.root_history_capacity as u64; + let previous_pending_batch = self .queue_batches .batches .get_mut(previous_pending_batch_index) .ok_or(BatchedMerkleTreeError::InvalidBatchIndex)?; - let no_insert_since_last_batch_root = (previous_pending_batch - .sequence_number - .saturating_sub(root_history_len)) - == sequence_number; + let previous_batch_is_inserted = previous_pending_batch.get_state() == BatchState::Inserted; let previous_batch_is_ready = previous_batch_is_inserted && !previous_pending_batch.bloom_filter_is_zeroed(); // Current batch is at least half full, previous batch is inserted, and not zeroed. - if current_batch_is_half_full && previous_batch_is_ready && !no_insert_since_last_batch_root - { + if current_batch_is_half_full && previous_batch_is_ready { // 3.1. Mark bloom filter zeroed. previous_pending_batch.set_bloom_filter_to_zeroed(); let seq = previous_pending_batch.sequence_number; - // previous_pending_batch.root_index is the index the root - // of the last update of that batch was inserted at. - // This is the last unsafe root index. - // The next index is safe. - let first_safe_root_index = - (previous_pending_batch.root_index + 1) % self.metadata.root_history_capacity; - + let root_index = previous_pending_batch.root_index; // 3.2. Zero out bloom filter. { let bloom_filter = self @@ -856,7 +837,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> { // which allows to prove inclusion of a value // that was inserted into the bloom filter just zeroed out. { - self.zero_out_roots(seq, first_safe_root_index)?; + self.zero_out_roots(seq, root_index); } } diff --git a/program-tests/batched-merkle-tree-test/tests/e2e_tests/shared.rs b/program-tests/batched-merkle-tree-test/tests/e2e_tests/shared.rs index 706f202fe5..513b7cde66 100644 --- a/program-tests/batched-merkle-tree-test/tests/e2e_tests/shared.rs +++ b/program-tests/batched-merkle-tree-test/tests/e2e_tests/shared.rs @@ -149,19 +149,13 @@ pub fn assert_merkle_tree_update( let is_half_full = input_queue_current_batch.get_num_inserted_elements() >= input_queue_current_batch.batch_size / 2 && input_queue_current_batch.get_state() != BatchState::Inserted; - let root_history_len = old_account.root_history.capacity() as u64; - let previous_batch = old_account.queue_batches.get_previous_batch(); - let no_insert_since_last_batch_root = (previous_batch - .sequence_number - .saturating_sub(root_history_len)) - == old_account.sequence_number; + if is_half_full && input_queue_previous_batch_state == BatchState::Inserted && !old_account .queue_batches .get_previous_batch() .bloom_filter_is_zeroed() - && !no_insert_since_last_batch_root { println!("Entering zeroing block for batch {}", previous_batch_index); println!( @@ -213,27 +207,45 @@ pub fn assert_merkle_tree_update( // inclusion of values nullified in the previous batch. let num_remaining_roots = sequence_number - old_account.sequence_number; // 2.2. Zero out roots oldest to first safe root index. - for _ in 0..num_remaining_roots { + // Skip one iteration we don't need to zero out + // the first safe root. + for _ in 1..num_remaining_roots { old_account.root_history[oldest_root_index] = [0u8; 32]; oldest_root_index += 1; oldest_root_index %= old_account.root_history.len(); } - // Assert that all unsafe roots from this batch are zeroed + // Assert that all unsafe roots except the last one are zeroed + // The last root (at root_index) is the first safe root and should remain let batch_key = previous_batch_index as u32; if let Some(unsafe_roots) = batch_roots.get_by_key(&batch_key) { - for unsafe_root in unsafe_roots { - assert!( - !old_account - .root_history - .iter() - .any(|x| *x == *unsafe_root), - "Unsafe root from batch {} should be zeroed: {:?} root history {:?}, unsafe roots {:?}", - previous_batch_index, - unsafe_root, - old_account.root_history, unsafe_roots - ); + // Check all roots except the last one are zeroed + for (idx, unsafe_root) in unsafe_roots.iter().enumerate() { + let is_last_root = idx == unsafe_roots.len() - 1; + + if is_last_root { + // The last root is the first safe root. + // Skip check if it's been rotated out of root_history + if old_account.root_history.iter().any(|x| *x == *unsafe_root) { + assert!( + *unsafe_root != [0u8; 32], + "Last root from batch {} should remain as first safe root: {:?}", + previous_batch_index, + unsafe_root + ); + } + } else { + // All other roots should NOT exist in root_history (either zeroed or rotated out) + assert!( + !old_account.root_history.iter().any(|x| *x == *unsafe_root), + "Unsafe root from batch {} should be zeroed: {:?} root history {:?}, unsafe roots {:?}", + previous_batch_index, + unsafe_root, + old_account.root_history, unsafe_roots + ); + } } + // Clear unsafe roots after verification - batch index will be reused if let Some(roots) = batch_roots.get_mut_by_key(&batch_key) { roots.clear(); diff --git a/program-tests/batched-merkle-tree-test/tests/e2e_tests/state.rs b/program-tests/batched-merkle-tree-test/tests/e2e_tests/state.rs index 234f1501ee..286981f207 100644 --- a/program-tests/batched-merkle-tree-test/tests/e2e_tests/state.rs +++ b/program-tests/batched-merkle-tree-test/tests/e2e_tests/state.rs @@ -206,10 +206,16 @@ async fn test_fill_state_queues_completely() { // Fill up complete input queue. let num_tx = NUM_BATCHES as u64 * params.input_queue_batch_size; let mut first_value = [0u8; 32]; + let mut first_batch_values = Vec::new(); + let mut counter = 0; for tx in 0..num_tx { println!("Input insert ----------------------------- {}", tx); let (_, leaf) = get_random_leaf(&mut rng, &mut mock_indexer.active_leaves); let leaf_index = mock_indexer.merkle_tree.get_leaf_index(&leaf).unwrap(); + if counter < params.input_queue_batch_size { + first_batch_values.push(leaf); + } + counter += 1; let mut pre_mt_account_data = mt_account_data.clone(); let pre_merkle_tree_account = @@ -320,6 +326,9 @@ async fn test_fill_state_queues_completely() { } // Root of the final batch of first input queue batch let mut first_input_batch_update_root_value = [0u8; 32]; + for value in first_batch_values.iter(){ + assert!(mock_indexer.merkle_tree.get_leaf_index(&value).is_some()); + } let num_updates = params.input_queue_batch_size / params.input_queue_zkp_batch_size * NUM_BATCHES as u64; for i in 0..num_updates { @@ -339,23 +348,44 @@ async fn test_fill_state_queues_completely() { .unwrap(); // after 5 updates the first batch is completely inserted - // As soon as we switch to inserting the second batch we zero out the first batch since - // the second batch is completely full. - if i >= 5 { + // At the 5th update (i=4), batch 0 is marked as inserted and the bloom filter is zeroed + // since the second batch is completely full. + if i >= 4 { let batch = merkle_tree_account.queue_batches.batches.first().unwrap(); assert!(batch.bloom_filter_is_zeroed()); + //zeroed out values are no longer on the tree + for value in first_batch_values.iter(){ + assert!(!mock_indexer.merkle_tree.get_leaf_index(&value).is_some()); + } - // Assert that none of the unsafe roots from batch 0 exist in root history + // Assert that all unsafe roots except the last one are zeroed + // The last root (at root_index) is the first safe root and should remain if let Some(unsafe_roots) = batch_roots.get_by_key(&0) { - for unsafe_root in unsafe_roots { - assert!( - !merkle_tree_account - .root_history - .iter() - .any(|x| *x == *unsafe_root), - "Unsafe root from batch 0 should be zeroed: {:?}", - unsafe_root - ); + // Only check roots that are still in root_history (not rotated out) + for (idx, unsafe_root) in unsafe_roots.iter().enumerate() { + let is_last_root = idx == unsafe_roots.len() - 1; + + if is_last_root { + // The last root should still exist in root_history as the first safe root + // Skip check if it's been rotated out + if merkle_tree_account.root_history.iter().any(|x| *x == *unsafe_root) { + assert!( + *unsafe_root != [0u8; 32], + "Last root from batch 0 should remain as first safe root: {:?}", + unsafe_root + ); + } + } else { + // All other roots should NOT exist in root_history (either zeroed or rotated out) + assert!( + !merkle_tree_account + .root_history + .iter() + .any(|x| *x == *unsafe_root), + "Unsafe root from batch 0 should be zeroed: {:?}", + unsafe_root + ); + } } } } else {