diff --git a/README.md b/README.md
index d8a40144..cce64ba4 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ Minimal hash-based zkVM, targeting recursion and aggregation of hash-based signa
+
diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py
index f11c8138..bec59983 100644
--- a/crates/lean_compiler/snark_lib.py
+++ b/crates/lean_compiler/snark_lib.py
@@ -6,7 +6,7 @@
# Type annotations
Mut = Any
Const = Any
-Imu = Any
+Imm = Any
# @inline decorator (does nothing in Python execution)
diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs
index 2598ed3d..18254526 100644
--- a/crates/lean_compiler/src/a_simplify_lang/mod.rs
+++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs
@@ -332,21 +332,14 @@ pub fn simplify_program(mut program: Program) -> Result {
let mut array_manager = ArrayManager::default();
let mut mut_tracker = MutableVarTracker::default();
- // Register mutable arguments and capture their initial versioned names
- // BEFORE simplifying the body
+ // All arguments are immutable; record them as assigned to detect illegal reassignment.
let arguments: Vec = func
.arguments
.iter()
.map(|arg| {
assert!(!arg.is_const);
- if arg.is_mutable {
- mut_tracker.register_mutable(&arg.name);
- // Capture the initial versioned name (version 0)
- mut_tracker.current_name(&arg.name)
- } else {
- mut_tracker.assigned.insert(arg.name.clone());
- arg.name.clone()
- }
+ mut_tracker.assigned.insert(arg.name.clone());
+ arg.name.clone()
})
.collect();
@@ -398,9 +391,6 @@ fn compile_time_transform_in_program(
.collect();
for func in inlined_functions.values() {
- if func.has_mutable_arguments() {
- return Err("Inlined functions with mutable arguments are not supported yet".to_string());
- }
if func.has_const_arguments() {
return Err(format!(
"Inlined function should not have \"Const\" arguments (function \"{}\")",
diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest
index 16b041a6..6640d2de 100644
--- a/crates/lean_compiler/src/grammar.pest
+++ b/crates/lean_compiler/src/grammar.pest
@@ -49,9 +49,9 @@ return_statement = { "return" ~ (("(" ~ tuple_expression ~ ")") | tuple_expressi
mut_keyword = @{ "mut" ~ !(ASCII_ALPHANUMERIC | "_") }
mut_annotation = { ":" ~ "Mut" }
-im_annotation = { ":" ~ "Imu" }
+im_annotation = { ":" ~ "Imm" }
-// Forward declaration: x: Imu or x: Mut (not followed by =)
+// Forward declaration: x: Imm or x: Mut (not followed by =)
forward_declaration = { identifier ~ (im_annotation | mut_annotation) ~ !("=") }
// General assignment: LHS is optional, RHS is any expression
diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs
index 94871b47..91fbd0e8 100644
--- a/crates/lean_compiler/src/lang.rs
+++ b/crates/lean_compiler/src/lang.rs
@@ -32,7 +32,6 @@ impl Program {
pub struct FunctionArg {
pub name: Var,
pub is_const: bool,
- pub is_mutable: bool,
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
@@ -48,9 +47,6 @@ impl Function {
pub fn has_const_arguments(&self) -> bool {
self.arguments.iter().any(|arg| arg.is_const)
}
- pub fn has_mutable_arguments(&self) -> bool {
- self.arguments.iter().any(|arg| arg.is_mutable)
- }
}
pub type Var = String;
@@ -663,7 +659,7 @@ impl Line {
if *is_mutable {
format!("{var}: Mut")
} else {
- format!("{var}: Imu")
+ format!("{var}: Imm")
}
}
Self::Statement { targets, value, .. } => {
@@ -899,8 +895,6 @@ impl Display for Function {
.map(|arg| {
if arg.is_const {
format!("const {}", arg.name)
- } else if arg.is_mutable {
- format!("mut {}", arg.name)
} else {
arg.name.to_string()
}
diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs
index 590d58de..d99d1c2e 100644
--- a/crates/lean_compiler/src/lib.rs
+++ b/crates/lean_compiler/src/lib.rs
@@ -150,7 +150,11 @@ pub fn compile_program(input: &ProgramSource) -> Bytecode {
try_compile_program(input).unwrap()
}
-pub fn try_compile_and_run(input: &ProgramSource, public_input: &[F], profiler: bool) -> Result {
+pub fn try_compile_and_run(
+ input: &ProgramSource,
+ public_input: &[F; PUBLIC_INPUT_LEN],
+ profiler: bool,
+) -> Result {
let bytecode = try_compile_program(input)?;
let witness = ExecutionWitness::default();
let result = try_execute_bytecode(&bytecode, public_input, &witness, profiler)?;
@@ -158,7 +162,7 @@ pub fn try_compile_and_run(input: &ProgramSource, public_input: &[F], profiler:
Ok(result.metadata.display())
}
-pub fn compile_and_run(input: &ProgramSource, public_input: &[F], profiler: bool) {
+pub fn compile_and_run(input: &ProgramSource, public_input: &[F; PUBLIC_INPUT_LEN], profiler: bool) {
let summary = try_compile_and_run(input, public_input, profiler).unwrap();
println!("{summary}");
}
diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs
index 374f75d4..0334f5dc 100644
--- a/crates/lean_compiler/src/parser/parsers/function.rs
+++ b/crates/lean_compiler/src/parser/parsers/function.rs
@@ -15,11 +15,14 @@ pub const RESERVED_FUNCTION_NAMES: &[&str] = &[
// Built-in functions
"print",
"Array",
+ "hint_witness",
// Compile-time only functions
"len",
"log2_ceil",
"next_multiple_of",
"saturating_sub",
+ "div_ceil",
+ "div_floor",
"range",
"parallel_range",
"match_range",
@@ -155,22 +158,24 @@ impl Parse for ParameterParser {
.into());
}
- // Check for optional type annotation (: Const or : Mut)
- let (is_const, is_mutable) = if let Some(annotation) = inner.next() {
+ // Check for optional type annotation (: Const). ': Mut' parameters are forbidden.
+ let is_const = if let Some(annotation) = inner.next() {
match annotation.as_str().trim() {
- ": Const" => (true, false),
- ": Mut" => (false, true),
+ ": Const" => true,
+ ": Mut" => {
+ return Err(SemanticError::new(format!(
+ "Parameter '{name}' cannot be declared ': Mut'. Mutable parameters are not allowed; \
+ introduce a local '{name}_mut: Mut = {name}' instead."
+ ))
+ .into());
+ }
other => return Err(SemanticError::new(format!("Invalid parameter annotation: {other}")).into()),
}
} else {
- (false, false)
+ false
};
- Ok(FunctionArg {
- name,
- is_const,
- is_mutable,
- })
+ Ok(FunctionArg { name, is_const })
}
}
diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs
index edef389d..f0835697 100644
--- a/crates/lean_compiler/src/parser/parsers/statement.rs
+++ b/crates/lean_compiler/src/parser/parsers/statement.rs
@@ -287,7 +287,7 @@ impl Parse for AssertParser {
}
}
-/// Parser for forward declarations: `x: Imu` or `x: Mut`
+/// Parser for forward declarations: `x: Imm` or `x: Mut`
pub struct ForwardDeclarationParser;
impl Parse for ForwardDeclarationParser {
@@ -297,7 +297,7 @@ impl Parse for ForwardDeclarationParser {
// Parse variable name
let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string();
- // Check for : Mut or : Imu annotation
+ // Check for : Mut or : Imm annotation
let annotation = next_inner_pair(&mut inner, "type annotation")?;
let is_mutable = annotation.as_rule() == Rule::mut_annotation;
diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs
index 04b3f70e..cc8ac473 100644
--- a/crates/lean_compiler/tests/test_compiler.rs
+++ b/crates/lean_compiler/tests/test_compiler.rs
@@ -4,27 +4,6 @@ use backend::{BasedVectorSpace, PrimeCharacteristicRing};
use lean_compiler::*;
use lean_vm::*;
use rand::{RngExt, SeedableRng, rngs::StdRng};
-use utils::poseidon16_compress;
-
-#[test]
-fn test_poseidon() {
- let program = r#"
-def main():
- a = 0
- b = a + 8
- c = Array(8)
- poseidon16_compress(a, b, c)
-
- for i in range(0, 8):
- cc = c[i]
- print(cc)
- return
- "#;
- let public_input: [F; 16] = (0..16).map(F::new).collect::>().try_into().unwrap();
- compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false);
-
- let _ = dbg!(poseidon16_compress(public_input));
-}
#[test]
fn test_div_extension_field() {
@@ -32,13 +11,16 @@ fn test_div_extension_field() {
DIM = 5
def main():
- n = 0
- d = n + DIM
- q = n + 2 * DIM
+ nd = Array(2 * DIM)
+ hint_witness("nd", nd)
+ n = nd
+ d = nd + DIM
+ expected_q = Array(DIM)
+ hint_witness("q", expected_q)
computed_q_1 = div_ext_1(n, d)
computed_q_2 = div_ext_2(n, d)
- assert_eq_ext(computed_q_2, q)
- assert_eq_ext(computed_q_1, q)
+ assert_eq_ext(computed_q_2, expected_q)
+ assert_eq_ext(computed_q_1, expected_q)
return
def assert_eq_ext(x, y):
@@ -61,12 +43,19 @@ def div_ext_2(n, d):
let n: EF = rng.random();
let d: EF = rng.random();
let q = n / d;
- let mut public_input = vec![];
- public_input.extend(n.as_basis_coefficients_slice());
- public_input.extend(d.as_basis_coefficients_slice());
- public_input.extend(q.as_basis_coefficients_slice());
- public_input.resize(16, F::ZERO);
- compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false);
+ let mut nd_buf: Vec = Vec::new();
+ nd_buf.extend(n.as_basis_coefficients_slice());
+ nd_buf.extend(d.as_basis_coefficients_slice());
+ let q_buf: Vec = q.as_basis_coefficients_slice().to_vec();
+ let mut hints = std::collections::HashMap::new();
+ hints.insert("nd".to_string(), vec![nd_buf]);
+ hints.insert("q".to_string(), vec![q_buf]);
+ let witness = ExecutionWitness {
+ hints,
+ ..ExecutionWitness::default()
+ };
+ let bytecode = compile_program(&ProgramSource::Raw(program.to_string()));
+ try_execute_bytecode(&bytecode, &[F::ZERO; PUBLIC_INPUT_LEN], &witness, false).unwrap();
}
fn test_data_dir() -> String {
@@ -134,7 +123,7 @@ fn test_all_programs() {
Ok(b) => b,
Err(err) => panic!("Program {} failed to compile: {:?}", path, err),
};
- if let Err(err) = try_execute_bytecode(&bytecode, &[], &witness, false) {
+ if let Err(err) = try_execute_bytecode(&bytecode, &[F::ZERO; PUBLIC_INPUT_LEN], &witness, false) {
panic!("Program {} failed with error: {:?}", path, err);
}
}
@@ -176,7 +165,13 @@ def func(a, b):
return
"#;
let bytecode = compile_program(&ProgramSource::Raw(program.to_string()));
- let n_cycles = execute_bytecode(&bytecode, &[], &ExecutionWitness::default(), false).n_cycles();
+ let n_cycles = execute_bytecode(
+ &bytecode,
+ &[F::ZERO; PUBLIC_INPUT_LEN],
+ &ExecutionWitness::default(),
+ false,
+ )
+ .n_cycles();
assert!(n_cycles < 1100);
}
@@ -205,10 +200,20 @@ def factorial(n):
let compiled_parallel = compile_program(&ProgramSource::Raw(program.replace("loop", "parallel_range")));
let time_sequential = Instant::now();
- let exec_seq = execute_bytecode(&compiled_sequencial, &[], &ExecutionWitness::default(), false);
+ let exec_seq = execute_bytecode(
+ &compiled_sequencial,
+ &[F::ZERO; PUBLIC_INPUT_LEN],
+ &ExecutionWitness::default(),
+ false,
+ );
let duration_sequential = time_sequential.elapsed();
let time_parallel = Instant::now();
- let exec_par = execute_bytecode(&compiled_parallel, &[], &ExecutionWitness::default(), false);
+ let exec_par = execute_bytecode(
+ &compiled_parallel,
+ &[F::ZERO; PUBLIC_INPUT_LEN],
+ &ExecutionWitness::default(),
+ false,
+ );
let duration_parallel = time_parallel.elapsed();
assert_eq!(exec_seq.metadata.stdout, exec_par.metadata.stdout);
@@ -249,7 +254,13 @@ fn test_soundness_suite() {
("soundness_5", &[3, 4, 7, 19, 49, 28, 1, 3], &[(0, 4), (1, 5), (2, 8), (3, 20), (4, 50), (5, 29), (6, 0), (6, 2), (7, 4)]),
];
- let to_input = |v: &[u32]| v.iter().copied().map(F::new).collect::>();
+ let to_input = |v: &[u32]| -> [F; PUBLIC_INPUT_LEN] {
+ let mut out = [F::ZERO; PUBLIC_INPUT_LEN];
+ for (slot, &x) in out.iter_mut().zip(v) {
+ *slot = F::new(x);
+ }
+ out
+ };
for &(name, valid, perturbations) in cases {
let path = format!("{}/{}.py", test_data_dir(), name);
diff --git a/crates/lean_compiler/tests/test_data/error_13.py b/crates/lean_compiler/tests/test_data/error_13.py
index 0450a5aa..2224ac22 100644
--- a/crates/lean_compiler/tests/test_data/error_13.py
+++ b/crates/lean_compiler/tests/test_data/error_13.py
@@ -2,7 +2,7 @@
def main():
- a: Imu
+ a: Imm
a = 0
a = a + 1
if a == 1:
diff --git a/crates/lean_compiler/tests/test_data/error_7.py b/crates/lean_compiler/tests/test_data/error_7.py
index 94c262a6..e4435908 100644
--- a/crates/lean_compiler/tests/test_data/error_7.py
+++ b/crates/lean_compiler/tests/test_data/error_7.py
@@ -1,7 +1,7 @@
from snark_lib import *
-# Error: inline functions with parameters: Mut are not supported
+# Error: function parameters cannot be declared ': Mut'
def main():
return
diff --git a/crates/lean_compiler/tests/test_data/program_100.py b/crates/lean_compiler/tests/test_data/program_100.py
index 5eeac3b6..a81d0480 100644
--- a/crates/lean_compiler/tests/test_data/program_100.py
+++ b/crates/lean_compiler/tests/test_data/program_100.py
@@ -2,8 +2,8 @@
def main():
- x: Imu
- y: Imu
+ x: Imm
+ y: Imm
cond = 1
if cond == 1:
diff --git a/crates/lean_compiler/tests/test_data/program_109.py b/crates/lean_compiler/tests/test_data/program_109.py
index b04f2608..77e86702 100644
--- a/crates/lean_compiler/tests/test_data/program_109.py
+++ b/crates/lean_compiler/tests/test_data/program_109.py
@@ -9,10 +9,10 @@ def main():
def test_func(a, b):
x = 1
- mut_x_2: Imu
+ mut_x_2: Imm
match a:
case 0:
- mut_x_1: Imu
+ mut_x_1: Imm
mut_x_1 = x + 2
match b:
case 0:
diff --git a/crates/lean_compiler/tests/test_data/program_110.py b/crates/lean_compiler/tests/test_data/program_110.py
index f863f968..df47ee5c 100644
--- a/crates/lean_compiler/tests/test_data/program_110.py
+++ b/crates/lean_compiler/tests/test_data/program_110.py
@@ -79,7 +79,7 @@ def main():
result = complex_compute(3, 4, 5)
assert result == 47
- fwd_val: Imu
+ fwd_val: Imm
cond = 1
if cond == 0:
fwd_val = 100
diff --git a/crates/lean_compiler/tests/test_data/program_111.py b/crates/lean_compiler/tests/test_data/program_111.py
index e497afdb..9634df0f 100644
--- a/crates/lean_compiler/tests/test_data/program_111.py
+++ b/crates/lean_compiler/tests/test_data/program_111.py
@@ -78,13 +78,14 @@ def chain_compute(x, y):
a2, b2, c2 = step_compute(a1, b1)
return a2, b2, c1 + c2
-def nested_mut_params(base: Mut):
+def nested_mut_params(base):
+ acc: Mut = base
for i in unroll(0, 3):
- base = base + i * 2
- return base
+ acc = acc + i * 2
+ return acc
def state_machine_step(current_state, phase):
- result: Imu
+ result: Imm
if phase == 0:
if current_state == 0:
result = 1
diff --git a/crates/lean_compiler/tests/test_data/program_112.py b/crates/lean_compiler/tests/test_data/program_112.py
index 7cd0fe3b..51a9de2e 100644
--- a/crates/lean_compiler/tests/test_data/program_112.py
+++ b/crates/lean_compiler/tests/test_data/program_112.py
@@ -2,7 +2,7 @@
def main():
- result1: Imu
+ result1: Imm
outer_sel = 1
match outer_sel:
case 0:
@@ -18,8 +18,8 @@ def main():
result1 = 456
assert result1 == 456
- counter: Imu
- flag: Imu
+ counter: Imm
+ flag: Imm
phase = 1
if phase == 0:
@@ -40,8 +40,8 @@ def main():
assert counter2 == 15
assert flag2 == 400
- x: Imu
- y: Imu
+ x: Imm
+ y: Imm
init_sel = 0
if init_sel == 0:
@@ -61,7 +61,7 @@ def main():
assert x2 == 220
assert y2 == 20
- outcome: Imu
+ outcome: Imm
selector = 4
match selector:
case 0:
@@ -78,9 +78,9 @@ def main():
outcome = compute_outcome(5, 25)
assert outcome == 84
- p: Imu
- q: Imu
- r: Imu
+ p: Imm
+ q: Imm
+ r: Imm
s1 = 1
if s1 == 1:
diff --git a/crates/lean_compiler/tests/test_data/program_114.py b/crates/lean_compiler/tests/test_data/program_114.py
index 1d0b0b95..c92f7044 100644
--- a/crates/lean_compiler/tests/test_data/program_114.py
+++ b/crates/lean_compiler/tests/test_data/program_114.py
@@ -31,8 +31,8 @@ def main():
result4 = complex_nested_compute(2, 1, 3)
assert result4 == 280
- fwd_x: Imu
- fwd_y: Imu
+ fwd_x: Imm
+ fwd_y: Imm
mode = 2
if mode == 0:
@@ -90,7 +90,7 @@ def sum_array_func(arr, n: Const):
def complex_nested_compute(outer, inner, depth):
- result: Imu
+ result: Imm
if outer == 0:
result = 100
diff --git a/crates/lean_compiler/tests/test_data/program_143.py b/crates/lean_compiler/tests/test_data/program_143.py
index 08be3b5d..1e16c510 100644
--- a/crates/lean_compiler/tests/test_data/program_143.py
+++ b/crates/lean_compiler/tests/test_data/program_143.py
@@ -71,7 +71,7 @@ def main():
assert sum == 10
# Test 13: Inline functions in if condition (comparison)
- result13: Imu
+ result13: Imm
if incr(incr(0)) == 2:
result13 = 100
else:
@@ -79,7 +79,7 @@ def main():
assert result13 == 100
# Test 14: Nested inline calls in both sides of if condition
- result14: Imu
+ result14: Imm
if double(3) == triple(2):
result14 = 1
else:
@@ -88,7 +88,7 @@ def main():
assert result14 == 1
# Test 15: Inline calls inside if/else branches
- result15: Imu
+ result15: Imm
if 1 == 1:
result15 = incr(incr(incr(10)))
else:
@@ -96,7 +96,7 @@ def main():
assert result15 == 13
# Test 16: Multiple nested inline calls in if condition
- result16: Imu
+ result16: Imm
if incr(double(incr(1))) == 5:
# incr(1) = 2, double(2) = 4, incr(4) = 5
result16 = 200
@@ -105,7 +105,7 @@ def main():
assert result16 == 200
# Test 17: Inline call with != comparison
- result17: Imu
+ result17: Imm
if incr(5) != 5:
result17 = 300
else:
@@ -152,7 +152,7 @@ def main():
assert sum23 == 21
# Test 24: Chained else-if with inline conditions
- result24: Imu
+ result24: Imm
x24 = 5
if incr(x24) == 4:
result24 = 1
@@ -194,7 +194,7 @@ def double(x):
@inline
def triple(x):
y: Mut = x
- two: Imu
+ two: Imm
match y - x + 1:
case 0:
assert False
diff --git a/crates/lean_compiler/tests/test_data/program_144.py b/crates/lean_compiler/tests/test_data/program_144.py
index f6b60dff..9ceb190f 100644
--- a/crates/lean_compiler/tests/test_data/program_144.py
+++ b/crates/lean_compiler/tests/test_data/program_144.py
@@ -8,7 +8,7 @@ def main():
# ==========================================================================
# TEST 1: Basic - panic in else branch (the original bug case)
# ==========================================================================
- two: Imu
+ two: Imm
if 1 == 1:
two = 2
else:
@@ -18,7 +18,7 @@ def main():
# ==========================================================================
# TEST 2: panic in then branch
# ==========================================================================
- three: Imu
+ three: Imm
if 1 != 1:
assert False
else:
@@ -28,9 +28,9 @@ def main():
# ==========================================================================
# TEST 3: Multiple mutable variables, panic in else
# ==========================================================================
- a: Imu
- b: Imu
- c: Imu
+ a: Imm
+ b: Imm
+ c: Imm
if 1 == 1:
a = 10
b = 20
@@ -44,7 +44,7 @@ def main():
# ==========================================================================
# TEST 4: Nested if with panic in inner else
# ==========================================================================
- x: Imu
+ x: Imm
if 1 == 1:
if 2 == 2:
x = 42
@@ -79,7 +79,7 @@ def main():
# ==========================================================================
# TEST 7: Chain of else-if with panic in final else
# ==========================================================================
- result: Imu
+ result: Imm
selector = 1
if selector == 0:
result = 100
@@ -94,7 +94,7 @@ def main():
# ==========================================================================
# TEST 8: Match with panic in one arm
# ==========================================================================
- matched: Imu
+ matched: Imm
tag = 1
match tag:
case 0:
@@ -108,7 +108,7 @@ def main():
# ==========================================================================
# TEST 9: Match where only one arm doesn't panic
# ==========================================================================
- only_valid: Imu
+ only_valid: Imm
tag2 = 2
match tag2:
case 0:
@@ -124,7 +124,7 @@ def main():
# ==========================================================================
# TEST 10: Panic in deeply nested structure
# ==========================================================================
- deep: Imu
+ deep: Imm
if 1 == 1:
if 1 == 1:
if 1 == 1:
@@ -151,7 +151,7 @@ def main():
# ==========================================================================
# TEST 12: Forward declared with = None panic in branch
# ==========================================================================
- fwd: Imu
+ fwd: Imm
cond = 1
if cond == 1:
fwd = 777
@@ -162,8 +162,8 @@ def main():
# ==========================================================================
# TEST 13: Both mutable and immutable forward decl with panic
# ==========================================================================
- imm: Imu
- mtbl: Imu
+ imm: Imm
+ mtbl: Imm
flag = 0
if flag == 0:
imm = 100
@@ -206,7 +206,7 @@ def main():
# ==========================================================================
# TEST 17: Nested match with panic
# ==========================================================================
- nested_match: Imu
+ nested_match: Imm
outer = 1
match outer:
case 0:
@@ -223,7 +223,7 @@ def main():
# ==========================================================================
# TEST 18: If inside match with panic
# ==========================================================================
- if_in_match: Imu
+ if_in_match: Imm
m18_sel = 0
match m18_sel:
case 0:
@@ -239,7 +239,7 @@ def main():
# ==========================================================================
# TEST 19: Match inside if with panic
# ==========================================================================
- match_in_if: Imu
+ match_in_if: Imm
cond19 = 1
if cond19 == 1:
tag19 = 1
@@ -255,7 +255,7 @@ def main():
# ==========================================================================
# TEST 20: Panic after partial assignment
# ==========================================================================
- partial: Imu
+ partial: Imm
check = 0
if check == 0:
partial_tmp: Mut = 1
@@ -288,7 +288,7 @@ def main():
# ==========================================================================
# TEST 23: Multiple levels - if/match/if with panics
# ==========================================================================
- multi_level: Imu
+ multi_level: Imm
c1 = 1
if c1 == 1:
s1 = 0
@@ -308,7 +308,7 @@ def main():
# ==========================================================================
# TEST 24: Panic in both outer branches but inner assigns
# ==========================================================================
- inner_assigns: Imu
+ inner_assigns: Imm
outer24 = 0
match outer24:
case 0:
@@ -324,9 +324,9 @@ def main():
# ==========================================================================
# TEST 25: Complex - multiple vars, nested, with panics
# ==========================================================================
- va: Imu
- vb: Imu
- vc: Imu
+ va: Imm
+ vb: Imm
+ vc: Imm
outer25 = 1
if outer25 == 1:
@@ -353,7 +353,7 @@ def main():
# Helper function for TEST 14
def test_early_return(flag):
- result: Imu
+ result: Imm
if flag == 1:
result = 10
else:
@@ -363,8 +363,8 @@ def test_early_return(flag):
# Helper function for TEST 15
def test_multi_return(flag):
- a: Imu
- b: Imu
+ a: Imm
+ b: Imm
if flag == 1:
a = 100
b = 200
@@ -374,9 +374,10 @@ def test_multi_return(flag):
# Helper function for TEST 22
-def func_with_mut_param(x: Mut, flag):
+def func_with_mut_param(x, flag):
+ y: Mut = x
if flag == 1:
- x = x * 10
+ y = y * 10
else:
assert False
- return x
+ return y
diff --git a/crates/lean_compiler/tests/test_data/program_15.py b/crates/lean_compiler/tests/test_data/program_15.py
index 6ea149c2..55433c26 100644
--- a/crates/lean_compiler/tests/test_data/program_15.py
+++ b/crates/lean_compiler/tests/test_data/program_15.py
@@ -1,6 +1,6 @@
from snark_lib import *
-ONE_EF_PTR = 1 # right after the (empty-public-input) zero-padded cell at memory[0]
+ONE_EF_PTR = 8 # right after the 8-cell public input region
def main():
diff --git a/crates/lean_compiler/tests/test_data/program_17.py b/crates/lean_compiler/tests/test_data/program_17.py
index d9274d42..092e2fc2 100644
--- a/crates/lean_compiler/tests/test_data/program_17.py
+++ b/crates/lean_compiler/tests/test_data/program_17.py
@@ -7,7 +7,7 @@ def main():
def func():
- a: Imu
+ a: Imm
if 0 == 0:
a = aux()
return a
diff --git a/crates/lean_compiler/tests/test_data/program_170.py b/crates/lean_compiler/tests/test_data/program_170.py
index 96c00012..62f5d2e4 100644
--- a/crates/lean_compiler/tests/test_data/program_170.py
+++ b/crates/lean_compiler/tests/test_data/program_170.py
@@ -15,7 +15,7 @@ def multi_return(a, b):
def multi_line_params(
a,
- b: Mut,
+ b,
c: Const,
):
return a + b + c
@@ -30,14 +30,14 @@ def main():
x = 5
y = 10
- z: Imu
+ z: Imm
if x + y == 15:
z = 1
else:
z = 0
assert z == 1
- w: Imu
+ w: Imm
if x + y * 2 == 25:
w = 100
else:
diff --git a/crates/lean_compiler/tests/test_data/program_171.py b/crates/lean_compiler/tests/test_data/program_171.py
index bb3c47c6..4af97391 100644
--- a/crates/lean_compiler/tests/test_data/program_171.py
+++ b/crates/lean_compiler/tests/test_data/program_171.py
@@ -1,7 +1,7 @@
from snark_lib import *
# Comprehensive test for inlining with mutable variables in branches
-# Tests: @inline functions, Mut/Imu variables, match, if/else, loops, nesting
+# Tests: @inline functions, Mut/Imm variables, match, if/else, loops, nesting
# ============================================================================
# Simple inline functions with mutable variables
@@ -118,7 +118,7 @@ def inline_with_if(x):
@inline
def inline_with_match(selector):
"""Inline function that itself contains match"""
- out: Imu
+ out: Imm
match selector:
case 0:
out = 1000
@@ -132,7 +132,7 @@ def inline_with_match(selector):
@inline
def inline_with_nested_branch(a, b):
"""Inline with nested if inside match"""
- res: Imu
+ res: Imm
match a:
case 0:
if b == 0:
@@ -321,7 +321,7 @@ def main():
# TEST 1: Basic inline in match arms (different inlined vars per arm)
# This was the original bug - each arm gets its own inlined variable names
# -------------------------------------------------------------------
- res1: Imu
+ res1: Imm
match 0:
case 0:
res1 = count_up(5)
@@ -329,7 +329,7 @@ def main():
res1 = count_up(10)
assert res1 == 5
- res2: Imu
+ res2: Imm
match 1:
case 0:
res2 = count_up(5)
@@ -340,7 +340,7 @@ def main():
# -------------------------------------------------------------------
# TEST 2: Different inline functions in different arms
# -------------------------------------------------------------------
- res3: Imu
+ res3: Imm
match 0:
case 0:
res3 = count_up(3)
@@ -350,7 +350,7 @@ def main():
res3 = double_count(3)
assert res3 == 3
- res4: Imu
+ res4: Imm
match 1:
case 0:
res4 = count_up(3)
@@ -360,7 +360,7 @@ def main():
res4 = double_count(3)
assert res4 == 3 # 0+1+2
- res5: Imu
+ res5: Imm
match 2:
case 0:
res5 = count_up(3)
@@ -392,7 +392,7 @@ def main():
# -------------------------------------------------------------------
# TEST 4: Multiple inlines in same arm
# -------------------------------------------------------------------
- multi: Imu
+ multi: Imm
match 0:
case 0:
a = count_up(3)
@@ -406,7 +406,7 @@ def main():
# -------------------------------------------------------------------
# TEST 5: Nested inline functions in match arms
# -------------------------------------------------------------------
- nested1: Imu
+ nested1: Imm
match 0:
case 0:
nested1 = outer_with_inner(4)
@@ -416,7 +416,7 @@ def main():
# = 0 + 0 + 1 + 3 = 4
assert nested1 == 4
- nested2: Imu
+ nested2: Imm
match 1:
case 0:
nested2 = outer_with_inner(4)
@@ -428,7 +428,7 @@ def main():
# -------------------------------------------------------------------
# TEST 6: Deep nesting in match
# -------------------------------------------------------------------
- deep1: Imu
+ deep1: Imm
match 0:
case 0:
deep1 = deep_nested(3)
@@ -442,14 +442,14 @@ def main():
# -------------------------------------------------------------------
# TEST 7: Inline in if/else branches
# -------------------------------------------------------------------
- if_res1: Imu
+ if_res1: Imm
if 1 == 1:
if_res1 = count_up(7)
else:
if_res1 = count_up(3)
assert if_res1 == 7
- if_res2: Imu
+ if_res2: Imm
if 1 == 0:
if_res2 = count_up(7)
else:
@@ -459,7 +459,7 @@ def main():
# -------------------------------------------------------------------
# TEST 8: Nested if/else with inlines
# -------------------------------------------------------------------
- nested_if: Imu
+ nested_if: Imm
if 1 == 1:
if 2 == 2:
nested_if = sum_range(0, 5)
@@ -472,7 +472,7 @@ def main():
# -------------------------------------------------------------------
# TEST 9: Match inside if with inlines
# -------------------------------------------------------------------
- mixed: Imu
+ mixed: Imm
if 1 == 1:
match 1:
case 0:
@@ -486,7 +486,7 @@ def main():
# -------------------------------------------------------------------
# TEST 10: If inside match with inlines
# -------------------------------------------------------------------
- mixed2: Imu
+ mixed2: Imm
match 0:
case 0:
if 1 == 1:
@@ -500,7 +500,7 @@ def main():
# -------------------------------------------------------------------
# TEST 11: Complex mutable variables in inline
# -------------------------------------------------------------------
- cx: Imu
+ cx: Imm
match 0:
case 0:
cx = complex_muts(4)
@@ -519,7 +519,7 @@ def main():
# TEST 12: Mix of Mut and immutable in branches with inlines
# -------------------------------------------------------------------
outer_mut: Mut = 10
- inner_imu: Imu
+ inner_imu: Imm
match 0:
case 0:
local_imm = with_immutable(3)
@@ -535,7 +535,7 @@ def main():
# -------------------------------------------------------------------
# TEST 13: Inline inside unroll loop inside match
# -------------------------------------------------------------------
- unroll_in_match: Imu
+ unroll_in_match: Imm
match 0:
case 0:
acc: Mut = 0
@@ -550,10 +550,10 @@ def main():
# -------------------------------------------------------------------
# TEST 14: Multiple match levels with different inlines at each
# -------------------------------------------------------------------
- multi_match: Imu
+ multi_match: Imm
match 1:
case 0:
- inner: Imu
+ inner: Imm
match 0:
case 0:
inner = count_up(2)
@@ -561,7 +561,7 @@ def main():
inner = count_up(3)
multi_match = inner
case 1:
- inner2: Imu
+ inner2: Imm
match 1:
case 0:
inner2 = sum_range(0, 2)
@@ -573,7 +573,7 @@ def main():
# -------------------------------------------------------------------
# TEST 15: Same inline function called multiple times in same arm
# -------------------------------------------------------------------
- same_fn: Imu
+ same_fn: Imm
match 0:
case 0:
r1 = count_up(3)
@@ -607,7 +607,7 @@ def main():
# -------------------------------------------------------------------
# TEST 17: Variables declared inside only some branches
# -------------------------------------------------------------------
- outside: Imu
+ outside: Imm
match 0:
case 0:
local_only_here = count_up(5)
@@ -622,7 +622,7 @@ def main():
# -------------------------------------------------------------------
# TEST 18: Very deeply nested structure
# -------------------------------------------------------------------
- very_deep: Imu
+ very_deep: Imm
if 1 == 1:
match 0:
case 0:
@@ -668,7 +668,7 @@ def main():
# -------------------------------------------------------------------
# TEST 20: Inline result used immediately in arithmetic in branch
# -------------------------------------------------------------------
- arith: Imu
+ arith: Imm
match 0:
case 0:
arith = count_up(3) * 10 + sum_range(0, 3) * 100
@@ -684,7 +684,7 @@ def main():
# -------------------------------------------------------------------
# TEST 21: Inline containing if/else in different match arms
# -------------------------------------------------------------------
- t21: Imu
+ t21: Imm
match 0:
case 0:
t21 = inline_with_if(0)
@@ -693,7 +693,7 @@ def main():
# inline_with_if(0): result=100, result=100+0=100
assert t21 == 100
- t21b: Imu
+ t21b: Imm
match 1:
case 0:
t21b = inline_with_if(0)
@@ -705,7 +705,7 @@ def main():
# -------------------------------------------------------------------
# TEST 22: Inline containing match in different branches
# -------------------------------------------------------------------
- t22: Imu
+ t22: Imm
match 0:
case 0:
t22 = inline_with_match(0)
@@ -715,7 +715,7 @@ def main():
t22 = inline_with_match(2)
assert t22 == 1000
- t22b: Imu
+ t22b: Imm
match 2:
case 0:
t22b = inline_with_match(0)
@@ -728,7 +728,7 @@ def main():
# -------------------------------------------------------------------
# TEST 23: Inline with nested branches called in nested branches
# -------------------------------------------------------------------
- t23: Imu
+ t23: Imm
match 0:
case 0:
if 1 == 1:
@@ -740,7 +740,7 @@ def main():
# inline_with_nested_branch(0, 1): a=0 -> if b==0 else -> 20
assert t23 == 20
- t23b: Imu
+ t23b: Imm
match 1:
case 0:
t23b = inline_with_nested_branch(0, 0)
@@ -752,8 +752,8 @@ def main():
# -------------------------------------------------------------------
# TEST 24: Multi-return inline in match arms
# -------------------------------------------------------------------
- t24a: Imu
- t24b: Imu
+ t24a: Imm
+ t24b: Imm
match 0:
case 0:
t24a, t24b = multi_return_inline(5)
@@ -763,8 +763,8 @@ def main():
assert t24a == 5
assert t24b == 110
- t24c: Imu
- t24d: Imu
+ t24c: Imm
+ t24d: Imm
match 1:
case 0:
t24c, t24d = multi_return_inline(5)
@@ -777,9 +777,9 @@ def main():
# -------------------------------------------------------------------
# TEST 25: Triple return inline in branches
# -------------------------------------------------------------------
- t25a: Imu
- t25b: Imu
- t25c: Imu
+ t25a: Imm
+ t25b: Imm
+ t25c: Imm
match 0:
case 0:
t25a, t25b, t25c = triple_return(10)
@@ -793,7 +793,7 @@ def main():
# -------------------------------------------------------------------
# TEST 26: 4-level deep inline nesting in match arms
# -------------------------------------------------------------------
- t26: Imu
+ t26: Imm
match 0:
case 0:
t26 = level_a(1)
@@ -809,7 +809,7 @@ def main():
# = (1+2) + 20 + 200 + 2000 = 2223
assert t26 == 2223
- t26b: Imu
+ t26b: Imm
match 3:
case 0:
t26b = level_a(5)
@@ -825,7 +825,7 @@ def main():
# -------------------------------------------------------------------
# TEST 27: Inline with Array in match arms
# -------------------------------------------------------------------
- t27: Imu
+ t27: Imm
match 0:
case 0:
t27 = inline_with_array(10)
@@ -834,7 +834,7 @@ def main():
# inline_with_array(10): 10+11+12+13 = 46
assert t27 == 46
- t27b: Imu
+ t27b: Imm
match 1:
case 0:
t27b = inline_with_array(10)
@@ -846,7 +846,7 @@ def main():
# -------------------------------------------------------------------
# TEST 28: Inline modifying array in branches
# -------------------------------------------------------------------
- t28: Imu
+ t28: Imm
match 0:
case 0:
t28 = inline_modify_array(1)
@@ -858,7 +858,7 @@ def main():
# -------------------------------------------------------------------
# TEST 29: Chained inline calls in match arms
# -------------------------------------------------------------------
- t29: Imu
+ t29: Imm
match 0:
case 0:
# chain_a(5)=7, chain_b(7)=28, chain_c(28)=48
@@ -867,7 +867,7 @@ def main():
t29 = chain_a(100)
assert t29 == 48
- t29b: Imu
+ t29b: Imm
match 1:
case 0:
t29b = chain_c(chain_b(chain_a(1)))
@@ -879,7 +879,7 @@ def main():
# -------------------------------------------------------------------
# TEST 30: Different chain patterns in different arms
# -------------------------------------------------------------------
- t30: Imu
+ t30: Imm
match 0:
case 0:
t30 = chain_a(chain_a(chain_a(0)))
@@ -890,7 +890,7 @@ def main():
# chain_a(0)=2, chain_a(2)=4, chain_a(4)=6
assert t30 == 6
- t30b: Imu
+ t30b: Imm
match 1:
case 0:
t30b = chain_a(chain_a(chain_a(0)))
@@ -904,7 +904,7 @@ def main():
# -------------------------------------------------------------------
# TEST 31: Stress test - many variables inline in match
# -------------------------------------------------------------------
- t31: Imu
+ t31: Imm
match 0:
case 0:
t31 = many_vars(0)
@@ -918,7 +918,7 @@ def main():
# -------------------------------------------------------------------
# TEST 32: Multiple multi-return inlines in same arm
# -------------------------------------------------------------------
- t32_sum: Imu
+ t32_sum: Imm
match 0:
case 0:
a1, b1 = multi_return_inline(3)
@@ -936,7 +936,7 @@ def main():
# -------------------------------------------------------------------
# TEST 33: 5-way match with all different inline types
# -------------------------------------------------------------------
- t33: Imu
+ t33: Imm
match 0:
case 0:
t33 = count_up(10)
@@ -950,7 +950,7 @@ def main():
t33 = inline_with_array(1)
assert t33 == 10
- t33b: Imu
+ t33b: Imm
match 4:
case 0:
t33b = count_up(10)
@@ -968,10 +968,10 @@ def main():
# -------------------------------------------------------------------
# TEST 34: Triple nested match with inlines at each level
# -------------------------------------------------------------------
- t34: Imu
+ t34: Imm
match 0:
case 0:
- inner1: Imu
+ inner1: Imm
match 1:
case 0:
tmp34a = count_up(2)
@@ -986,10 +986,10 @@ def main():
assert t34 == 1423
# Additional triple nesting test - without forward declaration inside innermost
- t34b: Imu
+ t34b: Imm
match 0:
case 0:
- mid1: Imu
+ mid1: Imm
match 0:
case 0:
# Use inline directly without forward declaration
@@ -1003,10 +1003,10 @@ def main():
assert t34b == 1105
# Test forward declaration with nested match and inline
- t34c: Imu
+ t34c: Imm
match 0:
case 0:
- val34c: Imu
+ val34c: Imm
match 0:
case 0:
val34c = sum_range(0, 5)
@@ -1041,12 +1041,12 @@ def main():
assert deep_mut == 210
# -------------------------------------------------------------------
- # TEST 36: Multiple forward-declared Imu assigned via inlines
+ # TEST 36: Multiple forward-declared Imm assigned via inlines
# -------------------------------------------------------------------
- fwd1: Imu
- fwd2: Imu
- fwd3: Imu
- fwd4: Imu
+ fwd1: Imm
+ fwd2: Imm
+ fwd3: Imm
+ fwd4: Imm
match 0:
case 0:
fwd1 = count_up(1)
@@ -1086,7 +1086,7 @@ def main():
# -------------------------------------------------------------------
# TEST 38: If-else-if chain with different inlines
# -------------------------------------------------------------------
- t38: Imu
+ t38: Imm
if 0 == 1:
t38 = count_up(100)
else:
@@ -1126,7 +1126,7 @@ def main():
# -------------------------------------------------------------------
# TEST 40: Inline returning mutable at different states
# -------------------------------------------------------------------
- t40: Imu
+ t40: Imm
match 0:
case 0:
# complex_muts returns computation of interdependent muts
@@ -1164,7 +1164,7 @@ def main():
# TEST 42: Deeply nested with mixed mutable tracking
# -------------------------------------------------------------------
outer_m: Mut = 100
- t42: Imu
+ t42: Imm
if 1 == 1:
outer_m = outer_m + 50
match 0:
@@ -1197,14 +1197,14 @@ def main():
# -------------------------------------------------------------------
# TEST 43: All arms have different nesting patterns
# -------------------------------------------------------------------
- t43: Imu
+ t43: Imm
match 0:
case 0:
# Flat
t43 = count_up(5)
case 1:
# One level nested
- if_inner: Imu
+ if_inner: Imm
if 1 == 1:
if_inner = sum_range(0, 10)
else:
@@ -1212,7 +1212,7 @@ def main():
t43 = if_inner
case 2:
# Two levels nested
- m_inner: Imu
+ m_inner: Imm
match 0:
case 0:
m_inner = level_a(1)
@@ -1221,7 +1221,7 @@ def main():
t43 = m_inner
case 3:
# Three levels nested
- deep_inner: Imu
+ deep_inner: Imm
if 1 == 1:
match 0:
case 0:
@@ -1271,7 +1271,7 @@ def main():
# -------------------------------------------------------------------
# TEST 45: Inline calling another inline that has internal branches
# -------------------------------------------------------------------
- t45: Imu
+ t45: Imm
match 0:
case 0:
# outer_with_inner calls inner_loop
diff --git a/crates/lean_compiler/tests/test_data/program_172.py b/crates/lean_compiler/tests/test_data/program_172.py
index da813966..ba54fc05 100644
--- a/crates/lean_compiler/tests/test_data/program_172.py
+++ b/crates/lean_compiler/tests/test_data/program_172.py
@@ -8,7 +8,7 @@ def helper_const(n: Const):
def main():
- # Test 1: Basic match_range - no forward declaration needed (auto-generated as Imu)
+ # Test 1: Basic match_range - no forward declaration needed (auto-generated as Imm)
x = 2
r1 = match_range(x, range(0, 4), lambda i: i * 100)
assert r1 == 200
diff --git a/crates/lean_compiler/tests/test_data/program_174.py b/crates/lean_compiler/tests/test_data/program_174.py
index 19be7961..94e1f4d1 100644
--- a/crates/lean_compiler/tests/test_data/program_174.py
+++ b/crates/lean_compiler/tests/test_data/program_174.py
@@ -40,7 +40,7 @@ def main():
def match_start_at_1(x):
- result: Imu
+ result: Imm
match x:
case 1:
result = 100
@@ -54,7 +54,7 @@ def match_start_at_1(x):
def match_start_at_5(x):
- result: Imu
+ result: Imm
match x:
case 5:
result = 50
@@ -68,7 +68,7 @@ def match_start_at_5(x):
def match_start_at_10(x):
- result: Imu
+ result: Imm
match x:
case 10:
result = 1000
@@ -95,7 +95,7 @@ def match_nonzero_mutable(x):
def nested_nonzero_match(outer, inner):
- result: Imu
+ result: Imm
match outer:
case 1:
match inner:
@@ -125,7 +125,7 @@ def nested_nonzero_match(outer, inner):
def nonzero_match_in_if(cond, x):
- result: Imu
+ result: Imm
if cond == 0:
result = 0
else:
diff --git a/crates/lean_compiler/tests/test_data/program_179.py b/crates/lean_compiler/tests/test_data/program_179.py
index 84d1f0b0..521d0af6 100644
--- a/crates/lean_compiler/tests/test_data/program_179.py
+++ b/crates/lean_compiler/tests/test_data/program_179.py
@@ -1,6 +1,6 @@
from snark_lib import *
-ONE_EF_PTR = 1 # right after the (empty-public-input) zero-padded cell at memory[0]
+ONE_EF_PTR = 8 # right after the 8-cell public input region
def main():
diff --git a/crates/lean_compiler/tests/test_data/program_43.py b/crates/lean_compiler/tests/test_data/program_43.py
index aed3a00b..d0199de4 100644
--- a/crates/lean_compiler/tests/test_data/program_43.py
+++ b/crates/lean_compiler/tests/test_data/program_43.py
@@ -7,7 +7,8 @@ def main():
return
-def increment_twice(x: Mut):
- x = x + 1
- x = x + 1
- return x
+def increment_twice(x):
+ y: Mut = x
+ y = y + 1
+ y = y + 1
+ return y
diff --git a/crates/lean_compiler/tests/test_data/program_57.py b/crates/lean_compiler/tests/test_data/program_57.py
index 3d4965c4..4a96412e 100644
--- a/crates/lean_compiler/tests/test_data/program_57.py
+++ b/crates/lean_compiler/tests/test_data/program_57.py
@@ -10,19 +10,22 @@ def main():
return
-def step1(n: Mut):
- n = n * 2
- n = n + 1
- return n
+def step1(n):
+ m: Mut = n
+ m = m * 2
+ m = m + 1
+ return m
-def step2(n: Mut):
- n = n * 3
- n = n + 2
- return n
+def step2(n):
+ m: Mut = n
+ m = m * 3
+ m = m + 2
+ return m
-def step3(n: Mut):
- n = n * 4
- n = n + 3
- return n
+def step3(n):
+ m: Mut = n
+ m = m * 4
+ m = m + 3
+ return m
diff --git a/crates/lean_compiler/tests/test_data/program_67.py b/crates/lean_compiler/tests/test_data/program_67.py
index ed80db93..5c91e661 100644
--- a/crates/lean_compiler/tests/test_data/program_67.py
+++ b/crates/lean_compiler/tests/test_data/program_67.py
@@ -2,7 +2,7 @@
def main():
- mut_a: Imu
+ mut_a: Imm
mut_a = 5
assert mut_a == 5
return
diff --git a/crates/lean_compiler/tests/test_data/program_68.py b/crates/lean_compiler/tests/test_data/program_68.py
index 13eb50c6..a4797324 100644
--- a/crates/lean_compiler/tests/test_data/program_68.py
+++ b/crates/lean_compiler/tests/test_data/program_68.py
@@ -10,10 +10,10 @@ def main():
def test_func(a, b):
x = 1
- mut_x_2: Imu
+ mut_x_2: Imm
match a:
case 0:
- mut_x_1: Imu
+ mut_x_1: Imm
mut_x_1 = x + 2
match b:
case 0:
diff --git a/crates/lean_compiler/tests/test_data/program_69.py b/crates/lean_compiler/tests/test_data/program_69.py
index 4c19fda8..ecc0051e 100644
--- a/crates/lean_compiler/tests/test_data/program_69.py
+++ b/crates/lean_compiler/tests/test_data/program_69.py
@@ -15,20 +15,20 @@ def main():
def compute(a, b, c):
base = 1000
- outer_val: Imu
- mid_val: Imu
- inner_val: Imu
+ outer_val: Imm
+ mid_val: Imm
+ inner_val: Imm
match a:
case 0:
outer_val = 5
- local_a: Imu
+ local_a: Imm
local_a = a + outer_val
match b:
case 0:
mid_val = 3
- local_b: Imu
+ local_b: Imm
local_b = local_a + mid_val
match c:
@@ -38,7 +38,7 @@ def compute(a, b, c):
inner_val = base + local_b + c
case 1:
mid_val = 7
- local_b: Imu
+ local_b: Imm
local_b = local_a + mid_val
match c:
@@ -48,13 +48,13 @@ def compute(a, b, c):
inner_val = base + local_b + c
case 1:
outer_val = 15
- local_a: Imu
+ local_a: Imm
local_a = a + outer_val
match b:
case 0:
mid_val = 20
- local_b: Imu
+ local_b: Imm
local_b = local_a + mid_val
match c:
@@ -64,7 +64,7 @@ def compute(a, b, c):
inner_val = base + local_b + c
case 1:
mid_val = 30
- local_b: Imu
+ local_b: Imm
local_b = local_a + mid_val
match c:
diff --git a/crates/lean_compiler/tests/test_data/program_83.py b/crates/lean_compiler/tests/test_data/program_83.py
index aa423c9f..36bd35d3 100644
--- a/crates/lean_compiler/tests/test_data/program_83.py
+++ b/crates/lean_compiler/tests/test_data/program_83.py
@@ -2,7 +2,7 @@
def main():
- x: Imu
+ x: Imm
cond = 1
if cond == 1:
x = 10
diff --git a/crates/lean_compiler/tests/test_data/program_99.py b/crates/lean_compiler/tests/test_data/program_99.py
index 40d61c71..66b2c119 100644
--- a/crates/lean_compiler/tests/test_data/program_99.py
+++ b/crates/lean_compiler/tests/test_data/program_99.py
@@ -7,7 +7,8 @@ def main():
return
-def accumulate(x: Mut):
+def accumulate(x):
+ acc: Mut = x
for i in unroll(0, 3):
- x = x + i
- return x
+ acc = acc + i
+ return acc
diff --git a/crates/lean_compiler/tests/test_data/soundness_2.py b/crates/lean_compiler/tests/test_data/soundness_2.py
index 3fa2867b..630d756a 100644
--- a/crates/lean_compiler/tests/test_data/soundness_2.py
+++ b/crates/lean_compiler/tests/test_data/soundness_2.py
@@ -12,7 +12,7 @@ def main():
offset = p[6]
total = p[7]
- computed: Imu
+ computed: Imm
match mode:
case 0:
computed = add_op(x, y)
@@ -24,7 +24,7 @@ def main():
computed = combined(x, y)
assert computed == expected
- adjusted: Imu
+ adjusted: Imm
if flag == 0:
adjusted = bump(secondary, 1)
elif flag == 1:
diff --git a/crates/lean_compiler/tests/test_data/soundness_5.py b/crates/lean_compiler/tests/test_data/soundness_5.py
index 3333ab2f..eaa31b7c 100644
--- a/crates/lean_compiler/tests/test_data/soundness_5.py
+++ b/crates/lean_compiler/tests/test_data/soundness_5.py
@@ -36,7 +36,7 @@ def main():
assert paired_sum(seed, n) == paired
- chosen: Imu
+ chosen: Imm
if flag == 1:
chosen = seed
else:
diff --git a/crates/lean_compiler/tests/test_performance.rs b/crates/lean_compiler/tests/test_performance.rs
index 723b9bd4..893883e2 100644
--- a/crates/lean_compiler/tests/test_performance.rs
+++ b/crates/lean_compiler/tests/test_performance.rs
@@ -1,3 +1,4 @@
+use backend::PrimeCharacteristicRing;
use lean_compiler::*;
use lean_vm::*;
@@ -9,7 +10,13 @@ fn test_data_dir() -> String {
/// Helper to get the number of cycles for a program file
fn get_cycle_count(path: &str) -> usize {
let bytecode = compile_program(&ProgramSource::Filepath(path.to_string()));
- let result = try_execute_bytecode(&bytecode, &[], &ExecutionWitness::default(), false).unwrap();
+ let result = try_execute_bytecode(
+ &bytecode,
+ &[F::ZERO; PUBLIC_INPUT_LEN],
+ &ExecutionWitness::default(),
+ false,
+ )
+ .unwrap();
result.pcs.len()
}
diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md
index 5ea8d46a..08d67c37 100644
--- a/crates/lean_compiler/zkDSL.md
+++ b/crates/lean_compiler/zkDSL.md
@@ -1,65 +1,102 @@
# zkDSL Language Reference
-Warning: still under construction (i.e. it's messy).
+The zkDSL is a Python-syntax language that compiles to leanVM bytecode (4 basic instructions and 2 special ones (precompile): poseidon / extension operations). For the underlying VM, and proving system, see [`minimal_zkVM.pdf`](../../minimal_zkVM.pdf).
-## Program Structure
+Source files use the `.py` extension. They are **not** currently runnable as
+real Python, but the syntax is kept Python-compatible so that one day they
+could be (TODO).
+## Dev experience
+
+To recycle python tooling/linting on zkDSL files (which import [`snark_lib`](snark_lib.py)), point your editor at the compiler crate. With VSCode (for instance in `leanMultisig/.vscode/settings.json`):
+
+```json
+{
+ "python.analysis.extraPaths": [
+ "./crates/lean_compiler"
+ ]
+}
```
-from snark_lib import * # Python compatibility (ignored by compiler)
-from dir.file import * # imports (optional, Python-style)
-NAME = value # constants (optional, uppercase by convention)
-def main(): # entry point (required)
+
+## Entrypoint
+
+Programs are organized as one or more `.py` files. The toplevel of each file is a
+sequence of:
+
+1. `from import *` statements (optional)
+2. Top-level constant declarations (optional)
+3. Function definitions
+
+Execution starts at `def main(): ...`.
+
+```python
+from snark_lib import * # only there to keep the Python linter happy; stripped by the zkDSL compiler
+from utils import * # import other file
+
+X = 42 # constants must come before functions
+# array constants (or arbitrary dimmensions: 1D, 2D, etc)
+ARR_1D = [1, 2, 3]
+ARR_2D = [[1, 2, 3], [], [10, 4]]
+ARR_3D = [[[1, 2, 3], [7, 8], [9]], [], [[10], [10, 4]]]
+
+def main(): # required entry point
...
-def helper(): # other functions (optional)
+
+def helper(): # other functions
...
```
-The `from snark_lib import *` line imports Python definitions for zkDSL primitives (Array, Mut, Const, etc.), allowing `.py` files to be executed as normal Python scripts for testing. The zkDSL compiler ignores this import line.
+## Imports
-To run zkDSL files as Python scripts, run from the file's directory with PYTHONPATH pointing to the lean_compiler crate (for snark_lib.py):
-```bash
-export PYTHONPATH=/path/to/repo/crates/lean_compiler
-cd crates/lean_compiler/tests/test_data
-python program_0.py
+```python
+from utils import * # imports utils.py (resolved from the import root)
+from dir.subdir.file import * # nested module
+from ..module import * # parent-directory import (relative to current file)
```
+Imports are wildcard-only (`import *`). Each module is loaded once even if imported
+multiple times; circular imports are rejected. Constants with the same
+name in two imported files cause a compile-time error.
+
## Constants
-Constants are declared at the top level (outside functions) using simple assignment. By convention, constant names are UPPERCASE.
+Constants live at the top of the file, outside any function.
-```
+```python
X = 42
ARR = [1, 2, 3]
NESTED = [[1, 2], [3]]
```
-### Multi-Dimensional Const Arrays
-
-Const arrays can be nested to any depth, and inner arrays can have different lengths (ragged arrays). All const array values are resolved at compile time.
+### Nested (multi-dimensional, possibly ragged) constant arrays
-```
-MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] # ragged 2D array
-DEEP = [[[1, 2], [3]], [[4, 5, 6]]] # 3D array
+```python
+MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
+DEEP = [[[1, 2], [3]], [[4, 5, 6]]]
```
-**Accessing elements:** Use chained indexing with compile-time indices:
-```
-x = MATRIX[0][2] # x = 3
-y = DEEP[1][0][1] # y = 5
+Indexed access uses chained subscripts at compile time:
+
+```python
+x = MATRIX[0][2] # 3
+y = DEEP[1][0][1] # 5
```
-**Using `len()` on inner arrays:** The `len()` function can be applied to any level of a nested const array, including inner arrays accessed by index. This is particularly useful for iterating over ragged arrays where each row has a different length:
+`len()` works at every depth, including on a row addressed by a constant index:
-```
-len(MATRIX) # 3
-len(MATRIX[0]) # 3
-len(DEEP[0][0]) # 2
+```python
+len(MATRIX) # 3
+len(MATRIX[0]) # 3
+len(DEEP[0][0]) # 2
```
-**Important:** When using `len()` on an inner array with a variable index (e.g., `len(ARR[i])`), the index must be a compile-time constant. This works inside `unroll` loops because the loop variable becomes a compile-time constant during unrolling.
+When `len()` is applied with a variable index (`len(ARR[i])`), `i` must be a
+compile-time constant. `: Const` parameters always qualify (see [Functions]
+below), as do iterator variables of an `unroll` loop (see [For loops] below).
-**Example: Iterating over a ragged 2D array:**
-```
+Example: iterating a ragged 2D table:
+
+```python
MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
def main():
@@ -67,17 +104,17 @@ def main():
for row in unroll(0, len(MATRIX)):
for col in unroll(0, len(MATRIX[row])):
total = total + MATRIX[row][col]
- assert total == 45 # 1+2+3+4+5+6+7+8+9
+ assert total == 45
return
```
## Functions
-```
-def add(a, b): # return count is inferred from return statements
+```python
+def add(a, b):
return a + b
-def swap(a, b): # multiple return values
+def swap(a, b):
return b, a
def main():
@@ -85,122 +122,116 @@ def main():
return
```
-The number of return values is automatically inferred from the `return` statements. All return statements in a function must return the same number of values.
+Every function must contain at least one `return`. The compiler infers the number
+of returned values from the `return` statements; all `return`s in a function must
+agree. A function that "returns nothing" uses a bare `return`.
-### Parameter Modifiers
-| Syntax | Meaning |
-| ---------- | --------------------------------------------------------- |
-| `x` | immutable parameter |
-| `x: Const` | compile-time value (enables `unroll` with dynamic bounds) |
-| `x: Mut` | mutable within function body only |
+### Parameter types
-**All parameters are pass-by-value.** The `: Mut` modifier allows reassignment within the function, but changes are not visible to the caller. Use return values to communicate results.
+| Syntax | Meaning |
+| ---------- | ------------------------------------ |
+| `x` | normal (immutable) runtime parameter |
+| `x: Const` | compile-time parameter |
-```
-def repeat(n: Const): # Const enables unroll
+```python
+def repeat(n: Const): # Const enables unroll(0, n)
sum: Mut = 0
for i in unroll(0, n):
sum = sum + i
return sum
-def double(x: Mut): # Mut allows local reassignment
- x = x * 2 # only affects local copy
- return x # must return to pass result back
+def double(x): # parameter is immutable; shadow with a local
+ y: Mut = x
+ y = y * 2
+ return y
```
-### Inline Functions
-Use the `@inline` decorator to mark functions for inlining at call sites:
-```
+### Inline functions
+
+`@inline` expands a function at every call site instead of generating a JUMP
+instruction to another part of the bytecode. Useful for performance (calling a function costs a few cycles).
+
+```python
@inline
def square(x):
return x * x
```
-**Note:** Inline functions cannot have `: Mut` parameters.
-**Note:** Inline functions support at most one `return`, and it must be at the
-top level of the body — never nested inside an `if`, loop, or `match`. Early or
-conditional returns are rejected by the compiler, because inlining expands each
-`return` into a plain assignment with no control flow. Use a regular (non-inline)
-function — with `: Const` parameters if you need compile-time specialization —
-when you need a conditional return.
+Constraints on inline functions (compiler limitations): Exactly one `return`, placed as the last statement of the body, not nested inside `if`, a loop, or `match`. Inlining rewrites the `return` into a plain assignment in place, so early or conditional returns cannot be expressed.
## Variables
| Declaration | Mutability | Notes |
| ------------- | ---------- | ---------------------------------------------- |
| `x = 10` | immutable | cannot be reassigned |
-| `x: Mut = 10` | mutable | can be reassigned |
-| `x: Imu` | immutable | forward declaration, assign exactly once later |
-| `x: Mut` | mutable | forward declaration for mutable variable |
+| `x: Mut = 10` | mutable | reassignable |
+| `x: Imm` | immutable | forward declaration; assign exactly once later |
+| `x: Mut` | mutable | forward declaration; reassignable later |
-### Forward Declarations
+### Forward declarations
-Use `x: Imu` when a variable must be assigned in different branches:
+Use `x: Imm` when you want an immutable binding but the value comes from a
+branch:
-```
-result: Imu # immutable: assign exactly once
+```python
+result: Imm
if cond == 1:
result = 10
else:
result = 20
-# result cannot be reassigned after this
+# result is now immutable
```
-Use `x: Mut` when you need the variable to be mutable after assignment:
+Use `x: Mut` when you want to keep mutating the variable after the branch:
-```
+```python
x: Mut
if cond == 1:
x = 10
else:
x = 20
-x = x + 1 # OK: x was declared as mutable
+x = x + 1 # OK: x is mutable
```
-### Tuple Assignments with Mutable Variables
+### Mutability inside tuple assignments
-When a function returns multiple values and some need to be mutable, use forward declarations:
+To make a single component of a tuple-return mutable, forward-declare it:
-```
-b: Mut # declare b as mutable
+```python
+b: Mut
a, b, c = some_function()
-# a and c are immutable, b is mutable
-b = b + 1 # OK
-# a = 5 # ERROR: a is immutable
+b = b + 1 # OK
+# a = 5 # ERROR: a is immutable
```
-This is useful when a function returns multiple values and only some need to be modified later.
-
-## Memory and Arrays
+## Memory and arrays
-```
-buffer = Array(16) # allocate 16 field elements
+```python
+buffer = Array(16) # allocate 16 field elements
buffer[0] = 42
-x = buffer[5]
+buffer[0] = 42 # Valid
+# buffer[0] = 41 # ERROR: conflicting write (read only memory)
+buffer[5] = 34
+x = buffer[5] # x = 34
-matrix = Array(64) # 2D via manual indexing
+matrix = Array(64) # 2D via manual indexing
matrix[row * 8 + col] = value
-ptr2 = ptr + 5 # pointer arithmetic
-ptr2[0] = 100 # same as ptr[5] = 100
+ptr2 = buffer + 5 # pointer arithmetic
+ptr2[0] = 100 # same as buffer[5] = 100
```
-**Memory is write-once.** Due to SSA constraints, each memory location can only hold one value. Writing to the same location multiple times is allowed, but all writes must produce the same value—otherwise a runtime error occurs.
+`Array(n)` returns a pointer to a freshly allocated block of `n` field
+elements. `n` may be a compile-time constant (more efficient, analogy: allocated on the stack) or a runtime
+value (less efficient, analogy: allocated on the heap). Memory is **write-once**: a cell may be
+written more than once only if all writes store the same value.
-```
-arr = Array(3)
-arr[0] = 10 # OK: first write
-arr[0] = 10 # OK: same value
-arr[0] = 20 # ERROR: different value at same location
-```
+## Control flow
-Use `mut` variables when you need mutability, the compiler cannot handle mutability on hand-written allocated memory ("Array(...)").
+### `if` / `elif` / `else`
-## Control Flow
-
-### If/Else
-```
+```python
if x == 0:
y = 1
elif x == 1:
@@ -208,30 +239,42 @@ elif x == 1:
else:
y = 3
```
-Comparison operators: `==`, `!=`
-### Match
-Patterns must be consecutive integers:
-```
+Comparison operators on conditions: `==`, `!=`, `<`, `<=`. There is **no** `>`
+or `>=` (flip the operands to get the same effect).
+
+### `match`
+
+Patterns must be a set of integers of the form [n, n+1, n + 2, ...]:
+
+```python
match value:
case 5:
result = 500
+ do_stuf()
case 6:
result = 600
+ do_other_stuf()
case 7:
result = 700
+ ...
```
-### match_range
+The matched value must lie inside the listed range; out-of-range values produce
+undefined behaviour: **It's the responsability of the program to ensure this** (no checks added by the compiler). Letting a prover-controlled value escape the range in a `range` is a critical vulnerability.
-Compile-time construct that expands into a match statement, useful for dispatching to functions with const parameters based on runtime values. Results are always immutable.
+### `match_range`
-```
+`match_range` enables to automatically generate a `match` with repeated arms.
+
+```python
result = match_range(n, range(1, 5), lambda i: compute(i))
```
-Expands to:
-```
-result: Imu # auto-generated forward declaration (always immutable)
+
+is expanded by the compiler to:
+
+```python
+result: Imm
match n:
case 1: result = compute(1)
case 2: result = compute(2)
@@ -239,323 +282,471 @@ match n:
case 4: result = compute(4)
```
-**Multiple continuous ranges** with different lambdas:
-```
+It's possible to chain several `(range, lambda)` pairs, provided the ranges are
+**contiguous** (the end of one is the start of the next):
+
+```python
result = match_range(n,
range(0, 1), lambda i: special_case(),
range(1, 8), lambda i: normal_case(i))
```
-Expands to a match where case 0 uses `special_case()` and cases 1-7 use `normal_case(i)`.
-Ranges must be continuous (end of one equals start of next).
+Multiple return values are supported via tuple unpacking. The bindings produced
+by `match_range` are always immutable. Forward-declare with `: Mut` (and then
+reassign) if you need them mutable later:
-**Multiple return values:**
-```
+```python
+a: Mut
a, b = match_range(n, range(0, 4), lambda i: two_values(i))
+a += 1
```
-**Common use case:** Dispatching runtime values to const-parameter functions:
-```
+Idiomatic use: enables to dispatch a runtime value to a const-parameter function.
+
+```python
def helper_const(n: Const):
- # function that requires compile-time n
return n * n
def compute(value):
- result = match_range(value, range(0, 10), lambda i: helper_const(i))
- return result
+ assert value < 10
+ return match_range(value, range(0, 10), lambda i: helper_const(i))
```
+Similar to `match`, range validity of the matched value is the responsibility of the program, not the compiler. Letting a prover-controlled value escape the range in a `match_range` is a critical vulnerability.
-**IMPORTANT:** For both `match` and `match_range`, the programmer must ensure the value is within the specified range. Out-of-range values cause undefined behavior. Use `debug_assert` to validate:
-```
-debug_assert(n < 10)
-debug_assert(0 < n)
-result = match_range(n, range(1, 10), lambda i: compute(i))
-```
+### For loops
-### For Loops
-```
-for i in range(0, 10): # standard loop
- ...
-for i in parallel_range(0, n): # iterations executed in parallel (see below)
- ...
-for i in unroll(0, 4): # unrolled at compile time
- ...
-```
-Use `unroll` when bounds are const or compile-time expansion is needed.
+Three loop forms, all written `for i in (start, end):`. The
+iterator visits `start, start + 1, ..., end - 1`.
-**`parallel_range`** executes iterations concurrently using rayon. The produced bytecode is identical to `range`. Constraints:
-- The loop body must be **iteration-independent**: no `Mut` variables carried
- across iterations. Each iteration may only write to its own frame and to
- external addresses that do not affect other iterations .
-- The memory footprint (i.e. total memory usage) must be the same across iterations
-- XMSS / Merkle hint consumption must be the same across iterations
+Restrictions shared by all three forms:
-**Mutable variables in non-unrolled loops:** Mutable variables can be modified inside non-unrolled loops. The compiler automatically transforms these into buffer-based implementations:
+- No `break` or `continue` (not in the grammar).
-```
+#### `range(a, b)`: runtime loop
+
+The general-purpose runtime loop. `a` and `b` may be runtime values. The
+compiler lowers the loop to a recursive function.
+
+```python
sum: Mut = 0
for i in range(1, 11):
sum += i
assert sum == 55
```
-Loops limitations:
-- no "continue" or "break" are supported yet
-- the "return" keyword is not supported inside the body of a normal (non-unrolled) loop (because under the hood normal loops are transformed into recursive functions)
+Mutable variables carried across iterations are supported transparently.
+
+*Under the hood: the compiler inserts a buffer array, stores the per-iteration value into it, and reads the final value back after the loop.*
+
+Restrictions: No `return` inside the body
+
+*Under the hood: because the loop is lowered to a recursive function.*
+
+#### `unroll(a, b)`: compile-time unrolling
+
+The loop is expanded at compile time: the body is duplicated once per iteration
+with `i` substituted by its concrete value. Both `a` and `b` must be
+compile-time constants.
+
+```python
+for i in unroll(0, 4):
+ buffer[i] = i * i
+```
+
+#### `parallel_range(a, b)` — parallel runtime loop
+
+**`parallel_range` compiles to exactly the same bytecode as `range`.** It
+differs only in the runner's scheduling policy: iterations are dispatched
+concurrently across worker threads rather than evaluated in sequence. The only advantage is faster witness generation.
+Iteration `a` is executed first, in isolation, to determine the per-iteration
+memory footprint; the remaining iterations are then evaluated in parallel
+without inter-iteration synchronization.
+
+```python
+for i in parallel_range(0, n):
+ process(i, inputs[i], outputs[i])
+```
+
+Because there is no synchronization, the loop body must be
+iteration-independent:
+
+- No `Mut` variables carried across iterations (each iteration writes only to
+ its own call frame and to addresses disjoint from every other iteration).
+- Identical memory footprint per iteration.
+- Identical hint consumption per iteration (witness hints, XMSS-specific
+ decomposition hints, Merkle hints, etc.).
+
+These constraints are **not** checked at compile time. Violating them produces
+silently wrong proofs.
+
+### Statements without effect are rejected
+
+Every line must either be a declaration, an assignment, a control-flow form, an
+assertion, a `return`, or a side-effecting call (`hint_witness`, precompile,
+`print`, or a function call). A bare expression like `x + 1` on its own line is
+a compile error.
## Expressions
### Arithmetic
-- `+`, `-`, `*`, `/` (field operations): allowed at runtime
-- `%` (modulo), `**` (exponentiation): only allowed at compile time
-### Compound Assignment
-Syntactic sugar for updating mutable variables:
-```
+`+`, `-`, `*`, `/` are field operations and work at runtime.
+
+`%` (modulo) and `**` (exponentiation) are **compile-time only** — both operands
+must be constants known at compile time.
+
+### Compound assignment
+
+```python
x: Mut = 10
-x += 5 # equivalent to: x = x + 5
-x -= 3 # equivalent to: x = x - 3
-x *= 2 # equivalent to: x = x * 2
-x /= 4 # equivalent to: x = x / 4
+x += 5 # x = x + 5
+x -= 3 # x = x - 3
+x *= 2 # x = x * 2
+x /= 4 # x = x / 4
```
-### Built-in Functions
-Only allowed at compile time:
+Only a single target is allowed on the LHS of a compound assignment.
-```
-log2_ceil(x) # ceiling of log2
-next_multiple_of(x, n) # smallest multiple of n >= x
-div_ceil(a, b) # ceiling division: (a + b - 1) // b
-div_floor(a, b) # floor division: a // b
+### Compile-time built-ins
+
+These functions are evaluated at compile time only — their arguments must be
+constants:
+
+```python
+log2_ceil(x) # ceil(log2(x))
+next_multiple_of(x, n) # smallest multiple of n that is >= x
+div_ceil(a, b) # (a + b - 1) // b
+div_floor(a, b) # a // b
saturating_sub(a, b) # max(0, a - b)
-len(array) # length of const array or vector
+len(array) # length of a constant array (any depth)
```
-## Assertions
+### `_` (the discard target)
+
+Inside a tuple-unpacking LHS, `_` discards the value at that position. The
+compiler rewrites each `_` to a fresh anonymous name so they don't collide.
+
+```python
+_, b = swap(a, b) # only keep b
+_ = compute() # discard a single return value
```
-# constraint in proof
+
+## Assertions
+
+The zkDSL provides two assertion forms with very different semantics:
+
+| Form | Enforced by | Use for |
+| -------------- | --------------------------------------- | ------------------------------------------------------------------- |
+| `assert` | The proof system | Invariants the verifier must check |
+| `debug_assert` | The prover only (at witness generation) | Sanity checks; preconditions the verifier does not need to re-check |
+
+### `assert`: proof-enforced constraint
+
+```python
assert x == y
assert x != y
-assert x < y
+assert x < y
assert x <= y
-# unconditional failure (panic)
+```
+
+The four supported comparison operators are `==`, `!=`, `<`, `<=` (no `>` or
+`>=`; flip the operands).
+
+
+### Range checks: `assert a < b` and `assert a <= b`
+
+**The program must ensure `b <= 2^16`.** The compiler does not check this
+(`b` may be a runtime value). Violating the bound is a critical soundness
+vulnerability.
+
+*Under the hood: the compiler proves `a < b` by emitting two DEREF instructions,
+which check that `a` and `b - 1 - a` are both valid memory addresses. An
+address is valid iff it is `< M`, where `M` is the memory size. To stay sound
+for every admissible memory size, the construction relies on the smallest one,
+`M_min = 2^16` (= `2^MIN_LOG_MEMORY_SIZE`), giving the bound `b <= 2^16`.*
+
+#### Explicit panic
+
+`assert False` is the unconditional failure form. It compiles to a Panic and
+accepts an optional message:
+
+```python
assert False
-assert False, "error message"
-# runtime check only (not constrained by the snark)
-debug_assert(x == y)
-debug_assert(x != y)
+assert False, "human-readable message"
+```
+
+### `debug_assert`: sanity checks at witness generation
+
+```python
debug_assert(x < y)
-debug_assert(x <= y)
```
+`debug_assert` accepts the same four comparison operators. It is evaluated by
+the prover at trace-generation time and does **not** emit any constraint, so
+the verifier never re-checks it. Use it for invariants the prover is expected
+to maintain but that the verifier can take for granted — typically the
+range-validity preconditions of `match` / `match_range` dispatches.
+
## Comments
-```
-# Single-line comment
+```python
+# single-line comment
"""
-Multi-line comment
-can span multiple lines
+block comment
"""
```
-## Imports
-
-```
-from utils import * # imports utils.py (relative to import root)
-from dir.subdir.file import * # imports dir/subdir/file.py
-```
+## Line continuation
-## Memory Layout
+As in Python:
-The runner places the program's memory as:
+- **Implicit** continuation inside `(...)` or `[...]`.
+- **Explicit** continuation with `\` at end of line.
-```
-[ public_input | preamble_memory | runtime ]
+```python
+result = function_call(arg1,
+ arg2,
+ arg3) # implicit continuation inside parens
+y = 1 + 2 + \
+ 3 + 4 # explicit continuation with backslash
```
-- `public_input` lives at `memory[0..public_input.len()]` (zero-padded to a power of two by the runner so it can be evaluated as a multilinear polynomial).
-- `preamble_memory` is a region the runner reserves but does not initialize. The guest program is responsible for writing any constants it needs (e.g. `ZERO_VEC_PTR`, `ONE_EF_PTR`, etc.) in this area.
+## Hints (prover-supplied data)
-Prover-supplied witness data is fetched on demand with `hint_witness("name", ptr)`, where the string literal
-names an entry in the witness's `hints: HashMap>>` map and
-`ptr` is a caller-allocated buffer. Each call writes the next unused `Vec`
-under that name (per-name running index) into the buffer at `ptr`. The guest
-is responsible for allocating `ptr` with enough room; the witness's length is
-trusted.
-`hint_witness`
+A hint is data the *prover* writes into memory without adding any constraint —
+the program must still constrain the written value if it wants the verifier to
+believe anything about it. There are two flavours of hint:
-```
-data_buf = Array(64)
-hint_witness("input_data", data_buf) # writes next `input_data` entry into data_buf
-n = data_buf[0]
-# ...
-```
+### `hint_witness("name", ptr)`
-### Built-in Hints
+Writes the next buffer queued under the label `name` into memory starting at
+`ptr`. The guest must allocate `ptr` large enough to hold the data; no length
+is checked at runtime.
-hints = prover-supplied values at runtime (without adding snark constraints). Like `hint_witness`, they are bare statements (no return value) — the caller allocates any destination memory and is responsible for constraining the written values.
+The buffer comes from the host (Rust side), not from the guest. Before
+running the program, the host fills `ExecutionWitness::hints` with one queue
+of buffers per label; each `hint_witness("name", ptr)` call pops the next
+buffer from `hints["name"]`.
-| Hint | Signature | Writes |
-| --------------------------------- | --------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
-| `hint_decompose_bits` | `(to_decompose, ptr, num_bits, endianness)` | `num_bits` field elements at `ptr` (the 0/1 bit decomposition of `to_decompose`); `endianness` is `0` for big-endian, `1` for little-endian |
-| `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` else `0` |
-| `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr` |
-| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a/b)` at `q_ptr` and `a mod b` at `r_ptr` (requires `b != 0`) |
-| `hint_decompose_bits_xmss` | `(decomposed_ptr, remaining_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | XMSS-specific decomposition (see `crates/lean_vm/src/isa/hint.rs`) |
-| `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, remaining_ptr, value, chunk_size)` | Merkle/WHIR-specific decomposition |
+`ExecutionWitness` lives in `crates/lean_vm/src/execution/runner.rs`:
-Hints only *suggest* a value; the guest must add appropriate constraints to bind that value to its specification.
+```rust
+pub struct ExecutionWitness {
+ ...
+ pub hints: HashMap>>,
+ ...
+}
+```
+Each map key is a label; the value is the **ordered list of buffers** the
+guest will consume under that label. The N-th `hint_witness("name", ptr)` call
+the guest executes pops the N-th `Vec` from `hints["name"]` and writes it
+at `ptr`.
-## Precompiles
+For example, the guest below issues three `hint_witness` calls — two against
+`"input_data"` and one against `"other_stuff"`:
-### poseidon16_compress
-Always in "compression" mode
-```
-poseidon16_compress(left, right, output)
+```python
+data_buf_1 = Array(64)
+hint_witness("input_data", data_buf_1)
+n = data_buf_1[0]
+
+data_buf_2 = Array(64)
+hint_witness("input_data", data_buf_2)
+m = data_buf_2[3]
+assert n == m + 8
+
+data_buf_3 = Array(10)
+hint_witness("other_stuff", data_buf_3)
+...
```
-- `left`: pointer to 8 field elements
-- `right`: pointer to 8 field elements
-- `res`: pointer to result (8 elements)
-### Extension Operations
+The matching Rust side must register two buffers under `"input_data"` (in
+the order the guest will read them) and one under `"other_stuff"`:
-Six built-in functions route through a single `extension_op` precompile table. Each combines an element-wise operation with an accumulation over `length` element pairs.
+```rust
+let mut hints: HashMap>> = HashMap::new();
+hints.insert(
+ "input_data".to_string(),
+ vec![
+ first_input_buffer, // consumed by the first hint_witness("input_data", ...)
+ second_input_buffer, // consumed by the second hint_witness("input_data", ...)
+ ],
+);
+hints.insert("other_stuff".to_string(), vec![other_buffer]);
+let witness = ExecutionWitness { hints, ..Default::default() };
```
-func(ptr_a, ptr_b, ptr_result) # length defaults to 1
-func(ptr_a, ptr_b, ptr_result, length) # explicit length (N elements)
-```
-
-**Operand types (suffix):**
-- `_ee`: both `ptr_a` and `ptr_b` point to extension field elements (5 consecutive field elements each, stride = DIM)
-- `_be`: `ptr_a` points to base field elements (stride 1), `ptr_b` points to extension field elements (stride DIM)
-`ptr_result` always points to a single extension field element (DIM=5 field elements).
+A missing label, or running out of buffers under a label, is a runner-side
+panic: each call requires its corresponding entry to exist.
-**Operations:**
+### Custom hints
-| Function | Element-wise | Accumulation |
-| ----------------------------------- | --------------------------------- | -------------------- |
-| `add_ee` / `add_be` | `e_i = a_i + b_i` | `result = sum(e_i)` |
-| `dot_product_ee` / `dot_product_be` | `e_i = a_i * b_i` | `result = sum(e_i)` |
-| `poly_eq_ee` / `poly_eq_be` | `e_i = a_i*b_i + (1-a_i)*(1-b_i)` | `result = prod(e_i)` |
+Custom hints are a fixed set of built-in calls the prover uses to compute
+values that would be expensive to derive in-circuit — bit
+decompositions, comparisons, integer division, etc. Each is invoked like an
+ordinary function and writes its result into a caller-supplied memory
+location.
-**Note:** `length` must be a compile-time constant. For runtime-known lengths, use `match_range` to dispatch (see example below).
+Like every hint, **the result is unconstrained**: the verifier checks
+nothing about the hinted value. The guest program must add its own
+constraints binding the hinted bits / quotient / remainder / boolean to the
+original input — otherwise a malicious prover can substitute any value. The
+typical pattern is "hint, then assert the relationship":
+```python
+# hint the bits...
+bits = Array(8)
+hint_decompose_bits(value, bits, 8)
+# ...then constrain them to actually equal `value`
+acc: Mut = 0
+for i in unroll(0, 8):
+ assert bits[i] * (bits[i] - 1) == 0 # boolean
+ acc = acc * 2 + bits[i]
+assert acc == value
```
-# Multiply two extension field elements (length=1, default)
-dot_product_ee(x, y, z) # z = x * y
-# Copy extension element (multiply by [1,0,0,0,0]).
-# `ONE_EF_PTR` is a guest-program constant that the program must materialize
-# in its preamble memory at startup; see `crates/rec_aggregation/zkdsl_implem/utils.py`
-# for an example (`build_preamble_memory`).
-dot_product_ee(src, ONE_EF_PTR, dst)
+The full list:
-# Dot product of N extension field elements
-dot_product_ee(coeffs, basis, result, N)
+| Hint | Arguments | Effect |
+| --------------------------------- | ------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------- |
+| `hint_decompose_bits` | `(value, ptr, n_bits)` | Writes `n_bits` big-endian 0/1 field elements at `ptr` (MSB at `ptr[0]`). Requires `n_bits <= 31`. |
+| `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, value, chunk_size)` | Writes `24 / chunk_size` little-endian `chunk_size`-bit chunks of `value` at `decomposed_ptr` (`chunk_size` must divide 24). |
+| `hint_decompose_bits_xmss` | `(decomposed_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | For each of `num_to_decompose` values at `to_decompose_ptr[..]`, writes its `24 / chunk_size` little-endian chunks at `decomposed_ptr`. |
+| `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` (canonical integer compare), else `0`. |
+| `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr`. |
+| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a / b)` at `q_ptr`, `a mod b` at `r_ptr` (requires `b != 0`). |
-# Dot product with base-field scalars
-dot_product_be(alpha_powers, coeffs, result, N)
+## Precompiles
-# Extension field addition: c = a + b
-add_ee(a, b, c)
+Precompiles are special instructions in the leanVM ISA, alongside the four
+basic ones (ADD, MUL, DEREF, JUMP). The zkDSL exposes them as built-in
+functions. There are two families: Poseidon hashing and extension-field
+operations.
-# Extension field subtraction via constraint: c = a - b <=> b + c = a
-add_ee(b, c, a)
+### Poseidon16 family
-# Equality polynomial: eq(a, b) = a*b + (1-a)*(1-b)
-poly_eq_ee(a, b, eq_result)
+The variants are as follows:
-# Multi-point equality polynomial: prod_{i=0}^{n-1} eq(a[i], b[i])
-poly_eq_ee(a, b, result, n)
+- **compress vs. permute** — `compress` applies the feed-forward addition
+ (`Poseidon(L || R) + L`); `permute` is the raw 16-cell permutation.
+- **full vs. half output** — `_half` constrains only the first 4 output cells
+ (the rest are unconstrained); useful when the consumer only cares about
+ half a digest.
+- **hardcoded-left** — `_hardcoded_left` reads the first 4 cells of the left
+ input from a compile-time address instead of from `m[L..L+4]`; the last 4
+ cells of the left input still come from memory.
-# Runtime-known length via match_range
-def dot_product_ee_dynamic(a, b, res, n):
- debug_assert(n <= 256)
- match_range(n, range(1, 257), lambda i: dot_product_ee(a, b, res, i))
-```
+Common arguments: `L`, `R` are 8-cell input buffers; `O` is the output
+buffer; `off` (where present) is a compile-time address.
-## Debugging
+| Function | Cells written to `O` | Notes |
+| ------------------------------------------------------- | -------------------- | ----------------------------------------- |
+| `poseidon16_compress(L, R, O)` | `O[0..8]` | `Poseidon(L \|\| R) + L` |
+| `poseidon16_compress_half(L, R, O)` | `O[0..4]` | `O[4..8]` is unconstrained |
+| `poseidon16_compress_hardcoded_left(L, R, O, off)` | `O[0..8]` | left = `m[off..off+4] \|\| m[L..L+4]` |
+| `poseidon16_compress_half_hardcoded_left(L, R, O, off)` | `O[0..4]` | half-output + hardcoded-left composition |
+| `poseidon16_permute(L, R, O)` | `O[0..16]` | raw Poseidon permutation, no feed-forward |
-```
-print(value)
-print(a, b, c)
-```
+### Extension-field operations
-## Example
+Six built-in functions, each reading two length-`n` vectors `a` and `b` and
+writing one extension-field element to `result`. `n` defaults to `1` and must
+be a compile-time constant when given.
+```python
+add_ee(a, b, result, n=1) # result = sum_i (a[i] + b[i])
+dot_product_ee(a, b, result, n=1) # result = sum_i a[i] * b[i]
+poly_eq_ee(a, b, result, n=1) # result = prod_i (a[i]*b[i] + (1-a[i])*(1-b[i]))
```
-SIZE = 8
-def main():
- arr = Array(SIZE)
- for i in unroll(0, SIZE):
- arr[i] = i * i
- sum = compute_sum(arr, SIZE)
- assert sum == 140
- return
+The `_ee` suffix means both `a` and `b` are vectors of *extension*-field
+elements (each occupying `DIM = 5` consecutive cells). The `_be` variants
+(`add_be`, `dot_product_be`, `poly_eq_be`) are identical except `a` is a
+vector of *base*-field elements (1 cell each); `b` and `result` are still
+extension-field.
-def compute_sum(ptr, n: Const):
- acc: Mut = 0
- for i in unroll(0, n):
- acc = acc + ptr[i]
- return acc
-```
+`result` always points to a single extension-field element (5 cells).
-## Line Continuation
+For a runtime `n`, dispatch through `match_range`:
+
+```python
+def dot_product_ee_dynamic(a, b, res, n):
+ debug_assert(n <= 256)
+ match_range(n, range(1, 257), lambda i: dot_product_ee(a, b, res, i))
+```
-Like Python, lines can be continued in two ways:
+Common idioms:
-### Implicit continuation (inside parentheses/brackets/braces)
+```python
+# Multiply two extension elements (n defaults to 1)
+dot_product_ee(x, y, z) # z = x * y
-Expressions inside `()`, `[]`, or `{}` can span multiple lines without any special syntax:
+# Copy an extension element by multiplying by 1
+# (ONE_EF_PTR is a constant materialized in the preamble)
+dot_product_ee(src, ONE_EF_PTR, dst)
+# Extension subtraction: write-once memory turns "c = a + b" into
+# the constraint "b + c = a", i.e. c = a - b
+add_ee(b, c, a) # c = a - b
```
-result = function_call(
- arg1,
- arg2,
- arg3
-)
-ARR = [
- 1,
- 2,
- 3,
-]
+## Debugging
+
+```python
+print(value)
+print(a, b, c)
```
-### Explicit continuation with backslash
+`print` flushes its output during execution; **a Rust-side panic mid-program drops
+buffered prints**. When you need a print to survive a panic, temporarily change
+the print hint in `lean_vm/src/isa/hint.rs (Self::Print)` to `eprint!` directly.
-Long lines can also be split using `\` at the end of a line:
+## Memory layout
-```
-x = very_long_function_name(arg1, \
- arg2, \
- arg3)
+The runner lays out memory as
-y = 1 + 2 + \
- 3 + 4
+```python
+[ public_input (PUBLIC_INPUT_LEN cells) | preamble_memory | runtime ]
```
-The `\` and following newline are replaced with a single space. Any whitespace after `\` and before the newline is ignored.
+- `public_input` is fixed at `PUBLIC_INPUT_LEN = DIGEST_LEN = 8` cells (a hash
+ digest), occupying `memory[0..8]`.
+- `preamble_memory` is a region of `witness.preamble_memory_len` cells the
+ runner reserves immediately after the public input but does **not**
+ initialize. The guest program is expected to fill this region with whatever
+ helper constants it relies on (e.g. a vector of zeros for
+ `dot_product_ee`-as-copy, an extension-field one for multiply-by-one tricks,
+ a vector of ones for batched accumulations, …) at the start of `main`. The
+ names and offsets of these constants are not enshrined within leanVM. See
+ `crates/rec_aggregation/zkdsl_implem/utils.py (build_preamble_memory)` for
+ a concrete example.
+- The runtime region holds the program's stack frames, working memory, and any
+ prover-supplied witness data, all governed by the write-once rule.
## Tips
-1. Use `unroll` for small, fixed-size loops
-2. Use `const` parameters when loop bounds depend on arguments
-3. Use `mut` sparingly - immutable is easier to verify
-4. Use `x: Imu` or `x: Mut` for forward-declaring variables that will be assigned in branches
-5. Match patterns must be consecutive integers (can start from any value)
+1. Prefer `unroll` over `range` for small, fixed-size loops.
+2. Reach for `: Const` parameters when the function body needs `unroll` over the
+ parameter.
+3. `if` / `elif` branches that assign to the same outer variable should
+ forward-declare it (`x: Imm` or `x: Mut`) before the branch.
+7. Function parameters are always immutable. To mutate a parameter's value
+ inside a function, introduce a local `: Mut` alias at the top of the body
+ (e.g. `y: Mut = x`).
-## Example: From high level syntactic sugar to minimal ISA, with read-only memory
+## Example
-Take the following program:
+Look at the recursive aggregation program (to aggregate XMSS) at its entrypoint [main.py](../rec_aggregation/zkdsl_implem/main.py).
-```
+## Compilation step-by-step: zkDSL -> ISA
+
+Starting program:
+
+```python
def main():
x: Mut = 0
y: Mut = 3
@@ -571,9 +762,10 @@ def main():
return
```
-First, we use buffers to handle mutable variables across (non-unrolled) loops.
+Step 1 — the compiler replaces mutable-across-loop variables with index buffers, since memory
+is write-once:
-```
+```python
def main():
x: Mut = 0
y: Mut = 3
@@ -602,10 +794,9 @@ def main():
return
```
-Then, use auxiliary variables to transform it into SSA form (Static Single-Assignment):
-
+Step 2 — SSA-rename all reassignments to fresh names:
-```
+```python
def main():
x = 0
y = 3
@@ -634,9 +825,9 @@ def main():
return
```
-Finally, transform the loop into a recursive function:
+Step 3 — lower the runtime loop to a recursive function:
-```
+```python
def main():
x = 0
y = 3
@@ -647,14 +838,14 @@ def main():
x_buff[0] = x2
y_buff = Array(size + 1)
y_buff[0] = y2
- loop(4, x_buff, y_buff)
+ loop_helper(4, x_buff, y_buff)
x3 = x_buff[size]
y3 = y_buff[size]
assert x3 == 35
assert y3 == 40
return
-def loop(i, x_buff, y_buff):
+def loop_helper(i, x_buff, y_buff):
if i == 6:
return
else:
@@ -668,20 +859,7 @@ def loop(i, x_buff, y_buff):
next_idx = buff_idx + 1
x_buff[next_idx] = x_body3
y_buff[next_idx] = y_body3
- loop(i + 1, x_buff, y_buff)
+ loop_helper(i + 1, x_buff, y_buff)
return
```
-## Dev experience
-
-If using VScode, add the following to your local settings `.vscode/settings.json` :
-
-```json
-{
- "python.analysis.extraPaths": [
- "./crates/lean_compiler"
- ],
-}
-```
-
-(you will get better linting for the zkDSL files starting with `from snark_lib import *`, since it will expose zkDSL special functions from `crates/lean_compiler/snark_lib.py`).
\ No newline at end of file
diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs
index a9ee93df..c112c5a7 100644
--- a/crates/lean_prover/src/lib.rs
+++ b/crates/lean_prover/src/lib.rs
@@ -67,10 +67,6 @@ pub enum ProverError {
Runner(RunnerError),
UnknownMessage,
MultipleMessages,
- InvalidPublicInputSize {
- expected: usize,
- actual: usize,
- },
InvalidChildProof(ProofError),
LimitExceeded {
what: &'static str,
@@ -104,9 +100,6 @@ impl Display for ProverError {
Self::Runner(e) => write!(f, "{}", e),
Self::UnknownMessage => write!(f, "Unknown message, not part of the type2"),
Self::MultipleMessages => write!(f, "Multiple common messages in the type2"),
- Self::InvalidPublicInputSize { expected, actual } => {
- write!(f, "Invalid public input size: expected {}, actual {}", expected, actual)
- }
Self::InvalidChildProof(e) => write!(f, "Invalid child proof: {}", e),
Self::LimitExceeded { what, actual, max } => {
write!(f, "Too many {}: {} (max {})", what, actual, max)
diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs
index 1f985908..fb8fe1e3 100644
--- a/crates/lean_prover/src/prove_execution.rs
+++ b/crates/lean_prover/src/prove_execution.rs
@@ -19,7 +19,7 @@ pub struct ExecutionProof {
pub fn prove_execution(
bytecode: &Bytecode,
- public_input: &[F],
+ public_input: &[F; PUBLIC_INPUT_LEN],
witness: &ExecutionWitness,
whir_config: &WhirConfigBuilder,
vm_profiler: bool,
@@ -27,15 +27,8 @@ pub fn prove_execution(
check_rate(whir_config.starting_log_inv_rate)
.map_err(|err| panic!("{err}"))
.unwrap();
- if public_input.len() != PUBLIC_INPUT_LEN {
- return Err(ProverError::InvalidPublicInputSize {
- expected: PUBLIC_INPUT_LEN,
- actual: public_input.len(),
- });
- }
let ExecutionTrace {
traces,
- public_memory_size,
mut memory, // padded with zeros to next power of two
metadata,
} = info_span!("Witness generation").in_scope(|| -> Result<_, ProverError> {
@@ -232,8 +225,8 @@ pub fn prove_execution(
committed_statements.get_mut(table).unwrap().push(claim);
}
- let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(public_memory_size)));
- let public_memory_eval = (&memory[..public_memory_size]).evaluate(&public_memory_random_point);
+ let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(PUBLIC_INPUT_LEN)));
+ let public_memory_eval = (&memory[..PUBLIC_INPUT_LEN]).evaluate(&public_memory_random_point);
let previous_statements = vec![
SparseStatement::new(
diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs
index 91b3f76b..767a2756 100644
--- a/crates/lean_prover/src/test_zkvm.rs
+++ b/crates/lean_prover/src/test_zkvm.rs
@@ -97,7 +97,7 @@ fn all_precompiles_flags(loop_iters: usize) -> CompilationFlags {
}
}
-fn all_precompiles_witness() -> (Vec, ExecutionWitness) {
+fn all_precompiles_witness() -> ([F; PUBLIC_INPUT_LEN], ExecutionWitness) {
let mut rng = StdRng::seed_from_u64(0);
let mut scratch = F::zero_vec(8192);
@@ -194,7 +194,7 @@ fn all_precompiles_witness() -> (Vec, ExecutionWitness) {
.fold(EF::ONE, |acc, x| acc * x);
scratch[1300..][..DIMENSION].copy_from_slice(poly_eq_ee_result.as_basis_coefficients_slice());
- let mut public_input = vec![F::ZERO; PUBLIC_INPUT_LEN];
+ let mut public_input = [F::ZERO; PUBLIC_INPUT_LEN];
public_input[..4].copy_from_slice(&hardcoded_prefix);
let mut hints = std::collections::HashMap::new();
@@ -326,7 +326,7 @@ def fibonacci_const(a, b, n: Const):
);
}
-fn test_zk_vm_helper(program_str: &str, public_input: &[F]) {
+fn test_zk_vm_helper(program_str: &str, public_input: &[F; PUBLIC_INPUT_LEN]) {
test_zk_vm_helper_with_witness(
program_str,
public_input,
@@ -337,7 +337,7 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) {
fn test_zk_vm_helper_with_witness(
program_str: &str,
- public_input: &[F],
+ public_input: &[F; PUBLIC_INPUT_LEN],
witness: ExecutionWitness,
flags: CompilationFlags,
) {
diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs
index 45fc3057..5bee4a97 100644
--- a/crates/lean_prover/src/trace_gen.rs
+++ b/crates/lean_prover/src/trace_gen.rs
@@ -6,7 +6,6 @@ use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_iter_mut};
#[derive(Debug)]
pub struct ExecutionTrace {
pub traces: BTreeMap,
- pub public_memory_size: usize,
pub memory: Vec, // of length a multiple of public_memory_size
pub metadata: ExecutionMetadata,
}
@@ -171,7 +170,6 @@ pub fn get_execution_trace(
ExecutionTrace {
traces,
- public_memory_size: execution_result.public_memory_size,
memory: memory_padded,
metadata: execution_result.metadata,
}
diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs
index c1886acd..173d2a8d 100644
--- a/crates/lean_prover/src/verify_execution.rs
+++ b/crates/lean_prover/src/verify_execution.rs
@@ -14,7 +14,7 @@ pub struct ProofVerificationDetails {
pub fn verify_execution(
bytecode: &Bytecode,
- public_input: &[F],
+ public_input: &[F; PUBLIC_INPUT_LEN],
proof: Proof,
) -> Result<(ProofVerificationDetails, RawProof), ProofError> {
if bytecode.log_size() > MAX_BYTECODE_LOG_SIZE {
@@ -23,9 +23,6 @@ pub fn verify_execution(
max_log_size: MAX_BYTECODE_LOG_SIZE,
});
}
- if public_input.len() != PUBLIC_INPUT_LEN {
- return Err(ProofError::InvalidProof);
- }
let mut verifier_state =
VerifierState::::new(proof, get_poseidon16().clone(), fiat_shamir_domain_sep(bytecode))?;
verifier_state.observe_scalars(public_input);
@@ -58,8 +55,6 @@ pub fn verify_execution(
return Err(ProofError::InvalidProof);
}
- let public_memory = padd_with_zero_to_next_power_of_two(public_input);
-
if !(MIN_LOG_MEMORY_SIZE..=MAX_LOG_MEMORY_SIZE).contains(&log_memory) {
return Err(ProofError::InvalidProof);
}
@@ -175,9 +170,8 @@ pub fn verify_execution(
return Err(ProofError::InvalidProof);
}
- let public_memory_random_point =
- MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_memory.len())));
- let public_memory_eval = public_memory.evaluate(&public_memory_random_point);
+ let public_memory_random_point = MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_input.len())));
+ let public_memory_eval = public_input.evaluate(&public_memory_random_point);
let previous_statements = vec![
SparseStatement::new(
diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs
index dcb1ae0c..2024fa08 100644
--- a/crates/lean_vm/src/diagnostics/exec_result.rs
+++ b/crates/lean_vm/src/diagnostics/exec_result.rs
@@ -72,7 +72,6 @@ impl ExecutionMetadata {
#[derive(Debug)]
pub struct ExecutionResult {
pub runtime_memory_size: usize,
- pub public_memory_size: usize,
pub memory: Memory,
pub pcs: Vec,
pub fps: Vec,
diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs
index 4cc5d721..50ff5429 100644
--- a/crates/lean_vm/src/execution/runner.rs
+++ b/crates/lean_vm/src/execution/runner.rs
@@ -1,6 +1,6 @@
//! VM execution runner
-use crate::core::{DIMENSION, F};
+use crate::core::{DIMENSION, F, PUBLIC_INPUT_LEN};
use crate::diagnostics::{ExecutionMetadata, ExecutionResult, RunnerError};
use crate::execution::memory::MemoryAccess;
use crate::execution::{ExecutionHistory, Memory};
@@ -10,7 +10,7 @@ use crate::isa::instruction::{InstructionContext, InstructionCounts};
use crate::{ALL_TABLES, CodeAddress, HintExecutionContext, MemOrConstant, N_TABLES, STARTING_PC, Table, TableTrace};
use backend::*;
use std::collections::{BTreeMap, BTreeSet, HashMap};
-use utils::{ToUsize, padd_with_zero_to_next_power_of_two};
+use utils::ToUsize;
use super::memory::SegmentMemory;
@@ -27,7 +27,7 @@ pub struct ExecutionWitness {
pub fn try_execute_bytecode(
bytecode: &Bytecode,
- public_input: &[F],
+ public_input: &[F; PUBLIC_INPUT_LEN],
witness: &ExecutionWitness,
profiling: bool,
) -> Result {
@@ -58,7 +58,7 @@ pub fn try_execute_bytecode(
pub fn execute_bytecode(
bytecode: &Bytecode,
- public_input: &[F],
+ public_input: &[F; PUBLIC_INPUT_LEN],
witness: &ExecutionWitness,
profiling: bool,
) -> ExecutionResult {
@@ -239,7 +239,7 @@ fn resolve_deref_hints(memory: &mut Memory, pending: &[(usize, usize)]) -> Resul
#[allow(clippy::too_many_arguments)]
fn execute_bytecode_helper(
bytecode: &Bytecode,
- public_input: &[F],
+ public_input: &[F; PUBLIC_INPUT_LEN],
witness: &ExecutionWitness,
std_out: &mut String,
instruction_history: &mut ExecutionHistory,
@@ -250,10 +250,9 @@ fn execute_bytecode_helper(
.iter()
.map(|(name, entries)| (name.clone(), NamedHintCursor::new(entries)))
.collect();
- let public_memory = padd_with_zero_to_next_power_of_two(public_input);
- let public_memory_size = public_memory.len();
+ let public_memory = public_input.to_vec();
let mut memory = Memory::new(public_memory);
- let mut fp = public_memory_size + witness.preamble_memory_len;
+ let mut fp = PUBLIC_INPUT_LEN + witness.preamble_memory_len;
fp = fp.next_multiple_of(DIMENSION);
let initial_ap = fp + bytecode.starting_frame_memory;
let mut pc = STARTING_PC;
@@ -327,7 +326,7 @@ fn execute_bytecode_helper(
} else {
None
};
- let runtime_memory_size = memory.0.len() - public_memory_size - witness.preamble_memory_len;
+ let runtime_memory_size = memory.0.len() - PUBLIC_INPUT_LEN - witness.preamble_memory_len;
let used_memory_cells = memory.0.par_iter().filter(|&&x| x.is_some()).count();
let metadata = ExecutionMetadata {
cycles: trace.pcs.len(),
@@ -335,7 +334,7 @@ fn execute_bytecode_helper(
n_poseidons: trace.tables[&Table::poseidon16()].columns[0].len(),
n_extension_ops: trace.tables[&Table::extension_op()].columns[0].len(),
bytecode_size: bytecode.code.len(),
- public_input_size: public_input.len(),
+ public_input_size: PUBLIC_INPUT_LEN,
runtime_memory: runtime_memory_size,
memory_usage_percent: used_memory_cells as f64 / memory.0.len() as f64 * 100.0,
stdout: std::mem::take(std_out),
@@ -343,7 +342,6 @@ fn execute_bytecode_helper(
};
Ok(ExecutionResult {
runtime_memory_size: no_vec_runtime_memory,
- public_memory_size,
memory,
pcs: trace.pcs,
fps: trace.fps,
diff --git a/crates/rec_aggregation/src/type_1_aggregation.rs b/crates/rec_aggregation/src/type_1_aggregation.rs
index da3455f7..2f9dd09d 100644
--- a/crates/rec_aggregation/src/type_1_aggregation.rs
+++ b/crates/rec_aggregation/src/type_1_aggregation.rs
@@ -288,7 +288,7 @@ pub(crate) fn aggregate_type_1_with_min_padding(
&reduced_claims.final_claim_flat(),
bytecode,
);
- let public_input = poseidon_compress_slice(&pub_input_data).to_vec();
+ let public_input = poseidon_compress_slice(&pub_input_data);
let mut claimed: HashSet = HashSet::new();
let mut dup_pub_keys: Vec = Vec::new();
diff --git a/crates/rec_aggregation/src/type_2_aggregation.rs b/crates/rec_aggregation/src/type_2_aggregation.rs
index bf5706d8..6e93f299 100644
--- a/crates/rec_aggregation/src/type_2_aggregation.rs
+++ b/crates/rec_aggregation/src/type_2_aggregation.rs
@@ -112,7 +112,7 @@ pub fn merge_many_type_1(
let digests: Vec<[F; DIGEST_LEN]> = verified_children.iter().map(|v| v.input_data_hash).collect();
let pub_input_data = build_type2_input_data(&digests, &reduced_claims.final_claim_flat());
- let public_input_digest = poseidon_compress_slice(&pub_input_data).to_vec();
+ let public_input_digest = poseidon_compress_slice(&pub_input_data);
let bytecode_value_hint_blobs: Vec> = verified_children
.iter()
diff --git a/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py b/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py
index 9e92af76..d2395529 100644
--- a/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py
+++ b/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py
@@ -177,8 +177,8 @@ def fs_receive_ef_inlined(fs, n):
def fs_receive_ef_by_log_dynamic(fs, log_n, min_value: Const, max_value: Const):
debug_assert(log_n < max_value)
debug_assert(min_value <= log_n)
- new_fs: Imu
- ef_ptr: Imu
+ new_fs: Imm
+ ef_ptr: Imm
new_fs, ef_ptr = match_range(log_n, range(min_value, max_value), lambda ln: fs_receive_ef(fs, 2**ln))
return new_fs, ef_ptr
diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py
index 5146ee82..00b8a700 100644
--- a/crates/rec_aggregation/zkdsl_implem/hashing.py
+++ b/crates/rec_aggregation/zkdsl_implem/hashing.py
@@ -111,8 +111,8 @@ def euclidian_div_runtime(a, b):
# Requires:
# 1 <= b < 2^14
# floor(a / b) < 2^16 (so that q*b + r stays well below p)
- q: Imu
- r: Imu
+ q: Imm
+ r: Imm
hint_div_floor(a, b, q, r)
assert r < b
assert q < 2 ** 16
diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py
index 3b949c44..af081d4c 100644
--- a/crates/rec_aggregation/zkdsl_implem/recursion.py
+++ b/crates/rec_aggregation/zkdsl_implem/recursion.py
@@ -610,7 +610,8 @@ def fingerprint_n(domsep, data_evals, n, logup_alphas_eq_poly):
return res
-def verify_gkr_quotient(fs: Mut, n_vars):
+def verify_gkr_quotient(prev_fs, n_vars):
+ fs: Mut = prev_fs
fs, nums = fs_receive_ef_inlined(fs, LOGUP_GKR_N_COEFFS_SENT)
fs, denoms = fs_receive_ef_inlined(fs, LOGUP_GKR_N_COEFFS_SENT)
@@ -653,13 +654,16 @@ def verify_gkr_quotient(fs: Mut, n_vars):
)
-def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den):
+def verify_gkr_quotient_step(prev_fs, n_vars, point, claim_num, claim_den):
+ fs: Mut = prev_fs
fs = fs_duplex(fs)
fs, alpha = fs_sample_ef(fs)
alpha_mul_claim_den = mul_extension_ret(alpha, claim_den)
num_plus_alpha_mul_claim_den = add_extension_ret(claim_num, alpha_mul_claim_den)
postponed_point = Array((n_vars + 1) * DIM)
- fs, postponed_value = sumcheck_verify_reversed_helper(fs, n_vars, num_plus_alpha_mul_claim_den, 3, postponed_point)
+ fs, postponed_value = sumcheck_verify_reversed_helper(
+ fs, n_vars, num_plus_alpha_mul_claim_den, 3, postponed_point
+ )
fs, inner_evals = fs_receive_ef_inlined(fs, 4)
a_num = inner_evals
b_num = inner_evals + DIM
@@ -705,7 +709,7 @@ def compute_total_gkr_n_vars(log_memory, log_bytecode_padded, tables_heights):
def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, logup_alphas_eq_poly):
- res: Imu
+ res: Imm
debug_assert(table_index < N_TABLES)
match table_index:
case 0:
diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py
index b95eedfa..d4b2e983 100644
--- a/crates/rec_aggregation/zkdsl_implem/utils.py
+++ b/crates/rec_aggregation/zkdsl_implem/utils.py
@@ -236,7 +236,7 @@ def mle_of_01234567_etc(point, n):
@inline
def checked_less_than(a, b):
- res: Imu
+ res: Imm
hint_less_than(a, b, res)
assert res * (1 - res) == 0
if res == 1:
@@ -249,7 +249,7 @@ def checked_less_than(a, b):
@inline
def maximum(a, b):
is_a_less_than_b = checked_less_than(a, b)
- res: Imu
+ res: Imm
if is_a_less_than_b == 1:
res = b
else:
@@ -809,7 +809,7 @@ def _verify_log2_large(n, log2: Const):
def log2_ceil_runtime(n):
# requires: 2 < n <= 2^30
- log2: Imu
+ log2: Imm
hint_log2_ceil(n, log2)
assert log2 < 31
if two_exp(log2) != n:
diff --git a/crates/rec_aggregation/zkdsl_implem/whir.py b/crates/rec_aggregation/zkdsl_implem/whir.py
index 3124f253..d14a10ef 100644
--- a/crates/rec_aggregation/zkdsl_implem/whir.py
+++ b/crates/rec_aggregation/zkdsl_implem/whir.py
@@ -16,14 +16,17 @@
def whir_open(
- fs: Mut,
+ prev_fs,
n_vars,
initial_log_inv_rate,
- root: Mut,
+ prev_root,
ood_points_commit,
combination_randomness_powers_0,
- claimed_sum: Mut,
+ prev_claimed_sum,
):
+ fs: Mut = prev_fs
+ root: Mut = prev_root
+ claimed_sum: Mut = prev_claimed_sum
n_rounds, n_final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding = get_whir_params(
n_vars, initial_log_inv_rate
)
@@ -39,7 +42,7 @@ def whir_open(
domain_sz: Mut = n_vars + initial_log_inv_rate
for r in range(0, n_rounds):
- is_first_round: Imu
+ is_first_round: Imm
if r == 0:
is_first_round = 1
else:
@@ -175,13 +178,15 @@ def whir_open(
return fs, folding_randomness_global, s, final_value, end_sum
-def sumcheck_verify(fs: Mut, n_steps, claimed_sum, degree: Const):
+def sumcheck_verify(fs, n_steps, claimed_sum, degree: Const):
challenges = Array(n_steps * DIM)
- fs, new_claimed_sum = sumcheck_verify_helper(fs, n_steps, claimed_sum, degree, challenges)
- return fs, challenges, new_claimed_sum
+ new_fs, new_claimed_sum = sumcheck_verify_helper(fs, n_steps, claimed_sum, degree, challenges)
+ return new_fs, challenges, new_claimed_sum
-def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, challenges):
+def sumcheck_verify_helper(prev_fs, n_steps, prev_claimed_sum, degree: Const, challenges):
+ fs: Mut = prev_fs
+ claimed_sum: Mut = prev_claimed_sum
for sc_round in range(0, n_steps):
fs, poly = fs_receive_ef_inlined(fs, degree + 1)
polynomial_sum_at_0_and_1(poly, degree, claimed_sum)
@@ -192,10 +197,10 @@ def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, ch
return fs, claimed_sum
-def sumcheck_verify_reversed(fs: Mut, n_steps, claimed_sum: Mut, degree: Const):
+def sumcheck_verify_reversed(fs, n_steps, claimed_sum, degree: Const):
challenges = Array(n_steps * DIM)
- fs, new_claimed_sum = sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree, challenges)
- return fs, challenges, new_claimed_sum
+ new_fs, final_claimed_sum = sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree, challenges)
+ return new_fs, challenges, final_claimed_sum
def sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree: Const, challenges):
@@ -208,7 +213,9 @@ def sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree: Const, cha
return new_fd, final_sum
-def sumcheck_verify_reversed_helper_const(fs: Mut, n_steps: Const, claimed_sum: Mut, degree: Const, challenges):
+def sumcheck_verify_reversed_helper_const(prev_fs, n_steps: Const, prev_claimed_sum, degree: Const, challenges):
+ fs: Mut = prev_fs
+ claimed_sum: Mut = prev_claimed_sum
for sc_round in unroll(0, n_steps):
fs, poly = fs_receive_ef_inlined(fs, degree + 1)
polynomial_sum_at_0_and_1(poly, degree, claimed_sum)
@@ -219,7 +226,9 @@ def sumcheck_verify_reversed_helper_const(fs: Mut, n_steps: Const, claimed_sum:
return fs, claimed_sum
-def sumcheck_verify_with_grinding(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, folding_grinding_bits):
+def sumcheck_verify_with_grinding(prev_fs, n_steps, prev_claimed_sum, degree: Const, folding_grinding_bits):
+ fs: Mut = prev_fs
+ claimed_sum: Mut = prev_claimed_sum
challenges = Array(n_steps * DIM)
for sc_round in range(0, n_steps):
fs, poly = fs_receive_ef_inlined(fs, degree + 1)
@@ -285,7 +294,7 @@ def decompose_and_verify_merkle_batch_const(
def sample_stir_indexes_and_fold(
- fs: Mut,
+ prev_fs,
num_queries,
merkle_leaves_in_basefield,
folding_factor,
@@ -295,6 +304,7 @@ def sample_stir_indexes_and_fold(
folding_randomness,
query_grinding_bits,
):
+ fs: Mut = prev_fs
folded_domain_size = domain_size - folding_factor
fs = fs_grinding(fs, query_grinding_bits)
@@ -303,7 +313,7 @@ def sample_stir_indexes_and_fold(
merkle_leaves = Array(num_queries)
circle_values = Array(num_queries)
- n_chunks_per_answer: Imu
+ n_chunks_per_answer: Imm
# the number of chunk of 8 field elements per merkle leaf opened
if merkle_leaves_in_basefield == 1:
n_chunks_per_answer = two_pow_folding_factor
@@ -335,7 +345,7 @@ def sample_stir_indexes_and_fold(
def whir_round(
- fs: Mut,
+ prev_fs,
prev_root,
folding_factor,
two_pow_folding_factor,
@@ -347,6 +357,7 @@ def whir_round(
num_ood,
folding_grinding_bits,
):
+ fs: Mut = prev_fs
fs, folding_randomness, new_claimed_sum_a = sumcheck_verify_with_grinding(
fs, folding_factor, claimed_sum, 2, folding_grinding_bits
)
@@ -398,21 +409,24 @@ def polynomial_sum_at_0_and_1(coeffs, degree, dst):
return
-def parse_commitment(fs: Mut, num_ood):
- root: Imu
- ood_points: Imu
- ood_evals: Imu
+def parse_commitment(fs, num_ood):
+ root: Imm
+ ood_points: Imm
+ ood_evals: Imm
debug_assert(num_ood < 5)
debug_assert(num_ood != 0)
- fs, root, ood_points, ood_evals = match_range(num_ood, range(1, 5), lambda n: parse_whir_commitment_const(fs, n))
- return fs, root, ood_points, ood_evals
+ new_fs, root, ood_points, ood_evals = match_range(
+ num_ood, range(1, 5), lambda n: parse_whir_commitment_const(fs, n)
+ )
+ return new_fs, root, ood_points, ood_evals
-def parse_whir_commitment_const(fs: Mut, num_ood: Const):
- fs, root = fs_receive_chunks(fs, 1)
- fs, ood_points = fs_sample_many_ef(fs, num_ood)
- fs, ood_evals = fs_receive_ef_inlined(fs, num_ood)
- return fs, root, ood_points, ood_evals
+def parse_whir_commitment_const(fs, num_ood: Const):
+ new_fs: Mut
+ new_fs, root = fs_receive_chunks(fs, 1)
+ new_fs, ood_points = fs_sample_many_ef(new_fs, num_ood)
+ new_fs, ood_evals = fs_receive_ef_inlined(new_fs, num_ood)
+ return new_fs, root, ood_points, ood_evals
@inline
@@ -427,15 +441,15 @@ def get_whir_params(n_vars, log_inv_rate):
debug_assert(MIN_WHIR_LOG_INV_RATE <= log_inv_rate)
debug_assert(log_inv_rate <= MAX_WHIR_LOG_INV_RATE)
- num_queries: Imu
+ num_queries: Imm
num_queries = get_num_queries(log_inv_rate, n_vars)
- query_grinding_bits: Imu
+ query_grinding_bits: Imm
query_grinding_bits = get_query_grinding_bits(log_inv_rate, n_vars)
num_oods = get_num_oods(log_inv_rate, n_vars)
- folding_grinding: Imu
+ folding_grinding: Imm
folding_grinding = get_folding_grinding(log_inv_rate, n_vars)
return n_rounds, final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding