Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Minimal hash-based zkVM, targeting recursion and aggregation of hash-based signa

<p align="center">
<a href="minimal_zkVM.pdf"><img src="https://img.shields.io/badge/Documentation-blue?style=for-the-badge&logo=data:image/svg%2bxml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAyNCAyNCIgZmlsbD0id2hpdGUiPjxwYXRoIGQ9Ik0xNCAySDZjLTEuMSAwLTIgLjktMiAydjE2YzAgMS4xLjg5IDIgMS45OSAySDE4YzEuMSAwIDItLjkgMi0yVjhsLTYtNnpNOC41IDE0LjVoMS4yNWMuOTcgMCAxLjc1LS43OCAxLjc1LTEuNzVTMTAuNzIgMTEgOS43NSAxMUg3LjV2Nmgxdi0yLjV6bTAtMVYxMmgxLjI1Yy40MSAwIC43NS4zNC43NS43NXMtLjM0Ljc1LS43NS43NUg4LjV6bTUuNSAzLjVoMnYtMWgtMnYtMWgydi0xaC0ydi0xLjVjMC0uMjguMjItLjUuNS0uNUgxN3YtMWgtMmMtLjgzIDAtMS41LjY3LTEuNSAxLjVWMTd6TTEzIDlWMy41TDE4LjUgOUgxM3oiLz48L3N2Zz4=" alt="Documentation"></a>
<a href="crates/lean_compiler/zkDSL.md"><img src="https://img.shields.io/badge/zkDSL%20reference-7c3aed?style=for-the-badge&logo=markdown&logoColor=white" alt="zkDSL reference"></a>
<a href="crates/lean_prover/python-verifier/verifier.py"><img src="https://img.shields.io/badge/Python%20verifier-d97706?style=for-the-badge&logo=python&logoColor=white" alt="Python verifier"></a>
</p>

Expand Down
2 changes: 1 addition & 1 deletion crates/lean_compiler/snark_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Type annotations
Mut = Any
Const = Any
Imu = Any
Imm = Any


