Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ zp1-trace = { workspace = true }
zp1-air = { workspace = true }
sha2 = { workspace = true }
blake3 = { workspace = true }
rand = { workspace = true }
rayon = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
Expand All @@ -29,7 +30,6 @@ metal = { version = "0.28", optional = true }
objc = "0.2"

[dev-dependencies]
rand = { workspace = true }
criterion = "0.5"

[[bench]]
Expand Down
60 changes: 30 additions & 30 deletions crates/prover/src/bitwise_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ impl BitwiseLookupTables {
let mut and_values = Vec::with_capacity(size);
let mut or_values = Vec::with_capacity(size);
let mut xor_values = Vec::with_capacity(size);

for a in 0u32..256 {
for b in 0u32..256 {
and_values.push(M31::new(a & b));
or_values.push(M31::new(a | b));
xor_values.push(M31::new(a ^ b));
}
}

Self {
and_values,
or_values,
Expand All @@ -52,31 +52,31 @@ impl BitwiseLookupTables {
xor_mult: vec![0; size],
}
}

/// Look up 8-bit AND: returns a & b and increments multiplicity
#[inline]
pub fn and8(&mut self, a: u8, b: u8) -> M31 {
let idx = (a as usize) * 256 + (b as usize);
self.and_mult[idx] += 1;
self.and_values[idx]
}

/// Look up 8-bit OR: returns a | b and increments multiplicity
#[inline]
pub fn or8(&mut self, a: u8, b: u8) -> M31 {
let idx = (a as usize) * 256 + (b as usize);
self.or_mult[idx] += 1;
self.or_values[idx]
}

/// Look up 8-bit XOR: returns a ^ b and increments multiplicity
#[inline]
pub fn xor8(&mut self, a: u8, b: u8) -> M31 {
let idx = (a as usize) * 256 + (b as usize);
self.xor_mult[idx] += 1;
self.xor_values[idx]
}

/// Perform 32-bit AND using 4 byte-wise lookups.
pub fn and32(&mut self, a: u32, b: u32) -> u32 {
let mut result = 0u32;
Expand All @@ -88,7 +88,7 @@ impl BitwiseLookupTables {
}
result
}

/// Perform 32-bit OR using 4 byte-wise lookups.
pub fn or32(&mut self, a: u32, b: u32) -> u32 {
let mut result = 0u32;
Expand All @@ -100,7 +100,7 @@ impl BitwiseLookupTables {
}
result
}

/// Perform 32-bit XOR using 4 byte-wise lookups.
pub fn xor32(&mut self, a: u32, b: u32) -> u32 {
let mut result = 0u32;
Expand All @@ -112,12 +112,12 @@ impl BitwiseLookupTables {
}
result
}

/// Get multiplicities for LogUp proof generation.
pub fn get_multiplicities(&self) -> (&[u32], &[u32], &[u32]) {
(&self.and_mult, &self.or_mult, &self.xor_mult)
}

/// Get table values for LogUp proof generation.
pub fn get_values(&self) -> (&[M31], &[M31], &[M31]) {
(&self.and_values, &self.or_values, &self.xor_values)
Expand All @@ -133,55 +133,55 @@ impl Default for BitwiseLookupTables {
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_and() {
let mut tables = BitwiseLookupTables::new();
assert_eq!(tables.and8(0, 0), M31::new(0)); // 0 & 0 = 0
assert_eq!(tables.and8(255, 255), M31::new(255)); // 0xFF & 0xFF = 0xFF
assert_eq!(tables.and8(0xAA, 0x55), M31::new(0)); // 0xAA & 0x55 = 0

assert_eq!(tables.and8(0, 0), M31::new(0)); // 0 & 0 = 0
assert_eq!(tables.and8(255, 255), M31::new(255)); // 0xFF & 0xFF = 0xFF
assert_eq!(tables.and8(0xAA, 0x55), M31::new(0)); // 0xAA & 0x55 = 0
}

#[test]
fn test_or() {
let mut tables = BitwiseLookupTables::new();
assert_eq!(tables.or8(0, 0), M31::new(0)); // 0 | 0 = 0
assert_eq!(tables.or8(0xAA, 0x55), M31::new(255)); // 0xAA | 0x55 = 0xFF

assert_eq!(tables.or8(0, 0), M31::new(0)); // 0 | 0 = 0
assert_eq!(tables.or8(0xAA, 0x55), M31::new(255)); // 0xAA | 0x55 = 0xFF
}

#[test]
fn test_xor() {
let mut tables = BitwiseLookupTables::new();
assert_eq!(tables.xor8(0xFF, 0xFF), M31::new(0)); // 0xFF ^ 0xFF = 0

assert_eq!(tables.xor8(0xFF, 0xFF), M31::new(0)); // 0xFF ^ 0xFF = 0
assert_eq!(tables.xor8(0xAA, 0x55), M31::new(255)); // 0xAA ^ 0x55 = 0xFF
}

#[test]
fn test_bitwise_32bit() {
let mut tables = BitwiseLookupTables::new();

let a = 0xDEADBEEF_u32;
let b = 0xCAFEBABE_u32;

assert_eq!(tables.and32(a, b), a & b);
assert_eq!(tables.or32(a, b), a | b);
assert_eq!(tables.xor32(a, b), a ^ b);
}

#[test]
fn test_multiplicity_tracking() {
let mut tables = BitwiseLookupTables::new();

// Perform some lookups
tables.and8(0x12, 0x34);
tables.and8(0x12, 0x34); // Same lookup twice
tables.and8(0x12, 0x34); // Same lookup twice

let (and_mult, _, _) = tables.get_multiplicities();
let idx = 0x12 * 256 + 0x34;

// Should have been looked up twice
assert_eq!(and_mult[idx], 2);
}
Expand Down
92 changes: 63 additions & 29 deletions crates/prover/src/channel.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,72 @@
//! Fiat-Shamir transcript channel for the prover.
//! Fiat-Shamir transcript channel for the prover using SHA-256.
//!
//! This uses the same SHA-256-based transcript construction as the verifier
//! channel, ensuring prover and verifier derive identical challenges from the
//! same sequence of transcript messages.
//!
//! # Domain Separation
//!
//! Every channel is initialized by hashing the domain separator, binding all
//! subsequent challenges to the specific protocol context.
//!
//! # Byte-to-Field Encoding
//!
//! `absorb` prefixes each call with the byte-length of the data before hashing,
//! making the encoding injective for sequences of variable-length inputs.

use sha2::{Digest, Sha256};
use zp1_primitives::{M31, QM31};

/// Prover channel for Fiat-Shamir transcript.
/// Prover channel for Fiat-Shamir transcript (mirrors `VerifierChannel`).
#[derive(Clone)]
pub struct ProverChannel {
/// SHA256 state.
hasher: Sha256,
/// Transcript bytes for debugging.
transcript: Vec<u8>,
}

impl ProverChannel {
/// Create a new prover channel.
/// Create a new prover channel bound to `domain_separator`.
///
/// The domain separator is absorbed first, so two channels with different
/// separators always produce different challenges.
pub fn new(domain_separator: &[u8]) -> Self {
let mut ch = Self {
hasher: Sha256::new(),
transcript: Vec::new(),
};
ch.absorb(domain_separator);
ch
let mut hasher = Sha256::new();
// Length-prefix the domain separator to avoid extension collisions
hasher.update((domain_separator.len() as u64).to_le_bytes());
hasher.update(domain_separator);
Self { hasher }
}

/// Absorb bytes into the transcript.
/// Absorb arbitrary bytes into the transcript.
///
/// The byte-length is written before the data so that, e.g.,
/// `absorb(b"ab")` and `absorb(b"a"); absorb(b"b")` produce
/// different transcript states.
pub fn absorb(&mut self, data: &[u8]) {
// Length prefix for injectivity across variable-length inputs
self.hasher.update((data.len() as u64).to_le_bytes());
self.hasher.update(data);
self.transcript.extend_from_slice(data);
}

/// Absorb a 32-byte commitment.
/// Absorb a 32-byte commitment into the transcript.
pub fn absorb_commitment(&mut self, commitment: &[u8; 32]) {
self.absorb(commitment);
}

/// Absorb an M31 field element.
/// Absorb an M31 field element into the transcript.
pub fn absorb_felt(&mut self, felt: M31) {
self.absorb(&felt.as_u32().to_le_bytes());
self.hasher.update(felt.as_u32().to_le_bytes());
}

/// Squeeze a challenge in M31.
pub fn squeeze_challenge(&mut self) -> M31 {
let hash = self.hasher.clone().finalize();
self.hasher.update(&hash);

// Take first 4 bytes, reduce mod P
let bytes: [u8; 4] = hash[0..4].try_into().unwrap();
let val = u32::from_le_bytes(bytes);
M31::new(val % M31::P)
}

/// Squeeze a challenge in QM31 (extension field).
/// Squeeze a challenge in QM31 (four independent M31 challenges).
pub fn squeeze_extension_challenge(&mut self) -> QM31 {
let c0 = self.squeeze_challenge();
let c1 = self.squeeze_challenge();
Expand All @@ -59,19 +75,17 @@ impl ProverChannel {
QM31::new(c0, c1, c2, c3)
}

/// Alias for squeeze_extension_challenge.
/// Alias for `squeeze_extension_challenge`.
pub fn squeeze_qm31(&mut self) -> QM31 {
self.squeeze_extension_challenge()
}

/// Squeeze n query indices in range [0, domain_size).
/// Squeeze `n` query indices in `[0, domain_size)`.
pub fn squeeze_query_indices(&mut self, n: usize, domain_size: usize) -> Vec<usize> {
let mut indices = Vec::with_capacity(n);
while indices.len() < n {
let hash = self.hasher.clone().finalize();
self.hasher.update(&hash);

// Extract multiple indices from each hash
for chunk in hash.chunks(4) {
if indices.len() >= n {
break;
Expand All @@ -84,11 +98,6 @@ impl ProverChannel {
indices.truncate(n);
indices
}

/// Get the current transcript length.
pub fn transcript_len(&self) -> usize {
self.transcript.len()
}
}

impl Default for ProverChannel {
Expand All @@ -115,6 +124,31 @@ mod tests {
assert_eq!(c1, c2);
}

#[test]
fn test_domain_separator_matters() {
let mut ch1 = ProverChannel::new(b"protocol-a");
let mut ch2 = ProverChannel::new(b"protocol-b");

ch1.absorb(b"same data");
ch2.absorb(b"same data");

// Different domain separators must yield different challenges
assert_ne!(ch1.squeeze_challenge(), ch2.squeeze_challenge());
}

#[test]
fn test_absorb_injective() {
// absorb(b"ab") vs absorb(b"a") + absorb(b"b") must differ
let mut ch1 = ProverChannel::new(b"test");
ch1.absorb(b"ab");

let mut ch2 = ProverChannel::new(b"test");
ch2.absorb(b"a");
ch2.absorb(b"b");

assert_ne!(ch1.squeeze_challenge(), ch2.squeeze_challenge());
}

#[test]
fn test_query_indices() {
let mut ch = ProverChannel::new(b"test");
Expand Down
Loading