diff --git a/README.md b/README.md index d8a40144..cce64ba4 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ Minimal hash-based zkVM, targeting recursion and aggregation of hash-based signa

Documentation + zkDSL reference Python verifier

diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index f11c8138..bec59983 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -6,7 +6,7 @@ # Type annotations Mut = Any Const = Any -Imu = Any +Imm = Any # @inline decorator (does nothing in Python execution) diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index 2598ed3d..18254526 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -332,21 +332,14 @@ pub fn simplify_program(mut program: Program) -> Result { let mut array_manager = ArrayManager::default(); let mut mut_tracker = MutableVarTracker::default(); - // Register mutable arguments and capture their initial versioned names - // BEFORE simplifying the body + // All arguments are immutable; record them as assigned to detect illegal reassignment. let arguments: Vec = func .arguments .iter() .map(|arg| { assert!(!arg.is_const); - if arg.is_mutable { - mut_tracker.register_mutable(&arg.name); - // Capture the initial versioned name (version 0) - mut_tracker.current_name(&arg.name) - } else { - mut_tracker.assigned.insert(arg.name.clone()); - arg.name.clone() - } + mut_tracker.assigned.insert(arg.name.clone()); + arg.name.clone() }) .collect(); @@ -398,9 +391,6 @@ fn compile_time_transform_in_program( .collect(); for func in inlined_functions.values() { - if func.has_mutable_arguments() { - return Err("Inlined functions with mutable arguments are not supported yet".to_string()); - } if func.has_const_arguments() { return Err(format!( "Inlined function should not have \"Const\" arguments (function \"{}\")", diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index 16b041a6..6640d2de 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -49,9 +49,9 @@ return_statement = { "return" ~ (("(" ~ 
tuple_expression ~ ")") | tuple_expressi mut_keyword = @{ "mut" ~ !(ASCII_ALPHANUMERIC | "_") } mut_annotation = { ":" ~ "Mut" } -im_annotation = { ":" ~ "Imu" } +im_annotation = { ":" ~ "Imm" } -// Forward declaration: x: Imu or x: Mut (not followed by =) +// Forward declaration: x: Imm or x: Mut (not followed by =) forward_declaration = { identifier ~ (im_annotation | mut_annotation) ~ !("=") } // General assignment: LHS is optional, RHS is any expression diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 94871b47..91fbd0e8 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -32,7 +32,6 @@ impl Program { pub struct FunctionArg { pub name: Var, pub is_const: bool, - pub is_mutable: bool, } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] @@ -48,9 +47,6 @@ impl Function { pub fn has_const_arguments(&self) -> bool { self.arguments.iter().any(|arg| arg.is_const) } - pub fn has_mutable_arguments(&self) -> bool { - self.arguments.iter().any(|arg| arg.is_mutable) - } } pub type Var = String; @@ -663,7 +659,7 @@ impl Line { if *is_mutable { format!("{var}: Mut") } else { - format!("{var}: Imu") + format!("{var}: Imm") } } Self::Statement { targets, value, .. 
} => { @@ -899,8 +895,6 @@ impl Display for Function { .map(|arg| { if arg.is_const { format!("const {}", arg.name) - } else if arg.is_mutable { - format!("mut {}", arg.name) } else { arg.name.to_string() } diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index 590d58de..d99d1c2e 100644 --- a/crates/lean_compiler/src/lib.rs +++ b/crates/lean_compiler/src/lib.rs @@ -150,7 +150,11 @@ pub fn compile_program(input: &ProgramSource) -> Bytecode { try_compile_program(input).unwrap() } -pub fn try_compile_and_run(input: &ProgramSource, public_input: &[F], profiler: bool) -> Result { +pub fn try_compile_and_run( + input: &ProgramSource, + public_input: &[F; PUBLIC_INPUT_LEN], + profiler: bool, +) -> Result { let bytecode = try_compile_program(input)?; let witness = ExecutionWitness::default(); let result = try_execute_bytecode(&bytecode, public_input, &witness, profiler)?; @@ -158,7 +162,7 @@ pub fn try_compile_and_run(input: &ProgramSource, public_input: &[F], profiler: Ok(result.metadata.display()) } -pub fn compile_and_run(input: &ProgramSource, public_input: &[F], profiler: bool) { +pub fn compile_and_run(input: &ProgramSource, public_input: &[F; PUBLIC_INPUT_LEN], profiler: bool) { let summary = try_compile_and_run(input, public_input, profiler).unwrap(); println!("{summary}"); } diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 374f75d4..0334f5dc 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -15,11 +15,14 @@ pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ // Built-in functions "print", "Array", + "hint_witness", // Compile-time only functions "len", "log2_ceil", "next_multiple_of", "saturating_sub", + "div_ceil", + "div_floor", "range", "parallel_range", "match_range", @@ -155,22 +158,24 @@ impl Parse for ParameterParser { .into()); } - // Check for optional type annotation (: Const or 
: Mut) - let (is_const, is_mutable) = if let Some(annotation) = inner.next() { + // Check for optional type annotation (: Const). ': Mut' parameters are forbidden. + let is_const = if let Some(annotation) = inner.next() { match annotation.as_str().trim() { - ": Const" => (true, false), - ": Mut" => (false, true), + ": Const" => true, + ": Mut" => { + return Err(SemanticError::new(format!( + "Parameter '{name}' cannot be declared ': Mut'. Mutable parameters are not allowed; \ + introduce a local '{name}_mut: Mut = {name}' instead." + )) + .into()); + } other => return Err(SemanticError::new(format!("Invalid parameter annotation: {other}")).into()), } } else { - (false, false) + false }; - Ok(FunctionArg { - name, - is_const, - is_mutable, - }) + Ok(FunctionArg { name, is_const }) } } diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index edef389d..f0835697 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -287,7 +287,7 @@ impl Parse for AssertParser { } } -/// Parser for forward declarations: `x: Imu` or `x: Mut` +/// Parser for forward declarations: `x: Imm` or `x: Mut` pub struct ForwardDeclarationParser; impl Parse for ForwardDeclarationParser { @@ -297,7 +297,7 @@ impl Parse for ForwardDeclarationParser { // Parse variable name let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); - // Check for : Mut or : Imu annotation + // Check for : Mut or : Imm annotation let annotation = next_inner_pair(&mut inner, "type annotation")?; let is_mutable = annotation.as_rule() == Rule::mut_annotation; diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 04b3f70e..cc8ac473 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -4,27 +4,6 @@ use backend::{BasedVectorSpace, PrimeCharacteristicRing}; 
use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; -use utils::poseidon16_compress; - -#[test] -fn test_poseidon() { - let program = r#" -def main(): - a = 0 - b = a + 8 - c = Array(8) - poseidon16_compress(a, b, c) - - for i in range(0, 8): - cc = c[i] - print(cc) - return - "#; - let public_input: [F; 16] = (0..16).map(F::new).collect::>().try_into().unwrap(); - compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); - - let _ = dbg!(poseidon16_compress(public_input)); -} #[test] fn test_div_extension_field() { @@ -32,13 +11,16 @@ fn test_div_extension_field() { DIM = 5 def main(): - n = 0 - d = n + DIM - q = n + 2 * DIM + nd = Array(2 * DIM) + hint_witness("nd", nd) + n = nd + d = nd + DIM + expected_q = Array(DIM) + hint_witness("q", expected_q) computed_q_1 = div_ext_1(n, d) computed_q_2 = div_ext_2(n, d) - assert_eq_ext(computed_q_2, q) - assert_eq_ext(computed_q_1, q) + assert_eq_ext(computed_q_2, expected_q) + assert_eq_ext(computed_q_1, expected_q) return def assert_eq_ext(x, y): @@ -61,12 +43,19 @@ def div_ext_2(n, d): let n: EF = rng.random(); let d: EF = rng.random(); let q = n / d; - let mut public_input = vec![]; - public_input.extend(n.as_basis_coefficients_slice()); - public_input.extend(d.as_basis_coefficients_slice()); - public_input.extend(q.as_basis_coefficients_slice()); - public_input.resize(16, F::ZERO); - compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); + let mut nd_buf: Vec = Vec::new(); + nd_buf.extend(n.as_basis_coefficients_slice()); + nd_buf.extend(d.as_basis_coefficients_slice()); + let q_buf: Vec = q.as_basis_coefficients_slice().to_vec(); + let mut hints = std::collections::HashMap::new(); + hints.insert("nd".to_string(), vec![nd_buf]); + hints.insert("q".to_string(), vec![q_buf]); + let witness = ExecutionWitness { + hints, + ..ExecutionWitness::default() + }; + let bytecode = compile_program(&ProgramSource::Raw(program.to_string())); + 
try_execute_bytecode(&bytecode, &[F::ZERO; PUBLIC_INPUT_LEN], &witness, false).unwrap(); } fn test_data_dir() -> String { @@ -134,7 +123,7 @@ fn test_all_programs() { Ok(b) => b, Err(err) => panic!("Program {} failed to compile: {:?}", path, err), }; - if let Err(err) = try_execute_bytecode(&bytecode, &[], &witness, false) { + if let Err(err) = try_execute_bytecode(&bytecode, &[F::ZERO; PUBLIC_INPUT_LEN], &witness, false) { panic!("Program {} failed with error: {:?}", path, err); } } @@ -176,7 +165,13 @@ def func(a, b): return "#; let bytecode = compile_program(&ProgramSource::Raw(program.to_string())); - let n_cycles = execute_bytecode(&bytecode, &[], &ExecutionWitness::default(), false).n_cycles(); + let n_cycles = execute_bytecode( + &bytecode, + &[F::ZERO; PUBLIC_INPUT_LEN], + &ExecutionWitness::default(), + false, + ) + .n_cycles(); assert!(n_cycles < 1100); } @@ -205,10 +200,20 @@ def factorial(n): let compiled_parallel = compile_program(&ProgramSource::Raw(program.replace("loop", "parallel_range"))); let time_sequential = Instant::now(); - let exec_seq = execute_bytecode(&compiled_sequencial, &[], &ExecutionWitness::default(), false); + let exec_seq = execute_bytecode( + &compiled_sequencial, + &[F::ZERO; PUBLIC_INPUT_LEN], + &ExecutionWitness::default(), + false, + ); let duration_sequential = time_sequential.elapsed(); let time_parallel = Instant::now(); - let exec_par = execute_bytecode(&compiled_parallel, &[], &ExecutionWitness::default(), false); + let exec_par = execute_bytecode( + &compiled_parallel, + &[F::ZERO; PUBLIC_INPUT_LEN], + &ExecutionWitness::default(), + false, + ); let duration_parallel = time_parallel.elapsed(); assert_eq!(exec_seq.metadata.stdout, exec_par.metadata.stdout); @@ -249,7 +254,13 @@ fn test_soundness_suite() { ("soundness_5", &[3, 4, 7, 19, 49, 28, 1, 3], &[(0, 4), (1, 5), (2, 8), (3, 20), (4, 50), (5, 29), (6, 0), (6, 2), (7, 4)]), ]; - let to_input = |v: &[u32]| v.iter().copied().map(F::new).collect::>(); + let to_input = 
|v: &[u32]| -> [F; PUBLIC_INPUT_LEN] { + let mut out = [F::ZERO; PUBLIC_INPUT_LEN]; + for (slot, &x) in out.iter_mut().zip(v) { + *slot = F::new(x); + } + out + }; for &(name, valid, perturbations) in cases { let path = format!("{}/{}.py", test_data_dir(), name); diff --git a/crates/lean_compiler/tests/test_data/error_13.py b/crates/lean_compiler/tests/test_data/error_13.py index 0450a5aa..2224ac22 100644 --- a/crates/lean_compiler/tests/test_data/error_13.py +++ b/crates/lean_compiler/tests/test_data/error_13.py @@ -2,7 +2,7 @@ def main(): - a: Imu + a: Imm a = 0 a = a + 1 if a == 1: diff --git a/crates/lean_compiler/tests/test_data/error_7.py b/crates/lean_compiler/tests/test_data/error_7.py index 94c262a6..e4435908 100644 --- a/crates/lean_compiler/tests/test_data/error_7.py +++ b/crates/lean_compiler/tests/test_data/error_7.py @@ -1,7 +1,7 @@ from snark_lib import * -# Error: inline functions with parameters: Mut are not supported +# Error: function parameters cannot be declared ': Mut' def main(): return diff --git a/crates/lean_compiler/tests/test_data/program_100.py b/crates/lean_compiler/tests/test_data/program_100.py index 5eeac3b6..a81d0480 100644 --- a/crates/lean_compiler/tests/test_data/program_100.py +++ b/crates/lean_compiler/tests/test_data/program_100.py @@ -2,8 +2,8 @@ def main(): - x: Imu - y: Imu + x: Imm + y: Imm cond = 1 if cond == 1: diff --git a/crates/lean_compiler/tests/test_data/program_109.py b/crates/lean_compiler/tests/test_data/program_109.py index b04f2608..77e86702 100644 --- a/crates/lean_compiler/tests/test_data/program_109.py +++ b/crates/lean_compiler/tests/test_data/program_109.py @@ -9,10 +9,10 @@ def main(): def test_func(a, b): x = 1 - mut_x_2: Imu + mut_x_2: Imm match a: case 0: - mut_x_1: Imu + mut_x_1: Imm mut_x_1 = x + 2 match b: case 0: diff --git a/crates/lean_compiler/tests/test_data/program_110.py b/crates/lean_compiler/tests/test_data/program_110.py index f863f968..df47ee5c 100644 --- 
a/crates/lean_compiler/tests/test_data/program_110.py +++ b/crates/lean_compiler/tests/test_data/program_110.py @@ -79,7 +79,7 @@ def main(): result = complex_compute(3, 4, 5) assert result == 47 - fwd_val: Imu + fwd_val: Imm cond = 1 if cond == 0: fwd_val = 100 diff --git a/crates/lean_compiler/tests/test_data/program_111.py b/crates/lean_compiler/tests/test_data/program_111.py index e497afdb..9634df0f 100644 --- a/crates/lean_compiler/tests/test_data/program_111.py +++ b/crates/lean_compiler/tests/test_data/program_111.py @@ -78,13 +78,14 @@ def chain_compute(x, y): a2, b2, c2 = step_compute(a1, b1) return a2, b2, c1 + c2 -def nested_mut_params(base: Mut): +def nested_mut_params(base): + acc: Mut = base for i in unroll(0, 3): - base = base + i * 2 - return base + acc = acc + i * 2 + return acc def state_machine_step(current_state, phase): - result: Imu + result: Imm if phase == 0: if current_state == 0: result = 1 diff --git a/crates/lean_compiler/tests/test_data/program_112.py b/crates/lean_compiler/tests/test_data/program_112.py index 7cd0fe3b..51a9de2e 100644 --- a/crates/lean_compiler/tests/test_data/program_112.py +++ b/crates/lean_compiler/tests/test_data/program_112.py @@ -2,7 +2,7 @@ def main(): - result1: Imu + result1: Imm outer_sel = 1 match outer_sel: case 0: @@ -18,8 +18,8 @@ def main(): result1 = 456 assert result1 == 456 - counter: Imu - flag: Imu + counter: Imm + flag: Imm phase = 1 if phase == 0: @@ -40,8 +40,8 @@ def main(): assert counter2 == 15 assert flag2 == 400 - x: Imu - y: Imu + x: Imm + y: Imm init_sel = 0 if init_sel == 0: @@ -61,7 +61,7 @@ def main(): assert x2 == 220 assert y2 == 20 - outcome: Imu + outcome: Imm selector = 4 match selector: case 0: @@ -78,9 +78,9 @@ def main(): outcome = compute_outcome(5, 25) assert outcome == 84 - p: Imu - q: Imu - r: Imu + p: Imm + q: Imm + r: Imm s1 = 1 if s1 == 1: diff --git a/crates/lean_compiler/tests/test_data/program_114.py b/crates/lean_compiler/tests/test_data/program_114.py index 
1d0b0b95..c92f7044 100644 --- a/crates/lean_compiler/tests/test_data/program_114.py +++ b/crates/lean_compiler/tests/test_data/program_114.py @@ -31,8 +31,8 @@ def main(): result4 = complex_nested_compute(2, 1, 3) assert result4 == 280 - fwd_x: Imu - fwd_y: Imu + fwd_x: Imm + fwd_y: Imm mode = 2 if mode == 0: @@ -90,7 +90,7 @@ def sum_array_func(arr, n: Const): def complex_nested_compute(outer, inner, depth): - result: Imu + result: Imm if outer == 0: result = 100 diff --git a/crates/lean_compiler/tests/test_data/program_143.py b/crates/lean_compiler/tests/test_data/program_143.py index 08be3b5d..1e16c510 100644 --- a/crates/lean_compiler/tests/test_data/program_143.py +++ b/crates/lean_compiler/tests/test_data/program_143.py @@ -71,7 +71,7 @@ def main(): assert sum == 10 # Test 13: Inline functions in if condition (comparison) - result13: Imu + result13: Imm if incr(incr(0)) == 2: result13 = 100 else: @@ -79,7 +79,7 @@ def main(): assert result13 == 100 # Test 14: Nested inline calls in both sides of if condition - result14: Imu + result14: Imm if double(3) == triple(2): result14 = 1 else: @@ -88,7 +88,7 @@ def main(): assert result14 == 1 # Test 15: Inline calls inside if/else branches - result15: Imu + result15: Imm if 1 == 1: result15 = incr(incr(incr(10))) else: @@ -96,7 +96,7 @@ def main(): assert result15 == 13 # Test 16: Multiple nested inline calls in if condition - result16: Imu + result16: Imm if incr(double(incr(1))) == 5: # incr(1) = 2, double(2) = 4, incr(4) = 5 result16 = 200 @@ -105,7 +105,7 @@ def main(): assert result16 == 200 # Test 17: Inline call with != comparison - result17: Imu + result17: Imm if incr(5) != 5: result17 = 300 else: @@ -152,7 +152,7 @@ def main(): assert sum23 == 21 # Test 24: Chained else-if with inline conditions - result24: Imu + result24: Imm x24 = 5 if incr(x24) == 4: result24 = 1 @@ -194,7 +194,7 @@ def double(x): @inline def triple(x): y: Mut = x - two: Imu + two: Imm match y - x + 1: case 0: assert False diff --git 
a/crates/lean_compiler/tests/test_data/program_144.py b/crates/lean_compiler/tests/test_data/program_144.py index f6b60dff..9ceb190f 100644 --- a/crates/lean_compiler/tests/test_data/program_144.py +++ b/crates/lean_compiler/tests/test_data/program_144.py @@ -8,7 +8,7 @@ def main(): # ========================================================================== # TEST 1: Basic - panic in else branch (the original bug case) # ========================================================================== - two: Imu + two: Imm if 1 == 1: two = 2 else: @@ -18,7 +18,7 @@ def main(): # ========================================================================== # TEST 2: panic in then branch # ========================================================================== - three: Imu + three: Imm if 1 != 1: assert False else: @@ -28,9 +28,9 @@ def main(): # ========================================================================== # TEST 3: Multiple mutable variables, panic in else # ========================================================================== - a: Imu - b: Imu - c: Imu + a: Imm + b: Imm + c: Imm if 1 == 1: a = 10 b = 20 @@ -44,7 +44,7 @@ def main(): # ========================================================================== # TEST 4: Nested if with panic in inner else # ========================================================================== - x: Imu + x: Imm if 1 == 1: if 2 == 2: x = 42 @@ -79,7 +79,7 @@ def main(): # ========================================================================== # TEST 7: Chain of else-if with panic in final else # ========================================================================== - result: Imu + result: Imm selector = 1 if selector == 0: result = 100 @@ -94,7 +94,7 @@ def main(): # ========================================================================== # TEST 8: Match with panic in one arm # ========================================================================== - matched: Imu + matched: Imm tag = 1 match tag: case 0: @@ 
-108,7 +108,7 @@ def main(): # ========================================================================== # TEST 9: Match where only one arm doesn't panic # ========================================================================== - only_valid: Imu + only_valid: Imm tag2 = 2 match tag2: case 0: @@ -124,7 +124,7 @@ def main(): # ========================================================================== # TEST 10: Panic in deeply nested structure # ========================================================================== - deep: Imu + deep: Imm if 1 == 1: if 1 == 1: if 1 == 1: @@ -151,7 +151,7 @@ def main(): # ========================================================================== # TEST 12: Forward declared with = None panic in branch # ========================================================================== - fwd: Imu + fwd: Imm cond = 1 if cond == 1: fwd = 777 @@ -162,8 +162,8 @@ def main(): # ========================================================================== # TEST 13: Both mutable and immutable forward decl with panic # ========================================================================== - imm: Imu - mtbl: Imu + imm: Imm + mtbl: Imm flag = 0 if flag == 0: imm = 100 @@ -206,7 +206,7 @@ def main(): # ========================================================================== # TEST 17: Nested match with panic # ========================================================================== - nested_match: Imu + nested_match: Imm outer = 1 match outer: case 0: @@ -223,7 +223,7 @@ def main(): # ========================================================================== # TEST 18: If inside match with panic # ========================================================================== - if_in_match: Imu + if_in_match: Imm m18_sel = 0 match m18_sel: case 0: @@ -239,7 +239,7 @@ def main(): # ========================================================================== # TEST 19: Match inside if with panic # 
========================================================================== - match_in_if: Imu + match_in_if: Imm cond19 = 1 if cond19 == 1: tag19 = 1 @@ -255,7 +255,7 @@ def main(): # ========================================================================== # TEST 20: Panic after partial assignment # ========================================================================== - partial: Imu + partial: Imm check = 0 if check == 0: partial_tmp: Mut = 1 @@ -288,7 +288,7 @@ def main(): # ========================================================================== # TEST 23: Multiple levels - if/match/if with panics # ========================================================================== - multi_level: Imu + multi_level: Imm c1 = 1 if c1 == 1: s1 = 0 @@ -308,7 +308,7 @@ def main(): # ========================================================================== # TEST 24: Panic in both outer branches but inner assigns # ========================================================================== - inner_assigns: Imu + inner_assigns: Imm outer24 = 0 match outer24: case 0: @@ -324,9 +324,9 @@ def main(): # ========================================================================== # TEST 25: Complex - multiple vars, nested, with panics # ========================================================================== - va: Imu - vb: Imu - vc: Imu + va: Imm + vb: Imm + vc: Imm outer25 = 1 if outer25 == 1: @@ -353,7 +353,7 @@ def main(): # Helper function for TEST 14 def test_early_return(flag): - result: Imu + result: Imm if flag == 1: result = 10 else: @@ -363,8 +363,8 @@ def test_early_return(flag): # Helper function for TEST 15 def test_multi_return(flag): - a: Imu - b: Imu + a: Imm + b: Imm if flag == 1: a = 100 b = 200 @@ -374,9 +374,10 @@ def test_multi_return(flag): # Helper function for TEST 22 -def func_with_mut_param(x: Mut, flag): +def func_with_mut_param(x, flag): + y: Mut = x if flag == 1: - x = x * 10 + y = y * 10 else: assert False - return x + return y diff --git 
a/crates/lean_compiler/tests/test_data/program_15.py b/crates/lean_compiler/tests/test_data/program_15.py index 6ea149c2..55433c26 100644 --- a/crates/lean_compiler/tests/test_data/program_15.py +++ b/crates/lean_compiler/tests/test_data/program_15.py @@ -1,6 +1,6 @@ from snark_lib import * -ONE_EF_PTR = 1 # right after the (empty-public-input) zero-padded cell at memory[0] +ONE_EF_PTR = 8 # right after the 8-cell public input region def main(): diff --git a/crates/lean_compiler/tests/test_data/program_17.py b/crates/lean_compiler/tests/test_data/program_17.py index d9274d42..092e2fc2 100644 --- a/crates/lean_compiler/tests/test_data/program_17.py +++ b/crates/lean_compiler/tests/test_data/program_17.py @@ -7,7 +7,7 @@ def main(): def func(): - a: Imu + a: Imm if 0 == 0: a = aux() return a diff --git a/crates/lean_compiler/tests/test_data/program_170.py b/crates/lean_compiler/tests/test_data/program_170.py index 96c00012..62f5d2e4 100644 --- a/crates/lean_compiler/tests/test_data/program_170.py +++ b/crates/lean_compiler/tests/test_data/program_170.py @@ -15,7 +15,7 @@ def multi_return(a, b): def multi_line_params( a, - b: Mut, + b, c: Const, ): return a + b + c @@ -30,14 +30,14 @@ def main(): x = 5 y = 10 - z: Imu + z: Imm if x + y == 15: z = 1 else: z = 0 assert z == 1 - w: Imu + w: Imm if x + y * 2 == 25: w = 100 else: diff --git a/crates/lean_compiler/tests/test_data/program_171.py b/crates/lean_compiler/tests/test_data/program_171.py index bb3c47c6..4af97391 100644 --- a/crates/lean_compiler/tests/test_data/program_171.py +++ b/crates/lean_compiler/tests/test_data/program_171.py @@ -1,7 +1,7 @@ from snark_lib import * # Comprehensive test for inlining with mutable variables in branches -# Tests: @inline functions, Mut/Imu variables, match, if/else, loops, nesting +# Tests: @inline functions, Mut/Imm variables, match, if/else, loops, nesting # ============================================================================ # Simple inline functions with mutable 
variables @@ -118,7 +118,7 @@ def inline_with_if(x): @inline def inline_with_match(selector): """Inline function that itself contains match""" - out: Imu + out: Imm match selector: case 0: out = 1000 @@ -132,7 +132,7 @@ def inline_with_match(selector): @inline def inline_with_nested_branch(a, b): """Inline with nested if inside match""" - res: Imu + res: Imm match a: case 0: if b == 0: @@ -321,7 +321,7 @@ def main(): # TEST 1: Basic inline in match arms (different inlined vars per arm) # This was the original bug - each arm gets its own inlined variable names # ------------------------------------------------------------------- - res1: Imu + res1: Imm match 0: case 0: res1 = count_up(5) @@ -329,7 +329,7 @@ def main(): res1 = count_up(10) assert res1 == 5 - res2: Imu + res2: Imm match 1: case 0: res2 = count_up(5) @@ -340,7 +340,7 @@ def main(): # ------------------------------------------------------------------- # TEST 2: Different inline functions in different arms # ------------------------------------------------------------------- - res3: Imu + res3: Imm match 0: case 0: res3 = count_up(3) @@ -350,7 +350,7 @@ def main(): res3 = double_count(3) assert res3 == 3 - res4: Imu + res4: Imm match 1: case 0: res4 = count_up(3) @@ -360,7 +360,7 @@ def main(): res4 = double_count(3) assert res4 == 3 # 0+1+2 - res5: Imu + res5: Imm match 2: case 0: res5 = count_up(3) @@ -392,7 +392,7 @@ def main(): # ------------------------------------------------------------------- # TEST 4: Multiple inlines in same arm # ------------------------------------------------------------------- - multi: Imu + multi: Imm match 0: case 0: a = count_up(3) @@ -406,7 +406,7 @@ def main(): # ------------------------------------------------------------------- # TEST 5: Nested inline functions in match arms # ------------------------------------------------------------------- - nested1: Imu + nested1: Imm match 0: case 0: nested1 = outer_with_inner(4) @@ -416,7 +416,7 @@ def main(): # = 0 + 0 + 1 + 
3 = 4 assert nested1 == 4 - nested2: Imu + nested2: Imm match 1: case 0: nested2 = outer_with_inner(4) @@ -428,7 +428,7 @@ def main(): # ------------------------------------------------------------------- # TEST 6: Deep nesting in match # ------------------------------------------------------------------- - deep1: Imu + deep1: Imm match 0: case 0: deep1 = deep_nested(3) @@ -442,14 +442,14 @@ def main(): # ------------------------------------------------------------------- # TEST 7: Inline in if/else branches # ------------------------------------------------------------------- - if_res1: Imu + if_res1: Imm if 1 == 1: if_res1 = count_up(7) else: if_res1 = count_up(3) assert if_res1 == 7 - if_res2: Imu + if_res2: Imm if 1 == 0: if_res2 = count_up(7) else: @@ -459,7 +459,7 @@ def main(): # ------------------------------------------------------------------- # TEST 8: Nested if/else with inlines # ------------------------------------------------------------------- - nested_if: Imu + nested_if: Imm if 1 == 1: if 2 == 2: nested_if = sum_range(0, 5) @@ -472,7 +472,7 @@ def main(): # ------------------------------------------------------------------- # TEST 9: Match inside if with inlines # ------------------------------------------------------------------- - mixed: Imu + mixed: Imm if 1 == 1: match 1: case 0: @@ -486,7 +486,7 @@ def main(): # ------------------------------------------------------------------- # TEST 10: If inside match with inlines # ------------------------------------------------------------------- - mixed2: Imu + mixed2: Imm match 0: case 0: if 1 == 1: @@ -500,7 +500,7 @@ def main(): # ------------------------------------------------------------------- # TEST 11: Complex mutable variables in inline # ------------------------------------------------------------------- - cx: Imu + cx: Imm match 0: case 0: cx = complex_muts(4) @@ -519,7 +519,7 @@ def main(): # TEST 12: Mix of Mut and immutable in branches with inlines # 
------------------------------------------------------------------- outer_mut: Mut = 10 - inner_imu: Imu + inner_imu: Imm match 0: case 0: local_imm = with_immutable(3) @@ -535,7 +535,7 @@ def main(): # ------------------------------------------------------------------- # TEST 13: Inline inside unroll loop inside match # ------------------------------------------------------------------- - unroll_in_match: Imu + unroll_in_match: Imm match 0: case 0: acc: Mut = 0 @@ -550,10 +550,10 @@ def main(): # ------------------------------------------------------------------- # TEST 14: Multiple match levels with different inlines at each # ------------------------------------------------------------------- - multi_match: Imu + multi_match: Imm match 1: case 0: - inner: Imu + inner: Imm match 0: case 0: inner = count_up(2) @@ -561,7 +561,7 @@ def main(): inner = count_up(3) multi_match = inner case 1: - inner2: Imu + inner2: Imm match 1: case 0: inner2 = sum_range(0, 2) @@ -573,7 +573,7 @@ def main(): # ------------------------------------------------------------------- # TEST 15: Same inline function called multiple times in same arm # ------------------------------------------------------------------- - same_fn: Imu + same_fn: Imm match 0: case 0: r1 = count_up(3) @@ -607,7 +607,7 @@ def main(): # ------------------------------------------------------------------- # TEST 17: Variables declared inside only some branches # ------------------------------------------------------------------- - outside: Imu + outside: Imm match 0: case 0: local_only_here = count_up(5) @@ -622,7 +622,7 @@ def main(): # ------------------------------------------------------------------- # TEST 18: Very deeply nested structure # ------------------------------------------------------------------- - very_deep: Imu + very_deep: Imm if 1 == 1: match 0: case 0: @@ -668,7 +668,7 @@ def main(): # ------------------------------------------------------------------- # TEST 20: Inline result used immediately 
in arithmetic in branch # ------------------------------------------------------------------- - arith: Imu + arith: Imm match 0: case 0: arith = count_up(3) * 10 + sum_range(0, 3) * 100 @@ -684,7 +684,7 @@ def main(): # ------------------------------------------------------------------- # TEST 21: Inline containing if/else in different match arms # ------------------------------------------------------------------- - t21: Imu + t21: Imm match 0: case 0: t21 = inline_with_if(0) @@ -693,7 +693,7 @@ def main(): # inline_with_if(0): result=100, result=100+0=100 assert t21 == 100 - t21b: Imu + t21b: Imm match 1: case 0: t21b = inline_with_if(0) @@ -705,7 +705,7 @@ def main(): # ------------------------------------------------------------------- # TEST 22: Inline containing match in different branches # ------------------------------------------------------------------- - t22: Imu + t22: Imm match 0: case 0: t22 = inline_with_match(0) @@ -715,7 +715,7 @@ def main(): t22 = inline_with_match(2) assert t22 == 1000 - t22b: Imu + t22b: Imm match 2: case 0: t22b = inline_with_match(0) @@ -728,7 +728,7 @@ def main(): # ------------------------------------------------------------------- # TEST 23: Inline with nested branches called in nested branches # ------------------------------------------------------------------- - t23: Imu + t23: Imm match 0: case 0: if 1 == 1: @@ -740,7 +740,7 @@ def main(): # inline_with_nested_branch(0, 1): a=0 -> if b==0 else -> 20 assert t23 == 20 - t23b: Imu + t23b: Imm match 1: case 0: t23b = inline_with_nested_branch(0, 0) @@ -752,8 +752,8 @@ def main(): # ------------------------------------------------------------------- # TEST 24: Multi-return inline in match arms # ------------------------------------------------------------------- - t24a: Imu - t24b: Imu + t24a: Imm + t24b: Imm match 0: case 0: t24a, t24b = multi_return_inline(5) @@ -763,8 +763,8 @@ def main(): assert t24a == 5 assert t24b == 110 - t24c: Imu - t24d: Imu + t24c: Imm + t24d: 
Imm match 1: case 0: t24c, t24d = multi_return_inline(5) @@ -777,9 +777,9 @@ def main(): # ------------------------------------------------------------------- # TEST 25: Triple return inline in branches # ------------------------------------------------------------------- - t25a: Imu - t25b: Imu - t25c: Imu + t25a: Imm + t25b: Imm + t25c: Imm match 0: case 0: t25a, t25b, t25c = triple_return(10) @@ -793,7 +793,7 @@ def main(): # ------------------------------------------------------------------- # TEST 26: 4-level deep inline nesting in match arms # ------------------------------------------------------------------- - t26: Imu + t26: Imm match 0: case 0: t26 = level_a(1) @@ -809,7 +809,7 @@ def main(): # = (1+2) + 20 + 200 + 2000 = 2223 assert t26 == 2223 - t26b: Imu + t26b: Imm match 3: case 0: t26b = level_a(5) @@ -825,7 +825,7 @@ def main(): # ------------------------------------------------------------------- # TEST 27: Inline with Array in match arms # ------------------------------------------------------------------- - t27: Imu + t27: Imm match 0: case 0: t27 = inline_with_array(10) @@ -834,7 +834,7 @@ def main(): # inline_with_array(10): 10+11+12+13 = 46 assert t27 == 46 - t27b: Imu + t27b: Imm match 1: case 0: t27b = inline_with_array(10) @@ -846,7 +846,7 @@ def main(): # ------------------------------------------------------------------- # TEST 28: Inline modifying array in branches # ------------------------------------------------------------------- - t28: Imu + t28: Imm match 0: case 0: t28 = inline_modify_array(1) @@ -858,7 +858,7 @@ def main(): # ------------------------------------------------------------------- # TEST 29: Chained inline calls in match arms # ------------------------------------------------------------------- - t29: Imu + t29: Imm match 0: case 0: # chain_a(5)=7, chain_b(7)=28, chain_c(28)=48 @@ -867,7 +867,7 @@ def main(): t29 = chain_a(100) assert t29 == 48 - t29b: Imu + t29b: Imm match 1: case 0: t29b = 
chain_c(chain_b(chain_a(1))) @@ -879,7 +879,7 @@ def main(): # ------------------------------------------------------------------- # TEST 30: Different chain patterns in different arms # ------------------------------------------------------------------- - t30: Imu + t30: Imm match 0: case 0: t30 = chain_a(chain_a(chain_a(0))) @@ -890,7 +890,7 @@ def main(): # chain_a(0)=2, chain_a(2)=4, chain_a(4)=6 assert t30 == 6 - t30b: Imu + t30b: Imm match 1: case 0: t30b = chain_a(chain_a(chain_a(0))) @@ -904,7 +904,7 @@ def main(): # ------------------------------------------------------------------- # TEST 31: Stress test - many variables inline in match # ------------------------------------------------------------------- - t31: Imu + t31: Imm match 0: case 0: t31 = many_vars(0) @@ -918,7 +918,7 @@ def main(): # ------------------------------------------------------------------- # TEST 32: Multiple multi-return inlines in same arm # ------------------------------------------------------------------- - t32_sum: Imu + t32_sum: Imm match 0: case 0: a1, b1 = multi_return_inline(3) @@ -936,7 +936,7 @@ def main(): # ------------------------------------------------------------------- # TEST 33: 5-way match with all different inline types # ------------------------------------------------------------------- - t33: Imu + t33: Imm match 0: case 0: t33 = count_up(10) @@ -950,7 +950,7 @@ def main(): t33 = inline_with_array(1) assert t33 == 10 - t33b: Imu + t33b: Imm match 4: case 0: t33b = count_up(10) @@ -968,10 +968,10 @@ def main(): # ------------------------------------------------------------------- # TEST 34: Triple nested match with inlines at each level # ------------------------------------------------------------------- - t34: Imu + t34: Imm match 0: case 0: - inner1: Imu + inner1: Imm match 1: case 0: tmp34a = count_up(2) @@ -986,10 +986,10 @@ def main(): assert t34 == 1423 # Additional triple nesting test - without forward declaration inside innermost - t34b: Imu + t34b: 
Imm match 0: case 0: - mid1: Imu + mid1: Imm match 0: case 0: # Use inline directly without forward declaration @@ -1003,10 +1003,10 @@ def main(): assert t34b == 1105 # Test forward declaration with nested match and inline - t34c: Imu + t34c: Imm match 0: case 0: - val34c: Imu + val34c: Imm match 0: case 0: val34c = sum_range(0, 5) @@ -1041,12 +1041,12 @@ def main(): assert deep_mut == 210 # ------------------------------------------------------------------- - # TEST 36: Multiple forward-declared Imu assigned via inlines + # TEST 36: Multiple forward-declared Imm assigned via inlines # ------------------------------------------------------------------- - fwd1: Imu - fwd2: Imu - fwd3: Imu - fwd4: Imu + fwd1: Imm + fwd2: Imm + fwd3: Imm + fwd4: Imm match 0: case 0: fwd1 = count_up(1) @@ -1086,7 +1086,7 @@ def main(): # ------------------------------------------------------------------- # TEST 38: If-else-if chain with different inlines # ------------------------------------------------------------------- - t38: Imu + t38: Imm if 0 == 1: t38 = count_up(100) else: @@ -1126,7 +1126,7 @@ def main(): # ------------------------------------------------------------------- # TEST 40: Inline returning mutable at different states # ------------------------------------------------------------------- - t40: Imu + t40: Imm match 0: case 0: # complex_muts returns computation of interdependent muts @@ -1164,7 +1164,7 @@ def main(): # TEST 42: Deeply nested with mixed mutable tracking # ------------------------------------------------------------------- outer_m: Mut = 100 - t42: Imu + t42: Imm if 1 == 1: outer_m = outer_m + 50 match 0: @@ -1197,14 +1197,14 @@ def main(): # ------------------------------------------------------------------- # TEST 43: All arms have different nesting patterns # ------------------------------------------------------------------- - t43: Imu + t43: Imm match 0: case 0: # Flat t43 = count_up(5) case 1: # One level nested - if_inner: Imu + if_inner: Imm if 
1 == 1: if_inner = sum_range(0, 10) else: @@ -1212,7 +1212,7 @@ def main(): t43 = if_inner case 2: # Two levels nested - m_inner: Imu + m_inner: Imm match 0: case 0: m_inner = level_a(1) @@ -1221,7 +1221,7 @@ def main(): t43 = m_inner case 3: # Three levels nested - deep_inner: Imu + deep_inner: Imm if 1 == 1: match 0: case 0: @@ -1271,7 +1271,7 @@ def main(): # ------------------------------------------------------------------- # TEST 45: Inline calling another inline that has internal branches # ------------------------------------------------------------------- - t45: Imu + t45: Imm match 0: case 0: # outer_with_inner calls inner_loop diff --git a/crates/lean_compiler/tests/test_data/program_172.py b/crates/lean_compiler/tests/test_data/program_172.py index da813966..ba54fc05 100644 --- a/crates/lean_compiler/tests/test_data/program_172.py +++ b/crates/lean_compiler/tests/test_data/program_172.py @@ -8,7 +8,7 @@ def helper_const(n: Const): def main(): - # Test 1: Basic match_range - no forward declaration needed (auto-generated as Imu) + # Test 1: Basic match_range - no forward declaration needed (auto-generated as Imm) x = 2 r1 = match_range(x, range(0, 4), lambda i: i * 100) assert r1 == 200 diff --git a/crates/lean_compiler/tests/test_data/program_174.py b/crates/lean_compiler/tests/test_data/program_174.py index 19be7961..94e1f4d1 100644 --- a/crates/lean_compiler/tests/test_data/program_174.py +++ b/crates/lean_compiler/tests/test_data/program_174.py @@ -40,7 +40,7 @@ def main(): def match_start_at_1(x): - result: Imu + result: Imm match x: case 1: result = 100 @@ -54,7 +54,7 @@ def match_start_at_1(x): def match_start_at_5(x): - result: Imu + result: Imm match x: case 5: result = 50 @@ -68,7 +68,7 @@ def match_start_at_5(x): def match_start_at_10(x): - result: Imu + result: Imm match x: case 10: result = 1000 @@ -95,7 +95,7 @@ def match_nonzero_mutable(x): def nested_nonzero_match(outer, inner): - result: Imu + result: Imm match outer: case 1: match inner: 
@@ -125,7 +125,7 @@ def nested_nonzero_match(outer, inner): def nonzero_match_in_if(cond, x): - result: Imu + result: Imm if cond == 0: result = 0 else: diff --git a/crates/lean_compiler/tests/test_data/program_179.py b/crates/lean_compiler/tests/test_data/program_179.py index 84d1f0b0..521d0af6 100644 --- a/crates/lean_compiler/tests/test_data/program_179.py +++ b/crates/lean_compiler/tests/test_data/program_179.py @@ -1,6 +1,6 @@ from snark_lib import * -ONE_EF_PTR = 1 # right after the (empty-public-input) zero-padded cell at memory[0] +ONE_EF_PTR = 8 # right after the 8-cell public input region def main(): diff --git a/crates/lean_compiler/tests/test_data/program_43.py b/crates/lean_compiler/tests/test_data/program_43.py index aed3a00b..d0199de4 100644 --- a/crates/lean_compiler/tests/test_data/program_43.py +++ b/crates/lean_compiler/tests/test_data/program_43.py @@ -7,7 +7,8 @@ def main(): return -def increment_twice(x: Mut): - x = x + 1 - x = x + 1 - return x +def increment_twice(x): + y: Mut = x + y = y + 1 + y = y + 1 + return y diff --git a/crates/lean_compiler/tests/test_data/program_57.py b/crates/lean_compiler/tests/test_data/program_57.py index 3d4965c4..4a96412e 100644 --- a/crates/lean_compiler/tests/test_data/program_57.py +++ b/crates/lean_compiler/tests/test_data/program_57.py @@ -10,19 +10,22 @@ def main(): return -def step1(n: Mut): - n = n * 2 - n = n + 1 - return n +def step1(n): + m: Mut = n + m = m * 2 + m = m + 1 + return m -def step2(n: Mut): - n = n * 3 - n = n + 2 - return n +def step2(n): + m: Mut = n + m = m * 3 + m = m + 2 + return m -def step3(n: Mut): - n = n * 4 - n = n + 3 - return n +def step3(n): + m: Mut = n + m = m * 4 + m = m + 3 + return m diff --git a/crates/lean_compiler/tests/test_data/program_67.py b/crates/lean_compiler/tests/test_data/program_67.py index ed80db93..5c91e661 100644 --- a/crates/lean_compiler/tests/test_data/program_67.py +++ b/crates/lean_compiler/tests/test_data/program_67.py @@ -2,7 +2,7 @@ def 
main(): - mut_a: Imu + mut_a: Imm mut_a = 5 assert mut_a == 5 return diff --git a/crates/lean_compiler/tests/test_data/program_68.py b/crates/lean_compiler/tests/test_data/program_68.py index 13eb50c6..a4797324 100644 --- a/crates/lean_compiler/tests/test_data/program_68.py +++ b/crates/lean_compiler/tests/test_data/program_68.py @@ -10,10 +10,10 @@ def main(): def test_func(a, b): x = 1 - mut_x_2: Imu + mut_x_2: Imm match a: case 0: - mut_x_1: Imu + mut_x_1: Imm mut_x_1 = x + 2 match b: case 0: diff --git a/crates/lean_compiler/tests/test_data/program_69.py b/crates/lean_compiler/tests/test_data/program_69.py index 4c19fda8..ecc0051e 100644 --- a/crates/lean_compiler/tests/test_data/program_69.py +++ b/crates/lean_compiler/tests/test_data/program_69.py @@ -15,20 +15,20 @@ def main(): def compute(a, b, c): base = 1000 - outer_val: Imu - mid_val: Imu - inner_val: Imu + outer_val: Imm + mid_val: Imm + inner_val: Imm match a: case 0: outer_val = 5 - local_a: Imu + local_a: Imm local_a = a + outer_val match b: case 0: mid_val = 3 - local_b: Imu + local_b: Imm local_b = local_a + mid_val match c: @@ -38,7 +38,7 @@ def compute(a, b, c): inner_val = base + local_b + c case 1: mid_val = 7 - local_b: Imu + local_b: Imm local_b = local_a + mid_val match c: @@ -48,13 +48,13 @@ def compute(a, b, c): inner_val = base + local_b + c case 1: outer_val = 15 - local_a: Imu + local_a: Imm local_a = a + outer_val match b: case 0: mid_val = 20 - local_b: Imu + local_b: Imm local_b = local_a + mid_val match c: @@ -64,7 +64,7 @@ def compute(a, b, c): inner_val = base + local_b + c case 1: mid_val = 30 - local_b: Imu + local_b: Imm local_b = local_a + mid_val match c: diff --git a/crates/lean_compiler/tests/test_data/program_83.py b/crates/lean_compiler/tests/test_data/program_83.py index aa423c9f..36bd35d3 100644 --- a/crates/lean_compiler/tests/test_data/program_83.py +++ b/crates/lean_compiler/tests/test_data/program_83.py @@ -2,7 +2,7 @@ def main(): - x: Imu + x: Imm cond = 1 if cond 
== 1: x = 10 diff --git a/crates/lean_compiler/tests/test_data/program_99.py b/crates/lean_compiler/tests/test_data/program_99.py index 40d61c71..66b2c119 100644 --- a/crates/lean_compiler/tests/test_data/program_99.py +++ b/crates/lean_compiler/tests/test_data/program_99.py @@ -7,7 +7,8 @@ def main(): return -def accumulate(x: Mut): +def accumulate(x): + acc: Mut = x for i in unroll(0, 3): - x = x + i - return x + acc = acc + i + return acc diff --git a/crates/lean_compiler/tests/test_data/soundness_2.py b/crates/lean_compiler/tests/test_data/soundness_2.py index 3fa2867b..630d756a 100644 --- a/crates/lean_compiler/tests/test_data/soundness_2.py +++ b/crates/lean_compiler/tests/test_data/soundness_2.py @@ -12,7 +12,7 @@ def main(): offset = p[6] total = p[7] - computed: Imu + computed: Imm match mode: case 0: computed = add_op(x, y) @@ -24,7 +24,7 @@ def main(): computed = combined(x, y) assert computed == expected - adjusted: Imu + adjusted: Imm if flag == 0: adjusted = bump(secondary, 1) elif flag == 1: diff --git a/crates/lean_compiler/tests/test_data/soundness_5.py b/crates/lean_compiler/tests/test_data/soundness_5.py index 3333ab2f..eaa31b7c 100644 --- a/crates/lean_compiler/tests/test_data/soundness_5.py +++ b/crates/lean_compiler/tests/test_data/soundness_5.py @@ -36,7 +36,7 @@ def main(): assert paired_sum(seed, n) == paired - chosen: Imu + chosen: Imm if flag == 1: chosen = seed else: diff --git a/crates/lean_compiler/tests/test_performance.rs b/crates/lean_compiler/tests/test_performance.rs index 723b9bd4..893883e2 100644 --- a/crates/lean_compiler/tests/test_performance.rs +++ b/crates/lean_compiler/tests/test_performance.rs @@ -1,3 +1,4 @@ +use backend::PrimeCharacteristicRing; use lean_compiler::*; use lean_vm::*; @@ -9,7 +10,13 @@ fn test_data_dir() -> String { /// Helper to get the number of cycles for a program file fn get_cycle_count(path: &str) -> usize { let bytecode = compile_program(&ProgramSource::Filepath(path.to_string())); - let result = 
try_execute_bytecode(&bytecode, &[], &ExecutionWitness::default(), false).unwrap(); + let result = try_execute_bytecode( + &bytecode, + &[F::ZERO; PUBLIC_INPUT_LEN], + &ExecutionWitness::default(), + false, + ) + .unwrap(); result.pcs.len() } diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 5ea8d46a..08d67c37 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -1,65 +1,102 @@ # zkDSL Language Reference -Warning: still under construction (i.e. it's messy). +The zkDSL is a Python-syntax language that compiles to leanVM bytecode (4 basic instructions and 2 special ones (precompile): poseidon / extension operations). For the underlying VM, and proving system, see [`minimal_zkVM.pdf`](../../minimal_zkVM.pdf). -## Program Structure +Source files use the `.py` extension. They are **not** currently runnable as +real Python, but the syntax is kept Python-compatible so that one day they +could be (TODO). +## Dev experience + +To recycle python tooling/linting on zkDSL files (which import [`snark_lib`](snark_lib.py)), point your editor at the compiler crate. With VSCode (for instance in `leanMultisig/.vscode/settings.json`): + +```json +{ + "python.analysis.extraPaths": [ + "./crates/lean_compiler" + ] +} ``` -from snark_lib import * # Python compatibility (ignored by compiler) -from dir.file import * # imports (optional, Python-style) -NAME = value # constants (optional, uppercase by convention) -def main(): # entry point (required) + +## Entrypoint + +Programs are organized as one or more `.py` files. The toplevel of each file is a +sequence of: + +1. `from import *` statements (optional) +2. Top-level constant declarations (optional) +3. Function definitions + +Execution starts at `def main(): ...`. 
+ +```python +from snark_lib import * # only there to keep the Python linter happy; stripped by the zkDSL compiler +from utils import * # import other file + +X = 42 # constants must come before functions +# array constants (of arbitrary dimensions: 1D, 2D, etc.) +ARR_1D = [1, 2, 3] +ARR_2D = [[1, 2, 3], [], [10, 4]] +ARR_3D = [[[1, 2, 3], [7, 8], [9]], [], [[10], [10, 4]]] + +def main(): # required entry point ... -def helper(): # other functions (optional) + +def helper(): # other functions ... ``` -The `from snark_lib import *` line imports Python definitions for zkDSL primitives (Array, Mut, Const, etc.), allowing `.py` files to be executed as normal Python scripts for testing. The zkDSL compiler ignores this import line. +## Imports -To run zkDSL files as Python scripts, run from the file's directory with PYTHONPATH pointing to the lean_compiler crate (for snark_lib.py): -```bash -export PYTHONPATH=/path/to/repo/crates/lean_compiler -cd crates/lean_compiler/tests/test_data -python program_0.py +```python +from utils import * # imports utils.py (resolved from the import root) +from dir.subdir.file import * # nested module +from ..module import * # parent-directory import (relative to current file) ``` +Imports are wildcard-only (`import *`). Each module is loaded once even if imported +multiple times; circular imports are rejected. Constants with the same +name in two imported files cause a compile-time error. + ## Constants -Constants are declared at the top level (outside functions) using simple assignment. By convention, constant names are UPPERCASE. +Constants live at the top of the file, outside any function. -``` +```python X = 42 ARR = [1, 2, 3] NESTED = [[1, 2], [3]] ``` -### Multi-Dimensional Const Arrays - -Const arrays can be nested to any depth, and inner arrays can have different lengths (ragged arrays). All const array values are resolved at compile time. 
+### Nested (multi-dimensional, possibly ragged) constant arrays -``` -MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] # ragged 2D array -DEEP = [[[1, 2], [3]], [[4, 5, 6]]] # 3D array +```python +MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] +DEEP = [[[1, 2], [3]], [[4, 5, 6]]] ``` -**Accessing elements:** Use chained indexing with compile-time indices: -``` -x = MATRIX[0][2] # x = 3 -y = DEEP[1][0][1] # y = 5 +Indexed access uses chained subscripts at compile time: + +```python +x = MATRIX[0][2] # 3 +y = DEEP[1][0][1] # 5 ``` -**Using `len()` on inner arrays:** The `len()` function can be applied to any level of a nested const array, including inner arrays accessed by index. This is particularly useful for iterating over ragged arrays where each row has a different length: +`len()` works at every depth, including on a row addressed by a constant index: -``` -len(MATRIX) # 3 -len(MATRIX[0]) # 3 -len(DEEP[0][0]) # 2 +```python +len(MATRIX) # 3 +len(MATRIX[0]) # 3 +len(DEEP[0][0]) # 2 ``` -**Important:** When using `len()` on an inner array with a variable index (e.g., `len(ARR[i])`), the index must be a compile-time constant. This works inside `unroll` loops because the loop variable becomes a compile-time constant during unrolling. +When `len()` is applied with a variable index (`len(ARR[i])`), `i` must be a +compile-time constant. `: Const` parameters always qualify (see [Functions] +below), as do iterator variables of an `unroll` loop (see [For loops] below). 
-**Example: Iterating over a ragged 2D array:** -``` +Example: iterating a ragged 2D table: + +```python MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] def main(): @@ -67,17 +104,17 @@ def main(): for row in unroll(0, len(MATRIX)): for col in unroll(0, len(MATRIX[row])): total = total + MATRIX[row][col] - assert total == 45 # 1+2+3+4+5+6+7+8+9 + assert total == 45 return ``` ## Functions -``` -def add(a, b): # return count is inferred from return statements +```python +def add(a, b): return a + b -def swap(a, b): # multiple return values +def swap(a, b): return b, a def main(): @@ -85,122 +122,116 @@ def main(): return ``` -The number of return values is automatically inferred from the `return` statements. All return statements in a function must return the same number of values. +Every function must contain at least one `return`. The compiler infers the number +of returned values from the `return` statements; all `return`s in a function must +agree. A function that "returns nothing" uses a bare `return`. -### Parameter Modifiers -| Syntax | Meaning | -| ---------- | --------------------------------------------------------- | -| `x` | immutable parameter | -| `x: Const` | compile-time value (enables `unroll` with dynamic bounds) | -| `x: Mut` | mutable within function body only | +### Parameter types -**All parameters are pass-by-value.** The `: Mut` modifier allows reassignment within the function, but changes are not visible to the caller. Use return values to communicate results. 
+| Syntax | Meaning | +| ---------- | ------------------------------------ | +| `x` | normal (immutable) runtime parameter | +| `x: Const` | compile-time parameter | -``` -def repeat(n: Const): # Const enables unroll +```python +def repeat(n: Const): # Const enables unroll(0, n) sum: Mut = 0 for i in unroll(0, n): sum = sum + i return sum -def double(x: Mut): # Mut allows local reassignment - x = x * 2 # only affects local copy - return x # must return to pass result back +def double(x): # parameter is immutable; shadow with a local + y: Mut = x + y = y * 2 + return y ``` -### Inline Functions -Use the `@inline` decorator to mark functions for inlining at call sites: -``` +### Inline functions + +`@inline` expands a function at every call site instead of generating a JUMP +instruction to another part of the bytecode. Useful for performance (calling a function costs a few cycles). + +```python @inline def square(x): return x * x ``` -**Note:** Inline functions cannot have `: Mut` parameters. -**Note:** Inline functions support at most one `return`, and it must be at the -top level of the body — never nested inside an `if`, loop, or `match`. Early or -conditional returns are rejected by the compiler, because inlining expands each -`return` into a plain assignment with no control flow. Use a regular (non-inline) -function — with `: Const` parameters if you need compile-time specialization — -when you need a conditional return. +Constraints on inline functions (compiler limitations): Exactly one `return`, placed as the last statement of the body, not nested inside `if`, a loop, or `match`. Inlining rewrites the `return` into a plain assignment in place, so early or conditional returns cannot be expressed. 
## Variables | Declaration | Mutability | Notes | | ------------- | ---------- | ---------------------------------------------- | | `x = 10` | immutable | cannot be reassigned | -| `x: Mut = 10` | mutable | can be reassigned | -| `x: Imu` | immutable | forward declaration, assign exactly once later | -| `x: Mut` | mutable | forward declaration for mutable variable | +| `x: Mut = 10` | mutable | reassignable | +| `x: Imm` | immutable | forward declaration; assign exactly once later | +| `x: Mut` | mutable | forward declaration; reassignable later | -### Forward Declarations +### Forward declarations -Use `x: Imu` when a variable must be assigned in different branches: +Use `x: Imm` when you want an immutable binding but the value comes from a +branch: -``` -result: Imu # immutable: assign exactly once +```python +result: Imm if cond == 1: result = 10 else: result = 20 -# result cannot be reassigned after this +# result is now immutable ``` -Use `x: Mut` when you need the variable to be mutable after assignment: +Use `x: Mut` when you want to keep mutating the variable after the branch: -``` +```python x: Mut if cond == 1: x = 10 else: x = 20 -x = x + 1 # OK: x was declared as mutable +x = x + 1 # OK: x is mutable ``` -### Tuple Assignments with Mutable Variables +### Mutability inside tuple assignments -When a function returns multiple values and some need to be mutable, use forward declarations: +To make a single component of a tuple-return mutable, forward-declare it: -``` -b: Mut # declare b as mutable +```python +b: Mut a, b, c = some_function() -# a and c are immutable, b is mutable -b = b + 1 # OK -# a = 5 # ERROR: a is immutable +b = b + 1 # OK +# a = 5 # ERROR: a is immutable ``` -This is useful when a function returns multiple values and only some need to be modified later. 
- -## Memory and Arrays +## Memory and arrays -``` -buffer = Array(16) # allocate 16 field elements +```python +buffer = Array(16) # allocate 16 field elements buffer[0] = 42 -x = buffer[5] +buffer[0] = 42 # Valid +# buffer[0] = 41 # ERROR: conflicting write (write-once memory) +buffer[5] = 34 +x = buffer[5] # x = 34 -matrix = Array(64) # 2D via manual indexing +matrix = Array(64) # 2D via manual indexing matrix[row * 8 + col] = value -ptr2 = ptr + 5 # pointer arithmetic -ptr2[0] = 100 # same as ptr[5] = 100 +ptr2 = buffer + 6 # pointer arithmetic +ptr2[0] = 100 # same as buffer[6] = 100 ``` -**Memory is write-once.** Due to SSA constraints, each memory location can only hold one value. Writing to the same location multiple times is allowed, but all writes must produce the same value—otherwise a runtime error occurs. +`Array(n)` returns a pointer to a freshly allocated block of `n` field +elements. `n` may be a compile-time constant (more efficient, analogy: allocated on the stack) or a runtime +value (less efficient, analogy: allocated on the heap). Memory is **write-once**: a cell may be +written more than once only if all writes store the same value. -``` -arr = Array(3) -arr[0] = 10 # OK: first write -arr[0] = 10 # OK: same value -arr[0] = 20 # ERROR: different value at same location -``` +## Control flow -Use `mut` variables when you need mutability, the compiler cannot handle mutability on hand-written allocated memory ("Array(...)"). +### `if` / `elif` / `else` -## Control Flow - -### If/Else -``` +```python if x == 0: y = 1 elif x == 1: @@ -208,30 +239,42 @@ elif x == 1: else: y = 3 ``` -Comparison operators: `==`, `!=` -### Match -Patterns must be consecutive integers: -``` +Comparison operators on conditions: `==`, `!=`, `<`, `<=`. There is **no** `>` +or `>=` (flip the operands to get the same effect). 
+ +### `match` + +Patterns must be a set of integers of the form [n, n + 1, n + 2, ...]: + +```python match value: case 5: result = 500 + do_stuff() case 6: result = 600 + do_other_stuff() case 7: result = 700 + ... ``` -### match_range +The matched value must lie inside the listed range; out-of-range values produce +undefined behaviour. **It is the responsibility of the program to ensure this** (the compiler adds no checks). Letting a prover-controlled value escape the range in a `match` is a critical vulnerability. -Compile-time construct that expands into a match statement, useful for dispatching to functions with const parameters based on runtime values. Results are always immutable. +### `match_range` -``` +`match_range` automatically generates a `match` with repeated arms. + +```python result = match_range(n, range(1, 5), lambda i: compute(i)) ``` -Expands to: -``` -result: Imu # auto-generated forward declaration (always immutable) + +is expanded by the compiler to: + +```python +result: Imm match n: case 1: result = compute(1) case 2: result = compute(2) @@ -239,323 +282,471 @@ match n: case 4: result = compute(4) ``` -**Multiple continuous ranges** with different lambdas: -``` +It's possible to chain several `(range, lambda)` pairs, provided the ranges are +**contiguous** (the end of one is the start of the next): + +```python result = match_range(n, range(0, 1), lambda i: special_case(), range(1, 8), lambda i: normal_case(i)) ``` -Expands to a match where case 0 uses `special_case()` and cases 1-7 use `normal_case(i)`. -Ranges must be continuous (end of one equals start of next). +Multiple return values are supported via tuple unpacking. The bindings produced +by `match_range` are always immutable. 
Forward-declare with `: Mut` (and then +reassign) if you need them mutable later: -**Multiple return values:** -``` +```python +a: Mut a, b = match_range(n, range(0, 4), lambda i: two_values(i)) +a += 1 ``` -**Common use case:** Dispatching runtime values to const-parameter functions: -``` +Idiomatic use: dispatching a runtime value to a const-parameter function. + +```python def helper_const(n: Const): - # function that requires compile-time n return n * n def compute(value): - result = match_range(value, range(0, 10), lambda i: helper_const(i)) - return result + assert value < 10 + return match_range(value, range(0, 10), lambda i: helper_const(i)) ``` +Similar to `match`, range validity of the matched value is the responsibility of the program, not the compiler. Letting a prover-controlled value escape the range in a `match_range` is a critical vulnerability. -**IMPORTANT:** For both `match` and `match_range`, the programmer must ensure the value is within the specified range. Out-of-range values cause undefined behavior. Use `debug_assert` to validate: -``` -debug_assert(n < 10) -debug_assert(0 < n) -result = match_range(n, range(1, 10), lambda i: compute(i)) -``` +### For loops -### For Loops -``` -for i in range(0, 10): # standard loop - ... -for i in parallel_range(0, n): # iterations executed in parallel (see below) - ... -for i in unroll(0, 4): # unrolled at compile time - ... -``` -Use `unroll` when bounds are const or compile-time expansion is needed. +Three loop forms, all written `for i in <kind>(start, end):` (with `<kind>` one of `range`, `unroll`, `parallel_range`). The +iterator visits `start, start + 1, ..., end - 1`. -**`parallel_range`** executes iterations concurrently using rayon. The produced bytecode is identical to `range`. Constraints: -- The loop body must be **iteration-independent**: no `Mut` variables carried - across iterations. Each iteration may only write to its own frame and to - external addresses that do not affect other iterations . -- The memory footprint (i.e. 
total memory usage) must be the same across iterations -- XMSS / Merkle hint consumption must be the same across iterations +Restrictions shared by all three forms: -**Mutable variables in non-unrolled loops:** Mutable variables can be modified inside non-unrolled loops. The compiler automatically transforms these into buffer-based implementations: +- No `break` or `continue` (not in the grammar). -``` +#### `range(a, b)`: runtime loop + +The general-purpose runtime loop. `a` and `b` may be runtime values. The +compiler lowers the loop to a recursive function. + +```python sum: Mut = 0 for i in range(1, 11): sum += i assert sum == 55 ``` -Loops limitations: -- no "continue" or "break" are supported yet -- the "return" keyword is not supported inside the body of a normal (non-unrolled) loop (because under the hood normal loops are transformed into recursive functions) +Mutable variables carried across iterations are supported transparently. + +*Under the hood: the compiler inserts a buffer array, stores the per-iteration value into it, and reads the final value back after the loop.* + +Restrictions: No `return` inside the body + +*Under the hood: because the loop is lowered to a recursive function.* + +#### `unroll(a, b)`: compile-time unrolling + +The loop is expanded at compile time: the body is duplicated once per iteration +with `i` substituted by its concrete value. Both `a` and `b` must be +compile-time constants. + +```python +for i in unroll(0, 4): + buffer[i] = i * i +``` + +#### `parallel_range(a, b)` — parallel runtime loop + +**`parallel_range` compiles to exactly the same bytecode as `range`.** It +differs only in the runner's scheduling policy: iterations are dispatched +concurrently across worker threads rather than evaluated in sequence. The only advantage is faster witness generation. 
+Iteration `a` is executed first, in isolation, to determine the per-iteration +memory footprint; the remaining iterations are then evaluated in parallel +without inter-iteration synchronization. + +```python +for i in parallel_range(0, n): + process(i, inputs[i], outputs[i]) +``` + +Because there is no synchronization, the loop body must be +iteration-independent: + +- No `Mut` variables carried across iterations (each iteration writes only to + its own call frame and to addresses disjoint from every other iteration). +- Identical memory footprint per iteration. +- Identical hint consumption per iteration (witness hints, XMSS-specific + decomposition hints, Merkle hints, etc.). + +These constraints are **not** checked at compile time. Violating them produces +silently wrong proofs. + +### Statements without effect are rejected + +Every line must either be a declaration, an assignment, a control-flow form, an +assertion, a `return`, or a side-effecting call (`hint_witness`, precompile, +`print`, or a function call). A bare expression like `x + 1` on its own line is +a compile error. ## Expressions ### Arithmetic -- `+`, `-`, `*`, `/` (field operations): allowed at runtime -- `%` (modulo), `**` (exponentiation): only allowed at compile time -### Compound Assignment -Syntactic sugar for updating mutable variables: -``` +`+`, `-`, `*`, `/` are field operations and work at runtime. + +`%` (modulo) and `**` (exponentiation) are **compile-time only** — both operands +must be constants known at compile time. + +### Compound assignment + +```python x: Mut = 10 -x += 5 # equivalent to: x = x + 5 -x -= 3 # equivalent to: x = x - 3 -x *= 2 # equivalent to: x = x * 2 -x /= 4 # equivalent to: x = x / 4 +x += 5 # x = x + 5 +x -= 3 # x = x - 3 +x *= 2 # x = x * 2 +x /= 4 # x = x / 4 ``` -### Built-in Functions -Only allowed at compile time: +Only a single target is allowed on the LHS of a compound assignment. 
-``` -log2_ceil(x) # ceiling of log2 -next_multiple_of(x, n) # smallest multiple of n >= x -div_ceil(a, b) # ceiling division: (a + b - 1) // b -div_floor(a, b) # floor division: a // b +### Compile-time built-ins + +These functions are evaluated at compile time only — their arguments must be +constants: + +```python +log2_ceil(x) # ceil(log2(x)) +next_multiple_of(x, n) # smallest multiple of n that is >= x +div_ceil(a, b) # (a + b - 1) // b +div_floor(a, b) # a // b saturating_sub(a, b) # max(0, a - b) -len(array) # length of const array or vector +len(array) # length of a constant array (any depth) ``` -## Assertions +### `_` (the discard target) + +Inside a tuple-unpacking LHS, `_` discards the value at that position. The +compiler rewrites each `_` to a fresh anonymous name so they don't collide. + +```python +_, b = swap(a, b) # only keep b +_ = compute() # discard a single return value ``` -# constraint in proof + +## Assertions + +The zkDSL provides two assertion forms with very different semantics: + +| Form | Enforced by | Use for | +| -------------- | --------------------------------------- | ------------------------------------------------------------------- | +| `assert` | The proof system | Invariants the verifier must check | +| `debug_assert` | The prover only (at witness generation) | Sanity checks; preconditions the verifier does not need to re-check | + +### `assert`: proof-enforced constraint + +```python assert x == y assert x != y -assert x < y +assert x < y assert x <= y -# unconditional failure (panic) +``` + +The four supported comparison operators are `==`, `!=`, `<`, `<=` (no `>` or +`>=`; flip the operands). + + +### Range checks: `assert a < b` and `assert a <= b` + +**The program must ensure `b <= 2^16`.** The compiler does not check this +(`b` may be a runtime value). Violating the bound is a critical soundness +vulnerability. 
+ +*Under the hood: the compiler proves `a < b` by emitting two DEREF instructions, +which check that `a` and `b - 1 - a` are both valid memory addresses. An +address is valid iff it is `< M`, where `M` is the memory size. To stay sound +for every admissible memory size, the construction relies on the smallest one, +`M_min = 2^16` (= `2^MIN_LOG_MEMORY_SIZE`), giving the bound `b <= 2^16`.* + +#### Explicit panic + +`assert False` is the unconditional failure form. It compiles to a Panic and +accepts an optional message: + +```python assert False -assert False, "error message" -# runtime check only (not constrained by the snark) -debug_assert(x == y) -debug_assert(x != y) +assert False, "human-readable message" +``` + +### `debug_assert`: sanity checks at witness generation + +```python debug_assert(x < y) -debug_assert(x <= y) ``` +`debug_assert` accepts the same four comparison operators. It is evaluated by +the prover at trace-generation time and does **not** emit any constraint, so +the verifier never re-checks it. Use it for invariants the prover is expected +to maintain but that the verifier can take for granted — typically the +range-validity preconditions of `match` / `match_range` dispatches. + ## Comments -``` -# Single-line comment +```python +# single-line comment """ -Multi-line comment -can span multiple lines +block comment """ ``` -## Imports - -``` -from utils import * # imports utils.py (relative to import root) -from dir.subdir.file import * # imports dir/subdir/file.py -``` +## Line continuation -## Memory Layout +As in Python: -The runner places the program's memory as: +- **Implicit** continuation inside `(...)` or `[...]`. +- **Explicit** continuation with `\` at end of line. 
-``` -[ public_input | preamble_memory | runtime ] +```python +result = function_call(arg1, + arg2, + arg3) # implicit continuation inside parens +y = 1 + 2 + \ + 3 + 4 # explicit continuation with backslash ``` -- `public_input` lives at `memory[0..public_input.len()]` (zero-padded to a power of two by the runner so it can be evaluated as a multilinear polynomial). -- `preamble_memory` is a region the runner reserves but does not initialize. The guest program is responsible for writing any constants it needs (e.g. `ZERO_VEC_PTR`, `ONE_EF_PTR`, etc.) in this area. +## Hints (prover-supplied data) -Prover-supplied witness data is fetched on demand with `hint_witness("name", ptr)`, where the string literal -names an entry in the witness's `hints: HashMap>>` map and -`ptr` is a caller-allocated buffer. Each call writes the next unused `Vec` -under that name (per-name running index) into the buffer at `ptr`. The guest -is responsible for allocating `ptr` with enough room; the witness's length is -trusted. -`hint_witness` +A hint is data the *prover* writes into memory without adding any constraint — +the program must still constrain the written value if it wants the verifier to +believe anything about it. There are two flavours of hint: -``` -data_buf = Array(64) -hint_witness("input_data", data_buf) # writes next `input_data` entry into data_buf -n = data_buf[0] -# ... -``` +### `hint_witness("name", ptr)` -### Built-in Hints +Writes the next buffer queued under the label `name` into memory starting at +`ptr`. The guest must allocate `ptr` large enough to hold the data; no length +is checked at runtime. -hints = prover-supplied values at runtime (without adding snark constraints). Like `hint_witness`, they are bare statements (no return value) — the caller allocates any destination memory and is responsible for constraining the written values. +The buffer comes from the host (Rust side), not from the guest. 
Before +running the program, the host fills `ExecutionWitness::hints` with one queue +of buffers per label; each `hint_witness("name", ptr)` call pops the next +buffer from `hints["name"]`. -| Hint | Signature | Writes | -| --------------------------------- | --------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | -| `hint_decompose_bits` | `(to_decompose, ptr, num_bits, endianness)` | `num_bits` field elements at `ptr` (the 0/1 bit decomposition of `to_decompose`); `endianness` is `0` for big-endian, `1` for little-endian | -| `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` else `0` | -| `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr` | -| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a/b)` at `q_ptr` and `a mod b` at `r_ptr` (requires `b != 0`) | -| `hint_decompose_bits_xmss` | `(decomposed_ptr, remaining_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | XMSS-specific decomposition (see `crates/lean_vm/src/isa/hint.rs`) | -| `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, remaining_ptr, value, chunk_size)` | Merkle/WHIR-specific decomposition | +`ExecutionWitness` lives in `crates/lean_vm/src/execution/runner.rs`: -Hints only *suggest* a value; the guest must add appropriate constraints to bind that value to its specification. +```rust +pub struct ExecutionWitness { + ... + pub hints: HashMap<String, Vec<Vec<F>>>, + ... +} +``` +Each map key is a label; the value is the **ordered list of buffers** the +guest will consume under that label. The N-th `hint_witness("name", ptr)` call +the guest executes pops the N-th `Vec<F>` from `hints["name"]` and writes it +at `ptr`. 
-## Precompiles +For example, the guest below issues three `hint_witness` calls — two against +`"input_data"` and one against `"other_stuff"`: -### poseidon16_compress -Always in "compression" mode -``` -poseidon16_compress(left, right, output) +```python +data_buf_1 = Array(64) +hint_witness("input_data", data_buf_1) +n = data_buf_1[0] + +data_buf_2 = Array(64) +hint_witness("input_data", data_buf_2) +m = data_buf_2[3] +assert n == m + 8 + +data_buf_3 = Array(10) +hint_witness("other_stuff", data_buf_3) +... ``` -- `left`: pointer to 8 field elements -- `right`: pointer to 8 field elements -- `res`: pointer to result (8 elements) -### Extension Operations +The matching Rust side must register two buffers under `"input_data"` (in +the order the guest will read them) and one under `"other_stuff"`: -Six built-in functions route through a single `extension_op` precompile table. Each combines an element-wise operation with an accumulation over `length` element pairs. +```rust +let mut hints: HashMap<String, Vec<Vec<F>>> = HashMap::new(); +hints.insert( + "input_data".to_string(), + vec![ + first_input_buffer, // consumed by the first hint_witness("input_data", ...) + second_input_buffer, // consumed by the second hint_witness("input_data", ...) + ], +); +hints.insert("other_stuff".to_string(), vec![other_buffer]); +let witness = ExecutionWitness { hints, ..Default::default() }; ``` -func(ptr_a, ptr_b, ptr_result) # length defaults to 1 -func(ptr_a, ptr_b, ptr_result, length) # explicit length (N elements) -``` - -**Operand types (suffix):** -- `_ee`: both `ptr_a` and `ptr_b` point to extension field elements (5 consecutive field elements each, stride = DIM) -- `_be`: `ptr_a` points to base field elements (stride 1), `ptr_b` points to extension field elements (stride DIM) -`ptr_result` always points to a single extension field element (DIM=5 field elements). 
+A missing label, or running out of buffers under a label, is a runner-side +panic: each call requires its corresponding entry to exist. -**Operations:** +### Custom hints -| Function | Element-wise | Accumulation | -| ----------------------------------- | --------------------------------- | -------------------- | -| `add_ee` / `add_be` | `e_i = a_i + b_i` | `result = sum(e_i)` | -| `dot_product_ee` / `dot_product_be` | `e_i = a_i * b_i` | `result = sum(e_i)` | -| `poly_eq_ee` / `poly_eq_be` | `e_i = a_i*b_i + (1-a_i)*(1-b_i)` | `result = prod(e_i)` | +Custom hints are a fixed set of built-in calls the prover uses to compute +values that would be expensive to derive in-circuit — bit +decompositions, comparisons, integer division, etc. Each is invoked like an +ordinary function and writes its result into a caller-supplied memory +location. -**Note:** `length` must be a compile-time constant. For runtime-known lengths, use `match_range` to dispatch (see example below). +Like every hint, **the result is unconstrained**: the verifier checks +nothing about the hinted value. The guest program must add its own +constraints binding the hinted bits / quotient / remainder / boolean to the +original input — otherwise a malicious prover can substitute any value. The +typical pattern is "hint, then assert the relationship": +```python +# hint the bits... +bits = Array(8) +hint_decompose_bits(value, bits, 8) +# ...then constrain them to actually equal `value` +acc: Mut = 0 +for i in unroll(0, 8): + assert bits[i] * (bits[i] - 1) == 0 # boolean + acc = acc * 2 + bits[i] +assert acc == value ``` -# Multiply two extension field elements (length=1, default) -dot_product_ee(x, y, z) # z = x * y -# Copy extension element (multiply by [1,0,0,0,0]). -# `ONE_EF_PTR` is a guest-program constant that the program must materialize -# in its preamble memory at startup; see `crates/rec_aggregation/zkdsl_implem/utils.py` -# for an example (`build_preamble_memory`). 
-dot_product_ee(src, ONE_EF_PTR, dst) +The full list: -# Dot product of N extension field elements -dot_product_ee(coeffs, basis, result, N) +| Hint | Arguments | Effect | +| --------------------------------- | ------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------- | +| `hint_decompose_bits` | `(value, ptr, n_bits)` | Writes `n_bits` big-endian 0/1 field elements at `ptr` (MSB at `ptr[0]`). Requires `n_bits <= 31`. | +| `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, value, chunk_size)` | Writes `24 / chunk_size` little-endian `chunk_size`-bit chunks of `value` at `decomposed_ptr` (`chunk_size` must divide 24). | +| `hint_decompose_bits_xmss` | `(decomposed_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | For each of `num_to_decompose` values at `to_decompose_ptr[..]`, writes its `24 / chunk_size` little-endian chunks at `decomposed_ptr`. | +| `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` (canonical integer compare), else `0`. | +| `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr`. | +| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a / b)` at `q_ptr`, `a mod b` at `r_ptr` (requires `b != 0`). | -# Dot product with base-field scalars -dot_product_be(alpha_powers, coeffs, result, N) +## Precompiles -# Extension field addition: c = a + b -add_ee(a, b, c) +Precompiles are special instructions in the leanVM ISA, alongside the four +basic ones (ADD, MUL, DEREF, JUMP). The zkDSL exposes them as built-in +functions. There are two families: Poseidon hashing and extension-field +operations. 
-# Extension field subtraction via constraint: c = a - b <=> b + c = a -add_ee(b, c, a) +### Poseidon16 family -# Equality polynomial: eq(a, b) = a*b + (1-a)*(1-b) -poly_eq_ee(a, b, eq_result) +The variants are as follows: -# Multi-point equality polynomial: prod_{i=0}^{n-1} eq(a[i], b[i]) -poly_eq_ee(a, b, result, n) +- **compress vs. permute** — `compress` applies the feed-forward addition + (`Poseidon(L || R) + L`); `permute` is the raw 16-cell permutation. +- **full vs. half output** — `_half` constrains only the first 4 output cells + (the rest are unconstrained); useful when the consumer only cares about + half a digest. +- **hardcoded-left** — `_hardcoded_left` reads the first 4 cells of the left + input from a compile-time address instead of from `m[L..L+4]`; the last 4 + cells of the left input still come from memory. -# Runtime-known length via match_range -def dot_product_ee_dynamic(a, b, res, n): - debug_assert(n <= 256) - match_range(n, range(1, 257), lambda i: dot_product_ee(a, b, res, i)) -``` +Common arguments: `L`, `R` are 8-cell input buffers; `O` is the output +buffer; `off` (where present) is a compile-time address. 
-## Debugging +| Function | Cells written to `O` | Notes | +| ------------------------------------------------------- | -------------------- | ----------------------------------------- | +| `poseidon16_compress(L, R, O)` | `O[0..8]` | `Poseidon(L \|\| R) + L` | +| `poseidon16_compress_half(L, R, O)` | `O[0..4]` | `O[4..8]` is unconstrained | +| `poseidon16_compress_hardcoded_left(L, R, O, off)` | `O[0..8]` | left = `m[off..off+4] \|\| m[L..L+4]` | +| `poseidon16_compress_half_hardcoded_left(L, R, O, off)` | `O[0..4]` | half-output + hardcoded-left composition | +| `poseidon16_permute(L, R, O)` | `O[0..16]` | raw Poseidon permutation, no feed-forward | -``` -print(value) -print(a, b, c) -``` +### Extension-field operations -## Example +Six built-in functions, each reading two length-`n` vectors `a` and `b` and +writing one extension-field element to `result`. `n` defaults to `1` and must +be a compile-time constant when given. +```python +add_ee(a, b, result, n=1) # result = sum_i (a[i] + b[i]) +dot_product_ee(a, b, result, n=1) # result = sum_i a[i] * b[i] +poly_eq_ee(a, b, result, n=1) # result = prod_i (a[i]*b[i] + (1-a[i])*(1-b[i])) ``` -SIZE = 8 -def main(): - arr = Array(SIZE) - for i in unroll(0, SIZE): - arr[i] = i * i - sum = compute_sum(arr, SIZE) - assert sum == 140 - return +The `_ee` suffix means both `a` and `b` are vectors of *extension*-field +elements (each occupying `DIM = 5` consecutive cells). The `_be` variants +(`add_be`, `dot_product_be`, `poly_eq_be`) are identical except `a` is a +vector of *base*-field elements (1 cell each); `b` and `result` are still +extension-field. -def compute_sum(ptr, n: Const): - acc: Mut = 0 - for i in unroll(0, n): - acc = acc + ptr[i] - return acc -``` +`result` always points to a single extension-field element (5 cells). 
-## Line Continuation +For a runtime `n`, dispatch through `match_range`: + +```python +def dot_product_ee_dynamic(a, b, res, n): + debug_assert(n <= 256) + match_range(n, range(1, 257), lambda i: dot_product_ee(a, b, res, i)) +``` -Like Python, lines can be continued in two ways: +Common idioms: -### Implicit continuation (inside parentheses/brackets/braces) +```python +# Multiply two extension elements (n defaults to 1) +dot_product_ee(x, y, z) # z = x * y -Expressions inside `()`, `[]`, or `{}` can span multiple lines without any special syntax: +# Copy an extension element by multiplying by 1 +# (ONE_EF_PTR is a constant materialized in the preamble) +dot_product_ee(src, ONE_EF_PTR, dst) +# Extension subtraction: write-once memory turns "c = a + b" into +# the constraint "b + c = a", i.e. c = a - b +add_ee(b, c, a) # c = a - b ``` -result = function_call( - arg1, - arg2, - arg3 -) -ARR = [ - 1, - 2, - 3, -] +## Debugging + +```python +print(value) +print(a, b, c) ``` -### Explicit continuation with backslash +`print` buffers its output during execution; **a Rust-side panic mid-program drops +buffered prints**. When you need a print to survive a panic, temporarily change +the print hint in `lean_vm/src/isa/hint.rs (Self::Print)` to `eprint!` directly. -Long lines can also be split using `\` at the end of a line: +## Memory layout -``` -x = very_long_function_name(arg1, \ - arg2, \ - arg3) +The runner lays out memory as -y = 1 + 2 + \ - 3 + 4 +```python +[ public_input (PUBLIC_INPUT_LEN cells) | preamble_memory | runtime ] ``` -The `\` and following newline are replaced with a single space. Any whitespace after `\` and before the newline is ignored. +- `public_input` is fixed at `PUBLIC_INPUT_LEN = DIGEST_LEN = 8` cells (a hash + digest), occupying `memory[0..8]`. +- `preamble_memory` is a region of `witness.preamble_memory_len` cells the + runner reserves immediately after the public input but does **not** + initialize. 
The guest program is expected to fill this region with whatever + helper constants it relies on (e.g. a vector of zeros for + `dot_product_ee`-as-copy, an extension-field one for multiply-by-one tricks, + a vector of ones for batched accumulations, …) at the start of `main`. The + names and offsets of these constants are not enshrined within leanVM. See + `crates/rec_aggregation/zkdsl_implem/utils.py (build_preamble_memory)` for + a concrete example. +- The runtime region holds the program's stack frames, working memory, and any + prover-supplied witness data, all governed by the write-once rule. ## Tips -1. Use `unroll` for small, fixed-size loops -2. Use `const` parameters when loop bounds depend on arguments -3. Use `mut` sparingly - immutable is easier to verify -4. Use `x: Imu` or `x: Mut` for forward-declaring variables that will be assigned in branches -5. Match patterns must be consecutive integers (can start from any value) +1. Prefer `unroll` over `range` for small, fixed-size loops. +2. Reach for `: Const` parameters when the function body needs `unroll` over the + parameter. +3. `if` / `elif` branches that assign to the same outer variable should + forward-declare it (`x: Imm` or `x: Mut`) before the branch. +4. Function parameters are always immutable. To mutate a parameter's value + inside a function, introduce a local `: Mut` alias at the top of the body + (e.g. `y: Mut = x`). -## Example: From high level syntactic sugar to minimal ISA, with read-only memory +## Example -Take the following program: +Look at the recursive aggregation program (to aggregate XMSS) at its entrypoint [main.py](../rec_aggregation/zkdsl_implem/main.py). -``` +## Compilation step-by-step: zkDSL -> ISA + +Starting program: + +```python def main(): x: Mut = 0 y: Mut = 3 @@ -571,9 +762,10 @@ def main(): return ``` -First, we use buffers to handle mutable variables across (non-unrolled) loops. 
+Step 1 — the compiler replaces mutable-across-loop variables with index buffers, since memory +is write-once: -``` +```python def main(): x: Mut = 0 y: Mut = 3 @@ -602,10 +794,9 @@ def main(): return ``` -Then, use auxiliary variables to transform it into SSA form (Static Single-Assignment): - +Step 2 — SSA-rename all reassignments to fresh names: -``` +```python def main(): x = 0 y = 3 @@ -634,9 +825,9 @@ def main(): return ``` -Finally, transform the loop into a recursive function: +Step 3 — lower the runtime loop to a recursive function: -``` +```python def main(): x = 0 y = 3 @@ -647,14 +838,14 @@ def main(): x_buff[0] = x2 y_buff = Array(size + 1) y_buff[0] = y2 - loop(4, x_buff, y_buff) + loop_helper(4, x_buff, y_buff) x3 = x_buff[size] y3 = y_buff[size] assert x3 == 35 assert y3 == 40 return -def loop(i, x_buff, y_buff): +def loop_helper(i, x_buff, y_buff): if i == 6: return else: @@ -668,20 +859,7 @@ def loop(i, x_buff, y_buff): next_idx = buff_idx + 1 x_buff[next_idx] = x_body3 y_buff[next_idx] = y_body3 - loop(i + 1, x_buff, y_buff) + loop_helper(i + 1, x_buff, y_buff) return ``` -## Dev experience - -If using VScode, add the following to your local settings `.vscode/settings.json` : - -```json -{ - "python.analysis.extraPaths": [ - "./crates/lean_compiler" - ], -} -``` - -(you will get better linting for the zkDSL files starting with `from snark_lib import *`, since it will expose zkDSL special functions from `crates/lean_compiler/snark_lib.py`). 
\ No newline at end of file diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index a9ee93df..c112c5a7 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -67,10 +67,6 @@ pub enum ProverError { Runner(RunnerError), UnknownMessage, MultipleMessages, - InvalidPublicInputSize { - expected: usize, - actual: usize, - }, InvalidChildProof(ProofError), LimitExceeded { what: &'static str, @@ -104,9 +100,6 @@ impl Display for ProverError { Self::Runner(e) => write!(f, "{}", e), Self::UnknownMessage => write!(f, "Unknown message, not part of the type2"), Self::MultipleMessages => write!(f, "Multiple common messages in the type2"), - Self::InvalidPublicInputSize { expected, actual } => { - write!(f, "Invalid public input size: expected {}, actual {}", expected, actual) - } Self::InvalidChildProof(e) => write!(f, "Invalid child proof: {}", e), Self::LimitExceeded { what, actual, max } => { write!(f, "Too many {}: {} (max {})", what, actual, max) diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 1f985908..fb8fe1e3 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -19,7 +19,7 @@ pub struct ExecutionProof { pub fn prove_execution( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: &ExecutionWitness, whir_config: &WhirConfigBuilder, vm_profiler: bool, @@ -27,15 +27,8 @@ pub fn prove_execution( check_rate(whir_config.starting_log_inv_rate) .map_err(|err| panic!("{err}")) .unwrap(); - if public_input.len() != PUBLIC_INPUT_LEN { - return Err(ProverError::InvalidPublicInputSize { - expected: PUBLIC_INPUT_LEN, - actual: public_input.len(), - }); - } let ExecutionTrace { traces, - public_memory_size, mut memory, // padded with zeros to next power of two metadata, } = info_span!("Witness generation").in_scope(|| -> Result<_, ProverError> { @@ -232,8 +225,8 @@ pub fn prove_execution( 
committed_statements.get_mut(table).unwrap().push(claim); } - let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(public_memory_size))); - let public_memory_eval = (&memory[..public_memory_size]).evaluate(&public_memory_random_point); + let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(PUBLIC_INPUT_LEN))); + let public_memory_eval = (&memory[..PUBLIC_INPUT_LEN]).evaluate(&public_memory_random_point); let previous_statements = vec![ SparseStatement::new( diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 91b3f76b..767a2756 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -97,7 +97,7 @@ fn all_precompiles_flags(loop_iters: usize) -> CompilationFlags { } } -fn all_precompiles_witness() -> (Vec, ExecutionWitness) { +fn all_precompiles_witness() -> ([F; PUBLIC_INPUT_LEN], ExecutionWitness) { let mut rng = StdRng::seed_from_u64(0); let mut scratch = F::zero_vec(8192); @@ -194,7 +194,7 @@ fn all_precompiles_witness() -> (Vec, ExecutionWitness) { .fold(EF::ONE, |acc, x| acc * x); scratch[1300..][..DIMENSION].copy_from_slice(poly_eq_ee_result.as_basis_coefficients_slice()); - let mut public_input = vec![F::ZERO; PUBLIC_INPUT_LEN]; + let mut public_input = [F::ZERO; PUBLIC_INPUT_LEN]; public_input[..4].copy_from_slice(&hardcoded_prefix); let mut hints = std::collections::HashMap::new(); @@ -326,7 +326,7 @@ def fibonacci_const(a, b, n: Const): ); } -fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { +fn test_zk_vm_helper(program_str: &str, public_input: &[F; PUBLIC_INPUT_LEN]) { test_zk_vm_helper_with_witness( program_str, public_input, @@ -337,7 +337,7 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { fn test_zk_vm_helper_with_witness( program_str: &str, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: ExecutionWitness, flags: CompilationFlags, ) { diff --git 
a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 45fc3057..5bee4a97 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -6,7 +6,6 @@ use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_iter_mut}; #[derive(Debug)] pub struct ExecutionTrace { pub traces: BTreeMap, - pub public_memory_size: usize, pub memory: Vec, // of length a multiple of public_memory_size pub metadata: ExecutionMetadata, } @@ -171,7 +170,6 @@ pub fn get_execution_trace( ExecutionTrace { traces, - public_memory_size: execution_result.public_memory_size, memory: memory_padded, metadata: execution_result.metadata, } diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index c1886acd..173d2a8d 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -14,7 +14,7 @@ pub struct ProofVerificationDetails { pub fn verify_execution( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], proof: Proof, ) -> Result<(ProofVerificationDetails, RawProof), ProofError> { if bytecode.log_size() > MAX_BYTECODE_LOG_SIZE { @@ -23,9 +23,6 @@ pub fn verify_execution( max_log_size: MAX_BYTECODE_LOG_SIZE, }); } - if public_input.len() != PUBLIC_INPUT_LEN { - return Err(ProofError::InvalidProof); - } let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone(), fiat_shamir_domain_sep(bytecode))?; verifier_state.observe_scalars(public_input); @@ -58,8 +55,6 @@ pub fn verify_execution( return Err(ProofError::InvalidProof); } - let public_memory = padd_with_zero_to_next_power_of_two(public_input); - if !(MIN_LOG_MEMORY_SIZE..=MAX_LOG_MEMORY_SIZE).contains(&log_memory) { return Err(ProofError::InvalidProof); } @@ -175,9 +170,8 @@ pub fn verify_execution( return Err(ProofError::InvalidProof); } - let public_memory_random_point = - 
MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_memory.len()))); - let public_memory_eval = public_memory.evaluate(&public_memory_random_point); + let public_memory_random_point = MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_input.len()))); + let public_memory_eval = public_input.evaluate(&public_memory_random_point); let previous_statements = vec![ SparseStatement::new( diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs index dcb1ae0c..2024fa08 100644 --- a/crates/lean_vm/src/diagnostics/exec_result.rs +++ b/crates/lean_vm/src/diagnostics/exec_result.rs @@ -72,7 +72,6 @@ impl ExecutionMetadata { #[derive(Debug)] pub struct ExecutionResult { pub runtime_memory_size: usize, - pub public_memory_size: usize, pub memory: Memory, pub pcs: Vec, pub fps: Vec, diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 4cc5d721..50ff5429 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -1,6 +1,6 @@ //! 
VM execution runner -use crate::core::{DIMENSION, F}; +use crate::core::{DIMENSION, F, PUBLIC_INPUT_LEN}; use crate::diagnostics::{ExecutionMetadata, ExecutionResult, RunnerError}; use crate::execution::memory::MemoryAccess; use crate::execution::{ExecutionHistory, Memory}; @@ -10,7 +10,7 @@ use crate::isa::instruction::{InstructionContext, InstructionCounts}; use crate::{ALL_TABLES, CodeAddress, HintExecutionContext, MemOrConstant, N_TABLES, STARTING_PC, Table, TableTrace}; use backend::*; use std::collections::{BTreeMap, BTreeSet, HashMap}; -use utils::{ToUsize, padd_with_zero_to_next_power_of_two}; +use utils::ToUsize; use super::memory::SegmentMemory; @@ -27,7 +27,7 @@ pub struct ExecutionWitness { pub fn try_execute_bytecode( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: &ExecutionWitness, profiling: bool, ) -> Result { @@ -58,7 +58,7 @@ pub fn try_execute_bytecode( pub fn execute_bytecode( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: &ExecutionWitness, profiling: bool, ) -> ExecutionResult { @@ -239,7 +239,7 @@ fn resolve_deref_hints(memory: &mut Memory, pending: &[(usize, usize)]) -> Resul #[allow(clippy::too_many_arguments)] fn execute_bytecode_helper( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: &ExecutionWitness, std_out: &mut String, instruction_history: &mut ExecutionHistory, @@ -250,10 +250,9 @@ fn execute_bytecode_helper( .iter() .map(|(name, entries)| (name.clone(), NamedHintCursor::new(entries))) .collect(); - let public_memory = padd_with_zero_to_next_power_of_two(public_input); - let public_memory_size = public_memory.len(); + let public_memory = public_input.to_vec(); let mut memory = Memory::new(public_memory); - let mut fp = public_memory_size + witness.preamble_memory_len; + let mut fp = PUBLIC_INPUT_LEN + witness.preamble_memory_len; fp = fp.next_multiple_of(DIMENSION); let initial_ap = fp + 
bytecode.starting_frame_memory; let mut pc = STARTING_PC; @@ -327,7 +326,7 @@ fn execute_bytecode_helper( } else { None }; - let runtime_memory_size = memory.0.len() - public_memory_size - witness.preamble_memory_len; + let runtime_memory_size = memory.0.len() - PUBLIC_INPUT_LEN - witness.preamble_memory_len; let used_memory_cells = memory.0.par_iter().filter(|&&x| x.is_some()).count(); let metadata = ExecutionMetadata { cycles: trace.pcs.len(), @@ -335,7 +334,7 @@ fn execute_bytecode_helper( n_poseidons: trace.tables[&Table::poseidon16()].columns[0].len(), n_extension_ops: trace.tables[&Table::extension_op()].columns[0].len(), bytecode_size: bytecode.code.len(), - public_input_size: public_input.len(), + public_input_size: PUBLIC_INPUT_LEN, runtime_memory: runtime_memory_size, memory_usage_percent: used_memory_cells as f64 / memory.0.len() as f64 * 100.0, stdout: std::mem::take(std_out), @@ -343,7 +342,6 @@ fn execute_bytecode_helper( }; Ok(ExecutionResult { runtime_memory_size: no_vec_runtime_memory, - public_memory_size, memory, pcs: trace.pcs, fps: trace.fps, diff --git a/crates/rec_aggregation/src/type_1_aggregation.rs b/crates/rec_aggregation/src/type_1_aggregation.rs index da3455f7..2f9dd09d 100644 --- a/crates/rec_aggregation/src/type_1_aggregation.rs +++ b/crates/rec_aggregation/src/type_1_aggregation.rs @@ -288,7 +288,7 @@ pub(crate) fn aggregate_type_1_with_min_padding( &reduced_claims.final_claim_flat(), bytecode, ); - let public_input = poseidon_compress_slice(&pub_input_data).to_vec(); + let public_input = poseidon_compress_slice(&pub_input_data); let mut claimed: HashSet = HashSet::new(); let mut dup_pub_keys: Vec = Vec::new(); diff --git a/crates/rec_aggregation/src/type_2_aggregation.rs b/crates/rec_aggregation/src/type_2_aggregation.rs index bf5706d8..6e93f299 100644 --- a/crates/rec_aggregation/src/type_2_aggregation.rs +++ b/crates/rec_aggregation/src/type_2_aggregation.rs @@ -112,7 +112,7 @@ pub fn merge_many_type_1( let digests: Vec<[F; 
DIGEST_LEN]> = verified_children.iter().map(|v| v.input_data_hash).collect(); let pub_input_data = build_type2_input_data(&digests, &reduced_claims.final_claim_flat()); - let public_input_digest = poseidon_compress_slice(&pub_input_data).to_vec(); + let public_input_digest = poseidon_compress_slice(&pub_input_data); let bytecode_value_hint_blobs: Vec> = verified_children .iter() diff --git a/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py b/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py index 9e92af76..d2395529 100644 --- a/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py +++ b/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py @@ -177,8 +177,8 @@ def fs_receive_ef_inlined(fs, n): def fs_receive_ef_by_log_dynamic(fs, log_n, min_value: Const, max_value: Const): debug_assert(log_n < max_value) debug_assert(min_value <= log_n) - new_fs: Imu - ef_ptr: Imu + new_fs: Imm + ef_ptr: Imm new_fs, ef_ptr = match_range(log_n, range(min_value, max_value), lambda ln: fs_receive_ef(fs, 2**ln)) return new_fs, ef_ptr diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index 5146ee82..00b8a700 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -111,8 +111,8 @@ def euclidian_div_runtime(a, b): # Requires: # 1 <= b < 2^14 # floor(a / b) < 2^16 (so that q*b + r stays well below p) - q: Imu - r: Imu + q: Imm + r: Imm hint_div_floor(a, b, q, r) assert r < b assert q < 2 ** 16 diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 3b949c44..af081d4c 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -610,7 +610,8 @@ def fingerprint_n(domsep, data_evals, n, logup_alphas_eq_poly): return res -def verify_gkr_quotient(fs: Mut, n_vars): +def verify_gkr_quotient(prev_fs, n_vars): + fs: Mut = prev_fs fs, nums = fs_receive_ef_inlined(fs, 
LOGUP_GKR_N_COEFFS_SENT) fs, denoms = fs_receive_ef_inlined(fs, LOGUP_GKR_N_COEFFS_SENT) @@ -653,13 +654,16 @@ def verify_gkr_quotient(fs: Mut, n_vars): ) -def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den): +def verify_gkr_quotient_step(prev_fs, n_vars, point, claim_num, claim_den): + fs: Mut = prev_fs fs = fs_duplex(fs) fs, alpha = fs_sample_ef(fs) alpha_mul_claim_den = mul_extension_ret(alpha, claim_den) num_plus_alpha_mul_claim_den = add_extension_ret(claim_num, alpha_mul_claim_den) postponed_point = Array((n_vars + 1) * DIM) - fs, postponed_value = sumcheck_verify_reversed_helper(fs, n_vars, num_plus_alpha_mul_claim_den, 3, postponed_point) + fs, postponed_value = sumcheck_verify_reversed_helper( + fs, n_vars, num_plus_alpha_mul_claim_den, 3, postponed_point + ) fs, inner_evals = fs_receive_ef_inlined(fs, 4) a_num = inner_evals b_num = inner_evals + DIM @@ -705,7 +709,7 @@ def compute_total_gkr_n_vars(log_memory, log_bytecode_padded, tables_heights): def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, logup_alphas_eq_poly): - res: Imu + res: Imm debug_assert(table_index < N_TABLES) match table_index: case 0: diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py index b95eedfa..d4b2e983 100644 --- a/crates/rec_aggregation/zkdsl_implem/utils.py +++ b/crates/rec_aggregation/zkdsl_implem/utils.py @@ -236,7 +236,7 @@ def mle_of_01234567_etc(point, n): @inline def checked_less_than(a, b): - res: Imu + res: Imm hint_less_than(a, b, res) assert res * (1 - res) == 0 if res == 1: @@ -249,7 +249,7 @@ def checked_less_than(a, b): @inline def maximum(a, b): is_a_less_than_b = checked_less_than(a, b) - res: Imu + res: Imm if is_a_less_than_b == 1: res = b else: @@ -809,7 +809,7 @@ def _verify_log2_large(n, log2: Const): def log2_ceil_runtime(n): # requires: 2 < n <= 2^30 - log2: Imu + log2: Imm hint_log2_ceil(n, log2) assert log2 < 31 if two_exp(log2) != n: diff --git 
a/crates/rec_aggregation/zkdsl_implem/whir.py b/crates/rec_aggregation/zkdsl_implem/whir.py index 3124f253..d14a10ef 100644 --- a/crates/rec_aggregation/zkdsl_implem/whir.py +++ b/crates/rec_aggregation/zkdsl_implem/whir.py @@ -16,14 +16,17 @@ def whir_open( - fs: Mut, + prev_fs, n_vars, initial_log_inv_rate, - root: Mut, + prev_root, ood_points_commit, combination_randomness_powers_0, - claimed_sum: Mut, + prev_claimed_sum, ): + fs: Mut = prev_fs + root: Mut = prev_root + claimed_sum: Mut = prev_claimed_sum n_rounds, n_final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding = get_whir_params( n_vars, initial_log_inv_rate ) @@ -39,7 +42,7 @@ def whir_open( domain_sz: Mut = n_vars + initial_log_inv_rate for r in range(0, n_rounds): - is_first_round: Imu + is_first_round: Imm if r == 0: is_first_round = 1 else: @@ -175,13 +178,15 @@ def whir_open( return fs, folding_randomness_global, s, final_value, end_sum -def sumcheck_verify(fs: Mut, n_steps, claimed_sum, degree: Const): +def sumcheck_verify(fs, n_steps, claimed_sum, degree: Const): challenges = Array(n_steps * DIM) - fs, new_claimed_sum = sumcheck_verify_helper(fs, n_steps, claimed_sum, degree, challenges) - return fs, challenges, new_claimed_sum + new_fs, new_claimed_sum = sumcheck_verify_helper(fs, n_steps, claimed_sum, degree, challenges) + return new_fs, challenges, new_claimed_sum -def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, challenges): +def sumcheck_verify_helper(prev_fs, n_steps, prev_claimed_sum, degree: Const, challenges): + fs: Mut = prev_fs + claimed_sum: Mut = prev_claimed_sum for sc_round in range(0, n_steps): fs, poly = fs_receive_ef_inlined(fs, degree + 1) polynomial_sum_at_0_and_1(poly, degree, claimed_sum) @@ -192,10 +197,10 @@ def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, ch return fs, claimed_sum -def sumcheck_verify_reversed(fs: Mut, n_steps, claimed_sum: Mut, degree: Const): +def sumcheck_verify_reversed(fs, 
n_steps, claimed_sum, degree: Const): challenges = Array(n_steps * DIM) - fs, new_claimed_sum = sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree, challenges) - return fs, challenges, new_claimed_sum + new_fs, final_claimed_sum = sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree, challenges) + return new_fs, challenges, final_claimed_sum def sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree: Const, challenges): @@ -208,7 +213,9 @@ def sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree: Const, cha return new_fd, final_sum -def sumcheck_verify_reversed_helper_const(fs: Mut, n_steps: Const, claimed_sum: Mut, degree: Const, challenges): +def sumcheck_verify_reversed_helper_const(prev_fs, n_steps: Const, prev_claimed_sum, degree: Const, challenges): + fs: Mut = prev_fs + claimed_sum: Mut = prev_claimed_sum for sc_round in unroll(0, n_steps): fs, poly = fs_receive_ef_inlined(fs, degree + 1) polynomial_sum_at_0_and_1(poly, degree, claimed_sum) @@ -219,7 +226,9 @@ def sumcheck_verify_reversed_helper_const(fs: Mut, n_steps: Const, claimed_sum: return fs, claimed_sum -def sumcheck_verify_with_grinding(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, folding_grinding_bits): +def sumcheck_verify_with_grinding(prev_fs, n_steps, prev_claimed_sum, degree: Const, folding_grinding_bits): + fs: Mut = prev_fs + claimed_sum: Mut = prev_claimed_sum challenges = Array(n_steps * DIM) for sc_round in range(0, n_steps): fs, poly = fs_receive_ef_inlined(fs, degree + 1) @@ -285,7 +294,7 @@ def decompose_and_verify_merkle_batch_const( def sample_stir_indexes_and_fold( - fs: Mut, + prev_fs, num_queries, merkle_leaves_in_basefield, folding_factor, @@ -295,6 +304,7 @@ def sample_stir_indexes_and_fold( folding_randomness, query_grinding_bits, ): + fs: Mut = prev_fs folded_domain_size = domain_size - folding_factor fs = fs_grinding(fs, query_grinding_bits) @@ -303,7 +313,7 @@ def sample_stir_indexes_and_fold( merkle_leaves = 
Array(num_queries) circle_values = Array(num_queries) - n_chunks_per_answer: Imu + n_chunks_per_answer: Imm # the number of chunk of 8 field elements per merkle leaf opened if merkle_leaves_in_basefield == 1: n_chunks_per_answer = two_pow_folding_factor @@ -335,7 +345,7 @@ def sample_stir_indexes_and_fold( def whir_round( - fs: Mut, + prev_fs, prev_root, folding_factor, two_pow_folding_factor, @@ -347,6 +357,7 @@ def whir_round( num_ood, folding_grinding_bits, ): + fs: Mut = prev_fs fs, folding_randomness, new_claimed_sum_a = sumcheck_verify_with_grinding( fs, folding_factor, claimed_sum, 2, folding_grinding_bits ) @@ -398,21 +409,24 @@ def polynomial_sum_at_0_and_1(coeffs, degree, dst): return -def parse_commitment(fs: Mut, num_ood): - root: Imu - ood_points: Imu - ood_evals: Imu +def parse_commitment(fs, num_ood): + root: Imm + ood_points: Imm + ood_evals: Imm debug_assert(num_ood < 5) debug_assert(num_ood != 0) - fs, root, ood_points, ood_evals = match_range(num_ood, range(1, 5), lambda n: parse_whir_commitment_const(fs, n)) - return fs, root, ood_points, ood_evals + new_fs, root, ood_points, ood_evals = match_range( + num_ood, range(1, 5), lambda n: parse_whir_commitment_const(fs, n) + ) + return new_fs, root, ood_points, ood_evals -def parse_whir_commitment_const(fs: Mut, num_ood: Const): - fs, root = fs_receive_chunks(fs, 1) - fs, ood_points = fs_sample_many_ef(fs, num_ood) - fs, ood_evals = fs_receive_ef_inlined(fs, num_ood) - return fs, root, ood_points, ood_evals +def parse_whir_commitment_const(fs, num_ood: Const): + new_fs: Mut + new_fs, root = fs_receive_chunks(fs, 1) + new_fs, ood_points = fs_sample_many_ef(new_fs, num_ood) + new_fs, ood_evals = fs_receive_ef_inlined(new_fs, num_ood) + return new_fs, root, ood_points, ood_evals @inline @@ -427,15 +441,15 @@ def get_whir_params(n_vars, log_inv_rate): debug_assert(MIN_WHIR_LOG_INV_RATE <= log_inv_rate) debug_assert(log_inv_rate <= MAX_WHIR_LOG_INV_RATE) - num_queries: Imu + num_queries: Imm num_queries = 
get_num_queries(log_inv_rate, n_vars) - query_grinding_bits: Imu + query_grinding_bits: Imm query_grinding_bits = get_query_grinding_bits(log_inv_rate, n_vars) num_oods = get_num_oods(log_inv_rate, n_vars) - folding_grinding: Imu + folding_grinding: Imm folding_grinding = get_folding_grinding(log_inv_rate, n_vars) return n_rounds, final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding