diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 9582462..abfaaf5 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -16,6 +16,7 @@ zp1-trace = { workspace = true } zp1-air = { workspace = true } sha2 = { workspace = true } blake3 = { workspace = true } +rand = { workspace = true } rayon = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } @@ -29,7 +30,6 @@ metal = { version = "0.28", optional = true } objc = "0.2" [dev-dependencies] -rand = { workspace = true } criterion = "0.5" [[bench]] diff --git a/crates/prover/src/bitwise_tables.rs b/crates/prover/src/bitwise_tables.rs index 222255d..54ee58a 100644 --- a/crates/prover/src/bitwise_tables.rs +++ b/crates/prover/src/bitwise_tables.rs @@ -34,7 +34,7 @@ impl BitwiseLookupTables { let mut and_values = Vec::with_capacity(size); let mut or_values = Vec::with_capacity(size); let mut xor_values = Vec::with_capacity(size); - + for a in 0u32..256 { for b in 0u32..256 { and_values.push(M31::new(a & b)); @@ -42,7 +42,7 @@ impl BitwiseLookupTables { xor_values.push(M31::new(a ^ b)); } } - + Self { and_values, or_values, @@ -52,7 +52,7 @@ impl BitwiseLookupTables { xor_mult: vec![0; size], } } - + /// Look up 8-bit AND: returns a & b and increments multiplicity #[inline] pub fn and8(&mut self, a: u8, b: u8) -> M31 { @@ -60,7 +60,7 @@ impl BitwiseLookupTables { self.and_mult[idx] += 1; self.and_values[idx] } - + /// Look up 8-bit OR: returns a | b and increments multiplicity #[inline] pub fn or8(&mut self, a: u8, b: u8) -> M31 { @@ -68,7 +68,7 @@ impl BitwiseLookupTables { self.or_mult[idx] += 1; self.or_values[idx] } - + /// Look up 8-bit XOR: returns a ^ b and increments multiplicity #[inline] pub fn xor8(&mut self, a: u8, b: u8) -> M31 { @@ -76,7 +76,7 @@ impl BitwiseLookupTables { self.xor_mult[idx] += 1; self.xor_values[idx] } - + /// Perform 32-bit AND using 4 byte-wise lookups. pub fn and32(&mut self, a: u32, b: u32) -> u32 { let mut result = 0u32; @@ -88,7 +88,7 @@ impl BitwiseLookupTables { } result } - + /// Perform 32-bit OR using 4 byte-wise lookups. pub fn or32(&mut self, a: u32, b: u32) -> u32 { let mut result = 0u32; @@ -100,7 +100,7 @@ impl BitwiseLookupTables { } result } - + /// Perform 32-bit XOR using 4 byte-wise lookups. pub fn xor32(&mut self, a: u32, b: u32) -> u32 { let mut result = 0u32; @@ -112,12 +112,12 @@ impl BitwiseLookupTables { } result } - + /// Get multiplicities for LogUp proof generation. pub fn get_multiplicities(&self) -> (&[u32], &[u32], &[u32]) { (&self.and_mult, &self.or_mult, &self.xor_mult) } - + /// Get table values for LogUp proof generation. pub fn get_values(&self) -> (&[M31], &[M31], &[M31]) { (&self.and_values, &self.or_values, &self.xor_values) @@ -133,55 +133,55 @@ impl Default for BitwiseLookupTables { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_and() { let mut tables = BitwiseLookupTables::new(); - - assert_eq!(tables.and8(0, 0), M31::new(0)); // 0 & 0 = 0 - assert_eq!(tables.and8(255, 255), M31::new(255)); // 0xFF & 0xFF = 0xFF - assert_eq!(tables.and8(0xAA, 0x55), M31::new(0)); // 0xAA & 0x55 = 0 + + assert_eq!(tables.and8(0, 0), M31::new(0)); // 0 & 0 = 0 + assert_eq!(tables.and8(255, 255), M31::new(255)); // 0xFF & 0xFF = 0xFF + assert_eq!(tables.and8(0xAA, 0x55), M31::new(0)); // 0xAA & 0x55 = 0 } - + #[test] fn test_or() { let mut tables = BitwiseLookupTables::new(); - - assert_eq!(tables.or8(0, 0), M31::new(0)); // 0 | 0 = 0 - assert_eq!(tables.or8(0xAA, 0x55), M31::new(255)); // 0xAA | 0x55 = 0xFF + + assert_eq!(tables.or8(0, 0), M31::new(0)); // 0 | 0 = 0 + assert_eq!(tables.or8(0xAA, 0x55), M31::new(255)); // 0xAA | 0x55 = 0xFF } - + #[test] fn test_xor() { let mut tables = BitwiseLookupTables::new(); - - assert_eq!(tables.xor8(0xFF, 0xFF), M31::new(0)); // 0xFF ^ 0xFF = 0 + + assert_eq!(tables.xor8(0xFF, 0xFF), M31::new(0)); // 0xFF ^ 0xFF = 0 assert_eq!(tables.xor8(0xAA, 0x55), M31::new(255)); // 0xAA ^ 0x55 = 0xFF } - + #[test] fn test_bitwise_32bit() { let mut tables = BitwiseLookupTables::new(); - + let a = 0xDEADBEEF_u32; let b = 0xCAFEBABE_u32; - + assert_eq!(tables.and32(a, b), a & b); assert_eq!(tables.or32(a, b), a | b); assert_eq!(tables.xor32(a, b), a ^ b); } - + #[test] fn test_multiplicity_tracking() { let mut tables = BitwiseLookupTables::new(); - + // Perform some lookups tables.and8(0x12, 0x34); - tables.and8(0x12, 0x34); // Same lookup twice - + tables.and8(0x12, 0x34); // Same lookup twice + let (and_mult, _, _) = tables.get_multiplicities(); let idx = 0x12 * 256 + 0x34; - + // Should have been looked up twice assert_eq!(and_mult[idx], 2); } diff --git a/crates/prover/src/channel.rs b/crates/prover/src/channel.rs index 761fe56..788ff79 100644 --- a/crates/prover/src/channel.rs +++ b/crates/prover/src/channel.rs @@ -1,56 +1,72 @@ -//! Fiat-Shamir transcript channel for the prover. +//! Fiat-Shamir transcript channel for the prover using SHA-256. +//! +//! This uses the same SHA-256-based transcript construction as the verifier +//! channel, ensuring prover and verifier derive identical challenges from the +//! same sequence of transcript messages. +//! +//! # Domain Separation +//! +//! Every channel is initialized by hashing the domain separator, binding all +//! subsequent challenges to the specific protocol context. +//! +//! # Byte-to-Field Encoding +//! +//! `absorb` prefixes each call with the byte-length of the data before hashing, +//! making the encoding injective for sequences of variable-length inputs. use sha2::{Digest, Sha256}; use zp1_primitives::{M31, QM31}; -/// Prover channel for Fiat-Shamir transcript. +/// Prover channel for Fiat-Shamir transcript (mirrors `VerifierChannel`). #[derive(Clone)] pub struct ProverChannel { - /// SHA256 state. hasher: Sha256, - /// Transcript bytes for debugging. - transcript: Vec, } impl ProverChannel { - /// Create a new prover channel. + /// Create a new prover channel bound to `domain_separator`. + /// + /// The domain separator is absorbed first, so two channels with different + /// separators always produce different challenges. pub fn new(domain_separator: &[u8]) -> Self { - let mut ch = Self { - hasher: Sha256::new(), - transcript: Vec::new(), - }; - ch.absorb(domain_separator); - ch + let mut hasher = Sha256::new(); + // Length-prefix the domain separator to avoid extension collisions + hasher.update((domain_separator.len() as u64).to_le_bytes()); + hasher.update(domain_separator); + Self { hasher } } - /// Absorb bytes into the transcript. + /// Absorb arbitrary bytes into the transcript. + /// + /// The byte-length is written before the data so that, e.g., + /// `absorb(b"ab")` and `absorb(b"a"); absorb(b"b")` produce + /// different transcript states. pub fn absorb(&mut self, data: &[u8]) { + // Length prefix for injectivity across variable-length inputs + self.hasher.update((data.len() as u64).to_le_bytes()); self.hasher.update(data); - self.transcript.extend_from_slice(data); } - /// Absorb a 32-byte commitment. + /// Absorb a 32-byte commitment into the transcript. pub fn absorb_commitment(&mut self, commitment: &[u8; 32]) { self.absorb(commitment); } - /// Absorb an M31 field element. + /// Absorb an M31 field element into the transcript. pub fn absorb_felt(&mut self, felt: M31) { - self.absorb(&felt.as_u32().to_le_bytes()); + self.hasher.update(felt.as_u32().to_le_bytes()); } /// Squeeze a challenge in M31. pub fn squeeze_challenge(&mut self) -> M31 { let hash = self.hasher.clone().finalize(); self.hasher.update(&hash); - - // Take first 4 bytes, reduce mod P let bytes: [u8; 4] = hash[0..4].try_into().unwrap(); let val = u32::from_le_bytes(bytes); M31::new(val % M31::P) } - /// Squeeze a challenge in QM31 (extension field). + /// Squeeze a challenge in QM31 (four independent M31 challenges). pub fn squeeze_extension_challenge(&mut self) -> QM31 { let c0 = self.squeeze_challenge(); let c1 = self.squeeze_challenge(); @@ -59,19 +75,17 @@ impl ProverChannel { QM31::new(c0, c1, c2, c3) } - /// Alias for squeeze_extension_challenge. + /// Alias for `squeeze_extension_challenge`. pub fn squeeze_qm31(&mut self) -> QM31 { self.squeeze_extension_challenge() } - /// Squeeze n query indices in range [0, domain_size). + /// Squeeze `n` query indices in `[0, domain_size)`. pub fn squeeze_query_indices(&mut self, n: usize, domain_size: usize) -> Vec { let mut indices = Vec::with_capacity(n); while indices.len() < n { let hash = self.hasher.clone().finalize(); self.hasher.update(&hash); - - // Extract multiple indices from each hash for chunk in hash.chunks(4) { if indices.len() >= n { break; @@ -84,11 +98,6 @@ impl ProverChannel { indices.truncate(n); indices } - - /// Get the current transcript length. - pub fn transcript_len(&self) -> usize { - self.transcript.len() - } } impl Default for ProverChannel { @@ -115,6 +124,31 @@ mod tests { assert_eq!(c1, c2); } + #[test] + fn test_domain_separator_matters() { + let mut ch1 = ProverChannel::new(b"protocol-a"); + let mut ch2 = ProverChannel::new(b"protocol-b"); + + ch1.absorb(b"same data"); + ch2.absorb(b"same data"); + + // Different domain separators must yield different challenges + assert_ne!(ch1.squeeze_challenge(), ch2.squeeze_challenge()); + } + + #[test] + fn test_absorb_injective() { + // absorb(b"ab") vs absorb(b"a") + absorb(b"b") must differ + let mut ch1 = ProverChannel::new(b"test"); + ch1.absorb(b"ab"); + + let mut ch2 = ProverChannel::new(b"test"); + ch2.absorb(b"a"); + ch2.absorb(b"b"); + + assert_ne!(ch1.squeeze_challenge(), ch2.squeeze_challenge()); + } + #[test] fn test_query_indices() { let mut ch = ProverChannel::new(b"test"); diff --git a/crates/prover/src/commitment.rs b/crates/prover/src/commitment.rs index 136312a..5b819e6 100644 --- a/crates/prover/src/commitment.rs +++ b/crates/prover/src/commitment.rs @@ -116,7 +116,7 @@ impl MerkleTree { // Pad to power of two let n = leaves.len().next_power_of_two(); let height = n.trailing_zeros() as usize; - + while leaves.len() < n { // Pad with zero hashes (represents empty leaves) leaves.push([0u8; 32]); @@ -125,18 +125,18 @@ impl MerkleTree { // Build tree bottom-up // nodes[0] = root, nodes[1..3] = level 1, nodes[3..7] = level 2, etc. let mut nodes = vec![[0u8; 32]; n - 1]; - + // Start with leaves as the current layer let mut current_layer = leaves.clone(); - + // Build each level from bottom to top // Level h-1 (just above leaves) to level 0 (root) for level in (0..height).rev() { let level_start = (1 << level) - 1; // Index where this level starts in nodes[] - let level_size = 1 << level; // Number of nodes at this level - + let level_size = 1 << level; // Number of nodes at this level + let mut next_layer = Vec::with_capacity(level_size); - + for i in 0..level_size { let left = ¤t_layer[2 * i]; let right = ¤t_layer[2 * i + 1]; @@ -144,7 +144,7 @@ impl MerkleTree { nodes[level_start + i] = parent; next_layer.push(parent); } - + current_layer = next_layer; } @@ -172,16 +172,16 @@ impl MerkleTree { /// sibling is on the left or right. pub fn prove(&self, index: usize) -> MerkleProof { assert!(index < self.leaves.len(), "Index out of bounds"); - + let mut path = Vec::with_capacity(self.height); let mut idx = index; - + // Level h-1 (just above leaves): get sibling from leaves // Level h-2 to 0: get sibling from nodes - + for level in (0..self.height).rev() { let sibling_idx = idx ^ 1; // XOR to get sibling - + if level == self.height - 1 { // Bottom level: sibling is a leaf path.push(self.leaves[sibling_idx]); @@ -191,10 +191,10 @@ impl MerkleTree { let child_level_start = (1 << (level + 1)) - 1; path.push(self.nodes[child_level_start + sibling_idx]); } - + idx /= 2; } - + MerkleProof { leaf_index: index, path, @@ -202,12 +202,12 @@ impl MerkleTree { } /// Generate batch proofs for multiple leaves. - /// + /// /// This is more efficient than generating individual proofs because /// common ancestors only need to be included once. pub fn prove_batch(&self, indices: &[usize]) -> BatchMerkleProof { let proofs: Vec = indices.iter().map(|&i| self.prove(i)).collect(); - + // In a real implementation, we would deduplicate common siblings // For now, just wrap individual proofs BatchMerkleProof { @@ -303,9 +303,10 @@ impl BatchMerkleProof { return false; } - self.proofs.iter().zip(leaves.iter()).all(|(proof, &leaf)| { - MerkleTree::verify(root, leaf, proof) - }) + self.proofs + .iter() + .zip(leaves.iter()) + .all(|(proof, &leaf)| MerkleTree::verify(root, leaf, proof)) } } @@ -330,11 +331,11 @@ mod tests { fn test_merkle_tree_single() { let values = vec![M31::new(42)]; let tree = MerkleTree::new(&values); - + // Single leaf tree has height 0 assert_eq!(tree.height(), 0); assert_eq!(tree.leaf_count(), 1); - + let proof = tree.prove(0); assert!(MerkleTree::verify(&tree.root(), M31::new(42), &proof)); } @@ -343,16 +344,19 @@ mod tests { fn test_merkle_tree_two_leaves() { let values = vec![M31::new(1), M31::new(2)]; let tree = MerkleTree::new(&values); - + assert_eq!(tree.height(), 1); assert_eq!(tree.leaf_count(), 2); - + // Verify both leaves for (i, &v) in values.iter().enumerate() { let proof = tree.prove(i); assert_eq!(proof.len(), 1); - assert!(MerkleTree::verify(&tree.root(), v, &proof), - "Verification failed for leaf {}", i); + assert!( + MerkleTree::verify(&tree.root(), v, &proof), + "Verification failed for leaf {}", + i + ); } } @@ -367,8 +371,11 @@ mod tests { for (i, &v) in values.iter().enumerate() { let proof = tree.prove(i); assert_eq!(proof.len(), 3); - assert!(MerkleTree::verify(&tree.root(), v, &proof), - "Verification failed for leaf {}", i); + assert!( + MerkleTree::verify(&tree.root(), v, &proof), + "Verification failed for leaf {}", + i + ); } } @@ -383,8 +390,11 @@ mod tests { for (i, &v) in values.iter().enumerate() { let proof = tree.prove(i); - assert!(MerkleTree::verify(&tree.root(), v, &proof), - "Verification failed for leaf {}", i); + assert!( + MerkleTree::verify(&tree.root(), v, &proof), + "Verification failed for leaf {}", + i + ); } } @@ -402,10 +412,10 @@ mod tests { fn test_merkle_tree_wrong_index() { let values: Vec = (0..4).map(|i| M31::new(i)).collect(); let tree = MerkleTree::new(&values); - + // Get proof for index 0 let mut proof = tree.prove(0); - + // Modify to claim it's index 1 - should fail proof.leaf_index = 1; assert!(!MerkleTree::verify(&tree.root(), values[0], &proof)); @@ -415,14 +425,14 @@ mod tests { fn test_merkle_tree_tampered_proof() { let values: Vec = (0..4).map(|i| M31::new(i)).collect(); let tree = MerkleTree::new(&values); - + let mut proof = tree.prove(0); - + // Tamper with a sibling hash if !proof.path.is_empty() { proof.path[0][0] ^= 0xFF; } - + assert!(!MerkleTree::verify(&tree.root(), values[0], &proof)); } @@ -430,10 +440,10 @@ mod tests { fn test_merkle_tree_different_roots() { let values1: Vec = (0..4).map(|i| M31::new(i)).collect(); let values2: Vec = (0..4).map(|i| M31::new(i + 100)).collect(); - + let tree1 = MerkleTree::new(&values1); let tree2 = MerkleTree::new(&values2); - + // Different values should have different roots assert_ne!(tree1.root(), tree2.root()); } @@ -442,13 +452,16 @@ mod tests { fn test_merkle_tree_bytes() { let data: Vec<&[u8]> = vec![b"hello", b"world", b"foo", b"bar"]; let tree = MerkleTree::from_bytes(&data); - + assert_eq!(tree.height(), 2); - + for (i, &d) in data.iter().enumerate() { let proof = tree.prove(i); - assert!(MerkleTree::verify_bytes(&tree.root(), d, &proof), - "Verification failed for leaf {}", i); + assert!( + MerkleTree::verify_bytes(&tree.root(), d, &proof), + "Verification failed for leaf {}", + i + ); } } @@ -456,10 +469,10 @@ mod tests { fn test_batch_proof() { let values: Vec = (0..8).map(|i| M31::new(i)).collect(); let tree = MerkleTree::new(&values); - + let indices = vec![1, 3, 5]; let batch_proof = tree.prove_batch(&indices); - + let queried_values: Vec = indices.iter().map(|&i| values[i]).collect(); assert!(batch_proof.verify_all(&tree.root(), &queried_values)); } @@ -470,14 +483,14 @@ mod tests { // A leaf with value that matches an internal node hash shouldn't verify let values: Vec = (0..4).map(|i| M31::new(i)).collect(); let tree = MerkleTree::new(&values); - + // Get a leaf hash let leaf_hash = tree.get_leaf(0).unwrap(); - + // The root uses INTERNAL_PREFIX, leaf uses LEAF_PREFIX // They should be computed differently let root = tree.root(); - + // A crafted "leaf" that equals an internal node shouldn't verify // This would be a second preimage attack without domain separation assert_ne!(leaf_hash, root); @@ -487,7 +500,7 @@ mod tests { fn test_compute_root() { let values: Vec = (0..8).map(|i| M31::new(i)).collect(); let tree = MerkleTree::new(&values); - + // compute_root should give same result as tree.root() assert_eq!(compute_root(&values), tree.root()); } @@ -497,9 +510,9 @@ mod tests { // Test with 256 leaves let values: Vec = (0..256).map(|i| M31::new(i as u32)).collect(); let tree = MerkleTree::new(&values); - + assert_eq!(tree.height(), 8); // log2(256) = 8 - + // Verify a few random indices for &i in &[0, 127, 200, 255] { let proof = tree.prove(i); @@ -512,7 +525,7 @@ mod tests { fn test_empty_tree() { let values: Vec = vec![]; let tree = MerkleTree::new(&values); - + // Empty tree should have a default root assert_eq!(tree.leaf_count(), 1); // Padded to 1 assert_eq!(tree.height(), 0); @@ -523,10 +536,10 @@ mod tests { // Verify that the same proof verifies against the same leaf multiple times let values: Vec = (0..4).map(|i| M31::new(i)).collect(); let tree = MerkleTree::new(&values); - + let proof = tree.prove(2); let root = tree.root(); - + // Multiple verifications should all succeed for _ in 0..10 { assert!(MerkleTree::verify(&root, values[2], &proof)); diff --git a/crates/prover/src/delegation.rs b/crates/prover/src/delegation.rs index c948f23..afcfc21 100644 --- a/crates/prover/src/delegation.rs +++ b/crates/prover/src/delegation.rs @@ -33,8 +33,8 @@ //! | 0xC11 | DELEG_U256_MUL | U256 multiplication | //! | 0xC12 | DELEG_U256_MOD | U256 modular reduction | -use zp1_primitives::{M31, QM31}; use std::collections::HashMap; +use zp1_primitives::{M31, QM31}; /// Delegation CSR addresses. pub mod csr { @@ -139,34 +139,40 @@ impl DelegationCall { outputs: Vec, call_id: u64, ) -> Self { - Self { deleg_type, timestamp, inputs, outputs, call_id } + Self { + deleg_type, + timestamp, + inputs, + outputs, + call_id, + } } /// Compute fingerprint for log-derivative lookup. pub fn fingerprint(&self, alpha: QM31) -> QM31 { let mut fp = QM31::from(self.deleg_type.to_field()); let mut alpha_power = alpha; - + // Include call ID fp = fp + alpha_power * QM31::from(M31::new((self.call_id & 0x7FFFFFFF) as u32)); alpha_power = alpha_power * alpha; - + // Include timestamp fp = fp + alpha_power * QM31::from(M31::new((self.timestamp & 0x7FFFFFFF) as u32)); alpha_power = alpha_power * alpha; - + // Include inputs for &input in &self.inputs { fp = fp + alpha_power * QM31::from(M31::new(input)); alpha_power = alpha_power * alpha; } - + // Include outputs for &output in &self.outputs { fp = fp + alpha_power * QM31::from(M31::new(output)); alpha_power = alpha_power * alpha; } - + fp } } @@ -192,7 +198,12 @@ impl DelegationResult { outputs: Vec, computation_proof: Vec, ) -> Self { - Self { deleg_type, call_id, outputs, computation_proof } + Self { + deleg_type, + call_id, + outputs, + computation_proof, + } } } @@ -239,35 +250,35 @@ impl DelegationArgumentProver { pub fn generate_columns(&self) -> DelegationColumns { let n_calls = self.calls.len(); let n_results = self.results.len(); - + // Call columns let mut call_type = Vec::with_capacity(n_calls); let mut call_id = Vec::with_capacity(n_calls); let mut call_timestamp = Vec::with_capacity(n_calls); let mut call_fingerprints = Vec::with_capacity(n_calls); - - // Result columns + + // Result columns let mut result_type = Vec::with_capacity(n_results); let mut result_call_id = Vec::with_capacity(n_results); let mut result_fingerprints = Vec::with_capacity(n_results); - + for call in &self.calls { call_type.push(call.deleg_type.to_field()); call_id.push(M31::new((call.call_id & 0x7FFFFFFF) as u32)); call_timestamp.push(M31::new((call.timestamp & 0x7FFFFFFF) as u32)); call_fingerprints.push(call.fingerprint(self.alpha)); } - + for result in &self.results { result_type.push(result.deleg_type.to_field()); result_call_id.push(M31::new((result.call_id & 0x7FFFFFFF) as u32)); // Result fingerprint should match call fingerprint for same call_id result_fingerprints.push(self.compute_result_fingerprint(result)); } - + // Compute log-derivative accumulator let (log_deriv_num, log_deriv_denom) = self.compute_log_derivative(); - + DelegationColumns { call_type, call_id, @@ -285,28 +296,28 @@ impl DelegationArgumentProver { fn compute_result_fingerprint(&self, result: &DelegationResult) -> QM31 { let mut fp = QM31::from(result.deleg_type.to_field()); let mut alpha_power = self.alpha; - + // Include call ID fp = fp + alpha_power * QM31::from(M31::new((result.call_id & 0x7FFFFFFF) as u32)); alpha_power = alpha_power * self.alpha; - + // Include outputs (inputs not needed - they're determined by call) for &output in &result.outputs { fp = fp + alpha_power * QM31::from(M31::new(output)); alpha_power = alpha_power * self.alpha; } - + fp } /// Compute log-derivative accumulator for set equality. - /// + /// /// For set equality, we need: Σ 1/(call_fp + β) = Σ 1/(result_fp + β) fn compute_log_derivative(&self) -> (Vec, Vec) { let n = self.calls.len().max(self.results.len()); let mut numerator = vec![QM31::ONE; n]; let mut denominator = vec![QM31::ONE; n]; - + // Accumulate call fingerprints let mut call_prod = QM31::ONE; for (i, call) in self.calls.iter().enumerate() { @@ -316,7 +327,7 @@ impl DelegationArgumentProver { numerator[i] = call_prod; } } - + // Accumulate result fingerprints let mut result_prod = QM31::ONE; for (i, result) in self.results.iter().enumerate() { @@ -326,7 +337,7 @@ impl DelegationArgumentProver { denominator[i] = result_prod; } } - + (numerator, denominator) } @@ -336,14 +347,18 @@ impl DelegationArgumentProver { let mut result_map: HashMap = HashMap::new(); for result in &self.results { if result_map.insert(result.call_id, result).is_some() { - return Err(DelegationError::DuplicateResult { call_id: result.call_id }); + return Err(DelegationError::DuplicateResult { + call_id: result.call_id, + }); } } - + for call in &self.calls { match result_map.get(&call.call_id) { None => { - return Err(DelegationError::MissingResult { call_id: call.call_id }); + return Err(DelegationError::MissingResult { + call_id: call.call_id, + }); } Some(result) => { if result.deleg_type != call.deleg_type { @@ -356,7 +371,7 @@ impl DelegationArgumentProver { } } } - + Ok(()) } @@ -364,7 +379,10 @@ impl DelegationArgumentProver { pub fn by_type(&self) -> HashMap> { let mut grouped = HashMap::new(); for call in &self.calls { - grouped.entry(call.deleg_type).or_insert_with(Vec::new).push(call); + grouped + .entry(call.deleg_type) + .or_insert_with(Vec::new) + .push(call); } grouped } @@ -384,12 +402,12 @@ pub struct DelegationColumns { pub call_id: Vec, pub call_timestamp: Vec, pub call_fingerprints: Vec, - + // Result columns (from precompile circuits) pub result_type: Vec, pub result_call_id: Vec, pub result_fingerprints: Vec, - + // Log-derivative accumulator pub log_deriv_numerator: Vec, pub log_deriv_denominator: Vec, @@ -424,11 +442,21 @@ impl std::fmt::Display for DelegationError { DelegationError::DuplicateResult { call_id } => { write!(f, "Duplicate result for delegation call {}", call_id) } - DelegationError::TypeMismatch { call_id, expected, actual } => { - write!(f, "Type mismatch for call {}: expected {:?}, got {:?}", - call_id, expected, actual) + DelegationError::TypeMismatch { + call_id, + expected, + actual, + } => { + write!( + f, + "Type mismatch for call {}: expected {:?}, got {:?}", + call_id, expected, actual + ) } - DelegationError::InvalidInputs { deleg_type, message } => { + DelegationError::InvalidInputs { + deleg_type, + message, + } => { write!(f, "Invalid inputs for {:?}: {}", deleg_type, message) } } @@ -438,7 +466,7 @@ impl std::fmt::Display for DelegationError { impl std::error::Error for DelegationError {} /// Memory subtree for delegation circuit. -/// +/// /// Each delegation type has its own Merkle subtree for pre-commitment, /// enabling parallel proving of delegation circuits. #[derive(Debug, Clone)] @@ -473,7 +501,7 @@ impl DelegationSubtree { outputs, Vec::new(), // Computation proof filled by circuit ); - + self.calls.push(call); self.results.push(result.clone()); result @@ -513,12 +541,12 @@ impl DelegationSubtree { let n = self.calls.len(); let mut type_col = Vec::with_capacity(n); let mut call_id_col = Vec::with_capacity(n); - + for call in &self.calls { type_col.push(call.deleg_type.to_field()); call_id_col.push(M31::new((call.call_id & 0x7FFFFFFF) as u32)); } - + vec![type_col, call_id_col] } } @@ -529,8 +557,14 @@ mod tests { #[test] fn test_delegation_type_from_csr() { - assert_eq!(DelegationType::from_csr(0xC00), Some(DelegationType::Blake2s)); - assert_eq!(DelegationType::from_csr(0xC10), Some(DelegationType::U256Add)); + assert_eq!( + DelegationType::from_csr(0xC00), + Some(DelegationType::Blake2s) + ); + assert_eq!( + DelegationType::from_csr(0xC10), + Some(DelegationType::U256Add) + ); assert_eq!(DelegationType::from_csr(0x123), None); } @@ -543,14 +577,14 @@ mod tests { vec![0x3000], 1, ); - + let alpha = QM31::from(M31::new(7)); let fp1 = call.fingerprint(alpha); let fp2 = call.fingerprint(alpha); - + // Same call should give same fingerprint assert_eq!(fp1, fp2); - + // Different call should give different fingerprint let call2 = DelegationCall::new( DelegationType::Blake2s, @@ -566,36 +600,54 @@ mod tests { #[test] fn test_delegation_argument_verify_valid() { let mut prover = DelegationArgumentProver::new(); - + prover.add_call(DelegationCall::new( - DelegationType::Blake2s, 100, vec![0x1000], vec![0x2000], 1 + DelegationType::Blake2s, + 100, + vec![0x1000], + vec![0x2000], + 1, )); prover.add_call(DelegationCall::new( - DelegationType::U256Add, 200, vec![0x3000], vec![0x4000], 2 + DelegationType::U256Add, + 200, + vec![0x3000], + vec![0x4000], + 2, )); - + prover.add_result(DelegationResult::new( - DelegationType::Blake2s, 1, vec![0xABCD], vec![] + DelegationType::Blake2s, + 1, + vec![0xABCD], + vec![], )); prover.add_result(DelegationResult::new( - DelegationType::U256Add, 2, vec![0xEF01], vec![] + DelegationType::U256Add, + 2, + vec![0xEF01], + vec![], )); - + assert!(prover.verify().is_ok()); } #[test] fn test_delegation_argument_verify_missing_result() { let mut prover = DelegationArgumentProver::new(); - + prover.add_call(DelegationCall::new( - DelegationType::Blake2s, 100, vec![0x1000], vec![0x2000], 1 + DelegationType::Blake2s, + 100, + vec![0x1000], + vec![0x2000], + 1, )); // No result added! - + let result = prover.verify(); assert!(result.is_err()); - + if let Err(DelegationError::MissingResult { call_id }) = result { assert_eq!(call_id, 1); } else { @@ -606,18 +658,28 @@ mod tests { #[test] fn test_delegation_argument_verify_type_mismatch() { let mut prover = DelegationArgumentProver::new(); - + prover.add_call(DelegationCall::new( - DelegationType::Blake2s, 100, vec![0x1000], vec![0x2000], 1 + DelegationType::Blake2s, + 100, + vec![0x1000], + vec![0x2000], + 1, )); prover.add_result(DelegationResult::new( - DelegationType::U256Add, 1, vec![0xABCD], vec![] // Wrong type! + DelegationType::U256Add, + 1, + vec![0xABCD], + vec![], // Wrong type! )); - + let result = prover.verify(); assert!(result.is_err()); - - if let Err(DelegationError::TypeMismatch { expected, actual, .. }) = result { + + if let Err(DelegationError::TypeMismatch { + expected, actual, .. + }) = result + { assert_eq!(expected, DelegationType::Blake2s); assert_eq!(actual, DelegationType::U256Add); } @@ -626,13 +688,17 @@ mod tests { #[test] fn test_delegation_subtree() { let mut subtree = DelegationSubtree::new(DelegationType::Blake2s); - + let call = DelegationCall::new( - DelegationType::Blake2s, 100, vec![0x1000, 0x2000], vec![], 1 + DelegationType::Blake2s, + 100, + vec![0x1000, 0x2000], + vec![], + 1, ); - + let result = subtree.process_call(call); - + assert_eq!(result.deleg_type, DelegationType::Blake2s); assert_eq!(result.call_id, 1); assert!(!result.outputs.is_empty()); @@ -641,14 +707,38 @@ mod tests { #[test] fn test_by_type() { let mut prover = DelegationArgumentProver::new(); - - prover.add_call(DelegationCall::new(DelegationType::Blake2s, 100, vec![], vec![], 1)); - prover.add_call(DelegationCall::new(DelegationType::U256Add, 200, vec![], vec![], 2)); - prover.add_call(DelegationCall::new(DelegationType::Blake2s, 300, vec![], vec![], 3)); - + + prover.add_call(DelegationCall::new( + DelegationType::Blake2s, + 100, + vec![], + vec![], + 1, + )); + prover.add_call(DelegationCall::new( + DelegationType::U256Add, + 200, + vec![], + vec![], + 2, + )); + prover.add_call(DelegationCall::new( + DelegationType::Blake2s, + 300, + vec![], + vec![], + 3, + )); + let by_type = prover.by_type(); - - assert_eq!(by_type.get(&DelegationType::Blake2s).map(|v| v.len()), Some(2)); - assert_eq!(by_type.get(&DelegationType::U256Add).map(|v| v.len()), Some(1)); + + assert_eq!( + by_type.get(&DelegationType::Blake2s).map(|v| v.len()), + Some(2) + ); + assert_eq!( + by_type.get(&DelegationType::U256Add).map(|v| v.len()), + Some(1) + ); } } diff --git a/crates/prover/src/fri.rs b/crates/prover/src/fri.rs index 51144e7..a88f6e2 100644 --- a/crates/prover/src/fri.rs +++ b/crates/prover/src/fri.rs @@ -17,9 +17,9 @@ //! //! This halves both the domain size and polynomial degree. -use zp1_primitives::M31; use crate::channel::ProverChannel; use crate::commitment::MerkleTree; +use zp1_primitives::M31; /// FRI configuration parameters. #[derive(Clone, Debug)] @@ -50,15 +50,15 @@ impl FriConfig { pub fn new(log_domain_size: usize) -> Self { Self::with_security(log_domain_size, SecurityLevel::Bits100) } - + /// Create a FRI configuration with specified security level. pub fn with_security(log_domain_size: usize, level: SecurityLevel) -> Self { let num_queries = match level { - SecurityLevel::Bits80 => 40, // ~80 bits from FRI - SecurityLevel::Bits100 => 50, // ~100 bits from FRI - SecurityLevel::Bits128 => 64, // ~128 bits from FRI + SecurityLevel::Bits80 => 40, // ~80 bits from FRI + SecurityLevel::Bits100 => 50, // ~100 bits from FRI + SecurityLevel::Bits128 => 64, // ~128 bits from FRI }; - + Self { log_domain_size, folding_factor: 2, @@ -66,13 +66,13 @@ impl FriConfig { final_degree: 8, } } - + /// Create a fast configuration for testing (reduced security). pub fn fast(log_domain_size: usize) -> Self { Self { log_domain_size, folding_factor: 2, - num_queries: 10, // Fast but insecure + num_queries: 10, // Fast but insecure final_degree: 8, } } @@ -86,7 +86,7 @@ impl FriConfig { } (self.log_domain_size - log_final) / log_fold } - + /// Get domain size at a specific layer. pub fn layer_domain_size(&self, layer: usize) -> usize { let log_fold = (self.folding_factor as f64).log2() as usize; @@ -116,7 +116,7 @@ impl FriLayer { tree, } } - + /// Generate a Merkle proof for an index. pub fn prove(&self, index: usize) -> Vec<[u8; 32]> { self.tree.prove(index).path @@ -182,7 +182,7 @@ impl FriProver { evaluations.len() == 1 << self.config.log_domain_size, "Evaluations must match domain size" ); - + let mut layers = Vec::with_capacity(self.config.num_layers()); let mut current_evals = evaluations; @@ -204,10 +204,8 @@ impl FriProver { let final_poly = current_evals; // Generate query proofs - let query_indices = channel.squeeze_query_indices( - self.config.num_queries, - 1 << self.config.log_domain_size, - ); + let query_indices = channel + .squeeze_query_indices(self.config.num_queries, 1 << self.config.log_domain_size); let query_proofs = self.generate_query_proofs(&layers, &query_indices); let proof = FriProof { @@ -228,7 +226,7 @@ impl FriProver { /// This halves the domain size while maintaining the RS proximity property. fn fold_circle(&self, evals: &[M31], alpha: M31, layer: usize) -> Vec { use zp1_primitives::CirclePoint; - + let n = evals.len(); let half_n = n / 2; let mut folded = Vec::with_capacity(half_n); @@ -254,7 +252,7 @@ impl FriProver { // point_i = generator^i let point_i = generator.pow(i as u64); let y_i = point_i.y; - + // Proper Circle FRI folding formula: // f_folded = (sum / 2) + alpha * (diff / (2 * y_i)) // = (sum / 2) + alpha * diff * inv_two * y_i^(-1) @@ -266,7 +264,7 @@ impl FriProver { let y_inv = y_i.inv(); sum * inv_two + alpha * diff * inv_two * y_inv }; - + folded.push(folded_val); } @@ -274,11 +272,7 @@ impl FriProver { } /// Generate query proofs for all requested positions. - fn generate_query_proofs( - &self, - layers: &[FriLayer], - indices: &[usize], - ) -> Vec { + fn generate_query_proofs(&self, layers: &[FriLayer], indices: &[usize]) -> Vec { indices .iter() .map(|&initial_idx| { @@ -289,14 +283,14 @@ impl FriProver { let n = layer.evaluations.len(); // Ensure index is in range current_idx %= n; - + // Sibling is at index + n/2 (mod n) for twin-coset structure let sibling_idx = (current_idx + n / 2) % n; - + // Get values let value = layer.evaluations[current_idx]; let sibling_value = layer.evaluations[sibling_idx]; - + // Get Merkle proof let merkle_proof = layer.prove(current_idx); @@ -317,7 +311,7 @@ impl FriProver { }) .collect() } - + /// Verify a FRI proof (used by the verifier). pub fn verify( &self, @@ -327,29 +321,27 @@ impl FriProver { ) -> bool { // Absorb initial commitment channel.absorb_commitment(initial_commitment); - + // Collect challenges let mut challenges = Vec::with_capacity(proof.layer_commitments.len()); for commitment in &proof.layer_commitments { channel.absorb_commitment(commitment); challenges.push(channel.squeeze_challenge()); } - + // Verify each query - let query_indices = channel.squeeze_query_indices( - self.config.num_queries, - 1 << self.config.log_domain_size, - ); - + let query_indices = channel + .squeeze_query_indices(self.config.num_queries, 1 << self.config.log_domain_size); + for (query_idx, query_proof) in proof.query_proofs.iter().enumerate() { if query_proof.index != query_indices[query_idx] { return false; } - + // Verify folding consistency through layers let mut current_idx = query_proof.index; let mut expected_value = None; - + for (layer_idx, layer_proof) in query_proof.layer_proofs.iter().enumerate() { // If we have an expected value from previous folding, verify it if let Some(expected) = expected_value { @@ -357,21 +349,21 @@ impl FriProver { return false; } } - + // Verify Merkle proof // (In full implementation, would verify against layer commitment) - + // Compute expected folded value for next layer let alpha = challenges[layer_idx]; let inv_two = M31::new(2).inv(); let sum = layer_proof.value + layer_proof.sibling_value; let diff = layer_proof.value - layer_proof.sibling_value; let folded = sum * inv_two + alpha * diff * inv_two; - + expected_value = Some(folded); current_idx /= 2; } - + // Verify final value matches final polynomial evaluation if let Some(expected) = expected_value { let final_eval = evaluate_poly_at(&proof.final_poly, current_idx); @@ -380,7 +372,7 @@ impl FriProver { } } } - + true } } @@ -414,7 +406,7 @@ mod tests { let folded = prover.fold_circle(&evals, alpha, 0); assert_eq!(folded.len(), 8); - + // Verify folding is deterministic let folded2 = prover.fold_circle(&evals, alpha, 0); assert_eq!(folded, folded2); @@ -433,7 +425,7 @@ mod tests { assert!(!proof.final_poly.is_empty()); assert!(!proof.layer_commitments.is_empty()); assert!(!proof.query_proofs.is_empty()); - + // Verify query proofs have Merkle paths for query in &proof.query_proofs { for layer_proof in &query.layer_proofs { @@ -442,22 +434,22 @@ mod tests { } } } - + #[test] fn test_fri_layer() { let evals: Vec = (0..8).map(|i| M31::new(i)).collect(); let layer = FriLayer::new(evals.clone()); - + // Verify commitment is non-zero assert_ne!(layer.commitment, [0u8; 32]); - + // Verify Merkle proofs for i in 0..8 { let proof = layer.prove(i); assert!(!proof.is_empty()); } } - + #[test] fn test_fri_multiple_folds() { let config = FriConfig { @@ -468,21 +460,21 @@ mod tests { }; let prover = FriProver::new(config.clone()); let evals: Vec = (0..64).map(|i| M31::new(i)).collect(); - + let mut channel = ProverChannel::new(b"test"); let (layers, proof) = prover.commit(evals, &mut channel); - + // Should have multiple layers assert!(layers.len() >= 2); - + // Each layer should be half the size of previous for i in 1..layers.len() { assert_eq!( layers[i].evaluations.len(), - layers[i-1].evaluations.len() / 2 + layers[i - 1].evaluations.len() / 2 ); } - + // Final poly should be small assert!(proof.final_poly.len() <= config.final_degree * 2); } diff --git a/crates/prover/src/lde.rs b/crates/prover/src/lde.rs index c1249ec..0342a76 100644 --- a/crates/prover/src/lde.rs +++ b/crates/prover/src/lde.rs @@ -8,7 +8,7 @@ //! 2. Extend coefficients to larger domain //! 3. Evaluate on extended domain (FFT) -use zp1_primitives::{M31, CircleFFT, CirclePoint}; +use zp1_primitives::{CircleFFT, CirclePoint, M31}; /// LDE domain configuration. #[derive(Clone, Debug)] @@ -28,9 +28,12 @@ pub struct LdeDomain { impl LdeDomain { /// Create a new LDE domain. pub fn new(trace_len: usize, blowup: usize) -> Self { - assert!(trace_len.is_power_of_two(), "Trace length must be power of 2"); + assert!( + trace_len.is_power_of_two(), + "Trace length must be power of 2" + ); assert!(blowup.is_power_of_two(), "Blowup must be power of 2"); - + let log_trace_len = trace_len.trailing_zeros() as usize; let log_blowup = blowup.trailing_zeros() as usize; let log_domain_size = log_trace_len + log_blowup; @@ -70,11 +73,15 @@ impl LdeDomain { /// Perform LDE on a single column using Circle FFT. pub fn extend(&self, values: &[M31]) -> Vec { - assert_eq!(values.len(), self.trace_len(), "Input must match trace length"); - + assert_eq!( + values.len(), + self.trace_len(), + "Input must match trace length" + ); + // Step 1: iFFT to get coefficients let coeffs = self.trace_fft.ifft(values); - + // Step 2: FFT on extended domain (zero-padded coefficients) self.extended_fft.fft(&coeffs) } @@ -113,11 +120,11 @@ impl TraceLDE { /// Create a new trace LDE. pub fn new(trace_columns: &[Vec], blowup: usize) -> Self { assert!(!trace_columns.is_empty(), "Need at least one column"); - + let trace_len = trace_columns[0].len(); let domain = LdeDomain::new(trace_len, blowup); let columns = domain.extend_columns(trace_columns); - + Self { domain, columns } } @@ -177,10 +184,10 @@ mod tests { fn test_lde_single_column() { let domain = LdeDomain::new(8, 4); let values: Vec = (0..8).map(|i| M31::new(i)).collect(); - + let extended = domain.extend(&values); assert_eq!(extended.len(), 32); - + // The extended evaluations should match the original at trace points // (every 4th point since blowup=4) // Note: This depends on domain structure, simplified check here @@ -191,9 +198,9 @@ mod tests { fn test_trace_lde() { let col1: Vec = (0..8).map(|i| M31::new(i)).collect(); let col2: Vec = (0..8).map(|i| M31::new(i * 2)).collect(); - + let trace_lde = TraceLDE::new(&[col1, col2], 4); - + assert_eq!(trace_lde.num_columns(), 2); assert_eq!(trace_lde.domain_size(), 32); } @@ -202,15 +209,14 @@ mod tests { fn test_lde_preserves_low_degree() { // A low-degree polynomial should remain low-degree after extension let domain = LdeDomain::new(8, 4); - + // Linear function: f(i) = 3i + 7 let values: Vec = (0..8).map(|i| M31::new(3 * i as u32 + 7)).collect(); - + let extended = domain.extend(&values); assert_eq!(extended.len(), 32); - + // After LDE, interpolating through extended points should give // a polynomial of degree < trace_len (since original was low-degree) } } - diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index f0ca8ea..52c1670 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -2,6 +2,7 @@ //! //! CPU and GPU backends for commitment, LDE, constraint evaluation, and FRI. +pub mod bitwise_tables; pub mod channel; pub mod commitment; pub mod delegation; @@ -9,7 +10,6 @@ pub mod fri; pub mod gpu; pub mod lde; pub mod logup; -pub mod bitwise_tables; pub mod memory; pub mod parallel; pub mod ram; @@ -18,16 +18,22 @@ pub mod serialize; pub mod snark; pub mod stark; -pub use commitment::MerkleTree; pub use channel::ProverChannel; -pub use stark::{StarkConfig, StarkProver, StarkProof, QueryProof}; +pub use commitment::MerkleTree; +pub use delegation::{ + DelegationArgumentProver, DelegationCall, DelegationColumns, DelegationResult, + DelegationSubtree, DelegationType, +}; +pub use gpu::{detect_devices, DeviceType, GpuBackend, GpuDevice, GpuError}; pub use lde::{LdeDomain, TraceLDE}; -pub use logup::{LookupTable, LogUpProver, RangeCheck, PermutationArgument}; -pub use memory::{MemoryConsistencyProver, MemoryAccess, MemoryOp, MemoryColumns}; -pub use parallel::{ParallelConfig, parallel_lde, parallel_merkle_tree, parallel_fri_fold}; -pub use serialize::{SerializableProof, VerificationKey, ProofConfig}; -pub use gpu::{GpuBackend, GpuDevice, GpuError, DeviceType, detect_devices}; -pub use recursion::{RecursiveProver, RecursiveProof, RecursionConfig, SegmentedProver}; -pub use ram::{RamArgumentProver, RamAccess, RamOp, RamColumns, ChunkMemorySubtree}; -pub use delegation::{DelegationArgumentProver, DelegationCall, DelegationResult, DelegationType, DelegationColumns, DelegationSubtree}; -pub use snark::{SnarkWrapper, SnarkProof, SnarkVerifier, SnarkSystem, SnarkConfig, SnarkError, groth16_wrapper, plonk_wrapper, halo2_wrapper}; +pub use logup::{LogUpProver, LookupTable, PermutationArgument, RangeCheck}; +pub use memory::{MemoryAccess, MemoryColumns, MemoryConsistencyProver, MemoryOp}; +pub use parallel::{parallel_fri_fold, parallel_lde, parallel_merkle_tree, ParallelConfig}; +pub use ram::{ChunkMemorySubtree, RamAccess, RamArgumentProver, RamColumns, RamOp}; +pub use recursion::{RecursionConfig, RecursiveProof, RecursiveProver, SegmentedProver}; +pub use serialize::{ProofConfig, SerializableProof, VerificationKey}; +pub use snark::{ + groth16_wrapper, halo2_wrapper, plonk_wrapper, SnarkConfig, SnarkError, SnarkProof, + SnarkSystem, SnarkVerifier, SnarkWrapper, +}; +pub use stark::{QueryProof, StarkConfig, StarkProof, StarkProver}; diff --git a/crates/prover/src/logup.rs b/crates/prover/src/logup.rs index d91065e..2d63052 100644 --- a/crates/prover/src/logup.rs +++ b/crates/prover/src/logup.rs @@ -96,14 +96,18 @@ impl LookupTable { let mut index = HashMap::with_capacity(entries.len()); let mut values = Vec::with_capacity(entries.len()); let mut multiplicities = Vec::with_capacity(entries.len()); - + for (i, (v, m)) in entries.into_iter().enumerate() { index.insert(v.as_u32(), i); values.push(v); multiplicities.push(m); } - - Self { values, multiplicities, index } + + Self { + values, + multiplicities, + index, + } } /// Record a lookup of value v, incrementing its multiplicity. @@ -206,12 +210,12 @@ impl LogUpAccumulator { pub fn add(&mut self, v: M31) { let v_ext = QM31::from_base(v); let diff = self.alpha - v_ext; - + // Handle case where v = α (shouldn't happen with good challenges) if diff == QM31::ZERO { panic!("LogUp: lookup value equals challenge"); } - + let inv = diff.inv(); let prev = *self.running_sum.last().unwrap(); self.running_sum.push(prev + inv); @@ -227,18 +231,18 @@ impl LogUpAccumulator { self.terms.push(QM31::ZERO); return; } - + let v_ext = QM31::from_base(v); let diff = self.alpha - v_ext; - + if diff == QM31::ZERO { panic!("LogUp: table value equals challenge"); } - + let inv = diff.inv(); let mult = QM31::from_base(M31::new(multiplicity)); let term = mult * inv; - + let prev = *self.running_sum.last().unwrap(); self.running_sum.push(prev - term); self.terms.push(term); @@ -387,12 +391,12 @@ impl MultiColumnLogUp { pub fn combine(&self, values: &[M31]) -> QM31 { let mut result = QM31::ZERO; let mut beta_power = QM31::ONE; - + for &v in values { result = result + QM31::from_base(v) * beta_power; beta_power = beta_power * self.beta; } - + result } @@ -585,7 +589,12 @@ pub struct MemoryOp { impl MemoryOp { /// Create a new memory operation. pub fn new(addr: M31, value: M31, timestamp: u32, is_write: bool) -> Self { - Self { addr, value, timestamp, is_write } + Self { + addr, + value, + timestamp, + is_write, + } } /// Create a read operation. @@ -613,7 +622,9 @@ pub struct MemoryConsistency { impl MemoryConsistency { /// Create a new memory consistency checker. pub fn new() -> Self { - Self { operations: Vec::new() } + Self { + operations: Vec::new(), + } } /// Add a memory operation. @@ -649,11 +660,9 @@ impl MemoryConsistency { // Sort by (address, timestamp) let mut sorted = self.operations.clone(); - sorted.sort_by(|a, b| { - match a.addr.as_u32().cmp(&b.addr.as_u32()) { - std::cmp::Ordering::Equal => a.timestamp.cmp(&b.timestamp), - ord => ord, - } + sorted.sort_by(|a, b| match a.addr.as_u32().cmp(&b.addr.as_u32()) { + std::cmp::Ordering::Equal => a.timestamp.cmp(&b.timestamp), + ord => ord, }); // Check consistency: for each address, reads must match preceding writes @@ -685,14 +694,17 @@ impl MemoryConsistency { // Generate LogUp proof using multi-column approach let mc_logup = MultiColumnLogUp::new(alpha, beta); - + // Original order tuples (addr, value, timestamp) - let original_tuples: Vec> = self.operations.iter() + let original_tuples: Vec> = self + .operations + .iter() .map(|op| vec![op.addr, op.value, M31::new(op.timestamp)]) .collect(); // Sorted order tuples - let sorted_tuples: Vec> = sorted.iter() + let sorted_tuples: Vec> = sorted + .iter() .map(|op| vec![op.addr, op.value, M31::new(op.timestamp)]) .collect(); @@ -735,10 +747,8 @@ mod tests { #[test] fn test_lookup_table_basic() { - let table = LookupTable::new(vec![ - M31::new(10), M31::new(20), M31::new(30) - ]); - + let table = LookupTable::new(vec![M31::new(10), M31::new(20), M31::new(30)]); + assert_eq!(table.len(), 3); assert!(!table.is_empty()); assert!(table.contains(M31::new(10))); @@ -747,9 +757,7 @@ mod tests { #[test] fn test_lookup_table_lookup() { - let mut table = LookupTable::new(vec![ - M31::new(0), M31::new(1), M31::new(2), M31::new(3) - ]); + let mut table = LookupTable::new(vec![M31::new(0), M31::new(1), M31::new(2), M31::new(3)]); // Lookup existing values assert_eq!(table.lookup(M31::new(1)), Some(1)); @@ -769,7 +777,7 @@ mod tests { fn test_lookup_table_range() { let table = LookupTable::range_table(4); assert_eq!(table.len(), 16); - + for i in 0..16 { assert!(table.contains(M31::new(i))); } @@ -822,9 +830,7 @@ mod tests { let prover = LogUpProver::new(alpha); // Create table [0, 1, 2, 3] - let mut table = LookupTable::new(vec![ - M31::new(0), M31::new(1), M31::new(2), M31::new(3) - ]); + let mut table = LookupTable::new(vec![M31::new(0), M31::new(1), M31::new(2), M31::new(3)]); // Lookup values that are in the table let lookups = vec![M31::new(1), M31::new(2), M31::new(1)]; @@ -846,9 +852,7 @@ mod tests { let alpha = test_alpha(); let prover = LogUpProver::new(alpha); - let mut table = LookupTable::new(vec![ - M31::new(0), M31::new(1), M31::new(2) - ]); + let mut table = LookupTable::new(vec![M31::new(0), M31::new(1), M31::new(2)]); // Lookup value NOT in table (don't record it) let lookups = vec![M31::new(1), M31::new(99)]; // 99 not in table @@ -904,7 +908,7 @@ mod tests { #[test] fn test_range_check_proof() { let mut rc = RangeCheck::new(4); // [0, 16) - + let values = vec![M31::new(1), M31::new(5), M31::new(1), M31::new(15)]; assert!(rc.check_all(&values)); diff --git a/crates/prover/src/memory.rs b/crates/prover/src/memory.rs index bdd62f5..af5c925 100644 --- a/crates/prover/src/memory.rs +++ b/crates/prover/src/memory.rs @@ -590,10 +590,10 @@ impl MemoryAirConstraints { /// Constraint: same_addr * is_read * (value_curr - value_prev) = 0 pub fn consistency_constraint( &self, - same_addr: M31, // 1 if same address as previous - is_read: M31, // 1 if current op is read - value_curr: M31, // current value - value_prev: M31, // previous value + same_addr: M31, // 1 if same address as previous + is_read: M31, // 1 if current op is read + value_curr: M31, // current value + value_prev: M31, // previous value ) -> M31 { // same_addr * is_read * (value_curr - value_prev) should be 0 let value_diff = value_curr - value_prev; @@ -607,12 +607,7 @@ impl MemoryAirConstraints { /// where inverse proves ts_curr - ts_prev - 1 >= 0 /// /// Simplified: just check ts_curr - ts_prev is positive when same_addr=1 - pub fn timestamp_constraint( - &self, - same_addr: M31, - ts_curr: M31, - ts_prev: M31, - ) -> M31 { + pub fn timestamp_constraint(&self, same_addr: M31, ts_curr: M31, ts_prev: M31) -> M31 { // When same_addr=1, we need ts_curr > ts_prev // This is typically proven via range check, but here we return the difference // A full implementation would use a range check argument @@ -781,7 +776,10 @@ mod tests { let result = prover.verify_consistency(); assert!(result.is_err()); - if let Err(MemoryError::TimestampOrder { prev_ts, curr_ts, .. }) = result { + if let Err(MemoryError::TimestampOrder { + prev_ts, curr_ts, .. + }) = result + { assert_eq!(prev_ts, 10); assert_eq!(curr_ts, 10); } else { @@ -862,95 +860,75 @@ mod tests { #[test] fn test_air_constraints_fingerprint() { - let constraints = MemoryAirConstraints::new( - QM31::from(M31::new(5)), - QM31::from(M31::new(7)), - ); + let constraints = + MemoryAirConstraints::new(QM31::from(M31::new(5)), QM31::from(M31::new(7))); - let fp = constraints.fingerprint( - M31::new(0x1000), - M31::new(42), - M31::new(1), - M31::ONE, - ); + let fp = constraints.fingerprint(M31::new(0x1000), M31::new(42), M31::new(1), M31::ONE); // Fingerprint should be non-zero assert_ne!(fp, QM31::ZERO); // Same inputs should give same fingerprint - let fp2 = constraints.fingerprint( - M31::new(0x1000), - M31::new(42), - M31::new(1), - M31::ONE, - ); + let fp2 = constraints.fingerprint(M31::new(0x1000), M31::new(42), M31::new(1), M31::ONE); assert_eq!(fp, fp2); } #[test] fn test_consistency_constraint_satisfied() { - let constraints = MemoryAirConstraints::new( - QM31::from(M31::new(5)), - QM31::from(M31::new(7)), - ); + let constraints = + MemoryAirConstraints::new(QM31::from(M31::new(5)), QM31::from(M31::new(7))); // Same address, read, same value -> should be 0 let result = constraints.consistency_constraint( - M31::ONE, // same_addr = 1 - M31::ONE, // is_read = 1 - M31::new(42), // value_curr - M31::new(42), // value_prev (same) + M31::ONE, // same_addr = 1 + M31::ONE, // is_read = 1 + M31::new(42), // value_curr + M31::new(42), // value_prev (same) ); assert_eq!(result, M31::ZERO); } #[test] fn test_consistency_constraint_write_ok() { - let constraints = MemoryAirConstraints::new( - QM31::from(M31::new(5)), - QM31::from(M31::new(7)), - ); + let constraints = + MemoryAirConstraints::new(QM31::from(M31::new(5)), QM31::from(M31::new(7))); // Same address, write, different value -> should be 0 (write can change) let result = constraints.consistency_constraint( - M31::ONE, // same_addr = 1 - M31::ZERO, // is_read = 0 (it's a write) - M31::new(100), // value_curr - M31::new(42), // value_prev (different) + M31::ONE, // same_addr = 1 + M31::ZERO, // is_read = 0 (it's a write) + M31::new(100), // value_curr + M31::new(42), // value_prev (different) ); assert_eq!(result, M31::ZERO); } #[test] fn test_consistency_constraint_different_addr() { - let constraints = MemoryAirConstraints::new( - QM31::from(M31::new(5)), - QM31::from(M31::new(7)), - ); + let constraints = + MemoryAirConstraints::new(QM31::from(M31::new(5)), QM31::from(M31::new(7))); // Different address -> constraint doesn't apply let result = constraints.consistency_constraint( - M31::ZERO, // same_addr = 0 - M31::ONE, // is_read = 1 - M31::new(100), // value_curr - M31::new(42), // value_prev (different but OK) + M31::ZERO, // same_addr = 0 + M31::ONE, // is_read = 1 + M31::new(100), // value_curr + M31::new(42), // value_prev (different but OK) ); assert_eq!(result, M31::ZERO); } #[test] fn test_consistency_constraint_violated() { - let constraints = MemoryAirConstraints::new( - QM31::from(M31::new(5)), - QM31::from(M31::new(7)), - ); + let constraints = + MemoryAirConstraints::new(QM31::from(M31::new(5)), QM31::from(M31::new(7))); // Same address, read, different value -> should be non-zero (violation) let result = constraints.consistency_constraint( - M31::ONE, // same_addr = 1 - M31::ONE, // is_read = 1 - M31::new(100), // value_curr - M31::new(42), // value_prev (different!) + M31::ONE, // same_addr = 1 + M31::ONE, // is_read = 1 + M31::new(100), // value_curr + M31::new(42), // value_prev (different!) ); assert_ne!(result, M31::ZERO); } diff --git a/crates/prover/src/parallel.rs b/crates/prover/src/parallel.rs index 7fb57db..d441d89 100644 --- a/crates/prover/src/parallel.rs +++ b/crates/prover/src/parallel.rs @@ -38,23 +38,23 @@ pub fn parallel_lde(columns: &[Vec], blowup: usize) -> Vec> { fn extend_column(column: &[M31], blowup: usize) -> Vec { let n = column.len(); let extended_len = n * blowup; - + // Simple extension: interpolate then evaluate on extended domain // In production, use Circle FFT for proper LDE let mut extended = vec![M31::ZERO; extended_len]; - + // For now, copy original values at stride positions for (i, &val) in column.iter().enumerate() { extended[i * blowup] = val; } - + // Fill intermediate values using linear interpolation (placeholder) for i in 0..n { let start_idx = i * blowup; let _end_idx = ((i + 1) % n) * blowup; let start_val = column[i]; let end_val = column[(i + 1) % n]; - + for j in 1..blowup { let t = M31::new(j as u32); let inv_blowup = M31::new(blowup as u32).inv(); @@ -63,7 +63,7 @@ fn extend_column(column: &[M31], blowup: usize) -> Vec { extended[idx] = start_val + (end_val - start_val) * t * inv_blowup; } } - + extended } @@ -107,32 +107,32 @@ pub fn parallel_merkle_tree(values: &[M31]) -> (Vec<[u8; 32]>, [u8; 32]) { if values.is_empty() { return (vec![[0u8; 32]], [0u8; 32]); } - + let n = values.len().next_power_of_two(); - + // Hash leaves in parallel let mut leaves = parallel_hash_leaves(values); - + // Pad to power of two while leaves.len() < n { leaves.push([0u8; 32]); } - + // Build tree layers let mut layers = vec![leaves.clone()]; let mut current = leaves; - + while current.len() > 1 { current = parallel_merkle_layer(¤t); layers.push(current.clone()); } - + let root = if current.is_empty() { [0u8; 32] } else { current[0] }; - + (layers.into_iter().flatten().collect(), root) } @@ -159,17 +159,17 @@ where F: Fn(usize, &[M31], &[M31]) -> M31 + Sync, { let blowup = domain_size / trace_lde[0].len().max(1); - + (0..domain_size) .into_par_iter() .map(|i| { // Get current row values let row: Vec = trace_lde.iter().map(|col| col[i]).collect(); - + // Get next row values (with wraparound) let next_idx = (i + blowup) % domain_size; let next_row: Vec = trace_lde.iter().map(|col| col[next_idx]).collect(); - + evaluator(i, &row, &next_row) }) .collect() @@ -180,10 +180,10 @@ pub fn parallel_batch_inverse(values: &[M31]) -> Vec { if values.is_empty() { return vec![]; } - + // For large batches, split and process in parallel const CHUNK_SIZE: usize = 1024; - + if values.len() <= CHUNK_SIZE { batch_inverse_sequential(values) } else { @@ -200,17 +200,17 @@ fn batch_inverse_sequential(values: &[M31]) -> Vec { if n == 0 { return vec![]; } - + // Compute prefix products let mut prefix = vec![M31::ONE; n]; prefix[0] = values[0]; for i in 1..n { prefix[i] = prefix[i - 1] * values[i]; } - + // Compute inverse of product let mut inv_prod = prefix[n - 1].inv(); - + // Compute individual inverses let mut result = vec![M31::ZERO; n]; for i in (1..n).rev() { @@ -218,7 +218,7 @@ fn batch_inverse_sequential(values: &[M31]) -> Vec { inv_prod = inv_prod * values[i]; } result[0] = inv_prod; - + result } @@ -248,7 +248,7 @@ impl ParallelConfig { ..Default::default() } } - + /// Initialize Rayon thread pool. pub fn init_thread_pool(&self) -> Result<(), rayon::ThreadPoolBuildError> { if self.num_threads > 0 { @@ -269,9 +269,9 @@ mod tests { fn test_parallel_evaluate_poly() { let coeffs = vec![M31::new(1), M31::new(2), M31::new(3)]; // 1 + 2x + 3x^2 let points = vec![M31::ZERO, M31::ONE, M31::new(2)]; - + let results = parallel_evaluate_poly(&coeffs, &points); - + assert_eq!(results[0].as_u32(), 1); // p(0) = 1 assert_eq!(results[1].as_u32(), 6); // p(1) = 1 + 2 + 3 = 6 assert_eq!(results[2].as_u32(), 17); // p(2) = 1 + 4 + 12 = 17 @@ -281,7 +281,7 @@ mod tests { fn test_parallel_merkle_tree() { let values: Vec = (0..8).map(|i| M31::new(i)).collect(); let (_, root) = parallel_merkle_tree(&values); - + assert_ne!(root, [0u8; 32]); } @@ -289,9 +289,9 @@ mod tests { fn test_parallel_fri_fold() { let evals: Vec = (0..8).map(|i| M31::new(i)).collect(); let alpha = M31::new(3); - + let folded = parallel_fri_fold(&evals, alpha); - + assert_eq!(folded.len(), 4); // folded[0] = evals[0] + alpha * evals[4] = 0 + 3*4 = 12 assert_eq!(folded[0].as_u32(), 12); @@ -301,7 +301,7 @@ mod tests { fn test_batch_inverse() { let values = vec![M31::new(2), M31::new(3), M31::new(5), M31::new(7)]; let inverses = parallel_batch_inverse(&values); - + for (v, inv) in values.iter().zip(inverses.iter()) { let product = *v * *inv; assert_eq!(product.as_u32(), 1); @@ -310,12 +310,10 @@ mod tests { #[test] fn test_parallel_lde() { - let columns = vec![ - vec![M31::new(0), M31::new(1), M31::new(2), M31::new(3)], - ]; - + let columns = vec![vec![M31::new(0), M31::new(1), M31::new(2), M31::new(3)]]; + let extended = parallel_lde(&columns, 4); - + assert_eq!(extended.len(), 1); assert_eq!(extended[0].len(), 16); } diff --git a/crates/prover/src/ram.rs b/crates/prover/src/ram.rs index 72d4f9f..a21d87a 100644 --- a/crates/prover/src/ram.rs +++ b/crates/prover/src/ram.rs @@ -496,7 +496,8 @@ impl RamArgumentProver { /// - Sum of init fingerprints = Sum of final fingerprints (with chunk linking) pub fn verify_shuffle2(&self) -> bool { let (init_tuples, final_tuples) = self.extract_init_final(); - let (init_fps, final_fps) = self.compute_init_final_fingerprints(&init_tuples, &final_tuples); + let (init_fps, final_fps) = + self.compute_init_final_fingerprints(&init_tuples, &final_tuples); // For single chunk or complete execution, init and final should balance // In multi-chunk setting, would need to link across chunks @@ -602,7 +603,11 @@ pub struct RamProof { #[derive(Debug, Clone, PartialEq, Eq)] pub enum RamError { /// Timestamp ordering violation - TimestampOrder { address: u32, prev_ts: u64, curr_ts: u64 }, + TimestampOrder { + address: u32, + prev_ts: u64, + curr_ts: u64, + }, /// Read value mismatch ReadMismatch { address: u32, @@ -787,13 +792,7 @@ impl RamAirConstraints { } /// Compute fingerprint for a RAM tuple. - pub fn fingerprint( - &self, - address: M31, - value: M31, - timestamp: M31, - op: M31, - ) -> QM31 { + pub fn fingerprint(&self, address: M31, value: M31, timestamp: M31, op: M31) -> QM31 { let addr = QM31::from(address); let val = QM31::from(value); let ts = QM31::from(timestamp); @@ -1096,26 +1095,23 @@ mod tests { #[test] fn test_air_constraints() { - let constraints = RamAirConstraints::new( - QM31::from(M31::new(5)), - QM31::from(M31::new(7)), - ); + let constraints = RamAirConstraints::new(QM31::from(M31::new(5)), QM31::from(M31::new(7))); // Consistency constraint: same addr, read, same value -> 0 let result = constraints.consistency_constraint( - M31::ONE, // same_addr - M31::ONE, // is_read - M31::new(42), // value_curr - M31::new(42), // value_prev + M31::ONE, // same_addr + M31::ONE, // is_read + M31::new(42), // value_curr + M31::new(42), // value_prev ); assert_eq!(result, M31::ZERO); // Consistency constraint: same addr, read, different value -> non-zero let result = constraints.consistency_constraint( - M31::ONE, // same_addr - M31::ONE, // is_read - M31::new(100), // value_curr - M31::new(42), // value_prev + M31::ONE, // same_addr + M31::ONE, // is_read + M31::new(100), // value_curr + M31::new(42), // value_prev ); assert_ne!(result, M31::ZERO); } diff --git a/crates/prover/src/recursion.rs b/crates/prover/src/recursion.rs index bfc7eec..5ae48c6 100644 --- a/crates/prover/src/recursion.rs +++ b/crates/prover/src/recursion.rs @@ -8,9 +8,9 @@ #![allow(dead_code)] +use crate::stark::StarkProof; use blake3::Hasher; use zp1_primitives::M31; -use crate::stark::StarkProof; // M31 modulus const M31_MODULUS: u32 = (1 << 31) - 1; @@ -62,7 +62,7 @@ impl RecursiveProver { pub fn new(config: RecursionConfig) -> Self { Self { config } } - + /// Aggregate multiple proofs into one. /// /// The aggregated proof proves that all inner proofs are valid. @@ -70,17 +70,18 @@ impl RecursiveProver { if proofs.is_empty() { return Err(RecursionError::EmptyBatch); } - + if proofs.len() > self.config.max_batch_size { return Err(RecursionError::BatchTooLarge { size: proofs.len(), max: self.config.max_batch_size, }); } - + if self.config.verify_structure { for (i, proof) in proofs.iter().enumerate() { - self.validate_proof(proof).map_err(|msg| RecursionError::InvalidProof(format!("proof {}: {}", i, msg)))?; + self.validate_proof(proof) + .map_err(|msg| RecursionError::InvalidProof(format!("proof {}: {}", i, msg)))?; } } @@ -98,16 +99,16 @@ impl RecursiveProver { .collect() }) .collect(); - + // Compute commitment to all proofs let verifier_commitment = self.compute_batch_commitment(proofs); - + // Placeholder aggregated proof: we cannot build a real recursive proof here. // We retain the first proof's structure but bind both commitments to the batch hash. let mut aggregated = proofs[0].clone(); aggregated.trace_commitment = verifier_commitment; aggregated.composition_commitment = verifier_commitment; - + Ok(RecursiveProof { inner_proof: aggregated, num_aggregated: proofs.len(), @@ -115,7 +116,7 @@ impl RecursiveProver { verifier_commitment, }) } - + /// Compute commitment to a batch of proofs. fn compute_batch_commitment(&self, proofs: &[StarkProof]) -> [u8; 32] { let mut hasher = Hasher::new(); @@ -150,23 +151,23 @@ impl RecursiveProver { } Ok(()) } - + /// Flatten public outputs for aggregated proof. fn flatten_public_outputs(outputs: &[Vec]) -> Vec { let mut result = Vec::new(); - + // Add count of inner proofs result.push(M31::new(outputs.len() as u32)); - + // Add each proof's outputs prefixed by length for out in outputs { result.push(M31::new(out.len() as u32)); result.extend(out.iter().cloned()); } - + result } - + /// Recursively aggregate proofs in a tree structure. /// /// This allows aggregating more proofs than the batch size by @@ -175,26 +176,26 @@ impl RecursiveProver { if proofs.is_empty() { return Err(RecursionError::EmptyBatch); } - + if proofs.len() == 1 { // Base case: wrap single proof return self.aggregate(proofs); } - + // Recursively aggregate in batches let mut level_proofs = proofs.to_vec(); - + while level_proofs.len() > self.config.max_batch_size { let mut next_level = Vec::new(); - + for chunk in level_proofs.chunks(self.config.max_batch_size) { let aggregated = self.aggregate(chunk)?; next_level.push(aggregated.inner_proof); } - + level_proofs = next_level; } - + self.aggregate(&level_proofs) } } @@ -258,35 +259,32 @@ impl SegmentedProver { config, } } - + /// Add a segment proof. pub fn add_segment(&mut self, continuation: ProofContinuation) { self.segments.push(continuation); } - + /// Get accumulated segment count. pub fn num_segments(&self) -> usize { self.segments.len() } - + /// Finalize and aggregate all segments. pub fn finalize(self) -> Result { if self.segments.is_empty() { return Err(RecursionError::EmptyBatch); } - + // Verify segment chain for _i in 1..self.segments.len() { // In real implementation, verify state continuity // prev.final_state should match curr.initial_state } - + // Aggregate segment proofs - let proofs: Vec = self.segments - .into_iter() - .map(|s| s.segment_proof) - .collect(); - + let proofs: Vec = self.segments.into_iter().map(|s| s.segment_proof).collect(); + let prover = RecursiveProver::new(self.config); prover.tree_aggregate(&proofs) } @@ -305,7 +303,7 @@ impl ProofCompressor { pub fn new(target_size: usize) -> Self { Self { target_size } } - + /// Compress a proof by recursively verifying it. pub fn compress(&self, proof: &StarkProof) -> Result { // In a real implementation: @@ -314,11 +312,11 @@ impl ProofCompressor { // 3. Prove the verification with smaller parameters // // This reduces proof size at the cost of proving time - + // Placeholder: return proof unchanged Ok(proof.clone()) } - + /// Estimate compressed proof size. pub fn estimate_size(&self, proof: &StarkProof) -> usize { // Simplified estimate based on actual StarkProof structure @@ -335,7 +333,7 @@ mod tests { use super::*; use crate::fri::FriProof; use crate::stark::OodValues; - + fn mock_proof() -> StarkProof { StarkProof { trace_commitment: [1u8; 32], @@ -353,33 +351,33 @@ mod tests { query_proofs: vec![], } } - + #[test] fn test_aggregate_single() { let prover = RecursiveProver::new(RecursionConfig::default()); let proof = mock_proof(); - + let result = prover.aggregate(&[proof]).unwrap(); assert_eq!(result.num_aggregated, 1); } - + #[test] fn test_aggregate_multiple() { let prover = RecursiveProver::new(RecursionConfig::default()); let proofs = vec![mock_proof(), mock_proof(), mock_proof()]; - + let result = prover.aggregate(&proofs).unwrap(); assert_eq!(result.num_aggregated, 3); assert_eq!(result.public_outputs.len(), 3); } - + #[test] fn test_aggregate_empty() { let prover = RecursiveProver::new(RecursionConfig::default()); let result = prover.aggregate(&[]); assert!(matches!(result, Err(RecursionError::EmptyBatch))); } - + #[test] fn test_aggregate_too_large() { let config = RecursionConfig { @@ -388,11 +386,11 @@ mod tests { }; let prover = RecursiveProver::new(config); let proofs = vec![mock_proof(), mock_proof(), mock_proof()]; - + let result = prover.aggregate(&proofs); assert!(matches!(result, Err(RecursionError::BatchTooLarge { .. }))); } - + #[test] fn test_tree_aggregate() { let config = RecursionConfig { @@ -400,18 +398,18 @@ mod tests { ..Default::default() }; let prover = RecursiveProver::new(config); - + // Create 5 proofs (requires tree aggregation with batch size 2) let proofs: Vec<_> = (0..5).map(|_| mock_proof()).collect(); - + let result = prover.tree_aggregate(&proofs).unwrap(); assert!(result.num_aggregated > 0); } - + #[test] fn test_segmented_prover() { let mut prover = SegmentedProver::new(RecursionConfig::default()); - + // Add segments for i in 0..3 { prover.add_segment(ProofContinuation { @@ -421,25 +419,25 @@ mod tests { is_final: i == 2, }); } - + assert_eq!(prover.num_segments(), 3); - + let result = prover.finalize().unwrap(); assert!(result.num_aggregated > 0); } - + #[test] fn test_proof_compressor() { let compressor = ProofCompressor::new(1024); let proof = mock_proof(); - + let size = compressor.estimate_size(&proof); assert!(size > 0); - + let compressed = compressor.compress(&proof).unwrap(); assert_eq!(compressed.trace_commitment, proof.trace_commitment); } - + #[test] fn test_recursion_error_display() { let err = RecursionError::BatchTooLarge { size: 10, max: 4 }; diff --git a/crates/prover/src/serialize.rs b/crates/prover/src/serialize.rs index 07cffe3..468dd1e 100644 --- a/crates/prover/src/serialize.rs +++ b/crates/prover/src/serialize.rs @@ -2,7 +2,7 @@ //! //! Provides serialization/deserialization for proofs and verification keys. -use serde::{Serialize, Deserialize, Serializer, Deserializer}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use zp1_primitives::M31; /// Serialize M31 as u32. @@ -23,7 +23,9 @@ pub fn serialize_m31_vec(vals: &[M31], serializer: S) -> Result from Vec. -pub fn deserialize_m31_vec<'de, D: Deserializer<'de>>(deserializer: D) -> Result, D::Error> { +pub fn deserialize_m31_vec<'de, D: Deserializer<'de>>( + deserializer: D, +) -> Result, D::Error> { let u32_vals: Vec = Vec::deserialize(deserializer)?; Ok(u32_vals.into_iter().map(M31::new).collect()) } @@ -34,22 +36,25 @@ pub struct SerializableProof { /// Trace commitment (Merkle root). #[serde(with = "hex_array")] pub trace_commitment: [u8; 32], - + /// Composition polynomial commitment. #[serde(with = "hex_array")] pub composition_commitment: [u8; 32], - + /// FRI layer commitments. #[serde(with = "hex_vec")] pub fri_commitments: Vec<[u8; 32]>, - + /// FRI final polynomial coefficients. - #[serde(serialize_with = "serialize_m31_vec", deserialize_with = "deserialize_m31_vec")] + #[serde( + serialize_with = "serialize_m31_vec", + deserialize_with = "deserialize_m31_vec" + )] pub fri_final_poly: Vec, - + /// Query proofs. pub query_proofs: Vec, - + /// Configuration used for this proof. pub config: ProofConfig, } @@ -59,20 +64,26 @@ pub struct SerializableProof { pub struct SerializableQueryProof { /// Query index in the domain. pub index: usize, - + /// Trace values at query point. - #[serde(serialize_with = "serialize_m31_vec", deserialize_with = "deserialize_m31_vec")] + #[serde( + serialize_with = "serialize_m31_vec", + deserialize_with = "deserialize_m31_vec" + )] pub trace_values: Vec, - + /// Composition value at query point. #[serde(serialize_with = "serialize_m31", deserialize_with = "deserialize_m31")] pub composition_value: M31, - + /// Merkle authentication paths. pub merkle_paths: Vec, - + /// FRI layer values. - #[serde(serialize_with = "serialize_m31_vec", deserialize_with = "deserialize_m31_vec")] + #[serde( + serialize_with = "serialize_m31_vec", + deserialize_with = "deserialize_m31_vec" + )] pub fri_values: Vec, } @@ -116,13 +127,13 @@ pub struct VerificationKey { /// Hex serialization for fixed-size arrays. mod hex_array { - use serde::{Serializer, Deserializer, Deserialize}; use super::hex; - + use serde::{Deserialize, Deserializer, Serializer}; + pub fn serialize(bytes: &[u8; 32], serializer: S) -> Result { serializer.serialize_str(&hex::encode(bytes)) } - + pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<[u8; 32], D::Error> { let s = String::deserialize(deserializer)?; let bytes = hex::decode(&s).map_err(serde::de::Error::custom)?; @@ -137,15 +148,17 @@ mod hex_array { /// Hex serialization for vectors of fixed-size arrays. mod hex_vec { - use serde::{Serialize, Serializer, Deserializer, Deserialize}; use super::hex; - + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + pub fn serialize(bytes: &[[u8; 32]], serializer: S) -> Result { let strs: Vec = bytes.iter().map(|b| hex::encode(b)).collect(); strs.serialize(serializer) } - - pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result, D::Error> { + + pub fn deserialize<'de, D: Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { let strs: Vec = Vec::deserialize(deserializer)?; strs.into_iter() .map(|s| { @@ -166,7 +179,7 @@ pub mod hex { pub fn encode(bytes: &[u8]) -> String { bytes.iter().map(|b| format!("{:02x}", b)).collect() } - + pub fn decode(s: &str) -> Result, String> { if s.len() % 2 != 0 { return Err("Odd length hex string".into()); @@ -183,25 +196,25 @@ impl SerializableProof { pub fn to_json(&self) -> Result { serde_json::to_string_pretty(self) } - + /// Deserialize from JSON. pub fn from_json(json: &str) -> Result { serde_json::from_str(json) } - + /// Serialize to binary (bincode). pub fn to_bytes(&self) -> Vec { // Simple binary format: JSON for now // In production, use proper binary encoding self.to_json().unwrap_or_default().into_bytes() } - + /// Deserialize from binary. pub fn from_bytes(bytes: &[u8]) -> Result { let json = std::str::from_utf8(bytes).map_err(|e| e.to_string())?; Self::from_json(json).map_err(|e| e.to_string()) } - + /// Get proof size in bytes. pub fn size(&self) -> usize { self.to_bytes().len() @@ -213,7 +226,7 @@ impl VerificationKey { pub fn to_json(&self) -> Result { serde_json::to_string_pretty(self) } - + /// Deserialize from JSON. pub fn from_json(json: &str) -> Result { serde_json::from_str(json) @@ -234,24 +247,25 @@ mod tests { security_bits: 100, entry_point: 0x0, }; - + let json = serde_json::to_string(&config).unwrap(); let parsed: ProofConfig = serde_json::from_str(&json).unwrap(); - + assert_eq!(parsed.log_trace_len, 10); assert_eq!(parsed.security_bits, 100); } #[test] fn test_hex_roundtrip() { - let bytes = [0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, - 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, - 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, - 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89]; - + let bytes = [ + 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, + 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, + 0x23, 0x45, 0x67, 0x89, + ]; + let encoded = hex::encode(&bytes); let decoded = hex::decode(&encoded).unwrap(); - + assert_eq!(decoded, bytes.to_vec()); } @@ -269,10 +283,10 @@ mod tests { constraints_hash: [1u8; 32], public_inputs_hash: [2u8; 32], }; - + let json = vk.to_json().unwrap(); let parsed = VerificationKey::from_json(&json).unwrap(); - + assert_eq!(parsed.config.log_trace_len, 12); assert_eq!(parsed.constraints_hash, [1u8; 32]); } diff --git a/crates/prover/src/snark.rs b/crates/prover/src/snark.rs index ef41f51..d5fa3af 100644 --- a/crates/prover/src/snark.rs +++ b/crates/prover/src/snark.rs @@ -37,10 +37,10 @@ //! The wrapper generates a circuit that verifies STARK proofs and //! produces a succinct proof of that verification. -use zp1_primitives::M31; +use crate::recursion::{RecursionError, RecursiveProof}; use crate::stark::StarkProof; -use crate::recursion::{RecursiveProof, RecursionError}; -use sha2::{Sha256, Digest}; +use sha2::{Digest, Sha256}; +use zp1_primitives::M31; // ============================================================================ // BN254 Elliptic Curve Implementation @@ -73,44 +73,50 @@ pub struct Fr { impl Fr { /// Zero element. - pub const ZERO: Fr = Fr { limbs: [0, 0, 0, 0] }; - + pub const ZERO: Fr = Fr { + limbs: [0, 0, 0, 0], + }; + /// One element. - pub const ONE: Fr = Fr { limbs: [1, 0, 0, 0] }; - + pub const ONE: Fr = Fr { + limbs: [1, 0, 0, 0], + }; + /// Create from u64. pub fn from_u64(val: u64) -> Self { - let mut result = Fr { limbs: [val, 0, 0, 0] }; + let mut result = Fr { + limbs: [val, 0, 0, 0], + }; result.reduce(); result } - + /// Create from M31 field element. pub fn from_m31(m: M31) -> Self { Self::from_u64(m.value() as u64) } - + /// Create from bytes (little-endian). pub fn from_bytes(bytes: &[u8; 32]) -> Self { let mut limbs = [0u64; 4]; for i in 0..4 { let start = i * 8; - limbs[i] = u64::from_le_bytes(bytes[start..start+8].try_into().unwrap()); + limbs[i] = u64::from_le_bytes(bytes[start..start + 8].try_into().unwrap()); } let mut result = Fr { limbs }; result.reduce(); result } - + /// Convert to bytes (little-endian). pub fn to_bytes(&self) -> [u8; 32] { let mut bytes = [0u8; 32]; for i in 0..4 { - bytes[i*8..(i+1)*8].copy_from_slice(&self.limbs[i].to_le_bytes()); + bytes[i * 8..(i + 1) * 8].copy_from_slice(&self.limbs[i].to_le_bytes()); } bytes } - + /// Reduce modulo r. fn reduce(&mut self) { // Simple reduction: subtract r if >= r @@ -118,7 +124,7 @@ impl Fr { self.sub_assign_limbs(&BN254_R); } } - + /// Compare limbs: returns -1, 0, or 1 fn cmp_limbs(&self, other: &[u64; 4]) -> i32 { for i in (0..4).rev() { @@ -131,7 +137,7 @@ impl Fr { } 0 } - + /// Subtract limbs (assumes self >= other). fn sub_assign_limbs(&mut self, other: &[u64; 4]) { let mut borrow = 0u64; @@ -142,27 +148,27 @@ impl Fr { borrow = (b1 as u64) + (b2 as u64); } } - + /// Add two field elements. pub fn add(&self, other: &Fr) -> Fr { let mut result = Fr { limbs: [0; 4] }; let mut carry = 0u64; - + for i in 0..4 { let (sum, c1) = self.limbs[i].overflowing_add(other.limbs[i]); let (sum2, c2) = sum.overflowing_add(carry); result.limbs[i] = sum2; carry = (c1 as u64) + (c2 as u64); } - + result.reduce(); result } - + /// Subtract two field elements. pub fn sub(&self, other: &Fr) -> Fr { let mut result = *self; - + // If self < other, add r first if self.cmp_limbs(&other.limbs) < 0 { let mut carry = 0u64; @@ -173,31 +179,32 @@ impl Fr { carry = (c1 as u64) + (c2 as u64); } } - + result.sub_assign_limbs(&other.limbs); result } - + /// Multiply two field elements using schoolbook multiplication. pub fn mul(&self, other: &Fr) -> Fr { // Full 512-bit product let mut product = [0u64; 8]; - + for i in 0..4 { let mut carry = 0u128; for j in 0..4 { - let prod = (self.limbs[i] as u128) * (other.limbs[j] as u128) - + (product[i + j] as u128) + carry; + let prod = (self.limbs[i] as u128) * (other.limbs[j] as u128) + + (product[i + j] as u128) + + carry; product[i + j] = prod as u64; carry = prod >> 64; } product[i + 4] = carry as u64; } - + // Barrett reduction self.barrett_reduce(&product) } - + /// Barrett reduction for 512-bit value. fn barrett_reduce(&self, product: &[u64; 8]) -> Fr { // Simplified reduction: take mod r @@ -205,7 +212,7 @@ impl Fr { let mut result = Fr { limbs: [product[0], product[1], product[2], product[3]], }; - + // Handle high bits by reducing for i in 4..8 { if product[i] != 0 { @@ -220,26 +227,26 @@ impl Fr { break; } } - + result.reduce(); result } - + /// Compute modular inverse using extended Euclidean algorithm. pub fn inverse(&self) -> Option { if *self == Fr::ZERO { return None; } - + // Using Fermat's little theorem: a^(-1) = a^(r-2) mod r self.pow(&fr_minus_two()) } - + /// Modular exponentiation using square-and-multiply. pub fn pow(&self, exp: &Fr) -> Option { let mut result = Fr::ONE; let mut base = *self; - + for limb in &exp.limbs { let mut bits = *limb; for _ in 0..64 { @@ -250,16 +257,16 @@ impl Fr { bits >>= 1; } } - + Some(result) } - + /// Negate the field element. pub fn neg(&self) -> Fr { if *self == Fr::ZERO { return Fr::ZERO; } - + let r = Fr { limbs: BN254_R }; r.sub(self) } @@ -283,48 +290,58 @@ pub struct Fq { } impl Fq { - pub const ZERO: Fq = Fq { limbs: [0, 0, 0, 0] }; - pub const ONE: Fq = Fq { limbs: [1, 0, 0, 0] }; - + pub const ZERO: Fq = Fq { + limbs: [0, 0, 0, 0], + }; + pub const ONE: Fq = Fq { + limbs: [1, 0, 0, 0], + }; + pub fn from_u64(val: u64) -> Self { - let mut result = Fq { limbs: [val, 0, 0, 0] }; + let mut result = Fq { + limbs: [val, 0, 0, 0], + }; result.reduce(); result } - + pub fn from_bytes(bytes: &[u8; 32]) -> Self { let mut limbs = [0u64; 4]; for i in 0..4 { let start = i * 8; - limbs[i] = u64::from_le_bytes(bytes[start..start+8].try_into().unwrap()); + limbs[i] = u64::from_le_bytes(bytes[start..start + 8].try_into().unwrap()); } let mut result = Fq { limbs }; result.reduce(); result } - + pub fn to_bytes(&self) -> [u8; 32] { let mut bytes = [0u8; 32]; for i in 0..4 { - bytes[i*8..(i+1)*8].copy_from_slice(&self.limbs[i].to_le_bytes()); + bytes[i * 8..(i + 1) * 8].copy_from_slice(&self.limbs[i].to_le_bytes()); } bytes } - + fn reduce(&mut self) { if self.cmp_limbs(&BN254_Q) >= 0 { self.sub_assign_limbs(&BN254_Q); } } - + fn cmp_limbs(&self, other: &[u64; 4]) -> i32 { for i in (0..4).rev() { - if self.limbs[i] > other[i] { return 1; } - if self.limbs[i] < other[i] { return -1; } + if self.limbs[i] > other[i] { + return 1; + } + if self.limbs[i] < other[i] { + return -1; + } } 0 } - + fn sub_assign_limbs(&mut self, other: &[u64; 4]) { let mut borrow = 0u64; for i in 0..4 { @@ -334,22 +351,22 @@ impl Fq { borrow = (b1 as u64) + (b2 as u64); } } - + pub fn add(&self, other: &Fq) -> Fq { let mut result = Fq { limbs: [0; 4] }; let mut carry = 0u64; - + for i in 0..4 { let (sum, c1) = self.limbs[i].overflowing_add(other.limbs[i]); let (sum2, c2) = sum.overflowing_add(carry); result.limbs[i] = sum2; carry = (c1 as u64) + (c2 as u64); } - + result.reduce(); result } - + pub fn sub(&self, other: &Fq) -> Fq { let mut result = *self; if self.cmp_limbs(&other.limbs) < 0 { @@ -364,21 +381,22 @@ impl Fq { result.sub_assign_limbs(&other.limbs); result } - + pub fn mul(&self, other: &Fq) -> Fq { let mut product = [0u64; 8]; - + for i in 0..4 { let mut carry = 0u128; for j in 0..4 { - let prod = (self.limbs[i] as u128) * (other.limbs[j] as u128) - + (product[i + j] as u128) + carry; + let prod = (self.limbs[i] as u128) * (other.limbs[j] as u128) + + (product[i + j] as u128) + + carry; product[i + j] = prod as u64; carry = prod >> 64; } product[i + 4] = carry as u64; } - + // Simple reduction let mut result = Fq { limbs: [product[0], product[1], product[2], product[3]], @@ -386,11 +404,11 @@ impl Fq { result.reduce(); result } - + pub fn square(&self) -> Fq { self.mul(self) } - + pub fn neg(&self) -> Fq { if *self == Fq::ZERO { return Fq::ZERO; @@ -398,7 +416,7 @@ impl Fq { let q = Fq { limbs: BN254_Q }; q.sub(self) } - + pub fn inverse(&self) -> Option { if *self == Fq::ZERO { return None; @@ -406,11 +424,11 @@ impl Fq { // Fermat: a^(-1) = a^(q-2) Some(self.pow_u64(&fq_minus_two())) } - + fn pow_u64(&self, exp: &[u64; 4]) -> Fq { let mut result = Fq::ONE; let mut base = *self; - + for limb in exp { let mut bits = *limb; for _ in 0..64 { @@ -447,7 +465,7 @@ impl G1Affine { y: Fq::ZERO, infinity: true, }; - + /// Generator point. pub fn generator() -> Self { G1Affine { @@ -456,26 +474,30 @@ impl G1Affine { infinity: false, } } - + /// Create from coordinates. pub fn new(x: Fq, y: Fq) -> Self { - G1Affine { x, y, infinity: false } + G1Affine { + x, + y, + infinity: false, + } } - + /// Check if point is on curve. pub fn is_on_curve(&self) -> bool { if self.infinity { return true; } - + // y² = x³ + 3 let y2 = self.y.square(); let x3 = self.x.mul(&self.x).mul(&self.x); let rhs = x3.add(&Fq::from_u64(3)); - + y2 == rhs } - + /// Negate point. pub fn neg(&self) -> G1Affine { if self.infinity { @@ -487,7 +509,7 @@ impl G1Affine { infinity: false, } } - + /// Serialize to bytes (64 bytes: x || y). pub fn to_bytes(&self) -> [u8; 64] { let mut bytes = [0u8; 64]; @@ -497,16 +519,16 @@ impl G1Affine { } bytes } - + /// Deserialize from bytes. pub fn from_bytes(bytes: &[u8; 64]) -> Option { let x = Fq::from_bytes(bytes[..32].try_into().ok()?); let y = Fq::from_bytes(bytes[32..].try_into().ok()?); - + if x == Fq::ZERO && y == Fq::ZERO { return Some(G1Affine::INFINITY); } - + let point = G1Affine::new(x, y); if point.is_on_curve() { Some(point) @@ -530,7 +552,7 @@ impl G1Projective { y: Fq::ONE, z: Fq::ZERO, }; - + pub fn from_affine(p: &G1Affine) -> Self { if p.infinity { return G1Projective::INFINITY; @@ -541,52 +563,56 @@ impl G1Projective { z: Fq::ONE, } } - + pub fn to_affine(&self) -> G1Affine { if self.z == Fq::ZERO { return G1Affine::INFINITY; } - + let z_inv = self.z.inverse().unwrap(); let z_inv2 = z_inv.square(); let z_inv3 = z_inv2.mul(&z_inv); - + G1Affine { x: self.x.mul(&z_inv2), y: self.y.mul(&z_inv3), infinity: false, } } - + /// Point doubling. pub fn double(&self) -> Self { if self.z == Fq::ZERO { return *self; } - + // Using standard doubling formulas for short Weierstrass curves let a = self.x.square(); let b = self.y.square(); let c = b.square(); - + let d = self.x.add(&b).square().sub(&a).sub(&c); let d = d.add(&d); // 2 * d - + let e = a.add(&a).add(&a); // 3 * a let f = e.square(); - + let x3 = f.sub(&d).sub(&d); - + let eight_c = c.add(&c).add(&c).add(&c); let eight_c = eight_c.add(&eight_c); - + let y3 = e.mul(&d.sub(&x3)).sub(&eight_c); let z3 = self.y.mul(&self.z); let z3 = z3.add(&z3); - - G1Projective { x: x3, y: y3, z: z3 } + + G1Projective { + x: x3, + y: y3, + z: z3, + } } - + /// Point addition. pub fn add(&self, other: &G1Projective) -> Self { if self.z == Fq::ZERO { @@ -595,16 +621,16 @@ impl G1Projective { if other.z == Fq::ZERO { return *self; } - + let z1z1 = self.z.square(); let z2z2 = other.z.square(); - + let u1 = self.x.mul(&z2z2); let u2 = other.x.mul(&z1z1); - + let s1 = self.y.mul(&other.z).mul(&z2z2); let s2 = other.y.mul(&self.z).mul(&z1z1); - + if u1 == u2 { if s1 == s2 { return self.double(); @@ -612,28 +638,32 @@ impl G1Projective { return G1Projective::INFINITY; } } - + let h = u2.sub(&u1); let i = h.add(&h).square(); let j = h.mul(&i); - + let r = s2.sub(&s1); let r = r.add(&r); - + let v = u1.mul(&i); - + let x3 = r.square().sub(&j).sub(&v).sub(&v); let y3 = r.mul(&v.sub(&x3)).sub(&s1.mul(&j).add(&s1.mul(&j))); let z3 = self.z.add(&other.z).square().sub(&z1z1).sub(&z2z2).mul(&h); - - G1Projective { x: x3, y: y3, z: z3 } + + G1Projective { + x: x3, + y: y3, + z: z3, + } } - + /// Scalar multiplication. pub fn scalar_mul(&self, scalar: &Fr) -> Self { let mut result = G1Projective::INFINITY; let mut temp = *self; - + for limb in &scalar.limbs { let mut bits = *limb; for _ in 0..64 { @@ -644,7 +674,7 @@ impl G1Projective { bits >>= 1; } } - + result } } @@ -661,71 +691,77 @@ pub struct Fq2 { } impl Fq2 { - pub const ZERO: Fq2 = Fq2 { c0: Fq::ZERO, c1: Fq::ZERO }; - pub const ONE: Fq2 = Fq2 { c0: Fq::ONE, c1: Fq::ZERO }; - + pub const ZERO: Fq2 = Fq2 { + c0: Fq::ZERO, + c1: Fq::ZERO, + }; + pub const ONE: Fq2 = Fq2 { + c0: Fq::ONE, + c1: Fq::ZERO, + }; + pub fn new(c0: Fq, c1: Fq) -> Self { Fq2 { c0, c1 } } - + pub fn add(&self, other: &Fq2) -> Fq2 { Fq2 { c0: self.c0.add(&other.c0), c1: self.c1.add(&other.c1), } } - + pub fn sub(&self, other: &Fq2) -> Fq2 { Fq2 { c0: self.c0.sub(&other.c0), c1: self.c1.sub(&other.c1), } } - + pub fn mul(&self, other: &Fq2) -> Fq2 { // (a + bu)(c + du) = (ac - bd) + (ad + bc)u let ac = self.c0.mul(&other.c0); let bd = self.c1.mul(&other.c1); let ad = self.c0.mul(&other.c1); let bc = self.c1.mul(&other.c0); - + Fq2 { c0: ac.sub(&bd), c1: ad.add(&bc), } } - + pub fn square(&self) -> Fq2 { // (a + bu)² = (a² - b²) + 2abu let a2 = self.c0.square(); let b2 = self.c1.square(); let ab = self.c0.mul(&self.c1); - + Fq2 { c0: a2.sub(&b2), c1: ab.add(&ab), } } - + pub fn neg(&self) -> Fq2 { Fq2 { c0: self.c0.neg(), c1: self.c1.neg(), } } - + pub fn conjugate(&self) -> Fq2 { Fq2 { c0: self.c0, c1: self.c1.neg(), } } - + pub fn inverse(&self) -> Option { // 1/(a + bu) = (a - bu)/(a² + b²) let norm = self.c0.square().add(&self.c1.square()); let norm_inv = norm.inverse()?; - + Some(Fq2 { c0: self.c0.mul(&norm_inv), c1: self.c1.neg().mul(&norm_inv), @@ -747,11 +783,15 @@ impl G2Affine { y: Fq2::ZERO, infinity: true, }; - + pub fn new(x: Fq2, y: Fq2) -> Self { - G2Affine { x, y, infinity: false } + G2Affine { + x, + y, + infinity: false, + } } - + /// Serialize to bytes (128 bytes). pub fn to_bytes(&self) -> [u8; 128] { let mut bytes = [0u8; 128]; @@ -785,19 +825,19 @@ impl Gt { hasher.update(b"pairing"); hasher.update(&p.to_bytes()); hasher.update(&q.to_bytes()); - + let mut value = [0u8; 32]; value.copy_from_slice(&hasher.finalize()); Gt { value } } - + /// Multiply two Gt elements (pairing product). pub fn mul(&self, other: &Gt) -> Gt { let mut hasher = Sha256::new(); hasher.update(b"gt_mul"); hasher.update(&self.value); hasher.update(&other.value); - + let mut value = [0u8; 32]; value.copy_from_slice(&hasher.finalize()); Gt { value } @@ -818,17 +858,17 @@ impl LinearCombination { pub fn new() -> Self { LinearCombination { terms: Vec::new() } } - + pub fn add_term(&mut self, var: usize, coeff: Fr) { self.terms.push((var, coeff)); } - + pub fn one() -> Self { let mut lc = LinearCombination::new(); lc.add_term(0, Fr::ONE); // Variable 0 is always 1 lc } - + /// Evaluate linear combination with witness. pub fn evaluate(&self, witness: &[Fr]) -> Fr { let mut result = Fr::ZERO; @@ -874,31 +914,36 @@ impl R1csSystem { constraints: Vec::new(), } } - + /// Allocate a new private variable. pub fn alloc_private(&mut self) -> usize { let idx = 1 + self.num_public + self.num_private; self.num_private += 1; idx } - + /// Add a constraint A * B = C. - pub fn add_constraint(&mut self, a: LinearCombination, b: LinearCombination, c: LinearCombination) { + pub fn add_constraint( + &mut self, + a: LinearCombination, + b: LinearCombination, + c: LinearCombination, + ) { self.constraints.push(R1csConstraint { a, b, c }); } - + /// Total number of variables (1 + public + private). pub fn num_vars(&self) -> usize { 1 + self.num_public + self.num_private } - + /// Check if witness satisfies all constraints. pub fn is_satisfied(&self, witness: &[Fr]) -> bool { for constraint in &self.constraints { let a = constraint.a.evaluate(witness); let b = constraint.b.evaluate(witness); let c = constraint.c.evaluate(witness); - + if a.mul(&b) != c { return false; } @@ -971,15 +1016,15 @@ impl Groth16Proof { bytes[192..256].copy_from_slice(&self.c.to_bytes()); bytes } - + /// Serialize to compact 128 bytes (compressed format). pub fn to_bytes_compressed(&self) -> [u8; 128] { // In compressed form: A.x (32) + A.sign (1) + B.x (64) + B.sign (1) + C.x (32) + C.sign (1) // Simplified: use first 128 bytes of uncompressed let full = self.to_bytes(); let mut compressed = [0u8; 128]; - compressed[..32].copy_from_slice(&full[..32]); // A.x - compressed[32..64].copy_from_slice(&full[64..96]); // B.x.c0 + compressed[..32].copy_from_slice(&full[..32]); // A.x + compressed[32..64].copy_from_slice(&full[64..96]); // B.x.c0 compressed[64..96].copy_from_slice(&full[96..128]); // B.x.c1 compressed[96..128].copy_from_slice(&full[192..224]); // C.x compressed @@ -1003,37 +1048,38 @@ impl StarkVerificationCircuit { pub fn new(num_public_inputs: usize) -> Self { let r1cs = R1csSystem::new(num_public_inputs); let mut witness = vec![Fr::ONE]; // Variable 0 is always 1 - + // Add public input placeholders for _ in 0..num_public_inputs { witness.push(Fr::ZERO); } - + StarkVerificationCircuit { r1cs, witness } } - + /// Add constraint for Merkle path verification. - pub fn add_merkle_path_constraint(&mut self, - leaf_hash: Fr, - path: &[Fr], + pub fn add_merkle_path_constraint( + &mut self, + leaf_hash: Fr, + path: &[Fr], path_bits: &[bool], - root: Fr + root: Fr, ) { // Verify: hash(left, right) at each level let mut current = leaf_hash; - + for (i, (sibling, is_right)) in path.iter().zip(path_bits.iter()).enumerate() { let left = if *is_right { *sibling } else { current }; let right = if *is_right { current } else { *sibling }; - + // Allocate intermediate hash let hash_var = self.r1cs.alloc_private(); - + // Add hash constraint (simplified: a * b = c represents hash mixing) let mut a = LinearCombination::new(); let mut b = LinearCombination::new(); let mut c = LinearCombination::new(); - + // Constraint: (left + 1) * (right + 1) = intermediate_product let prod_var = self.r1cs.alloc_private(); a.add_term(0, Fr::ONE); // constant 1 @@ -1041,9 +1087,9 @@ impl StarkVerificationCircuit { b.add_term(0, Fr::ONE); b.add_term(hash_var, right); c.add_term(prod_var, Fr::ONE); - + self.r1cs.add_constraint(a, b, c); - + // Store witness let product = left.add(&Fr::ONE).mul(&right.add(&Fr::ONE)); while self.witness.len() <= prod_var { @@ -1051,169 +1097,168 @@ impl StarkVerificationCircuit { } self.witness[hash_var] = left.add(&right); // Simplified hash self.witness[prod_var] = product; - + current = self.witness[hash_var]; - + let _ = i; // Suppress unused warning } - + // Final constraint: current == root let mut a = LinearCombination::new(); let mut b = LinearCombination::new(); let mut c = LinearCombination::new(); - + let final_var = self.r1cs.alloc_private(); while self.witness.len() <= final_var { self.witness.push(Fr::ZERO); } self.witness[final_var] = current; - + a.add_term(final_var, Fr::ONE); b.add_term(0, Fr::ONE); // * 1 c.add_term(0, root); // = root - + self.r1cs.add_constraint(a, b, c); } - + /// Add constraint for FRI consistency check. - pub fn add_fri_folding_constraint(&mut self, - f_x: Fr, - f_neg_x: Fr, - alpha: Fr, - f_folded: Fr - ) { + pub fn add_fri_folding_constraint(&mut self, f_x: Fr, f_neg_x: Fr, alpha: Fr, f_folded: Fr) { // FRI folding: f_folded = (f(x) + f(-x))/2 + alpha * (f(x) - f(-x))/(2x) // Simplified constraint: f_folded = (f_x + f_neg_x) * inv2 + alpha * (f_x - f_neg_x) * inv2x - + let sum_var = self.r1cs.alloc_private(); let diff_var = self.r1cs.alloc_private(); let result_var = self.r1cs.alloc_private(); - + while self.witness.len() <= result_var { self.witness.push(Fr::ZERO); } - + let sum = f_x.add(&f_neg_x); let diff = f_x.sub(&f_neg_x); - + self.witness[sum_var] = sum; self.witness[diff_var] = diff; self.witness[result_var] = f_folded; - + // Constraint: sum * 1 = sum_var let mut a = LinearCombination::new(); let mut b = LinearCombination::new(); let mut c = LinearCombination::new(); - + a.add_term(0, f_x); a.add_term(0, f_neg_x); b.add_term(0, Fr::ONE); c.add_term(sum_var, Fr::ONE); - + self.r1cs.add_constraint(a, b, c); - + // Add constraint for folded result let mut a2 = LinearCombination::new(); let mut b2 = LinearCombination::new(); let mut c2 = LinearCombination::new(); - + a2.add_term(sum_var, Fr::from_u64(1)); // half a2.add_term(diff_var, alpha); b2.add_term(0, Fr::ONE); c2.add_term(result_var, Fr::ONE); - + self.r1cs.add_constraint(a2, b2, c2); } - + /// Add constraint for field element range check. pub fn add_range_constraint(&mut self, value: Fr, bits: usize) { // Decompose into bits and verify let mut bit_vars = Vec::with_capacity(bits); let mut current = value; - + for _ in 0..bits { let bit_var = self.r1cs.alloc_private(); bit_vars.push(bit_var); - + while self.witness.len() <= bit_var { self.witness.push(Fr::ZERO); } - + // Extract least significant bit - let bit = if current.limbs[0] & 1 == 1 { Fr::ONE } else { Fr::ZERO }; + let bit = if current.limbs[0] & 1 == 1 { + Fr::ONE + } else { + Fr::ZERO + }; self.witness[bit_var] = bit; - + // Boolean constraint: bit * (1 - bit) = 0 let mut a = LinearCombination::new(); let mut b = LinearCombination::new(); let mut c = LinearCombination::new(); - + a.add_term(bit_var, Fr::ONE); b.add_term(0, Fr::ONE); b.add_term(bit_var, Fr::ONE.neg()); c.add_term(0, Fr::ZERO); - + self.r1cs.add_constraint(a, b, c); - + // Shift right current.limbs[0] >>= 1; for i in 1..4 { let carry = current.limbs[i] & 1; current.limbs[i] >>= 1; - current.limbs[i-1] |= carry << 63; + current.limbs[i - 1] |= carry << 63; } } - + // Reconstruct and verify equals original let sum_var = self.r1cs.alloc_private(); while self.witness.len() <= sum_var { self.witness.push(Fr::ZERO); } self.witness[sum_var] = value; - + let mut a = LinearCombination::new(); let mut coeff = Fr::ONE; for &var in &bit_vars { a.add_term(var, coeff); coeff = coeff.add(&coeff); // 2^i } - + let b = LinearCombination::one(); let mut c = LinearCombination::new(); c.add_term(sum_var, Fr::ONE); - + self.r1cs.add_constraint(a, b, c); } - + /// Build circuit from STARK proof components. pub fn build_from_stark(stark_proof: &StarkProof, config: &CircuitDescription) -> Self { let mut circuit = StarkVerificationCircuit::new(config.num_stark_inputs); - + // Add trace commitment verification let trace_hash = Fr::from_bytes(&stark_proof.trace_commitment); circuit.witness[1] = trace_hash; // First public input - + // Add composition commitment verification let comp_hash = Fr::from_bytes(&stark_proof.composition_commitment); circuit.witness[2] = comp_hash; // Second public input - + // Add FRI layer verification constraints for (i, layer_commit) in stark_proof.fri_proof.layer_commitments.iter().enumerate() { if i >= config.num_fri_layers { break; } let layer_hash = Fr::from_bytes(layer_commit); - + // Range check the layer commitment circuit.add_range_constraint(layer_hash, 256); } - + // Add final polynomial constraints for coeff in &stark_proof.fri_proof.final_poly { let fr_coeff = Fr::from_m31(*coeff); circuit.add_range_constraint(fr_coeff, 31); // M31 is 31 bits } - + circuit } } @@ -1232,9 +1277,12 @@ pub struct Groth16Prover { impl Groth16Prover { pub fn new() -> Self { - Groth16Prover { pk: None, r1cs: None } + Groth16Prover { + pk: None, + r1cs: None, + } } - + /// Generate proving and verification keys from R1CS. pub fn setup(&mut self, r1cs: &R1csSystem) -> Groth16VerificationKey { // Toxic waste (in real implementation, this comes from MPC) @@ -1242,45 +1290,59 @@ impl Groth16Prover { let beta = Fr::from_u64(67890); let _gamma = Fr::from_u64(11111); let delta = Fr::from_u64(22222); - + let g1_gen = G1Affine::generator(); let g2_gen = G2Affine::new( Fq2::new(Fq::from_u64(1), Fq::from_u64(2)), Fq2::new(Fq::from_u64(3), Fq::from_u64(4)), ); - + let g1_proj = G1Projective::from_affine(&g1_gen); - + // Compute proving key elements let alpha_g1 = g1_proj.scalar_mul(&alpha).to_affine(); let beta_g1 = g1_proj.scalar_mul(&beta).to_affine(); let delta_g1 = g1_proj.scalar_mul(&delta).to_affine(); - + // IC commitments (one per public input + 1) let mut ic = Vec::with_capacity(r1cs.num_public + 1); for i in 0..=r1cs.num_public { let scalar = Fr::from_u64((i + 1) as u64); ic.push(g1_proj.scalar_mul(&scalar).to_affine()); } - + // A, B, H, L queries let num_vars = r1cs.num_vars(); let a_query: Vec<_> = (0..num_vars) - .map(|i| g1_proj.scalar_mul(&Fr::from_u64(i as u64 + 100)).to_affine()) + .map(|i| { + g1_proj + .scalar_mul(&Fr::from_u64(i as u64 + 100)) + .to_affine() + }) .collect(); let b_g1_query: Vec<_> = (0..num_vars) - .map(|i| g1_proj.scalar_mul(&Fr::from_u64(i as u64 + 200)).to_affine()) - .collect(); - let b_g2_query: Vec<_> = (0..num_vars) - .map(|_| g2_gen) + .map(|i| { + g1_proj + .scalar_mul(&Fr::from_u64(i as u64 + 200)) + .to_affine() + }) .collect(); + let b_g2_query: Vec<_> = (0..num_vars).map(|_| g2_gen).collect(); let h_query: Vec<_> = (0..r1cs.constraints.len()) - .map(|i| g1_proj.scalar_mul(&Fr::from_u64(i as u64 + 300)).to_affine()) + .map(|i| { + g1_proj + .scalar_mul(&Fr::from_u64(i as u64 + 300)) + .to_affine() + }) .collect(); let l_query: Vec<_> = (0..r1cs.num_private) - .map(|i| g1_proj.scalar_mul(&Fr::from_u64(i as u64 + 400)).to_affine()) + .map(|i| { + g1_proj + .scalar_mul(&Fr::from_u64(i as u64 + 400)) + .to_affine() + }) .collect(); - + let pk = Groth16ProvingKey { alpha_g1, beta_g1, @@ -1294,13 +1356,13 @@ impl Groth16Prover { h_query, l_query, }; - + self.pk = Some(pk); self.r1cs = Some(r1cs.clone()); - + // Compute verification key let alpha_beta = Gt::pairing(&alpha_g1, &g2_gen); - + Groth16VerificationKey { alpha_beta_miller: alpha_beta, gamma_g2: g2_gen, @@ -1308,22 +1370,22 @@ impl Groth16Prover { ic, } } - + /// Generate a Groth16 proof given witness. pub fn prove(&self, witness: &[Fr]) -> Result { let pk = self.pk.as_ref().ok_or(SnarkError::SetupRequired)?; let _r1cs = self.r1cs.as_ref().ok_or(SnarkError::SetupRequired)?; - + // Note: In production, we would verify witness satisfies constraints. // For now, we skip this check to allow testing with mock data. // Real implementation would do: if !r1cs.is_satisfied(witness) { return Err(...) } - + // Random blinding factors let r = Fr::from_u64(rand_u64()); let s = Fr::from_u64(rand_u64()); - + let g1_proj = G1Projective::from_affine(&G1Affine::generator()); - + // Compute A = alpha + sum(a_i * w_i) + r * delta let mut a_acc = G1Projective::from_affine(&pk.alpha_g1); for (i, w) in witness.iter().enumerate() { @@ -1333,10 +1395,10 @@ impl Groth16Prover { } } a_acc = a_acc.add(&G1Projective::from_affine(&pk.delta_g1).scalar_mul(&r)); - + // Compute B in G2 (simplified) let b = pk.beta_g2; - + // Compute C = sum(l_i * w_i) + A*s + B*r - r*s*delta let mut c_acc = G1Projective::INFINITY; let num_public = self.r1cs.as_ref().map(|r| r.num_public).unwrap_or(0); @@ -1347,7 +1409,7 @@ impl Groth16Prover { c_acc = c_acc.add(&term); } } - + // Add H contribution (QAP divisibility) for (i, h) in pk.h_query.iter().enumerate() { if i < self.r1cs.as_ref().map(|r| r.constraints.len()).unwrap_or(0) { @@ -1355,11 +1417,11 @@ impl Groth16Prover { c_acc = c_acc.add(&term); } } - + // Add blinding c_acc = c_acc.add(&a_acc.scalar_mul(&s)); c_acc = c_acc.add(&g1_proj.scalar_mul(&r.mul(&s).neg())); - + Ok(Groth16Proof { a: a_acc.to_affine(), b, @@ -1394,7 +1456,7 @@ impl Groth16ProofVerifier { pub fn new(vk: Groth16VerificationKey) -> Self { Groth16ProofVerifier { vk } } - + /// Verify a Groth16 proof with public inputs. pub fn verify(&self, proof: &Groth16Proof, public_inputs: &[Fr]) -> bool { // Compute public input commitment: IC[0] + sum(IC[i+1] * input[i]) @@ -1406,17 +1468,17 @@ impl Groth16ProofVerifier { } } let pub_commitment = acc.to_affine(); - + // Verify pairing equation: // e(A, B) = e(alpha, beta) * e(pub_commitment, gamma) * e(C, delta) let lhs = Gt::pairing(&proof.a, &proof.b); - + let rhs1 = &self.vk.alpha_beta_miller; let rhs2 = Gt::pairing(&pub_commitment, &self.vk.gamma_g2); let rhs3 = Gt::pairing(&proof.c, &self.vk.delta_g2); - + let rhs = rhs1.mul(&rhs2).mul(&rhs3); - + lhs == rhs } } @@ -1475,55 +1537,56 @@ impl SnarkProof { pub fn size(&self) -> usize { self.proof_data.len() } - + /// Expected proof size for each system. pub fn expected_size(system: SnarkSystem) -> usize { match system { - SnarkSystem::Groth16 => 128, // 2 G1 + 1 G2 point - SnarkSystem::Plonk => 400, // Multiple commitments - SnarkSystem::Halo2 => 500, // Accumulator + commitments + SnarkSystem::Groth16 => 128, // 2 G1 + 1 G2 point + SnarkSystem::Plonk => 400, // Multiple commitments + SnarkSystem::Halo2 => 500, // Accumulator + commitments } } - + /// Serialize proof to bytes. pub fn to_bytes(&self) -> Vec { let mut bytes = Vec::new(); - + // System identifier (1 byte) bytes.push(match self.system { SnarkSystem::Groth16 => 0, SnarkSystem::Plonk => 1, SnarkSystem::Halo2 => 2, }); - + // Proof data length (4 bytes, little-endian) bytes.extend_from_slice(&(self.proof_data.len() as u32).to_le_bytes()); - + // Proof data bytes.extend_from_slice(&self.proof_data); - + // Public inputs count (4 bytes) bytes.extend_from_slice(&(self.public_inputs.len() as u32).to_le_bytes()); - + // Public inputs (4 bytes each) for input in &self.public_inputs { bytes.extend_from_slice(&input.value().to_le_bytes()); } - + // VK hash bytes.extend_from_slice(&self.vk_hash); - + bytes } - + /// Deserialize proof from bytes. pub fn from_bytes(bytes: &[u8]) -> Result { - if bytes.len() < 41 { // minimum: 1 + 4 + 0 + 4 + 32 + if bytes.len() < 41 { + // minimum: 1 + 4 + 0 + 4 + 32 return Err(SnarkError::InvalidProofFormat("Proof too short".into())); } - + let mut offset = 0; - + // System let system = match bytes[offset] { 0 => SnarkSystem::Groth16, @@ -1532,39 +1595,40 @@ impl SnarkProof { _ => return Err(SnarkError::InvalidProofFormat("Unknown system".into())), }; offset += 1; - + // Proof data length - let proof_len = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize; + let proof_len = u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()) as usize; offset += 4; - + if bytes.len() < offset + proof_len + 4 + 32 { return Err(SnarkError::InvalidProofFormat("Truncated proof".into())); } - + // Proof data - let proof_data = bytes[offset..offset+proof_len].to_vec(); + let proof_data = bytes[offset..offset + proof_len].to_vec(); offset += proof_len; - + // Public inputs count - let inputs_count = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()) as usize; + let inputs_count = + u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()) as usize; offset += 4; - + if bytes.len() < offset + inputs_count * 4 + 32 { return Err(SnarkError::InvalidProofFormat("Truncated inputs".into())); } - + // Public inputs let mut public_inputs = Vec::with_capacity(inputs_count); for _ in 0..inputs_count { - let val = u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap()); + let val = u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()); public_inputs.push(M31::new(val)); offset += 4; } - + // VK hash let mut vk_hash = [0u8; 32]; - vk_hash.copy_from_slice(&bytes[offset..offset+32]); - + vk_hash.copy_from_slice(&bytes[offset..offset + 32]); + Ok(Self { system, proof_data, @@ -1615,7 +1679,7 @@ impl SnarkVerificationKey { circuit_description: circuit, } } - + /// Compute hash of verification key. fn compute_hash(data: &[u8]) -> [u8; 32] { let mut hasher = Sha256::new(); @@ -1675,173 +1739,171 @@ pub struct SnarkWrapper { impl SnarkWrapper { /// Create a new SNARK wrapper. pub fn new(config: SnarkConfig) -> Self { - Self { - config, + Self { + config, vk: None, groth16_prover: None, groth16_vk: None, } } - + /// Perform setup and generate verification key. /// /// For Groth16, this requires a trusted setup ceremony. /// For PLONK, this uses a universal reference string. /// For Halo2, no setup is needed. - pub fn setup(&mut self, circuit_params: &CircuitDescription) -> Result { + pub fn setup( + &mut self, + circuit_params: &CircuitDescription, + ) -> Result { let vk_data = match self.config.system { SnarkSystem::Groth16 => self.groth16_setup(circuit_params)?, SnarkSystem::Plonk => self.plonk_setup(circuit_params)?, SnarkSystem::Halo2 => self.halo2_setup(circuit_params)?, }; - - let vk = SnarkVerificationKey::new( - self.config.system, - vk_data, - circuit_params.clone(), - ); - + + let vk = SnarkVerificationKey::new(self.config.system, vk_data, circuit_params.clone()); + self.vk = Some(vk.clone()); Ok(vk) } - + /// Generate Groth16 setup with real R1CS circuit. fn groth16_setup(&mut self, params: &CircuitDescription) -> Result, SnarkError> { // Build R1CS for STARK verification circuit let r1cs = self.build_stark_verification_r1cs(params); - + // Initialize Groth16 prover let mut prover = Groth16Prover::new(); let groth16_vk = prover.setup(&r1cs); - + // Serialize verification key let mut vk_bytes = Vec::with_capacity(512); - + // Alpha*beta pairing result vk_bytes.extend_from_slice(&groth16_vk.alpha_beta_miller.value); - + // Gamma G2 vk_bytes.extend_from_slice(&groth16_vk.gamma_g2.to_bytes()); - + // Delta G2 vk_bytes.extend_from_slice(&groth16_vk.delta_g2.to_bytes()); - + // IC commitments count vk_bytes.extend_from_slice(&(groth16_vk.ic.len() as u32).to_le_bytes()); - + // IC commitments for ic in &groth16_vk.ic { vk_bytes.extend_from_slice(&ic.to_bytes()); } - + self.groth16_prover = Some(prover); self.groth16_vk = Some(groth16_vk); - + Ok(vk_bytes) } - + /// Build R1CS for STARK verification. fn build_stark_verification_r1cs(&self, params: &CircuitDescription) -> R1csSystem { let mut r1cs = R1csSystem::new(params.num_stark_inputs); - + // Trace commitment verification constraints for _ in 0..8 { let _ = r1cs.alloc_private(); } - + // Composition commitment constraints for _ in 0..8 { let _ = r1cs.alloc_private(); } - + // FRI layer constraints for _ in 0..params.num_fri_layers { let a_var = r1cs.alloc_private(); let b_var = r1cs.alloc_private(); let c_var = r1cs.alloc_private(); - + // FRI folding constraint: a * b = c (simplified) let mut a = LinearCombination::new(); let mut b = LinearCombination::new(); let mut c = LinearCombination::new(); - + a.add_term(a_var, Fr::ONE); b.add_term(b_var, Fr::ONE); c.add_term(c_var, Fr::ONE); - + r1cs.add_constraint(a, b, c); } - + // Query verification constraints for _ in 0..params.num_queries { let query_var = r1cs.alloc_private(); let path_var = r1cs.alloc_private(); - + // Merkle path constraint let mut a = LinearCombination::new(); let mut b = LinearCombination::new(); let mut c = LinearCombination::new(); - + a.add_term(query_var, Fr::ONE); b.add_term(0, Fr::ONE); // constant 1 c.add_term(path_var, Fr::ONE); - + r1cs.add_constraint(a, b, c); } - + r1cs } - + /// Generate PLONK setup (placeholder). fn plonk_setup(&self, params: &CircuitDescription) -> Result, SnarkError> { // In real implementation: // 1. Build PLONK circuit for STARK verification // 2. Use universal SRS // 3. Generate circuit-specific verification key - + let mut hasher = Sha256::new(); hasher.update(b"plonk_vk"); hasher.update(¶ms.num_stark_inputs.to_le_bytes()); hasher.update(¶ms.num_fri_layers.to_le_bytes()); - + let hash = hasher.finalize(); let mut vk = vec![0u8; 512]; vk[..32].copy_from_slice(&hash); - + Ok(vk) } - + /// Generate Halo2 setup (placeholder). fn halo2_setup(&self, params: &CircuitDescription) -> Result, SnarkError> { // Halo2 doesn't require trusted setup // Generate parameters based on circuit size - + let mut hasher = Sha256::new(); hasher.update(b"halo2_params"); hasher.update(¶ms.num_stark_inputs.to_le_bytes()); - + let hash = hasher.finalize(); let mut vk = vec![0u8; 256]; vk[..32].copy_from_slice(&hash); - + Ok(vk) } - + /// Wrap a STARK proof in a SNARK proof. pub fn wrap(&self, stark_proof: &StarkProof) -> Result { - let vk = self.vk.as_ref() - .ok_or(SnarkError::SetupRequired)?; - + let vk = self.vk.as_ref().ok_or(SnarkError::SetupRequired)?; + // Extract public inputs from STARK proof let public_inputs = self.extract_public_inputs(stark_proof); - + // Generate SNARK proof let proof_data = match self.config.system { SnarkSystem::Groth16 => self.groth16_prove(stark_proof, &public_inputs)?, SnarkSystem::Plonk => self.plonk_prove(stark_proof, &public_inputs)?, SnarkSystem::Halo2 => self.halo2_prove(stark_proof, &public_inputs)?, }; - + Ok(SnarkProof { system: self.config.system, proof_data, @@ -1849,25 +1911,29 @@ impl SnarkWrapper { vk_hash: vk.hash, }) } - + /// Wrap a recursive proof in a SNARK proof. - pub fn wrap_recursive(&self, recursive_proof: &RecursiveProof) -> Result { - let vk = self.vk.as_ref() - .ok_or(SnarkError::SetupRequired)?; - + pub fn wrap_recursive( + &self, + recursive_proof: &RecursiveProof, + ) -> Result { + let vk = self.vk.as_ref().ok_or(SnarkError::SetupRequired)?; + // Extract public inputs from recursive proof let mut public_inputs = self.extract_public_inputs(&recursive_proof.inner_proof); - + // Add aggregation metadata public_inputs.push(M31::new(recursive_proof.num_aggregated as u32)); - + // Generate SNARK proof for the recursive verification let proof_data = match self.config.system { - SnarkSystem::Groth16 => self.groth16_prove(&recursive_proof.inner_proof, &public_inputs)?, + SnarkSystem::Groth16 => { + self.groth16_prove(&recursive_proof.inner_proof, &public_inputs)? + } SnarkSystem::Plonk => self.plonk_prove(&recursive_proof.inner_proof, &public_inputs)?, SnarkSystem::Halo2 => self.halo2_prove(&recursive_proof.inner_proof, &public_inputs)?, }; - + Ok(SnarkProof { system: self.config.system, proof_data, @@ -1875,43 +1941,49 @@ impl SnarkWrapper { vk_hash: vk.hash, }) } - + /// Extract public inputs from STARK proof. fn extract_public_inputs(&self, proof: &StarkProof) -> Vec { let mut inputs = Vec::new(); - + // Trace commitment (as field elements) for chunk in proof.trace_commitment.chunks(4) { let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]); inputs.push(M31::new(u32::from_le_bytes(bytes) & 0x7FFFFFFF)); } - + // Composition commitment for chunk in proof.composition_commitment.chunks(4) { let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]); inputs.push(M31::new(u32::from_le_bytes(bytes) & 0x7FFFFFFF)); } - + // FRI final polynomial (limited) let max_final = std::cmp::min(proof.fri_proof.final_poly.len(), 8); for i in 0..max_final { inputs.push(proof.fri_proof.final_poly[i]); } - + inputs } - + /// Generate Groth16 proof using real elliptic curve operations. - fn groth16_prove(&self, stark_proof: &StarkProof, public_inputs: &[M31]) -> Result, SnarkError> { - let prover = self.groth16_prover.as_ref() + fn groth16_prove( + &self, + stark_proof: &StarkProof, + public_inputs: &[M31], + ) -> Result, SnarkError> { + let prover = self + .groth16_prover + .as_ref() .ok_or(SnarkError::SetupRequired)?; - + // Build witness from STARK proof let circuit = StarkVerificationCircuit::build_from_stark( stark_proof, &self.vk.as_ref().unwrap().circuit_description, ); - + // Extend witness with public inputs let mut witness = circuit.witness; for (i, input) in public_inputs.iter().enumerate() { @@ -1919,51 +1991,59 @@ impl SnarkWrapper { witness[i + 1] = Fr::from_m31(*input); } } - + // Generate Groth16 proof let proof = prover.prove(&witness)?; - + // Return compressed proof bytes Ok(proof.to_bytes_compressed().to_vec()) } - + /// Generate PLONK proof (placeholder). - fn plonk_prove(&self, _stark_proof: &StarkProof, public_inputs: &[M31]) -> Result, SnarkError> { + fn plonk_prove( + &self, + _stark_proof: &StarkProof, + public_inputs: &[M31], + ) -> Result, SnarkError> { // In real implementation: // 1. Build PLONK witness // 2. Compute wire commitments // 3. Generate PLONK proof - + let mut hasher = Sha256::new(); hasher.update(b"plonk_proof"); for input in public_inputs { hasher.update(&input.value().to_le_bytes()); } - + let hash = hasher.finalize(); let mut proof = vec![0u8; 400]; proof[..32].copy_from_slice(&hash); - + Ok(proof) } - + /// Generate Halo2 proof (placeholder). - fn halo2_prove(&self, _stark_proof: &StarkProof, public_inputs: &[M31]) -> Result, SnarkError> { + fn halo2_prove( + &self, + _stark_proof: &StarkProof, + public_inputs: &[M31], + ) -> Result, SnarkError> { // In real implementation: // 1. Build Halo2 circuit witness // 2. Generate accumulator // 3. Produce proof with IPA - + let mut hasher = Sha256::new(); hasher.update(b"halo2_proof"); for input in public_inputs { hasher.update(&input.value().to_le_bytes()); } - + let hash = hasher.finalize(); let mut proof = vec![0u8; 500]; proof[..32].copy_from_slice(&hash); - + Ok(proof) } } @@ -1982,19 +2062,19 @@ impl SnarkVerifier { pub fn new(vk: SnarkVerificationKey) -> Self { Self { vk } } - + /// Verify a SNARK proof. pub fn verify(&self, proof: &SnarkProof) -> Result { // Check VK hash matches if proof.vk_hash != self.vk.hash { return Err(SnarkError::VerificationKeyMismatch); } - + // Check system matches if proof.system != self.vk.system { return Err(SnarkError::SystemMismatch); } - + // Verify based on system match proof.system { SnarkSystem::Groth16 => self.verify_groth16(proof), @@ -2002,13 +2082,15 @@ impl SnarkVerifier { SnarkSystem::Halo2 => self.verify_halo2(proof), } } - + /// Verify Groth16 proof using pairing equation. fn verify_groth16(&self, proof: &SnarkProof) -> Result { if proof.proof_data.len() < 128 { - return Err(SnarkError::InvalidProofFormat("Groth16 proof too short".into())); + return Err(SnarkError::InvalidProofFormat( + "Groth16 proof too short".into(), + )); } - + // Parse proof points from compressed format let mut a_bytes = [0u8; 64]; a_bytes[..32].copy_from_slice(&proof.proof_data[..32]); @@ -2017,10 +2099,9 @@ impl SnarkVerifier { hasher.update(&proof.proof_data[..32]); hasher.update(b"a_y"); a_bytes[32..64].copy_from_slice(&hasher.finalize()); - - let a = G1Affine::from_bytes(&a_bytes) - .unwrap_or(G1Affine::generator()); - + + let a = G1Affine::from_bytes(&a_bytes).unwrap_or(G1Affine::generator()); + // Parse B point (in G2) let b = G2Affine::new( Fq2::new( @@ -2029,7 +2110,7 @@ impl SnarkVerifier { ), Fq2::ONE, // Simplified ); - + // Parse C point let mut c_bytes = [0u8; 64]; c_bytes[..32].copy_from_slice(&proof.proof_data[96..128]); @@ -2037,75 +2118,80 @@ impl SnarkVerifier { hasher2.update(&proof.proof_data[96..128]); hasher2.update(b"c_y"); c_bytes[32..64].copy_from_slice(&hasher2.finalize()); - - let c = G1Affine::from_bytes(&c_bytes) - .unwrap_or(G1Affine::generator()); - + + let c = G1Affine::from_bytes(&c_bytes).unwrap_or(G1Affine::generator()); + let groth16_proof = Groth16Proof { a, b, c }; - + // Convert public inputs to Fr - let fr_inputs: Vec = proof.public_inputs + let fr_inputs: Vec = proof + .public_inputs .iter() .map(|m| Fr::from_m31(*m)) .collect(); - + // Deserialize verification key and verify // For now, use simplified pairing check let lhs = Gt::pairing(&groth16_proof.a, &groth16_proof.b); let _rhs = Gt::pairing(&groth16_proof.c, &groth16_proof.b); - + // Additional check: verify proof structure matches inputs let mut input_hash = Sha256::new(); for input in &fr_inputs { input_hash.update(&input.to_bytes()); } let input_commitment = input_hash.finalize(); - + // Proof is valid if pairing check passes and inputs are consistent let pairing_ok = lhs.value != [0u8; 32]; // Non-trivial pairing result let input_ok = proof.proof_data[..16] != [0u8; 16]; // Non-zero proof - - Ok(pairing_ok && input_ok && input_commitment[0] == input_commitment[0]) // Always true if reached + + Ok(pairing_ok && input_ok && input_commitment[0] == input_commitment[0]) + // Always true if reached } - + /// Verify PLONK proof (placeholder). fn verify_plonk(&self, proof: &SnarkProof) -> Result { if proof.proof_data.len() < 400 { - return Err(SnarkError::InvalidProofFormat("PLONK proof too short".into())); + return Err(SnarkError::InvalidProofFormat( + "PLONK proof too short".into(), + )); } - + let mut hasher = Sha256::new(); hasher.update(b"plonk_proof"); for input in &proof.public_inputs { hasher.update(&input.value().to_le_bytes()); } let expected_hash = hasher.finalize(); - + Ok(proof.proof_data[..32] == expected_hash[..]) } - + /// Verify Halo2 proof (placeholder). fn verify_halo2(&self, proof: &SnarkProof) -> Result { if proof.proof_data.len() < 500 { - return Err(SnarkError::InvalidProofFormat("Halo2 proof too short".into())); + return Err(SnarkError::InvalidProofFormat( + "Halo2 proof too short".into(), + )); } - + let mut hasher = Sha256::new(); hasher.update(b"halo2_proof"); for input in &proof.public_inputs { hasher.update(&input.value().to_le_bytes()); } let expected_hash = hasher.finalize(); - + Ok(proof.proof_data[..32] == expected_hash[..]) } - + /// Estimate verification gas cost for on-chain verification. pub fn estimate_gas_cost(&self) -> u64 { match self.vk.system { - SnarkSystem::Groth16 => 220_000, // ~220k gas on Ethereum - SnarkSystem::Plonk => 300_000, // ~300k gas - SnarkSystem::Halo2 => 500_000, // ~500k gas (more pairings) + SnarkSystem::Groth16 => 220_000, // ~220k gas on Ethereum + SnarkSystem::Plonk => 300_000, // ~300k gas + SnarkSystem::Halo2 => 500_000, // ~500k gas (more pairings) } } } @@ -2138,7 +2224,9 @@ impl std::fmt::Display for SnarkError { match self { SnarkError::SetupRequired => write!(f, "SNARK setup required before proving"), SnarkError::InvalidProofFormat(msg) => write!(f, "Invalid proof format: {}", msg), - SnarkError::VerificationKeyMismatch => write!(f, "Verification key does not match proof"), + SnarkError::VerificationKeyMismatch => { + write!(f, "Verification key does not match proof") + } SnarkError::SystemMismatch => write!(f, "SNARK system mismatch"), SnarkError::VerificationFailed(msg) => write!(f, "Verification failed: {}", msg), SnarkError::CircuitTooLarge { size, max } => { @@ -2171,7 +2259,7 @@ impl SolidityVerifierGenerator { pub fn new(system: SnarkSystem) -> Self { Self { system } } - + /// Generate Solidity verifier contract. pub fn generate(&self, vk: &SnarkVerificationKey) -> String { match self.system { @@ -2180,10 +2268,11 @@ impl SolidityVerifierGenerator { SnarkSystem::Halo2 => self.generate_halo2_verifier(vk), } } - + /// Generate Groth16 Solidity verifier. fn generate_groth16_verifier(&self, vk: &SnarkVerificationKey) -> String { - format!(r#"// SPDX-License-Identifier: MIT + format!( + r#"// SPDX-License-Identifier: MIT pragma solidity ^0.8.0; /// @title Groth16 Verifier for zp1 STARK proofs @@ -2239,10 +2328,11 @@ contract Groth16Verifier {{ hex_encode(&vk.hash) ) } - + /// Generate PLONK Solidity verifier. fn generate_plonk_verifier(&self, vk: &SnarkVerificationKey) -> String { - format!(r#"// SPDX-License-Identifier: MIT + format!( + r#"// SPDX-License-Identifier: MIT pragma solidity ^0.8.0; /// @title PLONK Verifier for zp1 STARK proofs @@ -2271,10 +2361,11 @@ contract PlonkVerifier {{ hex_encode(&vk.hash) ) } - + /// Generate Halo2 Solidity verifier. fn generate_halo2_verifier(&self, vk: &SnarkVerificationKey) -> String { - format!(r#"// SPDX-License-Identifier: MIT + format!( + r#"// SPDX-License-Identifier: MIT pragma solidity ^0.8.0; /// @title Halo2 Verifier for zp1 STARK proofs @@ -2342,7 +2433,7 @@ mod tests { use super::*; use crate::fri::FriProof; use crate::stark::OodValues; - + fn mock_stark_proof() -> StarkProof { StarkProof { trace_commitment: [1u8; 32], @@ -2360,7 +2451,7 @@ mod tests { query_proofs: vec![], } } - + fn circuit_description() -> CircuitDescription { CircuitDescription { num_stark_inputs: 16, @@ -2369,189 +2460,194 @@ mod tests { security_bits: 100, } } - + #[test] fn test_groth16_wrap() { let mut wrapper = groth16_wrapper(); let vk = wrapper.setup(&circuit_description()).unwrap(); - + assert_eq!(vk.system, SnarkSystem::Groth16); - + let stark_proof = mock_stark_proof(); let snark_proof = wrapper.wrap(&stark_proof).unwrap(); - + assert_eq!(snark_proof.system, SnarkSystem::Groth16); assert_eq!(snark_proof.proof_data.len(), 128); assert!(!snark_proof.public_inputs.is_empty()); } - + #[test] fn test_plonk_wrap() { let mut wrapper = plonk_wrapper(); let _vk = wrapper.setup(&circuit_description()).unwrap(); - + let stark_proof = mock_stark_proof(); let snark_proof = wrapper.wrap(&stark_proof).unwrap(); - + assert_eq!(snark_proof.system, SnarkSystem::Plonk); assert_eq!(snark_proof.proof_data.len(), 400); } - + #[test] fn test_halo2_wrap() { let mut wrapper = halo2_wrapper(); let _vk = wrapper.setup(&circuit_description()).unwrap(); - + let stark_proof = mock_stark_proof(); let snark_proof = wrapper.wrap(&stark_proof).unwrap(); - + assert_eq!(snark_proof.system, SnarkSystem::Halo2); assert_eq!(snark_proof.proof_data.len(), 500); } - + #[test] fn test_verify_groth16() { let mut wrapper = groth16_wrapper(); let vk = wrapper.setup(&circuit_description()).unwrap(); - + let stark_proof = mock_stark_proof(); let snark_proof = wrapper.wrap(&stark_proof).unwrap(); - + let verifier = SnarkVerifier::new(vk); assert!(verifier.verify(&snark_proof).unwrap()); } - + #[test] fn test_verify_plonk() { let mut wrapper = plonk_wrapper(); let vk = wrapper.setup(&circuit_description()).unwrap(); - + let stark_proof = mock_stark_proof(); let snark_proof = wrapper.wrap(&stark_proof).unwrap(); - + let verifier = SnarkVerifier::new(vk); assert!(verifier.verify(&snark_proof).unwrap()); } - + #[test] fn test_verify_halo2() { let mut wrapper = halo2_wrapper(); let vk = wrapper.setup(&circuit_description()).unwrap(); - + let stark_proof = mock_stark_proof(); let snark_proof = wrapper.wrap(&stark_proof).unwrap(); - + let verifier = SnarkVerifier::new(vk); assert!(verifier.verify(&snark_proof).unwrap()); } - + #[test] fn test_snark_proof_serialization() { let mut wrapper = groth16_wrapper(); let _vk = wrapper.setup(&circuit_description()).unwrap(); - + let stark_proof = mock_stark_proof(); let snark_proof = wrapper.wrap(&stark_proof).unwrap(); - + let bytes = snark_proof.to_bytes(); let recovered = SnarkProof::from_bytes(&bytes).unwrap(); - + assert_eq!(snark_proof.system, recovered.system); assert_eq!(snark_proof.proof_data, recovered.proof_data); assert_eq!(snark_proof.vk_hash, recovered.vk_hash); } - + #[test] fn test_expected_sizes() { assert_eq!(SnarkProof::expected_size(SnarkSystem::Groth16), 128); assert_eq!(SnarkProof::expected_size(SnarkSystem::Plonk), 400); assert_eq!(SnarkProof::expected_size(SnarkSystem::Halo2), 500); } - + #[test] fn test_setup_required_error() { let wrapper = groth16_wrapper(); let stark_proof = mock_stark_proof(); - + let result = wrapper.wrap(&stark_proof); assert!(matches!(result, Err(SnarkError::SetupRequired))); } - + #[test] fn test_vk_mismatch_error() { let mut wrapper1 = groth16_wrapper(); let mut wrapper2 = groth16_wrapper(); - + let vk1 = wrapper1.setup(&circuit_description()).unwrap(); - let _vk2 = wrapper2.setup(&CircuitDescription { - num_stark_inputs: 32, // Different! - ..circuit_description() - }).unwrap(); - + let _vk2 = wrapper2 + .setup(&CircuitDescription { + num_stark_inputs: 32, // Different! + ..circuit_description() + }) + .unwrap(); + let stark_proof = mock_stark_proof(); let snark_proof = wrapper2.wrap(&stark_proof).unwrap(); - + let verifier = SnarkVerifier::new(vk1); let result = verifier.verify(&snark_proof); assert!(matches!(result, Err(SnarkError::VerificationKeyMismatch))); } - + #[test] fn test_solidity_verifier_generation() { let mut wrapper = groth16_wrapper(); let vk = wrapper.setup(&circuit_description()).unwrap(); - + let generator = SolidityVerifierGenerator::new(SnarkSystem::Groth16); let solidity = generator.generate(&vk); - + assert!(solidity.contains("Groth16Verifier")); assert!(solidity.contains("function verify")); assert!(solidity.contains("pragma solidity")); } - + #[test] fn test_estimate_gas_cost() { let mut wrapper = groth16_wrapper(); let vk = wrapper.setup(&circuit_description()).unwrap(); let verifier = SnarkVerifier::new(vk); - + assert_eq!(verifier.estimate_gas_cost(), 220_000); } - + #[test] fn test_snark_error_display() { - let err = SnarkError::CircuitTooLarge { size: 1000, max: 500 }; + let err = SnarkError::CircuitTooLarge { + size: 1000, + max: 500, + }; let msg = format!("{}", err); assert!(msg.contains("1000")); assert!(msg.contains("500")); } - + // ======================================================================== // Cryptographic Tests // ======================================================================== - + #[test] fn test_fr_arithmetic() { let a = Fr::from_u64(12345); let b = Fr::from_u64(67890); - + // Addition let sum = a.add(&b); assert_ne!(sum, Fr::ZERO); - + // Subtraction let diff = b.sub(&a); assert_ne!(diff, Fr::ZERO); - + // Multiplication let prod = a.mul(&b); assert_ne!(prod, Fr::ZERO); - + // Identity properties assert_eq!(a.add(&Fr::ZERO), a); assert_eq!(a.mul(&Fr::ONE), a); } - + #[test] fn test_fr_inverse() { let a = Fr::from_u64(42); @@ -2561,59 +2657,59 @@ mod tests { assert_ne!(prod, Fr::ZERO); } } - + #[test] fn test_g1_operations() { let g = G1Affine::generator(); let scalar = Fr::from_u64(5); - + let g_proj = G1Projective::from_affine(&g); - + // Scalar multiplication let result = g_proj.scalar_mul(&scalar); let result_affine = result.to_affine(); - + assert!(!result_affine.infinity); } - + #[test] fn test_g1_double() { let g = G1Affine::generator(); let g_proj = G1Projective::from_affine(&g); - + let doubled = g_proj.double(); let doubled_affine = doubled.to_affine(); - + assert!(!doubled_affine.infinity); assert_ne!(doubled_affine, g); } - + #[test] fn test_g1_add() { let g = G1Affine::generator(); let g_proj = G1Projective::from_affine(&g); - + let sum = g_proj.add(&g_proj); let sum_affine = sum.to_affine(); - + assert!(!sum_affine.infinity); } - + #[test] fn test_fq2_arithmetic() { let a = Fq2::new(Fq::from_u64(1), Fq::from_u64(2)); let b = Fq2::new(Fq::from_u64(3), Fq::from_u64(4)); - + let sum = a.add(&b); assert_ne!(sum, Fq2::ZERO); - + let prod = a.mul(&b); assert_ne!(prod, Fq2::ZERO); - + let sq = a.square(); assert_ne!(sq, Fq2::ZERO); } - + #[test] fn test_pairing() { let g1 = G1Affine::generator(); @@ -2621,32 +2717,32 @@ mod tests { Fq2::new(Fq::from_u64(1), Fq::from_u64(2)), Fq2::new(Fq::from_u64(3), Fq::from_u64(4)), ); - + let result = Gt::pairing(&g1, &g2); - + // Pairing result should be non-trivial assert_ne!(result.value, [0u8; 32]); } - + #[test] fn test_r1cs_constraint() { let mut r1cs = R1csSystem::new(2); - + let x = r1cs.alloc_private(); let y = r1cs.alloc_private(); let z = r1cs.alloc_private(); - + // Add constraint: x * y = z let mut a = LinearCombination::new(); let mut b = LinearCombination::new(); let mut c = LinearCombination::new(); - + a.add_term(x, Fr::ONE); b.add_term(y, Fr::ONE); c.add_term(z, Fr::ONE); - + r1cs.add_constraint(a, b, c); - + // Test with satisfying witness: 1, pub1, pub2, 3, 4, 12 let witness = vec![ Fr::ONE, // constant 1 @@ -2656,40 +2752,40 @@ mod tests { Fr::from_u64(4), // y Fr::from_u64(12), // z = 3 * 4 ]; - + assert!(r1cs.is_satisfied(&witness)); } - + #[test] fn test_groth16_proof_structure() { let mut wrapper = groth16_wrapper(); let _vk = wrapper.setup(&circuit_description()).unwrap(); - + let stark_proof = mock_stark_proof(); let snark_proof = wrapper.wrap(&stark_proof).unwrap(); - + // Verify proof structure assert_eq!(snark_proof.proof_data.len(), 128); assert_eq!(snark_proof.system, SnarkSystem::Groth16); assert!(!snark_proof.public_inputs.is_empty()); - + // Verify vk_hash is set assert_ne!(snark_proof.vk_hash, [0u8; 32]); } - + #[test] fn test_linear_combination_evaluate() { let mut lc = LinearCombination::new(); - lc.add_term(0, Fr::ONE); // 1 * witness[0] - lc.add_term(1, Fr::from_u64(2)); // 2 * witness[1] - lc.add_term(2, Fr::from_u64(3)); // 3 * witness[2] - + lc.add_term(0, Fr::ONE); // 1 * witness[0] + lc.add_term(1, Fr::from_u64(2)); // 2 * witness[1] + lc.add_term(2, Fr::from_u64(3)); // 3 * witness[2] + let witness = vec![ - Fr::from_u64(10), // 10 - Fr::from_u64(20), // 20 - Fr::from_u64(30), // 30 + Fr::from_u64(10), // 10 + Fr::from_u64(20), // 20 + Fr::from_u64(30), // 30 ]; - + let result = lc.evaluate(&witness); // Expected: 1*10 + 2*20 + 3*30 = 10 + 40 + 90 = 140 assert_eq!(result, Fr::from_u64(140)); diff --git a/crates/prover/src/stark.rs b/crates/prover/src/stark.rs index 42d2477..bb4033b 100644 --- a/crates/prover/src/stark.rs +++ b/crates/prover/src/stark.rs @@ -77,12 +77,12 @@ use crate::{ channel::ProverChannel, - commitment::{MerkleTree, MerkleProof}, + commitment::{MerkleProof, MerkleTree}, + fri::{FriConfig, FriProof, FriProver}, lde::TraceLDE, - fri::{FriConfig, FriProver, FriProof}, }; -use zp1_primitives::{M31, QM31, CirclePoint}; -use zp1_air::{CpuTraceRow, ConstraintEvaluator as AirConstraintEvaluator}; +use zp1_air::{ConstraintEvaluator as AirConstraintEvaluator, CpuTraceRow}; +use zp1_primitives::{CirclePoint, M31, QM31}; /// Configuration for the STARK prover. #[derive(Clone, Debug)] @@ -132,7 +132,7 @@ impl StarkConfig { pub fn lde_domain_size(&self) -> usize { self.trace_len() * self.blowup_factor } - + /// Get log of LDE domain size. pub fn log_lde_domain_size(&self) -> usize { self.log_trace_len + self.blowup_factor.trailing_zeros() as usize @@ -212,7 +212,7 @@ impl StarkProver { } /// Enable range checks for 16-bit witness values. - /// + /// /// When enabled, verifies that witness columns (carry, borrow, mul_lo, etc.) /// are in the valid range [0, 2^16). pub fn enable_range_checks(&mut self) { @@ -220,7 +220,7 @@ impl StarkProver { } /// Enable parallel processing for improved performance. - /// + /// /// Uses rayon for parallel: /// - Constraint evaluation /// - Merkle tree construction @@ -246,7 +246,10 @@ impl StarkProver { let num_cols = trace_columns.len(); let trace_len = trace_columns[0].len(); - assert!(trace_len.is_power_of_two(), "Trace length must be power of 2"); + assert!( + trace_len.is_power_of_two(), + "Trace length must be power of 2" + ); assert_eq!(trace_len, self.config.trace_len(), "Trace length mismatch"); // ===== Phase 0: Bind Public Inputs ===== @@ -275,12 +278,12 @@ impl StarkProver { if let Err(e) = ram_prover.verify_consistency() { panic!("Memory consistency check failed: {}", e); } - + // Verify the permutation argument (execution order ↔ sorted order) if !ram_prover.verify_shuffle1() { panic!("RAM permutation argument failed"); } - + // Note: In a complete implementation, we would: // 1. Get RAM columns and add them to the trace // 2. Add RAM constraints to the composition polynomial @@ -292,9 +295,9 @@ impl StarkProver { // Verify 16-bit witness columns are in valid range [0, 2^16) if self.range_check_enabled && num_cols >= 77 { use crate::logup::RangeCheck; - + let mut range_checker = RangeCheck::new(16); // 16-bit range [0, 65536) - + // Column indices for 16-bit witness values (per trace column layout): // imm_lo (8), imm_hi (9), rd_val_lo (10), rd_val_hi (11), // rs1_val_lo (12), rs1_val_hi (13), rs2_val_lo (14), rs2_val_hi (15), @@ -302,18 +305,19 @@ impl StarkProver { // mul_lo (67), mul_hi (68), carry (69), borrow (70), // quotient_lo (71), quotient_hi (72), remainder_lo (73), remainder_hi (74) let witness_column_indices = [ - 8, 9, 10, 11, 12, 13, 14, 15, // Immediate and register values (16-bit limbs) - 62, 63, 64, 65, // Memory address/value limbs - 67, 68, 69, 70, 71, 72, 73, 74, // Multiply/carry/div witnesses + 8, 9, 10, 11, 12, 13, 14, 15, // Immediate and register values (16-bit limbs) + 62, 63, 64, 65, // Memory address/value limbs + 67, 68, 69, 70, 71, 72, 73, 74, // Multiply/carry/div witnesses ]; - + for &col_idx in &witness_column_indices { if col_idx < num_cols { for &value in &trace_columns[col_idx] { if !range_checker.check(value) { panic!( - "Range check failed: column {} value {} exceeds 16-bit range", - col_idx, value.as_u32() + "Range check failed: column {} value {} exceeds 16-bit range", + col_idx, + value.as_u32() ); } } @@ -347,14 +351,10 @@ impl StarkProver { // ===== Phase 3: DEEP Sampling ===== // Sample out-of-domain point let oods_point = self.channel.squeeze_qm31(); - + // Evaluate trace and composition at OOD point - let ood_values = self.evaluate_ood( - &trace_columns, - &composition_evals, - oods_point, - ); - + let ood_values = self.evaluate_ood(&trace_columns, &composition_evals, oods_point); + // Absorb OOD values for &v in &ood_values.trace_at_z { self.channel.absorb_felt(v); @@ -362,12 +362,20 @@ impl StarkProver { self.channel.absorb_felt(ood_values.composition_at_z); // ===== Phase 4: DEEP Quotient and FRI ===== - // Build DEEP quotient polynomial + // Squeeze DEEP combination alphas from transcript (one per trace column + composition) + // This MUST happen after absorbing OOD values and match the verifier's transcript order. + let num_deep_terms = trace_lde.num_columns() + 1; + let deep_alphas: Vec = (0..num_deep_terms) + .map(|_| self.channel.squeeze_challenge()) + .collect(); + + // Build DEEP quotient polynomial using transcript-derived alphas let deep_quotient = self.build_deep_quotient( &trace_lde, &composition_evals, &ood_values, oods_point, + &deep_alphas, ); // FRI commitment @@ -383,9 +391,7 @@ impl StarkProver { // ===== Phase 5: Query Phase ===== // Note: FRI commit already squeezed query indices internally, so we use those - let query_indices: Vec = fri_proof.query_proofs.iter() - .map(|q| q.index) - .collect(); + let query_indices: Vec = fri_proof.query_proofs.iter().map(|q| q.index).collect(); let query_proofs = self.generate_query_proofs( &query_indices, @@ -404,25 +410,34 @@ impl StarkProver { query_proofs, } } - /// Build Merkle tree for trace, committing to ALL columns via interleaving. - /// - /// Interleaves columns as: [col0[0], col1[0], ..., col0[1], col1[1], ...] - /// This ensures all columns are bound to the commitment (critical for soundness). + /// Build Merkle tree for trace, committing to ALL columns via per-row hashing. + /// + /// Each leaf is a Blake3 hash of all column values at that row, so the + /// commitment binds every column (critical for soundness). A single + /// Merkle proof per queried row then authenticates all trace values at + /// that row. fn build_trace_merkle_tree(&self, trace_lde: &TraceLDE) -> MerkleTree { + use blake3::Hasher as Blake3Hasher; + let domain_size = trace_lde.domain_size(); let num_cols = trace_lde.num_columns(); - - // Interleave all columns for commitment - let mut interleaved = Vec::with_capacity(domain_size * num_cols); - for row in 0..domain_size { - for col in 0..num_cols { - interleaved.push(trace_lde.get(col, row)); - } - } - - MerkleTree::new(&interleaved) + + // Hash each row into a single leaf + let row_hashes: Vec<[u8; 32]> = (0..domain_size) + .map(|row| { + let mut h = Blake3Hasher::new(); + // Domain prefix distinguishes trace row hashes from other Blake3 uses + h.update(b"zp1-trace-row"); + for col in 0..num_cols { + h.update(&trace_lde.get(col, row).as_u32().to_le_bytes()); + } + *h.finalize().as_bytes() + }) + .collect(); + + MerkleTree::from_leaf_hashes(row_hashes) } - + /// Squeeze random coefficients for combining constraints. fn squeeze_constraint_alphas(&mut self, num_cols: usize) -> Vec { // Generate enough alphas for boundary + transition constraints @@ -433,16 +448,12 @@ impl StarkProver { } /// Evaluate the composition polynomial at all LDE domain points. - /// + /// /// The composition polynomial combines all AIR constraints: /// C(x) = sum_i alpha_i * C_i(x) / Z_i(x) - /// + /// /// where C_i are constraint polynomials and Z_i are their zerofiers. - fn evaluate_composition_polynomial( - &self, - trace_lde: &TraceLDE, - alphas: &[M31], - ) -> Vec { + fn evaluate_composition_polynomial(&self, trace_lde: &TraceLDE, alphas: &[M31]) -> Vec { let domain_size = trace_lde.domain_size(); let blowup = self.config.blowup_factor; let trace_len = self.config.trace_len(); @@ -454,7 +465,7 @@ impl StarkProver { let trace_row: Vec = (0..trace_lde.num_columns()) .map(|c| trace_lde.get(c, i)) .collect(); - + // Get values at next row (with wraparound) let trace_row_next: Vec = (0..trace_lde.num_columns()) .map(|c| trace_lde.get(c, (i + blowup) % domain_size)) @@ -467,33 +478,33 @@ impl StarkProver { // Z_boundary(x) vanishes at i = 0, blowup, 2*blowup, ... (original trace positions) let is_trace_position = i % blowup == 0; let is_first_row = i < blowup; - + if is_first_row && is_trace_position { // Enforce boundary constraints at first execution step: // 1. PC must equal entry point // 2. next_pc must equal PC + 4 (sequential start) // 3. x0 register (rd=0) must be zero (rd_val_lo and rd_val_hi) - + let pc = trace_row[1]; // Column 1 is PC let next_pc = trace_row[2]; // Column 2 is next_pc let rd_val_lo = trace_row[10]; // Column 10 is rd_val_lo let rd_val_hi = trace_row[11]; // Column 11 is rd_val_hi - + let entry_pc = M31::new(self.config.entry_point & 0x7FFFFFFF); let four = M31::new(4); - + // Boundary constraint 1: PC = entry_point if alpha_idx < alphas.len() { constraint_sum += alphas[alpha_idx] * (pc - entry_pc); } alpha_idx += 1; - + // Boundary constraint 2: next_pc = pc + 4 if alpha_idx < alphas.len() { constraint_sum += alphas[alpha_idx] * (next_pc - pc - four); } alpha_idx += 1; - + // Boundary constraint 3: x0 = 0 (both limbs) // Note: This is enforced globally by x0_zero constraint, // but we add it here for explicit boundary checking @@ -501,7 +512,7 @@ impl StarkProver { constraint_sum += alphas[alpha_idx] * rd_val_lo; } alpha_idx += 1; - + if alpha_idx < alphas.len() { constraint_sum += alphas[alpha_idx] * rd_val_hi; } @@ -512,7 +523,7 @@ impl StarkProver { // Map columns to CpuTraceRow let row = CpuTraceRow::from_slice(&trace_row); let constraints = AirConstraintEvaluator::evaluate_all(&row); - + for c in constraints { if alpha_idx < alphas.len() { constraint_sum += alphas[alpha_idx] * c; @@ -522,14 +533,14 @@ impl StarkProver { // 2. Inter-row constraints (apply to non-last rows) let is_last_row = i >= (trace_len - 1) * blowup && i < trace_len * blowup; - + if !is_last_row { // pc' = next_pc // trace_row_next[1] (pc) == trace_row[2] (next_pc) let pc_next = trace_row_next[1]; let next_pc_curr = trace_row[2]; let pc_consistency = pc_next - next_pc_curr; - + if alpha_idx < alphas.len() { constraint_sum += alphas[alpha_idx] * pc_consistency; } @@ -543,7 +554,7 @@ impl StarkProver { } /// Parallel version of composition polynomial evaluation. - /// + /// /// Uses rayon for ~2-8x speedup on multi-core systems. fn evaluate_composition_polynomial_parallel( &self, @@ -551,7 +562,7 @@ impl StarkProver { alphas: &[M31], ) -> Vec { use rayon::prelude::*; - + let domain_size = trace_lde.domain_size(); let blowup = self.config.blowup_factor; let trace_len = self.config.trace_len(); @@ -563,10 +574,8 @@ impl StarkProver { .into_par_iter() .map(|i| { // Get values at current row - let trace_row: Vec = (0..num_columns) - .map(|c| trace_lde.get(c, i)) - .collect(); - + let trace_row: Vec = (0..num_columns).map(|c| trace_lde.get(c, i)).collect(); + // Get values at next row (with wraparound) let trace_row_next: Vec = (0..num_columns) .map(|c| trace_lde.get(c, (i + blowup) % domain_size)) @@ -578,31 +587,31 @@ impl StarkProver { // Boundary constraints let is_trace_position = i % blowup == 0; let is_first_row = i < blowup; - + if is_first_row && is_trace_position && trace_row.len() >= 12 { let pc = trace_row[1]; let next_pc = trace_row[2]; let rd_val_lo = trace_row[10]; let rd_val_hi = trace_row[11]; - + let entry_pc = M31::new(entry_point & 0x7FFFFFFF); let four = M31::new(4); - + if alpha_idx < alphas_vec.len() { constraint_sum += alphas_vec[alpha_idx] * (pc - entry_pc); } alpha_idx += 1; - + if alpha_idx < alphas_vec.len() { constraint_sum += alphas_vec[alpha_idx] * (next_pc - pc - four); } alpha_idx += 1; - + if alpha_idx < alphas_vec.len() { constraint_sum += alphas_vec[alpha_idx] * rd_val_lo; } alpha_idx += 1; - + if alpha_idx < alphas_vec.len() { constraint_sum += alphas_vec[alpha_idx] * rd_val_hi; } @@ -612,7 +621,7 @@ impl StarkProver { // Intra-row constraints let row = CpuTraceRow::from_slice(&trace_row); let constraints = AirConstraintEvaluator::evaluate_all(&row); - + for c in constraints { if alpha_idx < alphas_vec.len() { constraint_sum += alphas_vec[alpha_idx] * c; @@ -622,12 +631,12 @@ impl StarkProver { // Inter-row constraints let is_last_row = i >= (trace_len - 1) * blowup && i < trace_len * blowup; - + if !is_last_row && trace_row.len() >= 3 && trace_row_next.len() >= 2 { let pc_next = trace_row_next[1]; let next_pc_curr = trace_row[2]; let pc_consistency = pc_next - next_pc_curr; - + if alpha_idx < alphas_vec.len() { constraint_sum += alphas_vec[alpha_idx] * pc_consistency; } @@ -637,9 +646,9 @@ impl StarkProver { }) .collect() } - + /// Evaluate trace and composition at out-of-domain point. - /// + /// /// Uses full QM31 arithmetic for proper security (~128 bits). fn evaluate_ood( &self, @@ -648,30 +657,32 @@ impl StarkProver { z: QM31, ) -> OodValues { // Evaluate trace polynomials at z using full QM31 arithmetic - let trace_at_z: Vec = trace_columns.iter() + let trace_at_z: Vec = trace_columns + .iter() .map(|col| self.evaluate_poly_at_qm31(col, z)) .collect(); - + // For next row: multiply z by domain generator's x-coordinate // In circle domain, the generator advances us to the next point let generator = CirclePoint::generator(self.config.log_trace_len); let gen_x = QM31::from(generator.x); let z_next = z * gen_x; - - let trace_at_z_next: Vec = trace_columns.iter() + + let trace_at_z_next: Vec = trace_columns + .iter() .map(|col| self.evaluate_poly_at_qm31(col, z_next)) .collect(); - + // Composition at z let composition_at_z = self.evaluate_poly_at_qm31(composition_evals, z); - + OodValues { trace_at_z, trace_at_z_next, composition_at_z, } } - + /// Evaluate polynomial at a single point (Horner's method). fn evaluate_poly_at_point(&self, coeffs: &[M31], x: M31) -> M31 { let mut result = M31::ZERO; @@ -682,7 +693,7 @@ impl StarkProver { } /// Evaluate polynomial at a QM31 point using Horner's method. - /// + /// /// Uses full extension field arithmetic for proper security. /// Returns the c0 component (projection to base field). fn evaluate_poly_at_qm31(&self, coeffs: &[M31], x: QM31) -> M31 { @@ -693,11 +704,11 @@ impl StarkProver { // Return c0 component (projection to base field for constraint checking) result.c0 } - + /// Build the DEEP quotient polynomial. - /// + /// /// Q(x) = sum_i alpha_i * (f_i(x) - f_i(z)) / (x - z) - /// + /// /// This "lifts" the low-degree test to include the OOD values. /// Uses actual circle domain points for proper soundness. fn build_deep_quotient( @@ -706,25 +717,20 @@ impl StarkProver { composition_evals: &[M31], ood_values: &OodValues, z: QM31, + deep_alphas: &[M31], ) -> Vec { let domain_size = trace_lde.domain_size(); - - // Get DEEP combination alphas - let num_terms = trace_lde.num_columns() + 1; // trace columns + composition - let deep_alphas: Vec = (0..num_terms) - .map(|i| M31::new((i as u32 + 1) * 7919)) // Deterministic for testing - .collect(); - + let mut quotient = vec![M31::ZERO; domain_size]; - + for i in 0..domain_size { // Get ACTUAL circle domain point x-coordinate (critical for soundness!) let x_i = trace_lde.get_domain_x(i); let x_i_qm31 = QM31::from(x_i); - + // Compute (x_i - z)^(-1) in QM31 for proper security let denom = x_i_qm31 - z; - + // Handle singularity (when x_i == z, which is extremely rare) let denom_inv = if denom.is_zero() { // Skip this point in the quotient (contributes 0) @@ -732,26 +738,27 @@ impl StarkProver { } else { denom.inv() }; - + let mut sum = QM31::ZERO; - + // Add trace column contributions for (col_idx, &ood_val) in ood_values.trace_at_z.iter().enumerate() { let f_x = trace_lde.get(col_idx, i); let numerator = QM31::from(f_x - ood_val); sum = sum + QM31::from(deep_alphas[col_idx]) * numerator * denom_inv; } - + // Add composition contribution let comp_x = composition_evals[i]; let comp_z = ood_values.composition_at_z; let comp_numerator = QM31::from(comp_x - comp_z); - sum = sum + QM31::from(deep_alphas[trace_lde.num_columns()]) * comp_numerator * denom_inv; - + sum = + sum + QM31::from(deep_alphas[trace_lde.num_columns()]) * comp_numerator * denom_inv; + // Project result to base field quotient[i] = sum.c0; } - + quotient } @@ -766,17 +773,17 @@ impl StarkProver { deep_quotient: &[M31], ) -> Vec { let num_cols = trace_lde.num_columns(); - + indices .iter() .map(|&idx| { let trace_values = trace_lde.get_row(idx); - + // For interleaved commitment, prove starting index of the row // Row idx contains indices [idx*num_cols, idx*num_cols + 1, ..., idx*num_cols + num_cols-1] let interleaved_idx = idx * num_cols; let trace_proof = trace_tree.prove(interleaved_idx); - + let composition_value = composition_evals[idx]; let composition_proof = composition_tree.prove(idx); let deep_quotient_value = deep_quotient[idx]; @@ -846,32 +853,32 @@ impl StarkVerifier { pub fn new(config: StarkConfig) -> Self { Self { config } } - + /// Verify a STARK proof. pub fn verify(&self, proof: &StarkProof) -> bool { let mut channel = ProverChannel::new(b"zp1-stark-v1"); - + // Absorb trace commitment channel.absorb(&proof.trace_commitment); - + // Get constraint alphas (must match prover - use same count as num_cols * 2) let num_cols = proof.ood_values.trace_at_z.len(); let _constraint_alphas: Vec = (0..num_cols * 2) .map(|_| channel.squeeze_challenge()) .collect(); - + // Absorb composition commitment channel.absorb(&proof.composition_commitment); - + // Get OOD point let _oods_point = channel.squeeze_qm31(); - + // Absorb OOD values for &v in &proof.ood_values.trace_at_z { channel.absorb_felt(v); } channel.absorb_felt(proof.ood_values.composition_at_z); - + // Verify FRI proof using the same channel state as prover let fri_config = FriConfig { log_domain_size: self.config.log_lde_domain_size(), @@ -880,7 +887,7 @@ impl StarkVerifier { final_degree: 8, }; let _fri_prover = FriProver::new(fri_config); - + // FRI verification follows the prover's channel flow // Absorb layer commitments and squeeze challenges (matches FRI.commit) let mut fri_challenges = Vec::with_capacity(proof.fri_proof.layer_commitments.len()); @@ -888,20 +895,18 @@ impl StarkVerifier { channel.absorb_commitment(commitment); fri_challenges.push(channel.squeeze_challenge()); } - + // Get query indices - squeezed after all FRI layer commitments (matches FRI.commit) - let query_indices = channel.squeeze_query_indices( - self.config.num_queries, - self.config.lde_domain_size(), - ); - + let query_indices = + channel.squeeze_query_indices(self.config.num_queries, self.config.lde_domain_size()); + // Verify Merkle proofs for each query for (q_idx, query) in proof.query_proofs.iter().enumerate() { // Verify query index matches expected if query.index != query_indices[q_idx] { return false; } - + // Verify trace Merkle proof if !query.trace_values.is_empty() { let trace_valid = MerkleTree::verify( @@ -913,7 +918,7 @@ impl StarkVerifier { return false; } } - + // Verify composition Merkle proof let comp_valid = MerkleTree::verify( &proof.composition_commitment, @@ -924,17 +929,17 @@ impl StarkVerifier { return false; } } - + // Verify FRI query proofs for (query_idx, fri_query) in proof.fri_proof.query_proofs.iter().enumerate() { if fri_query.index != query_indices[query_idx] { return false; } - + // Verify folding consistency let mut expected_value: Option = None; let mut current_idx = fri_query.index; - + for (layer_idx, layer_proof) in fri_query.layer_proofs.iter().enumerate() { // Check expected value from previous layer if let Some(expected) = expected_value { @@ -942,18 +947,18 @@ impl StarkVerifier { return false; } } - + // Compute folded value for next layer let alpha = fri_challenges[layer_idx]; let inv_two = M31::new(2).inv(); let sum = layer_proof.value + layer_proof.sibling_value; let diff = layer_proof.value - layer_proof.sibling_value; let folded = sum * inv_two + alpha * diff * inv_two; - + expected_value = Some(folded); current_idx /= 2; } - + // Final polynomial check if let Some(expected) = expected_value { let final_idx = current_idx % proof.fri_proof.final_poly.len(); @@ -965,10 +970,9 @@ impl StarkVerifier { } } } - + // Basic structural checks - !proof.fri_proof.final_poly.is_empty() && - !proof.fri_proof.layer_commitments.is_empty() + !proof.fri_proof.final_poly.is_empty() && !proof.fri_proof.layer_commitments.is_empty() } } @@ -991,7 +995,9 @@ mod tests { // We need enough columns for CpuTraceRow (77 columns) let mut columns: Vec> = Vec::new(); for i in 0..77 { - let col: Vec = (0..trace_len).map(|j| M31::new((i + j as usize) as u32)).collect(); + let col: Vec = (0..trace_len) + .map(|j| M31::new((i + j as usize) as u32)) + .collect(); columns.push(col); } @@ -1013,40 +1019,57 @@ mod tests { assert_eq!(proof.composition_commitment.len(), 32); assert_eq!(proof.query_proofs.len(), 3); assert!(!proof.ood_values.trace_at_z.is_empty()); - + // Verify structural properties - assert!(!proof.fri_proof.layer_commitments.is_empty(), "Should have FRI layers"); - assert!(!proof.fri_proof.final_poly.is_empty(), "Should have final polynomial"); - assert_eq!(proof.fri_proof.query_proofs.len(), 3, "Should have 3 FRI query proofs"); - + assert!( + !proof.fri_proof.layer_commitments.is_empty(), + "Should have FRI layers" + ); + assert!( + !proof.fri_proof.final_poly.is_empty(), + "Should have final polynomial" + ); + assert_eq!( + proof.fri_proof.query_proofs.len(), + 3, + "Should have 3 FRI query proofs" + ); + // Each FRI query should have layer proofs for fri_query in &proof.fri_proof.query_proofs { - assert!(!fri_query.layer_proofs.is_empty(), "FRI query should have layer proofs"); + assert!( + !fri_query.layer_proofs.is_empty(), + "FRI query should have layer proofs" + ); } - + // Query proofs should have valid structure for query in &proof.query_proofs { assert!(query.index < 32, "Query index should be in domain"); assert!(!query.trace_values.is_empty(), "Should have trace values"); - assert!(!query.trace_proof.path.is_empty() || query.trace_proof.path.is_empty(), - "Trace proof should have valid structure"); + assert!( + !query.trace_proof.path.is_empty() || query.trace_proof.path.is_empty(), + "Trace proof should have valid structure" + ); } - + // Test verifier construction (verification may fail due to folding math) let verifier = StarkVerifier::new(config); let _ = verifier.verify(&proof); // Result doesn't matter for structural test } - + #[test] fn test_multi_column_proof() { let trace_len = 8; // We need enough columns for CpuTraceRow (77 columns) let mut columns: Vec> = Vec::new(); for i in 0..77 { - let col: Vec = (0..trace_len).map(|j| M31::new((i * j as usize) as u32)).collect(); + let col: Vec = (0..trace_len) + .map(|j| M31::new((i * j as usize) as u32)) + .collect(); columns.push(col); } - + let config = StarkConfig { log_trace_len: 3, blowup_factor: 4, @@ -1055,65 +1078,65 @@ mod tests { security_bits: 50, entry_point: 0x0, }; - + let mut prover = StarkProver::new(config.clone()); let public_inputs = vec![]; // No public inputs for this test let proof = prover.prove(columns, &public_inputs); - + assert_eq!(proof.ood_values.trace_at_z.len(), 77); assert_eq!(proof.query_proofs[0].trace_values.len(), 77); } - + #[test] fn test_constraint_evaluator() { let evaluator = ConstraintEvaluator::new(1, 2); - + let row = vec![M31::new(5)]; let row_next = vec![M31::new(6)]; let alphas = vec![M31::ONE, M31::ONE]; - + // Boundary constraint at first row let result = evaluator.evaluate(&row, &row_next, &alphas, true); // boundary = 5, transition = 6 - 5 - 1 = 0 assert_eq!(result, M31::new(5)); - + // Non-boundary let result2 = evaluator.evaluate(&row, &row_next, &alphas, false); // Only transition = 0 assert_eq!(result2, M31::ZERO); } - + #[test] fn test_boundary_constraint_entry_point() { // Test that boundary constraints enforce correct entry_point let trace_len = 8; let entry_point = 0x1000u32; - + // Create trace with correct entry point at first row let mut columns: Vec> = Vec::new(); - + // Column 0: clk columns.push((0..trace_len).map(|j| M31::new(j as u32)).collect()); - + // Column 1: PC (should start at entry_point) let mut pc_col = vec![M31::new(entry_point & 0x7FFFFFFF)]; for j in 1..trace_len { pc_col.push(M31::new((entry_point + (j * 4) as u32) & 0x7FFFFFFF)); } columns.push(pc_col); - + // Column 2: next_pc (should be PC + 4 at first row) let mut next_pc_col = vec![M31::new((entry_point + 4) & 0x7FFFFFFF)]; for j in 1..trace_len { next_pc_col.push(M31::new((entry_point + ((j + 1) * 4) as u32) & 0x7FFFFFFF)); } columns.push(next_pc_col); - + // Fill remaining columns (up to 77) with zeros for _ in 3..77 { columns.push(vec![M31::ZERO; trace_len]); } - + let config = StarkConfig { log_trace_len: 3, blowup_factor: 4, @@ -1122,64 +1145,66 @@ mod tests { security_bits: 50, entry_point, }; - + let mut prover = StarkProver::new(config.clone()); let public_inputs = vec![]; - + // Should succeed with correct entry_point let proof = prover.prove(columns.clone(), &public_inputs); assert_eq!(proof.trace_commitment.len(), 32); - + // Now test with WRONG entry point in config (should still generate proof, // but composition polynomial will be non-zero at boundary) let wrong_config = StarkConfig { entry_point: 0x2000u32, // Wrong entry point ..config }; - + let mut wrong_prover = StarkProver::new(wrong_config); let wrong_proof = wrong_prover.prove(columns, &public_inputs); - + // Proof still generates (prover doesn't check constraints) // But composition polynomial at boundary will be non-zero // This would be caught by the verifier assert_eq!(wrong_proof.trace_commitment.len(), 32); - + // The OOD composition value should be different when entry_point mismatches // (In a full implementation, verifier would reject this) - assert!(proof.ood_values.composition_at_z != wrong_proof.ood_values.composition_at_z - || proof.ood_values.composition_at_z == M31::ZERO); + assert!( + proof.ood_values.composition_at_z != wrong_proof.ood_values.composition_at_z + || proof.ood_values.composition_at_z == M31::ZERO + ); } - + #[test] fn test_boundary_constraint_x0_zero() { // Test that boundary constraints enforce x0 = 0 let trace_len = 8; let entry_point = 0x0u32; - + // Create trace where x0 (rd_val at first row) is non-zero let mut columns: Vec> = Vec::new(); - + // Columns 0-9: standard columns for i in 0..10 { columns.push((0..trace_len).map(|j| M31::new((i + j) as u32)).collect()); } - + // Column 10: rd_val_lo (should be 0 at first row for x0 constraint) let mut rd_val_lo = vec![M31::new(42)]; // Non-zero at first row for j in 1..trace_len { rd_val_lo.push(M31::new(j as u32)); } columns.push(rd_val_lo); - + // Column 11: rd_val_hi columns.push(vec![M31::ZERO; trace_len]); - + // Fill remaining columns for _ in 12..77 { columns.push(vec![M31::ZERO; trace_len]); } - + let config = StarkConfig { log_trace_len: 3, blowup_factor: 4, @@ -1188,10 +1213,10 @@ mod tests { security_bits: 50, entry_point, }; - + let mut prover = StarkProver::new(config); let proof = prover.prove(columns, &vec![]); - + // Proof generates but composition should be non-zero at boundary // (would be rejected by verifier) assert_eq!(proof.trace_commitment.len(), 32); diff --git a/crates/verifier/src/verify.rs b/crates/verifier/src/verify.rs index 4c0009a..4f53fa3 100644 --- a/crates/verifier/src/verify.rs +++ b/crates/verifier/src/verify.rs @@ -254,7 +254,7 @@ impl Verifier { } /// Verify a STARK proof. - /// + /// /// # Arguments /// * `proof` - The STARK proof to verify /// * `public_inputs` - Public inputs that the proof is bound to @@ -287,14 +287,14 @@ impl Verifier { // Step 4: Get DEEP/OODS sampling point let oods_point = channel.squeeze_extension_challenge(); - // Step 5: Absorb OOD values into transcript (must match prover exactly) - // CRITICAL: Prover only absorbs trace_at_z and composition_at_z, not trace_at_z_next - for v in &proof.ood_values.trace_at_z { - channel.absorb_felt(*v); - } - // Note: trace_at_z_next is NOT absorbed to match prover transcript - channel.absorb_felt(proof.ood_values.composition_at_z); - + // Step 5: Absorb OOD values into transcript (must match prover exactly) + // CRITICAL: Prover only absorbs trace_at_z and composition_at_z, not trace_at_z_next + for v in &proof.ood_values.trace_at_z { + channel.absorb_felt(*v); + } + // Note: trace_at_z_next is NOT absorbed to match prover transcript + channel.absorb_felt(proof.ood_values.composition_at_z); + // Generate DEEP combination alphas (for linear combination of quotients) // Need one alpha per trace column + one for composition let num_deep_terms = proof.ood_values.trace_at_z.len() + 1; @@ -310,10 +310,8 @@ impl Verifier { } // Step 7: Get query indices (must match prover's) - let query_indices = channel.squeeze_query_indices( - self.config.num_queries, - self.config.lde_domain_size(), - ); + let query_indices = + channel.squeeze_query_indices(self.config.num_queries, self.config.lde_domain_size()); // Step 8: Verify query count if proof.query_proofs.len() != self.config.num_queries { @@ -339,34 +337,36 @@ impl Verifier { // Verify trace Merkle proof (single-leaf commitment per row) if query_proof.trace_values.len() != trace_width { return Err(VerifyError::ConstraintError { - constraint: format!("Trace width mismatch: expected {}, got {}", trace_width, query_proof.trace_values.len()), + constraint: format!( + "Trace width mismatch: expected {}, got {}", + trace_width, + query_proof.trace_values.len() + ), }); } let trace_value = query_proof.trace_values[0]; - if !query_proof.trace_proof.verify(&proof.trace_commitment, trace_value) { + if !query_proof + .trace_proof + .verify(&proof.trace_commitment, trace_value) + { return Err(VerifyError::MerkleError { index: query_proof.index, }); } // Verify composition Merkle proof - if !query_proof.composition_proof.verify( - &proof.composition_commitment, - query_proof.composition_value, - ) { + if !query_proof + .composition_proof + .verify(&proof.composition_commitment, query_proof.composition_value) + { return Err(VerifyError::MerkleError { index: query_proof.index, }); } // Verify DEEP quotient (ensures trace/composition values are consistent with OOD samples) - self.verify_deep_quotient( - query_proof, - oods_point, - &proof.ood_values, - &deep_alphas, - )?; + self.verify_deep_quotient(query_proof, oods_point, &proof.ood_values, &deep_alphas)?; // Verify constraint consistency placeholder self.verify_constraint_consistency(query_proof, &oods_point)?; @@ -443,22 +443,22 @@ impl Verifier { // For simplicity, use query.index as M31 (real impl would use circle domain point) let domain_point_m31 = M31::new(query.index as u32); let domain_point = QM31::from(domain_point_m31); - + // Compute denominator: X - z let denom = domain_point - oods_point; - + // Check for division by zero (would indicate z is in LDE domain, which breaks soundness) if denom == QM31::ZERO { return Err(VerifyError::ConstraintError { constraint: "OODS point collides with query point (denominator is zero)".into(), }); } - + let denom_inv = denom.inv(); - + // Compute expected DEEP quotient: Σ α_i · (f_i(X) - f_i(z)) / (X - z) let mut expected_deep = QM31::ZERO; - + // Trace columns contribution for (col_idx, &trace_val_at_x) in query.trace_values.iter().enumerate() { if col_idx >= ood_values.trace_at_z.len() { @@ -470,7 +470,7 @@ impl Verifier { ), }); } - + if col_idx >= deep_alphas.len() { return Err(VerifyError::InvalidProof { reason: format!( @@ -480,17 +480,17 @@ impl Verifier { ), }); } - + let trace_val_at_z = ood_values.trace_at_z[col_idx]; - + // Numerator: f_i(X) - f_i(z) let numerator = QM31::from(trace_val_at_x) - QM31::from(trace_val_at_z); - + // Contribution: α_i · numerator / (X - z) let contribution = QM31::from(deep_alphas[col_idx]) * numerator * denom_inv; expected_deep = expected_deep + contribution; } - + // Composition polynomial contribution let comp_alpha_idx = query.trace_values.len(); if comp_alpha_idx >= deep_alphas.len() { @@ -502,35 +502,29 @@ impl Verifier { ), }); } - - let comp_numerator = QM31::from(query.composition_value) - - QM31::from(ood_values.composition_at_z); - let comp_contribution = QM31::from(deep_alphas[comp_alpha_idx]) - * comp_numerator * denom_inv; + + let comp_numerator = + QM31::from(query.composition_value) - QM31::from(ood_values.composition_at_z); + let comp_contribution = + QM31::from(deep_alphas[comp_alpha_idx]) * comp_numerator * denom_inv; expected_deep = expected_deep + comp_contribution; - + // Convert to M31 for comparison (taking real part of QM31) // Note: In a complete implementation, the DEEP quotient might be QM31, // but current proof structure stores it as M31 let expected_deep_m31 = expected_deep.c0; - + // Compare with claimed FRI value // Allow small numerical differences due to field arithmetic if expected_deep_m31 != query.deep_quotient_value { - return Err(VerifyError::DeepQuotientMismatch { - index: query.index, - }); + return Err(VerifyError::DeepQuotientMismatch { index: query.index }); } - + Ok(()) } /// Verify the FRI proof. - fn verify_fri( - &self, - fri_proof: &FriProof, - alphas: &[M31], - ) -> VerifyResult<()> { + fn verify_fri(&self, fri_proof: &FriProof, alphas: &[M31]) -> VerifyResult<()> { // Verify each query through the FRI layers for (query_idx, fri_query) in fri_proof.query_proofs.iter().enumerate() { self.verify_fri_query(fri_proof, fri_query, alphas, query_idx)?; @@ -538,7 +532,7 @@ impl Verifier { // Verify final polynomial is low-degree // (In a complete implementation, would evaluate final_poly at random points) - + Ok(()) } @@ -590,10 +584,15 @@ impl Verifier { } // Compute folded value for next layer (twin point folding) - let alpha = alphas.get(layer_idx).copied().ok_or_else(|| VerifyError::FriStructure { - reason: "Missing FRI alpha".into(), - })?; - let folded = fri_utils::compute_fold(layer_proof.value, layer_proof.sibling_value, alpha); + let alpha = + alphas + .get(layer_idx) + .copied() + .ok_or_else(|| VerifyError::FriStructure { + reason: "Missing FRI alpha".into(), + })?; + let folded = + fri_utils::compute_fold(layer_proof.value, layer_proof.sibling_value, alpha); expected_next = Some(folded); current_index /= 2; } @@ -601,7 +600,9 @@ impl Verifier { // Final polynomial check if let Some(expected) = expected_next { if fri_proof.final_poly.is_empty() { - return Err(VerifyError::FriStructure { reason: "Empty final polynomial".into() }); + return Err(VerifyError::FriStructure { + reason: "Empty final polynomial".into(), + }); } let final_idx = current_index % fri_proof.final_poly.len(); let final_val = fri_proof.final_poly[final_idx]; @@ -639,7 +640,7 @@ pub mod fri_utils { if coeffs.is_empty() { return M31::ZERO; } - + let mut result = coeffs[coeffs.len() - 1]; for i in (0..coeffs.len() - 1).rev() { result = result * x + coeffs[i]; @@ -679,11 +680,11 @@ mod tests { leaf_index: 0, path: vec![], }; - + // Single leaf tree - root equals leaf hash (with domain separation) let leaf = M31::new(42); let root = hash_leaf_m31(leaf); - + assert!(proof.verify(&root, leaf)); assert!(!proof.verify(&root, M31::new(43))); } @@ -693,7 +694,7 @@ mod tests { let even = M31::new(10); let odd = M31::new(20); let alpha = M31::new(3); - + let folded = fri_utils::compute_fold(even, odd, alpha); // (10+20)/2 + 3*(10-20)/2 = 15 + 3*(-10)/2 = 15 - 15 = 0 assert_eq!(folded.as_u32(), 0); @@ -703,13 +704,13 @@ mod tests { fn test_evaluate_poly() { // p(x) = 1 + 2x + 3x^2 let coeffs = vec![M31::new(1), M31::new(2), M31::new(3)]; - + // p(0) = 1 assert_eq!(fri_utils::evaluate_poly(&coeffs, M31::ZERO).as_u32(), 1); - + // p(1) = 1 + 2 + 3 = 6 assert_eq!(fri_utils::evaluate_poly(&coeffs, M31::ONE).as_u32(), 6); - + // p(2) = 1 + 4 + 12 = 17 assert_eq!(fri_utils::evaluate_poly(&coeffs, M31::new(2)).as_u32(), 17); }