add delayed modular reduction for split-eq sumcheck #105
Merged
srinathsetty merged 9 commits into microsoft:main on Feb 25, 2026
Conversation
Implements batched Montgomery reduction that defers O(N) reductions to a single reduction at the end of inner-product computations.

Performance (BN254 split-eq sumcheck):
- 1.09-1.35× speedup across all problem sizes
- n=2²⁶: 2.64s → 2.04s (1.29× faster)

Changes:
- Add big_num module with WideLimbs accumulator and montgomery_reduce_9
- Update sumcheck to use two-phase accumulation with DMR
- Add BN254 provider and criterion benchmarks
- Create macros.rs with impl_field_reduction_constants! and impl_montgomery_limbs! macros for zero-config field support
- Unify limb operations in limbs.rs as const fn (usable at both compile time and runtime, eliminating duplication)
- Add P256 (secp256r1) field support via single-line macro invocations
- Rename MAX_CANONICALIZE_SUBS to MAX_REDC_SUB_CORRECTIONS for clearer Montgomery REDC terminology
- Rename compute_q to compute_max_redc_sub_corrections
- Add test macros for consistent test generation across field types
- Remove unused DelayedReduction import from sumcheck.rs
- Add test_first_round_spartan_sumcheck test to src/sumcheck.rs using existing start_span! tracing infrastructure for timing
- Delete benches/sumcheck.rs criterion benchmark
- Remove criterion dev-dependency from Cargo.toml
- Make big_num, polys, sumcheck modules private in lib.rs
- Re-export DelayedReduction and FieldReductionConstants from traits module for external access
- Clean up unused zip_with imports in polys modules
wu-s-john
commented
Feb 23, 2026
Contributor
Author
So, I added delayed modular reduction to the different sumcheck-like functions (these are unlocked and revealed in later PRs). Here are the results I have.

Spartan Benchmark: main vs benching/sha256-is-faster

SHA256 Single Hash (msg=1024B)
SHA256 Chain (msg=32B, chain=1028)

Summary: Consistent improvements across both benchmarks:
- Small benchmark (1024B single hash)
- Large benchmark (1028-chain)
Contributor
Author
Here is the benching code I used:

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: MIT
//! examples/sha256_chain.rs
//! Configurable SHA-256 chain benchmark for Spartan and SpartanZK.
//!
//! Run with:
//! RUST_LOG=info cargo run --release --example sha256_chain -- -m 32 -c 4 -e bn254
//! RUST_LOG=info cargo run --release --example sha256_chain -- -m 32 -c 4 -e bn254 --zk
#[cfg(feature = "jem")]
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
use bellpepper::gadgets::sha256::sha256;
use bellpepper_core::{
ConstraintSystem, SynthesisError,
boolean::{AllocatedBit, Boolean},
num::AllocatedNum,
};
use clap::Parser;
use ff::{Field, PrimeField, PrimeFieldBits};
use sha2::{Digest, Sha256};
use spartan2::{
provider::{Bn254Engine, P256HyraxEngine, PallasHyraxEngine, T256HyraxEngine, VestaHyraxEngine},
spartan::SpartanSNARK,
spartan_zk::SpartanZkSNARK,
timing::{
SPARTAN_PHASES, SPARTAN_ZK_PHASES, TimingLayer, clear_timings, print_single_table,
snapshot_timings,
},
traits::{Engine, circuit::SpartanCircuit, pcs::FoldingEngineTrait, snark::R1CSSNARKTrait},
};
use std::{marker::PhantomData, time::Instant};
use tracing::info;
use tracing_subscriber::{EnvFilter, Layer as _, layer::SubscriberExt, util::SubscriberInitExt};
#[derive(Parser)]
#[command(about = "Configurable SHA-256 chain benchmark")]
struct Args {
/// Initial message size in bytes
#[arg(short = 'm', long, default_value = "32")]
msg_bytes: usize,
/// Number of SHA256 iterations (chain length)
#[arg(short = 'c', long, default_value = "1")]
chain: usize,
/// Engine to use: bn254, pallas, vesta, p256, t256
#[arg(short = 'e', long, default_value = "bn254")]
engine: String,
/// Use SpartanZK instead of Spartan (adds zero-knowledge)
#[arg(short = 'z', long)]
zk: bool,
}
/// SHA256 chain circuit: computes H(H(H(...H(msg)))) for `chain_length` iterations.
#[derive(Clone, Debug)]
struct Sha256ChainCircuit<Scalar: PrimeField> {
preimage: Vec<u8>,
chain_length: usize,
_p: PhantomData<Scalar>,
}
impl<Scalar: PrimeField + PrimeFieldBits> Sha256ChainCircuit<Scalar> {
fn new(preimage: Vec<u8>, chain_length: usize) -> Self {
Self {
preimage,
chain_length,
_p: PhantomData,
}
}
/// Compute the final hash after `chain_length` iterations.
fn compute_final_hash(&self) -> [u8; 32] {
// First hash the preimage
let mut hasher = Sha256::new();
hasher.update(&self.preimage);
let mut current: [u8; 32] = hasher.finalize().into();
// Chain additional iterations (already did 1 above)
for _ in 1..self.chain_length {
let mut hasher = Sha256::new();
hasher.update(current);
current = hasher.finalize().into();
}
current
}
}
impl<E: Engine> SpartanCircuit<E> for Sha256ChainCircuit<E::Scalar> {
fn public_values(&self) -> Result<Vec<<E as Engine>::Scalar>, SynthesisError> {
let final_hash = self.compute_final_hash();
let hash_scalars: Vec<<E as Engine>::Scalar> = final_hash
.iter()
.flat_map(|&byte| {
(0..8).rev().map(move |i| {
if (byte >> i) & 1 == 1 {
E::Scalar::ONE
} else {
E::Scalar::ZERO
}
})
})
.collect();
Ok(hash_scalars)
}
fn shared<CS: ConstraintSystem<E::Scalar>>(
&self,
_: &mut CS,
) -> Result<Vec<AllocatedNum<E::Scalar>>, SynthesisError> {
Ok(vec![])
}
fn precommitted<CS: ConstraintSystem<E::Scalar>>(
&self,
cs: &mut CS,
_: &[AllocatedNum<E::Scalar>],
) -> Result<Vec<AllocatedNum<E::Scalar>>, SynthesisError> {
// Allocate preimage bits
let bit_values: Vec<_> = self
.preimage
.iter()
.flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1 == 1))
.collect();
let preimage_bits: Vec<Boolean> = bit_values
.into_iter()
.enumerate()
.map(|(i, bit)| {
AllocatedBit::alloc(cs.namespace(|| format!("preimage bit {i}")), Some(bit))
.map(Boolean::from)
})
.collect::<Result<Vec<_>, _>>()?;
// First SHA256 on the preimage
let mut current_bits = sha256(cs.namespace(|| "sha256 round 0"), &preimage_bits)?;
// Chain additional SHA256 iterations
for round in 1..self.chain_length {
current_bits = sha256(cs.namespace(|| format!("sha256 round {round}")), &current_bits)?;
}
// Verify against expected final hash
let expected_hash = self.compute_final_hash();
let mut expected_bits = expected_hash
.iter()
.flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1 == 1));
for b in &current_bits {
match b {
Boolean::Is(bit) => assert_eq!(expected_bits.next().unwrap(), bit.get_value().unwrap()),
Boolean::Not(bit) => assert_ne!(expected_bits.next().unwrap(), bit.get_value().unwrap()),
Boolean::Constant(_) => unreachable!(),
}
}
// Expose final hash as public input (256 bits)
for (i, bit) in current_bits.iter().enumerate() {
let n = AllocatedNum::alloc_input(cs.namespace(|| format!("output bit {i}")), || {
Ok(
if bit.get_value().ok_or(SynthesisError::AssignmentMissing)? {
E::Scalar::ONE
} else {
E::Scalar::ZERO
},
)
})?;
cs.enforce(
|| format!("bit == num {i}"),
|_| bit.lc(CS::one(), E::Scalar::ONE),
|lc| lc + CS::one(),
|lc| lc + n.get_variable(),
);
}
Ok(vec![])
}
fn num_challenges(&self) -> usize {
0
}
fn synthesize<CS: ConstraintSystem<E::Scalar>>(
&self,
_: &mut CS,
_: &[AllocatedNum<E::Scalar>],
_: &[AllocatedNum<E::Scalar>],
_: Option<&[E::Scalar]>,
) -> Result<(), SynthesisError> {
Ok(())
}
}
/// Run benchmark with Spartan (non-ZK)
fn run_spartan<E: Engine>(
msg_bytes: usize,
chain_length: usize,
engine_name: &str,
timing_data: &spartan2::timing::TimingData,
) {
let circuit = Sha256ChainCircuit::<E::Scalar>::new(vec![0u8; msg_bytes], chain_length);
eprintln!(
"\n======= SHA256 Chain: engine={}, msg={}B, chain={}, zk=false =======",
engine_name, msg_bytes, chain_length
);
// SETUP
let t0 = Instant::now();
let (pk, vk) = SpartanSNARK::<E>::setup(circuit.clone()).expect("setup failed");
let setup_ms = t0.elapsed().as_millis();
info!(elapsed_ms = setup_ms as u64, "setup");
clear_timings(timing_data);
// PREP_PROVE
let t0 = Instant::now();
let prep_snark =
SpartanSNARK::<E>::prep_prove(&pk, circuit.clone(), false).expect("prep_prove failed");
let prep_ms = t0.elapsed().as_millis();
info!(elapsed_ms = prep_ms as u64, "prep_prove");
// PROVE
let t0 = Instant::now();
let proof =
SpartanSNARK::<E>::prove(&pk, circuit, &prep_snark, false).expect("prove failed");
let prove_ms = t0.elapsed().as_millis();
info!(elapsed_ms = prove_ms as u64, "prove");
// VERIFY
let t0 = Instant::now();
proof.verify(&vk).expect("verify failed");
let verify_ms = t0.elapsed().as_millis();
info!(elapsed_ms = verify_ms as u64, "verify");
// Print timing table
let timings = snapshot_timings(timing_data, SPARTAN_PHASES);
let header = format!(
"setup={}ms, prep={}ms, prove={}ms, verify={}ms",
setup_ms, prep_ms, prove_ms, verify_ms
);
print_single_table(&header, SPARTAN_PHASES, &timings);
}
/// Run benchmark with SpartanZK (with zero-knowledge)
fn run_spartan_zk<E: Engine>(
msg_bytes: usize,
chain_length: usize,
engine_name: &str,
timing_data: &spartan2::timing::TimingData,
)
where
E::PCS: FoldingEngineTrait<E>,
{
let circuit = Sha256ChainCircuit::<E::Scalar>::new(vec![0u8; msg_bytes], chain_length);
eprintln!(
"\n======= SHA256 Chain: engine={}, msg={}B, chain={}, zk=true =======",
engine_name, msg_bytes, chain_length
);
// SETUP
let t0 = Instant::now();
let (pk, vk) = SpartanZkSNARK::<E>::setup(circuit.clone()).expect("setup failed");
let setup_ms = t0.elapsed().as_millis();
info!(elapsed_ms = setup_ms as u64, "setup");
clear_timings(timing_data);
// PREP_PROVE
let t0 = Instant::now();
let prep_snark =
SpartanZkSNARK::<E>::prep_prove(&pk, circuit.clone(), false).expect("prep_prove failed");
let prep_ms = t0.elapsed().as_millis();
info!(elapsed_ms = prep_ms as u64, "prep_prove");
// PROVE
let t0 = Instant::now();
let proof =
SpartanZkSNARK::<E>::prove(&pk, circuit, &prep_snark, false).expect("prove failed");
let prove_ms = t0.elapsed().as_millis();
info!(elapsed_ms = prove_ms as u64, "prove");
// VERIFY
let t0 = Instant::now();
proof.verify(&vk).expect("verify failed");
let verify_ms = t0.elapsed().as_millis();
info!(elapsed_ms = verify_ms as u64, "verify");
// Print timing table
let timings = snapshot_timings(timing_data, SPARTAN_ZK_PHASES);
let header = format!(
"setup={}ms, prep={}ms, prove={}ms, verify={}ms",
setup_ms, prep_ms, prove_ms, verify_ms
);
print_single_table(&header, SPARTAN_ZK_PHASES, &timings);
}
fn main() {
let args = Args::parse();
let (timing_layer, timing_data, _constraints_data) = TimingLayer::new();
tracing_subscriber::registry()
.with(timing_layer)
.with(
tracing_subscriber::fmt::layer()
.with_target(false)
.with_ansi(true)
.with_writer(std::io::stderr)
.with_filter(EnvFilter::from_default_env()),
)
.init();
let msg_bytes = args.msg_bytes;
let chain_length = args.chain;
let engine_name = args.engine.as_str();
match (engine_name, args.zk) {
("bn254", false) => run_spartan::<Bn254Engine>(msg_bytes, chain_length, engine_name, &timing_data),
("bn254", true) => run_spartan_zk::<Bn254Engine>(msg_bytes, chain_length, engine_name, &timing_data),
("pallas", false) => run_spartan::<PallasHyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
("pallas", true) => run_spartan_zk::<PallasHyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
("vesta", false) => run_spartan::<VestaHyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
("vesta", true) => run_spartan_zk::<VestaHyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
("p256", false) => run_spartan::<P256HyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
("p256", true) => run_spartan_zk::<P256HyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
("t256", false) => run_spartan::<T256HyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
("t256", true) => run_spartan_zk::<T256HyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
_ => {
eprintln!("Unknown engine: {}. Valid options: bn254, pallas, vesta, p256, t256", engine_name);
std::process::exit(1);
}
}
}
- Dedupe limb functions with const generics: gte<N>, add<N>, sub<N>, shl<N>, clz<N> (kept mul_4_by_4 specialized to avoid unstable generic_const_exprs)
- Add compile-time assertion enforcing MAX_REDC_SUB_CORRECTIONS <= 8
- Update capacity bound comment from 2^30 to 2^40
- Add clearer comments for modulus byte conversion and fold operation
- Update two-phase accumulation comment to mention inner loop
- Make perf tests faster for debug builds (16,18 + BN254 only)
Contributor
Author
Continuation of this PR where I've added Delayed Modular Reduction on the inner sumcheck.
Contributor
Author
Continuation of this PR where I've added Delayed Modular Reduction on the outer sumcheck of NeutronNova ZK.
Contributor
Author
Continuation of this PR where I've added Delayed Modular Reduction on the NIFS sumcheck of NeutronNova.
Contributor
srinathsetty
left a comment
Looks good! Made some small comments around organization of modules.
srinathsetty
approved these changes
Feb 24, 2026
Delayed Modular Reduction for Split-Eq Sumcheck
NOTE: This is a chunk of the original Small-Value PR, which contains our implementations for sumcheck and delayed modular reduction; this PR just isolates that change. Here we focus on implementing a new fundamental primitive, delayed reduction, along with the big-num math it requires.
Summary
This PR extracts the delayed modular reduction (DMR) optimization from the larger
feat/procedure-9-accumulator branch, which implements Algorithm 6 from "Speeding Up Sum-Check Proving" (Bagad et al., ePrint 2024/1046). This change introduces a DelayedReduction trait and supporting infrastructure that reduces Montgomery reduction operations from O(N) to O(1) per inner-product computation in the sumcheck prover. The optimization specifically targets the split-eq sumcheck used in Spartan's outer sumcheck, achieving a 1.09-1.35× speedup on the sumcheck phase (BN254 field).

Core Insight: Batching Modular Reductions
The fundamental operation in sumcheck is computing inner products of the form

  Σᵢ aᵢ·bᵢ

where aᵢ and bᵢ are field elements. In the standard approach, each multiplication triggers an immediate Montgomery reduction. The key insight is that we can defer all reductions by accumulating unreduced products in a wider integer type, then reducing only once at the end.
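The two patterns can be contrasted in a toy setting. This is a hedged sketch, not the PR's 256-bit code: it uses an assumed 32-bit prime (2³² − 5) so that each unreduced product fits in a u64 and a u128 accumulator can absorb up to 2⁶⁴ of them, mirroring the 576-bit accumulator at miniature scale.

```rust
// Toy illustration (not the PR's code): a 32-bit prime field where unreduced
// products fit in u64, so a u128 accumulator can hold up to 2^64 of them
// before a single final reduction.
const P: u64 = 4_294_967_291; // 2^32 - 5, prime (assumed for illustration)

// Eager: one reduction per product and one per addition, O(N) reductions.
fn inner_product_eager(a: &[u64], b: &[u64]) -> u64 {
    a.iter().zip(b).fold(0u64, |acc, (&x, &y)| {
        let prod = (x as u128 * y as u128 % P as u128) as u64;
        ((acc as u128 + prod as u128) % P as u128) as u64
    })
}

// Delayed: accumulate full-width products, reduce once at the end, O(1) reductions.
fn inner_product_delayed(a: &[u64], b: &[u64]) -> u64 {
    let acc: u128 = a.iter().zip(b).map(|(&x, &y)| x as u128 * y as u128).sum();
    (acc % P as u128) as u64
}

fn main() {
    let a: Vec<u64> = (0..1000u64).map(|i| (i * 2_654_435_761) % P).collect();
    let b: Vec<u64> = (0..1000u64).map(|i| (i * 40_503 + 7) % P).collect();
    assert_eq!(inner_product_eager(&a, &b), inner_product_delayed(&a, &b));
    println!("eager == delayed");
}
```

Both functions return the same canonical result; only the number of reductions differs.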
Why this works: Field elements in Montgomery form satisfy (a×R) × (b×R) = ab×R². The R² factor is common across all products, so we accumulate unreduced 512-bit products in a 576-bit accumulator, then apply a single Montgomery reduction at the end.

Overflow safety: Each product contributes at most 1 carry to limb 8. With u64 limbs, we can safely accumulate 2⁶⁴ products. Sumcheck uses at most 2⁴⁰, so overflow is never approached.
Benchmarks
Measured on M1 Max MacBook Pro with criterion (10 samples, 20s measurement time).
Benchmark command: cargo bench --bench sumcheck
Speedup Summary (BN254, this PR vs base)
The optimization provides around a 1.20-1.27× speedup.
Systems-Level Justification: Assembly Analysis
Analyzing the ARM64 assembly reveals why delayed reduction is faster:
Breakdown of the 250 eager instructions:
With delayed reduction, we pay the reduction cost only once regardless of iteration count:
Delayed: 109N + 200 (reduction once)
Eager: 250N

Security Considerations
This implementation does not introduce new cryptographic primitives or weaken security.
Why This Is Safe
Bit-identical outputs: The reduce() function produces exactly the same canonical field element as performing N separate Montgomery multiplications followed by additions. This is not an approximation; it is mathematically equivalent and verified by tests.

Standard algorithm: Montgomery reduction (REDC) is a well-established algorithm from 1985, used in OpenSSL, libsecp256k1, arkworks, halo2curves, and virtually every production cryptographic library. We apply the same algorithm, just batched.
No new field arithmetic: The underlying prime field (BN254 Fr, Pallas Fp, etc.) is unchanged. Its security properties (embedding degree, DLP hardness, etc.) are unaffected.
Constant-time execution: The implementation uses fixed loop bounds (no data-dependent iteration counts) and branchless conditionals where possible, avoiding timing side-channel leakage.
Testable equivalence: Any implementation can verify correctness by checking that the delayed path produces the same canonical result as the eager path on the same inputs.
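The equivalence claim can be exercised in miniature. This is a hedged sketch with an assumed 16-bit modulus (p = 65521, R = 2³²), not any of the PR's fields: a single REDC over the unreduced sum must match N eager Montgomery multiplications, bit for bit.

```rust
// Toy equivalence check (hypothetical 16-bit modulus, not the PR's fields):
// one REDC over the unreduced sum equals N eager Montgomery multiplies.
const P: u64 = 65_521;   // largest 16-bit prime
const R: u64 = 1 << 32;  // Montgomery radix

// -P^{-1} mod 2^32 via Newton iteration (valid for any odd modulus).
fn mont_inv() -> u32 {
    let p = P as u32;
    let mut x: u32 = 1;
    for _ in 0..5 {
        x = x.wrapping_mul(2u32.wrapping_sub(p.wrapping_mul(x)));
    }
    x.wrapping_neg()
}

// Montgomery REDC: returns T * R^{-1} mod P, valid for T < P*R.
fn redc(t: u64) -> u64 {
    let m = (t as u32).wrapping_mul(mont_inv()) as u64;
    let s = ((t as u128 + m as u128 * P as u128) >> 32) as u64;
    if s >= P { s - P } else { s }
}

// Montgomery encoding: a -> a*R mod P.
fn to_mont(a: u64) -> u64 { (a % P) * R % P }

// Eager: REDC after every product, modular add into the accumulator.
fn ip_eager(a: &[u64], b: &[u64]) -> u64 {
    a.iter().zip(b).fold(0, |acc, (&x, &y)| (acc + redc(x * y)) % P)
}

// Delayed: sum unreduced products (each < P^2); one REDC at the end.
// Sound while the sum stays below P*R, i.e. up to ~R/P ≈ 2^16 terms here.
fn ip_delayed(a: &[u64], b: &[u64]) -> u64 {
    redc(a.iter().zip(b).map(|(&x, &y)| x * y).sum())
}

fn main() {
    let a: Vec<u64> = (1..2000u64).map(to_mont).collect();
    let b: Vec<u64> = (1..2000u64).map(|i| to_mont(i * 31)).collect();
    assert_eq!(ip_eager(&a, &b), ip_delayed(&a, &b)); // bit-identical outputs
    println!("delayed REDC == eager REDC");
}
```

The check relies only on the linearity of T ↦ T·R⁻¹ mod p, which is the same fact the real 576-bit accumulator uses.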
What Could Go Wrong (and why it doesn't)
Test Plan
- cargo test: all existing tests pass
- cargo clippy: no warnings
- cargo bench --bench sumcheck: performance regression tests

Optional: Reproduce Benchmarks
To verify the speedup claims, compare against the base branch.
Implementation
New Files
src/big_num/delayed_reduction.rs: DelayedReduction<V> trait defining:
- type Accumulator: wide integer accumulator type
- unreduced_multiply_accumulate(): accumulate a product without reduction
- reduce(): final Montgomery reduction

src/big_num/limbs.rs: WideLimbs<N> stack-allocated wide integers with:
- mul_4_by_4(): 4×4 limb multiplication to an 8-limb result

src/big_num/montgomery.rs: MontgomeryLimbs trait for limb access + montgomery_reduce_9() for 9→4 limb reduction

src/big_num/field_reduction_constants.rs: precomputed MODULUS, R512_MOD, MONT_INV for BN254, Pallas, Vesta, T256
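The trait shape listed above might look roughly like the following. This is a hypothetical sketch: the method names follow the PR description, but the type parameter is simplified away and the toy impl (u128 accumulator over an assumed 32-bit prime) stands in for the real 576-bit WideLimbs accumulator.

```rust
// Hypothetical sketch of the DelayedReduction trait shape; illustrative only.
trait DelayedReduction: Sized {
    type Accumulator: Default;
    /// Accumulate a*b into `acc` without performing a modular reduction.
    fn unreduced_multiply_accumulate(acc: &mut Self::Accumulator, a: &Self, b: &Self);
    /// Perform the single deferred reduction.
    fn reduce(acc: &Self::Accumulator) -> Self;
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Toy(u64);               // element of GF(2^32 - 5), an assumed toy field
const P: u128 = 4_294_967_291;

impl DelayedReduction for Toy {
    type Accumulator = u128;   // wide accumulator; the real code uses 9 u64 limbs
    fn unreduced_multiply_accumulate(acc: &mut u128, a: &Toy, b: &Toy) {
        *acc += a.0 as u128 * b.0 as u128; // full-width product, no reduction
    }
    fn reduce(acc: &u128) -> Toy {
        Toy((acc % P) as u64)              // one reduction for the whole sum
    }
}

fn main() {
    let (a, b) = ([Toy(2), Toy(3)], [Toy(4), Toy(5)]);
    let mut acc: <Toy as DelayedReduction>::Accumulator = Default::default();
    for (x, y) in a.iter().zip(&b) {
        Toy::unreduced_multiply_accumulate(&mut acc, x, y);
    }
    assert_eq!(Toy::reduce(&acc), Toy(23)); // 2*4 + 3*5
    println!("reduce(acc) = {:?}", Toy::reduce(&acc));
}
```

The two-phase shape (accumulate, then reduce) is what lets the sumcheck inner loop stay reduction-free.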
src/sumcheck.rs: Updated eq_sumcheck::evaluation_points_cubic_with_three_inputs() to use two-phase accumulation with DMR

Relation to feat/procedure-9-accumulator

This PR extracts the DMR infrastructure as a standalone improvement. The full feat/procedure-9-accumulator branch additionally implements:
- sl_mul for field×i32 and field×i64 products

Those optimizations achieve 1.8-2.7× total speedup but require more invasive changes. This PR provides a 1.09-1.35× speedup with minimal code changes, making it suitable for incremental adoption.
References
- feat/procedure-9-accumulator branch

Appendix: Mathematical Proofs
Constants and Notation
- Limbs: u64 (base b = 2⁶⁴)
- MONT_INV: −p₀⁻¹ mod 2⁶⁴
- MOD[0..3]: modulus p limbs
- R2_MOD[0..3]: R² mod p (for folding the 9th limb)
- ONE[0..3]: R mod p (Montgomery encoding of 1)

Field-specific Q values (with R = 2²⁵⁶):
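MONT_INV need not be looked up per field: for any odd p₀ it can be derived at compile time by Newton iteration. The helper below is a hypothetical sketch (compute_mont_inv is not necessarily a function in the PR); the BN254 scalar-field low limb is used only to exercise the defining property.

```rust
// Hypothetical const helper: MONT_INV = -p0^{-1} mod 2^64 from the modulus'
// least-significant limb. Each Newton step doubles the number of correct low
// bits, so 6 steps reach 64 bits from the 1-bit seed x = 1.
const fn compute_mont_inv(p0: u64) -> u64 {
    // p0 must be odd (true for any prime modulus > 2), so p0 * 1 ≡ 1 (mod 2).
    let mut x: u64 = 1;
    let mut i = 0;
    while i < 6 {
        x = x.wrapping_mul(2u64.wrapping_sub(p0.wrapping_mul(x)));
        i += 1;
    }
    x.wrapping_neg() // negate: we want -p0^{-1}, not p0^{-1}
}

fn main() {
    // Low limb of the BN254 scalar-field modulus.
    let p0: u64 = 0x43e1f593f0000001;
    let inv = compute_mont_inv(p0);
    // Defining property: MONT_INV * p0 ≡ -1 (mod 2^64).
    assert_eq!(inv.wrapping_mul(p0), u64::MAX);
    println!("MONT_INV = {inv:#x}");
}
```

Being a const fn, the same helper could back a macro like impl_field_reduction_constants! so new fields need no hand-derived constant.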
Algorithm: montgomery_reduce_8

Input: T[8] limbs representing integer T where 0 ≤ T < R²
Output: out[4] limbs = T × R⁻¹ mod p in canonical form [0, p)

Algorithm: montgomery_reduce_9

Input: C[9] limbs representing accumulated integer C = L + h·R², where L < R² and h = C[8]
Output: out[4] = C × R⁻¹ mod p in canonical form

Theorem 1: Correctness of montgomery_reduce_9

Statement: montgomery_reduce_9 correctly computes C × R⁻¹ mod p.

Proof:
Let C = L + h·R² where L = C[0..7] (as an integer) and h = C[8].
The folding step computes low8 = (L + h·(R² mod p)) mod R² with carry c = ⌊(L + h·(R² mod p))/R²⌋, so after folding:

  low8 + c·R² = L + h·(R² mod p) + c·R²

Since R² mod p ≡ R² (mod p), we have low8 + c·R² ≡ L + h·R² = C (mod p).

After montgomery_reduce_8(low8), we get low8 × R⁻¹ mod p. If c = 1, we need to add R² × R⁻¹ mod p = R mod p = ONE.

Therefore:
out = (low8 + c·R²) × R⁻¹ mod p = C × R⁻¹ mod p. ∎

Theorem 2: Carry bound after folding
Statement: The carry c after folding satisfies c ∈ {0, 1}.
Proof:
We have L < R², h ≤ 2⁶⁴ − 1, and R² mod p < p < R = 2²⁵⁶. Therefore:

  L + h·(R² mod p) < R² + 2⁶⁴·R = R² + 2³²⁰ < 2·R²

So ⌊(L + h·(R² mod p))/R²⌋ < 2, meaning c ∈ {0, 1}. ∎

Theorem 3: Bound after Montgomery elimination
Statement: After the Montgomery elimination loop in reduce_8, the result satisfies r[4..8] < R + p.

Proof:
Standard Montgomery REDC analysis: after k = 4 elimination steps on input T < R², the intermediate value satisfies

  (T + m·p)/R < (R² + R·p)/R = R + p

where m < R is the sum of Montgomery multipliers. Since R + p < 2R (as p < R), the 5th limb r[8] is at most 1. ∎
Theorem 4: Optimal canonicalization bounds
Statement: Q = ⌊R/p⌋ conditional subtractions are necessary and sufficient for canonicalization.
Proof of sufficiency:
After the r[8] == 1 check and potential subtract, we have out < R.
The maximum value is R − 1 = (Q + f)·p + r where f = {R/p} (fractional part) and r < p.
Since f < 1, we have R − 1 < (Q + 1)·p, so at most Q subtracts bring out into [0, p).
Proof of necessity:
Consider the input T = (Q·p)·R which gives T × R⁻¹ mod p = 0.
During reduction, the intermediate value before canonicalization can be Q·p.
Each subtract reduces by p, so exactly Q subtracts are needed to reach 0.
For specific fields:
Why No Data-Dependent Loops
The optimized implementation has zero variable-length loops.
All loops have compile-time-constant bounds, enabling full unrolling for maximum ILP.
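A fixed-bound, branchless canonicalization loop of the kind described might be sketched as follows (Q as a const generic is an assumption here, not the PR's exact signature):

```rust
// Sketch of a fixed-bound, branchless canonicalization loop. Q (the number of
// conditional subtractions, Theorem 4) is a compile-time constant per field,
// so the loop fully unrolls and performs no data-dependent iteration.
fn canonicalize<const Q: usize>(mut v: u64, p: u64) -> u64 {
    for _ in 0..Q {
        // Branchless conditional subtract: mask is all-ones iff v >= p.
        let mask = ((v >= p) as u64).wrapping_neg();
        v -= p & mask;
    }
    v
}

fn main() {
    // Worst case from Theorem 4: the value Q*p needs exactly Q subtractions.
    assert_eq!(canonicalize::<3>(21, 7), 0);
    assert_eq!(canonicalize::<3>(10, 7), 3);
    assert_eq!(canonicalize::<3>(5, 7), 5); // already canonical: every mask is zero
    println!("canonicalized");
}
```

Because the subtraction is applied through a mask rather than a branch, the same instruction sequence executes regardless of the value, which is the constant-time property claimed above.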