
add delayed modular reduction for split-eq sumcheck #105

Merged
srinathsetty merged 9 commits into microsoft:main from wu-s-john:split-eq-dmr-sumcheck
Feb 25, 2026

Conversation


@wu-s-john (Contributor) commented Feb 9, 2026

Delayed Modular Reduction for Split-Eq Sumcheck

NOTE: This PR isolates one chunk of the original Small-Value PR: the sumcheck and delayed modular reduction work. It introduces a new fundamental primitive, delayed reduction, along with the big-number arithmetic that implements it.

Summary

This PR extracts the delayed modular reduction (DMR) optimization from the larger feat/procedure-9-accumulator branch, which implements Algorithm 6 from "Speeding Up Sum-Check Proving" (Bagad et al., ePrint 2024/1046). This change introduces a DelayedReduction trait and supporting infrastructure that reduces Montgomery reduction operations from O(N) to O(1) per inner-product computation in the sumcheck prover. The optimization specifically targets the split-eq sumcheck used in Spartan's outer sumcheck, achieving 1.09-1.35× speedup on the sumcheck phase (BN254 field).

Core Insight: Batching Modular Reductions

The fundamental operation in sumcheck is computing inner products of the form:

result = Σᵢ (aᵢ × bᵢ)

where aᵢ and bᵢ are field elements. In the standard approach, each multiplication triggers an immediate Montgomery reduction:

// Standard approach: N reductions for N products
let mut acc = F::ZERO;
for (a, b) in pairs {
    acc += a * b;  // Montgomery REDC happens here on each iteration
}

The key insight is that we can defer all reductions by accumulating unreduced products in a wider integer type, then reducing only once at the end:

// Delayed reduction: 1 reduction for N products
let mut acc = WideLimbs::<9>::default();  // 576-bit accumulator
for (a, b) in pairs {
    unreduced_multiply_accumulate(&mut acc, a, b);  // No reduction
}
let result = reduce(&acc);  // Single reduction here

Why this works: Field elements in Montgomery form satisfy (a×R) × (b×R) = ab×R². The R² factor is common across all products, so we accumulate unreduced 512-bit products in a 576-bit accumulator, then apply a single Montgomery reduction at the end.

Overflow safety: Each product contributes at most 1 carry to limb 8. With u64 limbs, we can safely accumulate up to 2⁶⁴ products; sumcheck accumulates at most 2⁴⁰, so overflow is never approached.

Benchmarks

Measured on M1 Max MacBook Pro with criterion (10 samples, 20s measurement time).

Benchmark command:

sudo env BENCH_FIELD=bn254 nice -n -20 cargo bench --bench sumcheck -- bn254 --measurement-time 20

Speedup Summary (BN254, this PR vs base)

| num_vars | n | Before (ms) | After (ms) | Speedup |
|---|---|---|---|---|
| 16 | 65,536 | 5.96 | 4.66 | 1.28× |
| 17 | 131,072 | 11.11 | 7.13 | 1.56× |
| 18 | 262,144 | 15.33 | 11.40 | 1.35× |
| 19 | 524,288 | 28.42 | 21.67 | 1.31× |
| 20 | 1,048,576 | 51.90 | 40.66 | 1.28× |
| 21 | 2,097,152 | 91.27 | 73.17 | 1.25× |
| 22 | 4,194,304 | 171.10 | 137.96 | 1.24× |
| 23 | 8,388,608 | 330.12 | 261.50 | 1.26× |
| 24 | 16,777,216 | 646.67 | 509.90 | 1.27× |
| 25 | 33,554,432 | 1305.87 | 1054.91 | 1.24× |
| 26 | 67,108,864 | 2591.38 | 2066.17 | 1.25× |

The optimization provides a consistent 1.24-1.35× speedup across these sizes, with a 1.56× outlier at num_vars = 17.

Systems-Level Justification: Assembly Analysis

Analyzing the ARM64 assembly reveals why delayed reduction is faster:

| Method | Instructions/Iteration | mul ops | umulh ops | Notes |
|---|---|---|---|---|
| Eager (base) | 250 | 36 | 32 | Includes Montgomery REDC |
| Delayed (DMR) | 109 | 16 | 16 | Wide accumulation only |

Breakdown of the 250 eager instructions:

  • 16 multiplies for 4×4 product
  • 20 multiplies for Montgomery REDC
  • ~140 instructions (56%) spent on reduction per iteration

With delayed reduction, we pay the reduction cost only once regardless of iteration count:

  • DMR total cost: 109N + 200 (reduction once)
  • Eager total cost: 250N
  • Crossover at N ≈ 1.5, meaning DMR wins for any realistic workload

Security Considerations

This implementation does not introduce new cryptographic primitives or weaken security.

Why This Is Safe

  1. Bit-identical outputs: The reduce() function produces exactly the same canonical field element as performing N separate Montgomery multiplications followed by additions. This is not an approximation—it is mathematically equivalent and verified by tests.

  2. Standard algorithm: Montgomery reduction (REDC) is a well-established algorithm from 1985, used in OpenSSL, libsecp256k1, arkworks, halo2curves, and virtually every production cryptographic library. We apply the same algorithm—just batched.

  3. No new field arithmetic: The underlying prime field (BN254 Fr, Pallas Fp, etc.) is unchanged. Its security properties (embedding degree, DLP hardness, etc.) are unaffected.

  4. Constant-time execution: The implementation uses fixed loop bounds (no data-dependent iteration counts) and branchless conditionals where possible, avoiding timing side-channel leakage.

  5. Testable equivalence: Any implementation can verify correctness by checking:

    let eager = a.iter().zip(b).map(|(x, y)| *x * *y).sum::<F>();
    let delayed = DelayedReduction::reduce(&accumulator);
    assert_eq!(eager, delayed);

What Could Go Wrong (and why it doesn't)

| Concern | Why It's Not an Issue |
|---|---|
| Overflow in accumulator | Proven bounded: 9 limbs (576 bits) safely hold 2⁶⁴ products; sumcheck uses ≤2⁴⁰ |
| Incorrect reduction | Theorems 1-3 prove correctness; tests verify against the halo2curves reference |
| Timing leaks | All loops have compile-time-constant bounds; no secret-dependent branches |
| Non-canonical field elements | Canonicalization uses exactly Q subtracts (Theorem 4); result is always in [0, p) |

Test Plan

  • cargo test - All existing tests pass
  • cargo clippy - No warnings
  • cargo bench --bench sumcheck - Performance regression tests
  • Verified identical sumcheck outputs with/without DMR (deterministic protocol)

Optional: Reproduce Benchmarks

To verify the speedup claims, compare against the base branch:

# Benchmark base (without DMR)
git fetch origin wu-s-john:benching-sumcheck-base
git checkout benching-sumcheck-base
cargo bench --bench sumcheck

# Benchmark this PR (with DMR)
git checkout split-eq-dmr-sumcheck
cargo bench --bench sumcheck

Implementation

New Files

  1. src/big_num/delayed_reduction.rs: DelayedReduction<V> trait defining:

    • type Accumulator - wide integer accumulator type
    • unreduced_multiply_accumulate() - accumulate product without reduction
    • reduce() - final Montgomery reduction
  2. src/big_num/limbs.rs: WideLimbs<N> stack-allocated wide integers with:

    • mul_4_by_4() - 4×4 limb multiplication to 8-limb result
    • Helper functions for carry propagation and subtraction
  3. src/big_num/montgomery.rs: MontgomeryLimbs trait for limb access + montgomery_reduce_9() for 9→4 limb reduction

  4. src/big_num/field_reduction_constants.rs: Precomputed MODULUS, R512_MOD, MONT_INV for BN254, Pallas, Vesta, T256

Modified Files

  • src/sumcheck.rs: Updated eq_sumcheck::evaluation_points_cubic_with_three_inputs() to use two-phase accumulation with DMR

Relation to feat/procedure-9-accumulator

This PR extracts the DMR infrastructure as a standalone improvement. The full feat/procedure-9-accumulator branch additionally implements:

  1. Small-value sumcheck (Algorithm 6): First ℓ₀ rounds use native i32/i64 arithmetic
  2. Lagrange accumulator construction (Procedure 9): Precomputes accumulators for small-value rounds
  3. Field×small multiplication: sl_mul for field×i32 and field×i64 products

Those optimizations achieve 1.8-2.7× total speedup but require more invasive changes. This PR provides a 1.09-1.35× speedup with minimal code changes, making it suitable for incremental adoption.

References

  • Bagad et al., "Speeding Up Sum-Check Proving," Cryptology ePrint Archive, Paper 2024/1046

Appendix: Mathematical Proofs

Constants and Notation

  • Limbs: u64 (base b = 2⁶⁴)
  • Field elements: k = 4 limbs, so R = b⁴ = 2²⁵⁶
  • Arrays: little-endian (limb 0 = least significant)
  • MONT_INV: −p₀⁻¹ mod 2⁶⁴
  • MOD[0..3]: modulus p limbs
  • R2_MOD[0..3]: R² mod p (for folding the 9th limb)
  • ONE[0..3]: R mod p (Montgomery encoding of 1)
  • Q = ⌊R/p⌋: determines canonicalization subtract count

Field-specific Q values (with R = 2²⁵⁶):

| Field | Q | Max subtracts in canonicalization |
|---|---|---|
| Pallas Fp | 3 | 4 (1 + 3) |
| Vesta Fq | 3 | 4 (1 + 3) |
| BN254 Fr | 5 | 6 (1 + 5) |
| T256 Fq | 1 | 2 (1 + 1) |

Algorithm: montgomery_reduce_8

Input: T[8] limbs representing integer T where 0 ≤ T < R²
Output: out[4] limbs = T × R⁻¹ mod p in canonical form [0, p)

montgomery_reduce_8(T[8], MOD[4], MONT_INV, Q):
  // Working buffer with one extra limb for carry
  r[9] = [T[0], T[1], T[2], T[3], T[4], T[5], T[6], T[7], 0]

  // Word-by-word Montgomery elimination (exactly 4 iterations, one per limb)
  for i in 0..4:
    q = (r[i] × MONT_INV) mod 2⁶⁴

    // Fused: r[i..i+4] += q × MOD with carry propagation
    carry = 0
    for j in 0..4:
      prod = (u128)q × MOD[j] + r[i+j] + carry
      r[i+j] = LOW64(prod)
      carry = HIGH64(prod)

    // Propagate carry into higher limbs (fixed bound: at most to r[8])
    for k in (i+4)..8:
      sum = r[k] + carry
      r[k] = LOW64(sum)
      carry = HIGH64(sum)
      if carry == 0: break
    r[8] += carry

  // Result candidate is in top 5 limbs: x5 = [r[4], r[5], r[6], r[7], r[8]]
  // At this point: x5 < R + p, so x5[4] ∈ {0, 1}

  // If high limb is 1, subtract p once
  if r[8] == 1:
    SUB_5_BY_4_INPLACE(r[4..8], MOD)  // now r[8] = 0

  out = [r[4], r[5], r[6], r[7]]  // now out ∈ [0, R), NOT [0, p)

  // KEY INSIGHT: Standard REDC produces a value in [0, R), not [0, p).
  // Since R = 2²⁵⁶ > p for all 256-bit primes, we need additional subtractions.
  // Specifically, Q = ⌊R/p⌋ conditional subtracts canonicalize [0, R) → [0, p).
  repeat Q times:
    if GE_4(out, MOD):
      SUB_4_INPLACE(out, MOD)

  return out  // now out ∈ [0, p) ✓

Algorithm: montgomery_reduce_9

Input: C[9] limbs representing accumulated integer C = L + h·R², where L < R² and h = C[8]
Output: out[4] = C × R⁻¹ mod p in canonical form

montgomery_reduce_9(C[9], MOD[4], MONT_INV, R2_MOD[4], ONE[4], Q):
  h = C[8]              // top limb (u64)
  low8 = C[0..7]        // 8 limbs

  // Fold: low8 += h × (R² mod p) via fused multiply-add
  carry = 0
  for j in 0..4:
    prod = (u128)h × R2_MOD[j] + low8[j] + carry
    low8[j] = LOW64(prod)
    carry = HIGH64(prod)

  // Propagate into low8[4] through low8[7]
  for j in 4..8:
    sum = low8[j] + carry
    low8[j] = LOW64(sum)
    carry = HIGH64(sum)

  c = carry  // CRITICAL: c ∈ {0, 1} (proven below)

  // Reduce the folded 8-limb value
  out = montgomery_reduce_8(low8, MOD, MONT_INV, Q)

  // If c = 1, we had overflow: add ONE (Montgomery form of 1)
  if c == 1:
    out = out + ONE
    if GE_4(out, MOD):
      SUB_4_INPLACE(out, MOD)

  return out
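Both algorithms lean on the same primitive in their fused steps (`prod = (u128)q × MOD[j] + r[i+j] + carry`): a 64×64-bit multiply plus two 64-bit addends never overflows 128 bits. A minimal sketch of that step (a hypothetical helper, not the PR's exact code):

```rust
/// One fused multiply-accumulate-carry step: computes a*b + c + carry and
/// returns (low 64 bits, high 64 bits). This never overflows u128 because
/// (2^64 - 1)^2 + 2*(2^64 - 1) = 2^128 - 1 exactly.
fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    let t = a as u128 * b as u128 + c as u128 + carry as u128;
    (t as u64, (t >> 64) as u64)
}

fn main() {
    // Small case: 2*3 + 4 + 5 = 15, no high part.
    assert_eq!(mac(2, 3, 4, 5), (15, 0));
    // Worst case saturates u128 exactly, without wrapping.
    let m = u64::MAX;
    assert_eq!(mac(m, m, m, m), (u64::MAX, u64::MAX));
}
```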

Theorem 1: Correctness of montgomery_reduce_9

Statement: montgomery_reduce_9 correctly computes C × R⁻¹ mod p.

Proof:

Let C = L + h·R² where L = C[0..7] (as an integer) and h = C[8].

After folding: low8 + c·R² = L + h·(R² mod p) + c·R²

Since R² mod p ≡ R² (mod p), we have:

L + h·R² ≡ L + h·(R² mod p) + h·⌊R²/p⌋·p ≡ low8 + c·R² (mod p)

The folding step computes low8 = (L + h·(R² mod p)) mod R² with carry c = ⌊(L + h·(R² mod p))/R²⌋.

After montgomery_reduce_8(low8), we get low8 × R⁻¹ mod p.

If c = 1, we need to add R² × R⁻¹ mod p = R mod p = ONE.

Therefore: out = (low8 + c·R²) × R⁻¹ mod p = C × R⁻¹ mod p. ∎

Theorem 2: Carry bound after folding

Statement: The carry c after folding satisfies c ∈ {0, 1}.

Proof:

We have:

  • L < R² (8-limb input bound)
  • h < 2⁶⁴ (single limb)
  • R² mod p < p < R (since p is a 256-bit prime less than R)

Therefore:

L + h·(R² mod p) < R² + 2⁶⁴·R = R² + 2³²⁰ < 2R²

So ⌊(L + h·(R² mod p))/R²⌋ < 2, meaning c ∈ {0, 1}. ∎

Theorem 3: Bound after Montgomery elimination

Statement: After the Montgomery elimination loop in reduce_8, the result satisfies r[4..8] < R + p.

Proof:

Standard Montgomery REDC analysis: after k = 4 elimination steps on input T < R², the intermediate value satisfies:

(T + m·p) / R < (R² + R·p) / R = R + p

where m is the sum of Montgomery multipliers. Since R + p < 2R (as p < R), the top limb r[8] (the fifth limb of the result window r[4..8]) is at most 1. ∎

Theorem 4: Optimal canonicalization bounds

Statement: Q = ⌊R/p⌋ conditional subtractions are necessary and sufficient for canonicalization.

Proof of sufficiency:

After the r[8] == 1 check and potential subtract, we have out < R.
Write R = (Q + f)·p, where f = {R/p} < 1 is the fractional part of R/p.
Then out ≤ R − 1 < (Q + 1)·p, so at most Q subtracts bring out into [0, p).

Proof of necessity:

Consider the input T = (Q·p)·R which gives T × R⁻¹ mod p = 0.
During reduction, the intermediate value before canonicalization can be Q·p.
Each subtract reduces by p, so exactly Q subtracts are needed to reach 0.

For specific fields:

  • BN254 Fr (p ≈ 0.76·2²⁵⁴): R/p ≈ 5.3, so Q = ⌊2²⁵⁶/p⌋ = 5
  • Pallas/Vesta (p slightly above 2²⁵⁴): R/p is just below 4, so Q = 3
  • T256 (p very close to 2²⁵⁶): R/p is just above 1, so Q = 1 ∎

Why No Data-Dependent Loops

The optimized implementation has zero variable-length loops:

  1. Montgomery elimination: Exactly 4 iterations (one per limb to cancel)
  2. Carry propagation: Fixed chain through remaining limbs (≤4 steps)
  3. Canonicalization: Exactly Q conditional subtracts (field-specific constant)
  4. Folding in reduce_9: Single pass through 8 limbs

All loops have compile-time-constant bounds, enabling full unrolling for maximum ILP.

Implements batched Montgomery reduction that defers O(N) reductions to a
single reduction at the end of inner-product computations.

Performance (BN254 split-eq sumcheck):
- 1.09-1.35× speedup across all problem sizes
- n=2²⁶: 2.64s → 2.04s (1.29× faster)

Changes:
- Add big_num module with WideLimbs accumulator and montgomery_reduce_9
- Update sumcheck to use two-phase accumulation with DMR
- Add BN254 provider and criterion benchmarks
- Create macros.rs with impl_field_reduction_constants! and
  impl_montgomery_limbs! macros for zero-config field support
- Unify limb operations in limbs.rs as const fn (usable at both
  compile-time and runtime, eliminating duplication)
- Add P256 (secp256r1) field support via single-line macro invocations
- Rename MAX_CANONICALIZE_SUBS to MAX_REDC_SUB_CORRECTIONS for clearer
  Montgomery REDC terminology
- Rename compute_q to compute_max_redc_sub_corrections
- Add test macros for consistent test generation across field types
- Remove unused DelayedReduction import from sumcheck.rs
- Add test_first_round_spartan_sumcheck test to src/sumcheck.rs using
  existing start_span! tracing infrastructure for timing
- Delete benches/sumcheck.rs criterion benchmark
- Remove criterion dev-dependency from Cargo.toml
- Make big_num, polys, sumcheck modules private in lib.rs
- Re-export DelayedReduction and FieldReductionConstants from traits
  module for external access
- Clean up unused zip_with imports in polys modules
@wu-s-john (Contributor Author) commented Feb 23, 2026

So, I added delayed modular reduction to the other sumcheck-like functions (these are unlocked in later PRs). Here are the results I have:

Spartan Benchmark: main vs benching/sha256-is-faster

SHA256 Single Hash (msg=1024B)

| Metric | main | optimized | Δ | Speedup |
|---|---|---|---|---|
| prove_total | 184ms | 168ms | -16ms | 1.10x |
| setup | 1318ms | 1261ms | -57ms | 1.05x |
| prep | 125ms | 109ms | -16ms | 1.15x |
| verify | 52ms | 47ms | -5ms | 1.11x |

| Phase | main | optimized | Δ | Speedup |
|---|---|---|---|---|
| outer_sc | 45ms | 33ms | -12ms | 1.36x |
| eval_rx | 5ms | 11ms | +6ms | 0.45x |
| inner_sc | 37ms | 33ms | -4ms | 1.12x |
| mat_vec | 10ms | 9ms | -1ms | 1.11x |
| eval_sparse | 30ms | 30ms | 0ms | 1.00x |
| poly_ABC | 12ms | 10ms | -2ms | 1.20x |
| poly_z | 1ms | 1ms | 0ms | 1.00x |
| pcs | 22ms | 20ms | -2ms | 1.10x |
| synth_pre | 123ms | 107ms | -16ms | 1.15x |
| commit_pre | 27ms | 18ms | -9ms | 1.50x |

SHA256 Chain (msg=32B, chain=1028)

| Metric | main | optimized | Δ | Speedup |
|---|---|---|---|---|
| prove_total | 8770ms | 8751ms | -19ms | 1.00x |
| setup | 144807ms | 144832ms | +25ms | - |
| prep | 6281ms | 6850ms | +569ms | 0.92x |
| verify | 1973ms | 2024ms | +51ms | 0.98x |

| Phase | main | optimized | Δ | Speedup |
|---|---|---|---|---|
| outer_sc | 1294ms | 1048ms | -246ms | 1.23x |
| eval_rx | 415ms | 244ms | -171ms | 1.70x |
| inner_sc | 1221ms | 1101ms | -120ms | 1.11x |
| mat_vec | 974ms | 995ms | +21ms | 0.98x |
| eval_sparse | 2378ms | 2381ms | +3ms | 1.00x |
| poly_ABC | 726ms | 865ms | +139ms | 0.84x |
| poly_z | 150ms | 226ms | +76ms | 0.66x |
| pcs | 546ms | 641ms | +95ms | 0.85x |
| synth_pre | 6146ms | 6709ms | +563ms | 0.92x |
| commit_pre | 1183ms | 1188ms | +5ms | 1.00x |

Summary

Consistent improvements across both benchmarks:

  • outer_sc: 23-36% faster
  • inner_sc: 11-12% faster

Small benchmark (1024B single hash):

  • Overall 10% faster proving time (184ms → 168ms)
  • Improvements across most phases

Large benchmark (1028-chain):

  • Sumcheck optimizations save ~537ms
  • Regressions in synth_pre, poly_ABC, poly_z, pcs offset gains
  • Net effect: ~0% change in total proving time

@wu-s-john (Contributor Author) commented:

Here is the benching code I used:

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: MIT

//! examples/sha256_chain.rs
//! Configurable SHA-256 chain benchmark for Spartan and SpartanZK.
//!
//! Run with:
//!   RUST_LOG=info cargo run --release --example sha256_chain -- -m 32 -c 4 -e bn254
//!   RUST_LOG=info cargo run --release --example sha256_chain -- -m 32 -c 4 -e bn254 --zk

#[cfg(feature = "jem")]
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;

use bellpepper::gadgets::sha256::sha256;
use bellpepper_core::{
  ConstraintSystem, SynthesisError,
  boolean::{AllocatedBit, Boolean},
  num::AllocatedNum,
};
use clap::Parser;
use ff::{Field, PrimeField, PrimeFieldBits};
use sha2::{Digest, Sha256};
use spartan2::{
  provider::{Bn254Engine, P256HyraxEngine, PallasHyraxEngine, T256HyraxEngine, VestaHyraxEngine},
  spartan::SpartanSNARK,
  spartan_zk::SpartanZkSNARK,
  timing::{
    SPARTAN_PHASES, SPARTAN_ZK_PHASES, TimingLayer, clear_timings, print_single_table,
    snapshot_timings,
  },
  traits::{Engine, circuit::SpartanCircuit, pcs::FoldingEngineTrait, snark::R1CSSNARKTrait},
};
use std::{marker::PhantomData, time::Instant};
use tracing::info;
use tracing_subscriber::{EnvFilter, Layer as _, layer::SubscriberExt, util::SubscriberInitExt};

#[derive(Parser)]
#[command(about = "Configurable SHA-256 chain benchmark")]
struct Args {
  /// Initial message size in bytes
  #[arg(short = 'm', long, default_value = "32")]
  msg_bytes: usize,

  /// Number of SHA256 iterations (chain length)
  #[arg(short = 'c', long, default_value = "1")]
  chain: usize,

  /// Engine to use: bn254, pallas, vesta, p256, t256
  #[arg(short = 'e', long, default_value = "bn254")]
  engine: String,

  /// Use SpartanZK instead of Spartan (adds zero-knowledge)
  #[arg(short = 'z', long)]
  zk: bool,
}

/// SHA256 chain circuit: computes H(H(H(...H(msg)))) for `chain_length` iterations.
#[derive(Clone, Debug)]
struct Sha256ChainCircuit<Scalar: PrimeField> {
  preimage: Vec<u8>,
  chain_length: usize,
  _p: PhantomData<Scalar>,
}

impl<Scalar: PrimeField + PrimeFieldBits> Sha256ChainCircuit<Scalar> {
  fn new(preimage: Vec<u8>, chain_length: usize) -> Self {
    Self {
      preimage,
      chain_length,
      _p: PhantomData,
    }
  }

  /// Compute the final hash after `chain_length` iterations.
  fn compute_final_hash(&self) -> [u8; 32] {
    // First hash the preimage
    let mut hasher = Sha256::new();
    hasher.update(&self.preimage);
    let mut current: [u8; 32] = hasher.finalize().into();

    // Chain additional iterations (already did 1 above)
    for _ in 1..self.chain_length {
      let mut hasher = Sha256::new();
      hasher.update(current);
      current = hasher.finalize().into();
    }
    current
  }
}

impl<E: Engine> SpartanCircuit<E> for Sha256ChainCircuit<E::Scalar> {
  fn public_values(&self) -> Result<Vec<<E as Engine>::Scalar>, SynthesisError> {
    let final_hash = self.compute_final_hash();
    let hash_scalars: Vec<<E as Engine>::Scalar> = final_hash
      .iter()
      .flat_map(|&byte| {
        (0..8).rev().map(move |i| {
          if (byte >> i) & 1 == 1 {
            E::Scalar::ONE
          } else {
            E::Scalar::ZERO
          }
        })
      })
      .collect();
    Ok(hash_scalars)
  }

  fn shared<CS: ConstraintSystem<E::Scalar>>(
    &self,
    _: &mut CS,
  ) -> Result<Vec<AllocatedNum<E::Scalar>>, SynthesisError> {
    Ok(vec![])
  }

  fn precommitted<CS: ConstraintSystem<E::Scalar>>(
    &self,
    cs: &mut CS,
    _: &[AllocatedNum<E::Scalar>],
  ) -> Result<Vec<AllocatedNum<E::Scalar>>, SynthesisError> {
    // Allocate preimage bits
    let bit_values: Vec<_> = self
      .preimage
      .iter()
      .flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1 == 1))
      .collect();

    let preimage_bits: Vec<Boolean> = bit_values
      .into_iter()
      .enumerate()
      .map(|(i, bit)| {
        AllocatedBit::alloc(cs.namespace(|| format!("preimage bit {i}")), Some(bit))
          .map(Boolean::from)
      })
      .collect::<Result<Vec<_>, _>>()?;

    // First SHA256 on the preimage
    let mut current_bits = sha256(cs.namespace(|| "sha256 round 0"), &preimage_bits)?;

    // Chain additional SHA256 iterations
    for round in 1..self.chain_length {
      current_bits = sha256(cs.namespace(|| format!("sha256 round {round}")), &current_bits)?;
    }

    // Verify against expected final hash
    let expected_hash = self.compute_final_hash();
    let mut expected_bits = expected_hash
      .iter()
      .flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1 == 1));

    for b in &current_bits {
      match b {
        Boolean::Is(bit) => assert_eq!(expected_bits.next().unwrap(), bit.get_value().unwrap()),
        Boolean::Not(bit) => assert_ne!(expected_bits.next().unwrap(), bit.get_value().unwrap()),
        Boolean::Constant(_) => unreachable!(),
      }
    }

    // Expose final hash as public input (256 bits)
    for (i, bit) in current_bits.iter().enumerate() {
      let n = AllocatedNum::alloc_input(cs.namespace(|| format!("output bit {i}")), || {
        Ok(
          if bit.get_value().ok_or(SynthesisError::AssignmentMissing)? {
            E::Scalar::ONE
          } else {
            E::Scalar::ZERO
          },
        )
      })?;

      cs.enforce(
        || format!("bit == num {i}"),
        |_| bit.lc(CS::one(), E::Scalar::ONE),
        |lc| lc + CS::one(),
        |lc| lc + n.get_variable(),
      );
    }

    Ok(vec![])
  }

  fn num_challenges(&self) -> usize {
    0
  }

  fn synthesize<CS: ConstraintSystem<E::Scalar>>(
    &self,
    _: &mut CS,
    _: &[AllocatedNum<E::Scalar>],
    _: &[AllocatedNum<E::Scalar>],
    _: Option<&[E::Scalar]>,
  ) -> Result<(), SynthesisError> {
    Ok(())
  }
}

/// Run benchmark with Spartan (non-ZK)
fn run_spartan<E: Engine>(
  msg_bytes: usize,
  chain_length: usize,
  engine_name: &str,
  timing_data: &spartan2::timing::TimingData,
) {
  let circuit = Sha256ChainCircuit::<E::Scalar>::new(vec![0u8; msg_bytes], chain_length);

  eprintln!(
    "\n======= SHA256 Chain: engine={}, msg={}B, chain={}, zk=false =======",
    engine_name, msg_bytes, chain_length
  );

  // SETUP
  let t0 = Instant::now();
  let (pk, vk) = SpartanSNARK::<E>::setup(circuit.clone()).expect("setup failed");
  let setup_ms = t0.elapsed().as_millis();
  info!(elapsed_ms = setup_ms as u64, "setup");

  clear_timings(timing_data);

  // PREP_PROVE
  let t0 = Instant::now();
  let prep_snark =
    SpartanSNARK::<E>::prep_prove(&pk, circuit.clone(), false).expect("prep_prove failed");
  let prep_ms = t0.elapsed().as_millis();
  info!(elapsed_ms = prep_ms as u64, "prep_prove");

  // PROVE
  let t0 = Instant::now();
  let proof =
    SpartanSNARK::<E>::prove(&pk, circuit, &prep_snark, false).expect("prove failed");
  let prove_ms = t0.elapsed().as_millis();
  info!(elapsed_ms = prove_ms as u64, "prove");

  // VERIFY
  let t0 = Instant::now();
  proof.verify(&vk).expect("verify failed");
  let verify_ms = t0.elapsed().as_millis();
  info!(elapsed_ms = verify_ms as u64, "verify");

  // Print timing table
  let timings = snapshot_timings(timing_data, SPARTAN_PHASES);
  let header = format!(
    "setup={}ms, prep={}ms, prove={}ms, verify={}ms",
    setup_ms, prep_ms, prove_ms, verify_ms
  );
  print_single_table(&header, SPARTAN_PHASES, &timings);
}

/// Run benchmark with SpartanZK (with zero-knowledge)
fn run_spartan_zk<E: Engine>(
  msg_bytes: usize,
  chain_length: usize,
  engine_name: &str,
  timing_data: &spartan2::timing::TimingData,
)
where
  E::PCS: FoldingEngineTrait<E>,
{
  let circuit = Sha256ChainCircuit::<E::Scalar>::new(vec![0u8; msg_bytes], chain_length);

  eprintln!(
    "\n======= SHA256 Chain: engine={}, msg={}B, chain={}, zk=true =======",
    engine_name, msg_bytes, chain_length
  );

  // SETUP
  let t0 = Instant::now();
  let (pk, vk) = SpartanZkSNARK::<E>::setup(circuit.clone()).expect("setup failed");
  let setup_ms = t0.elapsed().as_millis();
  info!(elapsed_ms = setup_ms as u64, "setup");

  clear_timings(timing_data);

  // PREP_PROVE
  let t0 = Instant::now();
  let prep_snark =
    SpartanZkSNARK::<E>::prep_prove(&pk, circuit.clone(), false).expect("prep_prove failed");
  let prep_ms = t0.elapsed().as_millis();
  info!(elapsed_ms = prep_ms as u64, "prep_prove");

  // PROVE
  let t0 = Instant::now();
  let proof =
    SpartanZkSNARK::<E>::prove(&pk, circuit, &prep_snark, false).expect("prove failed");
  let prove_ms = t0.elapsed().as_millis();
  info!(elapsed_ms = prove_ms as u64, "prove");

  // VERIFY
  let t0 = Instant::now();
  proof.verify(&vk).expect("verify failed");
  let verify_ms = t0.elapsed().as_millis();
  info!(elapsed_ms = verify_ms as u64, "verify");

  // Print timing table
  let timings = snapshot_timings(timing_data, SPARTAN_ZK_PHASES);
  let header = format!(
    "setup={}ms, prep={}ms, prove={}ms, verify={}ms",
    setup_ms, prep_ms, prove_ms, verify_ms
  );
  print_single_table(&header, SPARTAN_ZK_PHASES, &timings);
}

fn main() {
  let args = Args::parse();

  let (timing_layer, timing_data, _constraints_data) = TimingLayer::new();

  tracing_subscriber::registry()
    .with(timing_layer)
    .with(
      tracing_subscriber::fmt::layer()
        .with_target(false)
        .with_ansi(true)
        .with_writer(std::io::stderr)
        .with_filter(EnvFilter::from_default_env()),
    )
    .init();

  let msg_bytes = args.msg_bytes;
  let chain_length = args.chain;
  let engine_name = args.engine.as_str();

  match (engine_name, args.zk) {
    ("bn254", false) => run_spartan::<Bn254Engine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("bn254", true) => run_spartan_zk::<Bn254Engine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("pallas", false) => run_spartan::<PallasHyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("pallas", true) => run_spartan_zk::<PallasHyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("vesta", false) => run_spartan::<VestaHyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("vesta", true) => run_spartan_zk::<VestaHyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("p256", false) => run_spartan::<P256HyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("p256", true) => run_spartan_zk::<P256HyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("t256", false) => run_spartan::<T256HyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
    ("t256", true) => run_spartan_zk::<T256HyraxEngine>(msg_bytes, chain_length, engine_name, &timing_data),
    _ => {
      eprintln!("Unknown engine: {}. Valid options: bn254, pallas, vesta, p256, t256", engine_name);
      std::process::exit(1);
    }
  }
}

- Dedupe limb functions with const generics: gte<N>, add<N>, sub<N>,
  shl<N>, clz<N> (kept mul_4_by_4 specialized to avoid unstable
  generic_const_exprs)
- Add compile-time assertion enforcing MAX_REDC_SUB_CORRECTIONS <= 8
- Update capacity bound comment from 2^30 to 2^40
- Add clearer comments for modulus byte conversion and fold operation
- Update two-phase accumulation comment to mention inner loop
- Make perf tests faster for debug builds (16,18 + BN254 only)
@wu-s-john (Contributor Author) commented:

Continuation of this PR, where I've added delayed modular reduction to the inner sumcheck:

wu-s-john#1

@wu-s-john (Contributor Author) commented:

Continuation of this PR where I've added Delayed Modular Reduction on the outer sumcheck of NeutronNova ZK

wu-s-john#2

@wu-s-john (Contributor Author) commented:

Continuation of this PR where I've added Delayed Modular Reduction on the NIFS sumcheck of NeutronNova:

wu-s-john#3

@srinathsetty (Contributor) left a comment

Looks good! Made some small comments around organization of modules.

@srinathsetty srinathsetty merged commit df22cb9 into microsoft:main Feb 25, 2026
7 checks passed