# @inline decorator (does nothing in Python execution)
Expand Down
16 changes: 3 additions & 13 deletions crates/lean_compiler/src/a_simplify_lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,21 +332,14 @@ pub fn simplify_program(mut program: Program) -> Result<SimpleProgram, String> {
let mut array_manager = ArrayManager::default();
let mut mut_tracker = MutableVarTracker::default();

// Register mutable arguments and capture their initial versioned names
// BEFORE simplifying the body
// All arguments are immutable; record them as assigned to detect illegal reassignment.
let arguments: Vec<Var> = func
.arguments
.iter()
.map(|arg| {
assert!(!arg.is_const);
if arg.is_mutable {
mut_tracker.register_mutable(&arg.name);
// Capture the initial versioned name (version 0)
mut_tracker.current_name(&arg.name)
} else {
mut_tracker.assigned.insert(arg.name.clone());
arg.name.clone()
}
mut_tracker.assigned.insert(arg.name.clone());
arg.name.clone()
})
.collect();

Expand Down Expand Up @@ -398,9 +391,6 @@ fn compile_time_transform_in_program(
.collect();

for func in inlined_functions.values() {
if func.has_mutable_arguments() {
return Err("Inlined functions with mutable arguments are not supported yet".to_string());
}
if func.has_const_arguments() {
return Err(format!(
"Inlined function should not have \"Const\" arguments (function \"{}\")",
Expand Down
4 changes: 2 additions & 2 deletions crates/lean_compiler/src/grammar.pest
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ return_statement = { "return" ~ (("(" ~ tuple_expression ~ ")") | tuple_expressi

mut_keyword = @{ "mut" ~ !(ASCII_ALPHANUMERIC | "_") }
mut_annotation = { ":" ~ "Mut" }
im_annotation = { ":" ~ "Imu" }
im_annotation = { ":" ~ "Imm" }

// Forward declaration: x: Imu or x: Mut (not followed by =)
// Forward declaration: x: Imm or x: Mut (not followed by =)
forward_declaration = { identifier ~ (im_annotation | mut_annotation) ~ !("=") }

// General assignment: LHS is optional, RHS is any expression
Expand Down
8 changes: 1 addition & 7 deletions crates/lean_compiler/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ impl Program {
pub struct FunctionArg {
pub name: Var,
pub is_const: bool,
pub is_mutable: bool,
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
Expand All @@ -48,9 +47,6 @@ impl Function {
pub fn has_const_arguments(&self) -> bool {
self.arguments.iter().any(|arg| arg.is_const)
}
pub fn has_mutable_arguments(&self) -> bool {
self.arguments.iter().any(|arg| arg.is_mutable)
}
}

pub type Var = String;
Expand Down Expand Up @@ -663,7 +659,7 @@ impl Line {
if *is_mutable {
format!("{var}: Mut")
} else {
format!("{var}: Imu")
format!("{var}: Imm")
}
}
Self::Statement { targets, value, .. } => {
Expand Down Expand Up @@ -899,8 +895,6 @@ impl Display for Function {
.map(|arg| {
if arg.is_const {
format!("const {}", arg.name)
} else if arg.is_mutable {
format!("mut {}", arg.name)
} else {
arg.name.to_string()
}
Expand Down
8 changes: 6 additions & 2 deletions crates/lean_compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,19 @@ pub fn compile_program(input: &ProgramSource) -> Bytecode {
try_compile_program(input).unwrap()
}

pub fn try_compile_and_run(input: &ProgramSource, public_input: &[F], profiler: bool) -> Result<String, Error> {
pub fn try_compile_and_run(
input: &ProgramSource,
public_input: &[F; PUBLIC_INPUT_LEN],
profiler: bool,
) -> Result<String, Error> {
let bytecode = try_compile_program(input)?;
let witness = ExecutionWitness::default();
let result = try_execute_bytecode(&bytecode, public_input, &witness, profiler)?;
println!("{}", result.metadata.display());
Ok(result.metadata.display())
}

pub fn compile_and_run(input: &ProgramSource, public_input: &[F], profiler: bool) {
pub fn compile_and_run(input: &ProgramSource, public_input: &[F; PUBLIC_INPUT_LEN], profiler: bool) {
let summary = try_compile_and_run(input, public_input, profiler).unwrap();
println!("{summary}");
}
25 changes: 15 additions & 10 deletions crates/lean_compiler/src/parser/parsers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ pub const RESERVED_FUNCTION_NAMES: &[&str] = &[
// Built-in functions
"print",
"Array",
"hint_witness",
// Compile-time only functions
"len",
"log2_ceil",
"next_multiple_of",
"saturating_sub",
"div_ceil",
"div_floor",
"range",
"parallel_range",
"match_range",
Expand Down Expand Up @@ -155,22 +158,24 @@ impl Parse<FunctionArg> for ParameterParser {
.into());
}

// Check for optional type annotation (: Const or : Mut)
let (is_const, is_mutable) = if let Some(annotation) = inner.next() {
// Check for optional type annotation (: Const). ': Mut' parameters are forbidden.
let is_const = if let Some(annotation) = inner.next() {
match annotation.as_str().trim() {
": Const" => (true, false),
": Mut" => (false, true),
": Const" => true,
": Mut" => {
return Err(SemanticError::new(format!(
"Parameter '{name}' cannot be declared ': Mut'. Mutable parameters are not allowed; \
introduce a local '{name}_mut: Mut = {name}' instead."
))
.into());
}
other => return Err(SemanticError::new(format!("Invalid parameter annotation: {other}")).into()),
}
} else {
(false, false)
false
};

Ok(FunctionArg {
name,
is_const,
is_mutable,
})
Ok(FunctionArg { name, is_const })
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/lean_compiler/src/parser/parsers/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ impl<const DEBUG: bool> Parse<Line> for AssertParser<DEBUG> {
}
}

/// Parser for forward declarations: `x: Imu` or `x: Mut`
/// Parser for forward declarations: `x: Imm` or `x: Mut`
pub struct ForwardDeclarationParser;

impl Parse<Line> for ForwardDeclarationParser {
Expand All @@ -297,7 +297,7 @@ impl Parse<Line> for ForwardDeclarationParser {
// Parse variable name
let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string();

// Check for : Mut or : Imu annotation
// Check for : Mut or : Imm annotation
let annotation = next_inner_pair(&mut inner, "type annotation")?;
let is_mutable = annotation.as_rule() == Rule::mut_annotation;

Expand Down
85 changes: 48 additions & 37 deletions crates/lean_compiler/tests/test_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,23 @@ use backend::{BasedVectorSpace, PrimeCharacteristicRing};
use lean_compiler::*;
use lean_vm::*;
use rand::{RngExt, SeedableRng, rngs::StdRng};
use utils::poseidon16_compress;

#[test]
fn test_poseidon() {
let program = r#"
def main():
a = 0
b = a + 8
c = Array(8)
poseidon16_compress(a, b, c)

for i in range(0, 8):
cc = c[i]
print(cc)
return
"#;
let public_input: [F; 16] = (0..16).map(F::new).collect::<Vec<F>>().try_into().unwrap();
compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false);

let _ = dbg!(poseidon16_compress(public_input));
}

#[test]
fn test_div_extension_field() {
let program = r#"
DIM = 5

def main():
n = 0
d = n + DIM
q = n + 2 * DIM
nd = Array(2 * DIM)
hint_witness("nd", nd)
n = nd
d = nd + DIM
expected_q = Array(DIM)
hint_witness("q", expected_q)
computed_q_1 = div_ext_1(n, d)
computed_q_2 = div_ext_2(n, d)
assert_eq_ext(computed_q_2, q)
assert_eq_ext(computed_q_1, q)
assert_eq_ext(computed_q_2, expected_q)
assert_eq_ext(computed_q_1, expected_q)
return

def assert_eq_ext(x, y):
Expand All @@ -61,12 +43,19 @@ def div_ext_2(n, d):
let n: EF = rng.random();
let d: EF = rng.random();
let q = n / d;
let mut public_input = vec![];
public_input.extend(n.as_basis_coefficients_slice());
public_input.extend(d.as_basis_coefficients_slice());
public_input.extend(q.as_basis_coefficients_slice());
public_input.resize(16, F::ZERO);
compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false);
let mut nd_buf: Vec<F> = Vec::new();
nd_buf.extend(n.as_basis_coefficients_slice());
nd_buf.extend(d.as_basis_coefficients_slice());
let q_buf: Vec<F> = q.as_basis_coefficients_slice().to_vec();
let mut hints = std::collections::HashMap::new();
hints.insert("nd".to_string(), vec![nd_buf]);
hints.insert("q".to_string(), vec![q_buf]);
let witness = ExecutionWitness {
hints,
..ExecutionWitness::default()
};
let bytecode = compile_program(&ProgramSource::Raw(program.to_string()));
try_execute_bytecode(&bytecode, &[F::ZERO; PUBLIC_INPUT_LEN], &witness, false).unwrap();
}

fn test_data_dir() -> String {
Expand Down Expand Up @@ -134,7 +123,7 @@ fn test_all_programs() {
Ok(b) => b,
Err(err) => panic!("Program {} failed to compile: {:?}", path, err),
};
if let Err(err) = try_execute_bytecode(&bytecode, &[], &witness, false) {
if let Err(err) = try_execute_bytecode(&bytecode, &[F::ZERO; PUBLIC_INPUT_LEN], &witness, false) {
panic!("Program {} failed with error: {:?}", path, err);
}
}
Expand Down Expand Up @@ -176,7 +165,13 @@ def func(a, b):
return
"#;
let bytecode = compile_program(&ProgramSource::Raw(program.to_string()));
let n_cycles = execute_bytecode(&bytecode, &[], &ExecutionWitness::default(), false).n_cycles();
let n_cycles = execute_bytecode(
&bytecode,
&[F::ZERO; PUBLIC_INPUT_LEN],
&ExecutionWitness::default(),
false,
)
.n_cycles();
assert!(n_cycles < 1100);
}

Expand Down Expand Up @@ -205,10 +200,20 @@ def factorial(n):
let compiled_parallel = compile_program(&ProgramSource::Raw(program.replace("loop", "parallel_range")));

let time_sequential = Instant::now();
let exec_seq = execute_bytecode(&compiled_sequencial, &[], &ExecutionWitness::default(), false);
let exec_seq = execute_bytecode(
&compiled_sequencial,
&[F::ZERO; PUBLIC_INPUT_LEN],
&ExecutionWitness::default(),
false,
);
let duration_sequential = time_sequential.elapsed();
let time_parallel = Instant::now();
let exec_par = execute_bytecode(&compiled_parallel, &[], &ExecutionWitness::default(), false);
let exec_par = execute_bytecode(
&compiled_parallel,
&[F::ZERO; PUBLIC_INPUT_LEN],
&ExecutionWitness::default(),
false,
);
let duration_parallel = time_parallel.elapsed();

assert_eq!(exec_seq.metadata.stdout, exec_par.metadata.stdout);
Expand Down Expand Up @@ -249,7 +254,13 @@ fn test_soundness_suite() {
("soundness_5", &[3, 4, 7, 19, 49, 28, 1, 3], &[(0, 4), (1, 5), (2, 8), (3, 20), (4, 50), (5, 29), (6, 0), (6, 2), (7, 4)]),
];

let to_input = |v: &[u32]| v.iter().copied().map(F::new).collect::<Vec<_>>();
let to_input = |v: &[u32]| -> [F; PUBLIC_INPUT_LEN] {
let mut out = [F::ZERO; PUBLIC_INPUT_LEN];
for (slot, &x) in out.iter_mut().zip(v) {
*slot = F::new(x);
}
out
};

for &(name, valid, perturbations) in cases {
let path = format!("{}/{}.py", test_data_dir(), name);
Expand Down
2 changes: 1 addition & 1 deletion crates/lean_compiler/tests/test_data/error_13.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def main():
a: Imu
a: Imm
a = 0
a = a + 1
if a == 1:
Expand Down
2 changes: 1 addition & 1 deletion crates/lean_compiler/tests/test_data/error_7.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from snark_lib import *


# Error: inline functions with parameters: Mut are not supported
# Error: function parameters cannot be declared ': Mut'
def main():
return

Expand Down
4 changes: 2 additions & 2 deletions crates/lean_compiler/tests/test_data/program_100.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


def main():
x: Imu
y: Imu
x: Imm
y: Imm

cond = 1
if cond == 1:
Expand Down
4 changes: 2 additions & 2 deletions crates/lean_compiler/tests/test_data/program_109.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ def main():
def test_func(a, b):
x = 1

mut_x_2: Imu
mut_x_2: Imm
match a:
case 0:
mut_x_1: Imu
mut_x_1: Imm
mut_x_1 = x + 2
match b:
case 0:
Expand Down
2 changes: 1 addition & 1 deletion crates/lean_compiler/tests/test_data/program_110.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def main():
result = complex_compute(3, 4, 5)
assert result == 47

fwd_val: Imu
fwd_val: Imm
cond = 1
if cond == 0:
fwd_val = 100
Expand Down
9 changes: 5 additions & 4 deletions crates/lean_compiler/tests/test_data/program_111.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,14 @@ def chain_compute(x, y):
a2, b2, c2 = step_compute(a1, b1)
return a2, b2, c1 + c2

def nested_mut_params(base: Mut):
def nested_mut_params(base):
acc: Mut = base
for i in unroll(0, 3):
base = base + i * 2
return base
acc = acc + i * 2
return acc

def state_machine_step(current_state, phase):
result: Imu
result: Imm
if phase == 0:
if current_state == 0:
result = 1
Expand Down
Loading
Loading