From f55beda47f69a2c412e11516c050bac01579a61f Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 27 May 2026 15:11:11 +0400 Subject: [PATCH 01/13] wip --- .../src/parser/parsers/function.rs | 3 + crates/lean_compiler/zkDSL.md | 734 +++++++++++------- 2 files changed, 465 insertions(+), 272 deletions(-) diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 40256ac9..8fb1cee6 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -15,11 +15,14 @@ pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ // Built-in functions "print", "Array", + "hint_witness", // Compile-time only functions "len", "log2_ceil", "next_multiple_of", "saturating_sub", + "div_ceil", + "div_floor", "range", "parallel_range", "match_range", diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 5ea8d46a..4246b1bc 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -1,31 +1,58 @@ # zkDSL Language Reference -Warning: still under construction (i.e. it's messy). +The zkDSL is a Python-syntax language that compiles to leanVM bytecode (4 instructions ++ 2 precompile tables). It is restricted enough that every `.py` source file also +runs as plain Python (using `crates/lean_compiler/snark_lib.py` as a stub library), +which lets you sanity-check programs with a regular interpreter before compiling. -## Program Structure +Programs are organized as one or more `.py` files. The toplevel of each file is a +sequence of: + +1. `from import *` statements (optional) +2. Top-level constant declarations (optional) +3. Function definitions + +Execution starts at `def main(): ...`. ``` -from snark_lib import * # Python compatibility (ignored by compiler) -from dir.file import * # imports (optional, Python-style) -NAME = value # constants (optional, uppercase by convention) -def main(): # entry point (required) +from snark_lib import * # Python compatibility shim, stripped by the compiler +from dir.file import * # other .py files in the import root +from ..parent_module import * # parent-directory imports + +X = 42 # constants must come before functions +ARR = [1, 2, 3] + +def main(): # required entry point ... -def helper(): # other functions (optional) + +def helper(): # other functions ... ``` -The `from snark_lib import *` line imports Python definitions for zkDSL primitives (Array, Mut, Const, etc.), allowing `.py` files to be executed as normal Python scripts for testing. The zkDSL compiler ignores this import line. +The compiler strips the `from snark_lib import *` line (and only that line) so the +same source is valid Python. To run a `.py` file under regular Python for testing: -To run zkDSL files as Python scripts, run from the file's directory with PYTHONPATH pointing to the lean_compiler crate (for snark_lib.py): ```bash export PYTHONPATH=/path/to/repo/crates/lean_compiler -cd crates/lean_compiler/tests/test_data -python program_0.py +python program.py +``` + +## Imports + +``` +from utils import * # imports utils.py (resolved from the import root) +from dir.subdir.file import * # nested module +from ..module import * # parent-directory import (relative to current file) ``` +Imports are wildcard-only (`import *`). Each module is loaded once even if imported +multiple times; circular imports are detected and rejected. Constants with the same +name in two imported files cause a compile-time error. + ## Constants -Constants are declared at the top level (outside functions) using simple assignment. By convention, constant names are UPPERCASE. +Constants live at the top of the file, outside any function. By convention they are +UPPERCASE. ``` X = 42 @@ -33,32 +60,34 @@ ARR = [1, 2, 3] NESTED = [[1, 2], [3]] ``` -### Multi-Dimensional Const Arrays - -Const arrays can be nested to any depth, and inner arrays can have different lengths (ragged arrays). All const array values are resolved at compile time. +### Nested (multi-dimensional, possibly ragged) constant arrays ``` -MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] # ragged 2D array -DEEP = [[[1, 2], [3]], [[4, 5, 6]]] # 3D array +MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] +DEEP = [[[1, 2], [3]], [[4, 5, 6]]] ``` -**Accessing elements:** Use chained indexing with compile-time indices: +Indexed access uses chained subscripts at compile time: + ``` -x = MATRIX[0][2] # x = 3 -y = DEEP[1][0][1] # y = 5 +x = MATRIX[0][2] # 3 +y = DEEP[1][0][1] # 5 ``` -**Using `len()` on inner arrays:** The `len()` function can be applied to any level of a nested const array, including inner arrays accessed by index. This is particularly useful for iterating over ragged arrays where each row has a different length: +`len()` works at every depth, including on a row addressed by a constant index: ``` -len(MATRIX) # 3 -len(MATRIX[0]) # 3 -len(DEEP[0][0]) # 2 +len(MATRIX) # 3 +len(MATRIX[0]) # 3 +len(DEEP[0][0]) # 2 ``` -**Important:** When using `len()` on an inner array with a variable index (e.g., `len(ARR[i])`), the index must be a compile-time constant. This works inside `unroll` loops because the loop variable becomes a compile-time constant during unrolling. +When `len()` is applied with a variable index (`len(ARR[i])`), `i` must be a +compile-time constant. `: Const` parameters always qualify (see [Functions] +below), as do iterator variables of an `unroll` loop (see [For loops] below) — +those are the two ways to get a value the compiler can substitute at expansion +time. Example: iterating a ragged 2D table: -**Example: Iterating over a ragged 2D array:** ``` MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] @@ -67,17 +96,17 @@ def main(): for row in unroll(0, len(MATRIX)): for col in unroll(0, len(MATRIX[row])): total = total + MATRIX[row][col] - assert total == 45 # 1+2+3+4+5+6+7+8+9 + assert total == 45 return ``` ## Functions ``` -def add(a, b): # return count is inferred from return statements +def add(a, b): return a + b -def swap(a, b): # multiple return values +def swap(a, b): return b, a def main(): @@ -85,69 +114,80 @@ def main(): return ``` -The number of return values is automatically inferred from the `return` statements. All return statements in a function must return the same number of values. +Every function must contain at least one `return`. The compiler infers the number +of returned values from the `return` statements; all `return`s in a function must +agree. A function that "returns nothing" uses a bare `return`. -### Parameter Modifiers +### Parameter modifiers -| Syntax | Meaning | -| ---------- | --------------------------------------------------------- | -| `x` | immutable parameter | -| `x: Const` | compile-time value (enables `unroll` with dynamic bounds) | -| `x: Mut` | mutable within function body only | +| Syntax | Meaning | +| ---------- | --------------------------------------------------------------------------------- | +| `x` | normal (immutable) parameter | +| `x: Const` | compile-time-known value; enables `unroll`/array sizes that depend on the param | +| `x: Mut` | locally mutable parameter (reassignable inside the function — caller is unaffected) | -**All parameters are pass-by-value.** The `: Mut` modifier allows reassignment within the function, but changes are not visible to the caller. Use return values to communicate results. +All parameters are pass-by-value. Use return values to propagate results — there +are no out-parameters. ``` -def repeat(n: Const): # Const enables unroll +def repeat(n: Const): # Const enables unroll(0, n) sum: Mut = 0 for i in unroll(0, n): sum = sum + i return sum -def double(x: Mut): # Mut allows local reassignment - x = x * 2 # only affects local copy - return x # must return to pass result back +def double(x: Mut): # Mut: only the local copy is reassignable + x = x * 2 + return x ``` -### Inline Functions -Use the `@inline` decorator to mark functions for inlining at call sites: +### Inline functions + +`@inline` expands a function at every call site instead of generating a call +instruction. Useful for small helpers, and for cases where the body must "see" the +caller's `: Const` context. + ``` @inline def square(x): return x * x ``` -**Note:** Inline functions cannot have `: Mut` parameters. -**Note:** Inline functions support at most one `return`, and it must be at the -top level of the body — never nested inside an `if`, loop, or `match`. Early or -conditional returns are rejected by the compiler, because inlining expands each -`return` into a plain assignment with no control flow. Use a regular (non-inline) -function — with `: Const` parameters if you need compile-time specialization — -when you need a conditional return. +Constraints on inline functions: + +- No `: Mut` parameters allowed. +- Exactly one `return`, placed at the top level of the body — not nested inside + `if`, a loop, or `match`. Inlining rewrites the `return` into a plain + assignment, so early or conditional returns cannot be expressed. + +If you need conditional returns, use a normal (non-`@inline`) function. Combine +it with `: Const` parameters when you need compile-time specialization at the +call site. ## Variables | Declaration | Mutability | Notes | | ------------- | ---------- | ---------------------------------------------- | | `x = 10` | immutable | cannot be reassigned | -| `x: Mut = 10` | mutable | can be reassigned | -| `x: Imu` | immutable | forward declaration, assign exactly once later | -| `x: Mut` | mutable | forward declaration for mutable variable | +| `x: Mut = 10` | mutable | reassignable | +| `x: Imu` | immutable | forward declaration; assign exactly once later | +| `x: Mut` | mutable | forward declaration; reassignable later | -### Forward Declarations +### Forward declarations -Use `x: Imu` when a variable must be assigned in different branches: +Use `x: Imu` when you want an immutable binding but the value comes from a +branch: ``` -result: Imu # immutable: assign exactly once +result: Imu if cond == 1: result = 10 else: result = 20 -# result cannot be reassigned after this +# result is now immutable ``` -Use `x: Mut` when you need the variable to be mutable after assignment: +Use `x: Mut` when you want to keep mutating the variable after the branch: ``` x: Mut @@ -155,51 +195,55 @@ if cond == 1: x = 10 else: x = 20 -x = x + 1 # OK: x was declared as mutable +x = x + 1 # OK: x is mutable ``` -### Tuple Assignments with Mutable Variables +### Mutability inside tuple assignments -When a function returns multiple values and some need to be mutable, use forward declarations: +To make a single component of a tuple-return mutable, forward-declare it: ``` -b: Mut # declare b as mutable +b: Mut a, b, c = some_function() -# a and c are immutable, b is mutable -b = b + 1 # OK -# a = 5 # ERROR: a is immutable +b = b + 1 # OK +# a = 5 # ERROR: a is immutable ``` -This is useful when a function returns multiple values and only some need to be modified later. - -## Memory and Arrays +## Memory and arrays ``` -buffer = Array(16) # allocate 16 field elements +buffer = Array(16) # allocate 16 field elements buffer[0] = 42 x = buffer[5] -matrix = Array(64) # 2D via manual indexing +matrix = Array(64) # 2D via manual indexing matrix[row * 8 + col] = value -ptr2 = ptr + 5 # pointer arithmetic -ptr2[0] = 100 # same as ptr[5] = 100 +ptr2 = buffer + 5 # pointer arithmetic +ptr2[0] = 100 # same as buffer[5] = 100 ``` -**Memory is write-once.** Due to SSA constraints, each memory location can only hold one value. Writing to the same location multiple times is allowed, but all writes must produce the same value—otherwise a runtime error occurs. +`Array(n)` returns a pointer to a freshly allocated block of `n` field +elements. `n` may be a compile-time constant (the common case) or a runtime +value; the runner handles both. Memory is **write-once**: a cell may be +written more than once only if all writes store the same value. The second +write of a different value is a runtime error at the point of the write. ``` arr = Array(3) -arr[0] = 10 # OK: first write -arr[0] = 10 # OK: same value -arr[0] = 20 # ERROR: different value at same location +arr[0] = 10 +arr[0] = 10 # OK: same value +arr[0] = 20 # ERROR: conflicting write ``` -Use `mut` variables when you need mutability, the compiler cannot handle mutability on hand-written allocated memory ("Array(...)"). +`Array` cells are not implicitly mutable — if you need a running accumulator, +use `x: Mut` for the variable and only commit final values to memory. Pointer +arithmetic (`ptr + offset`) is the way to address into sub-regions. + +## Control flow -## Control Flow +### `if` / `elif` / `else` -### If/Else ``` if x == 0: y = 1 @@ -208,10 +252,14 @@ elif x == 1: else: y = 3 ``` -Comparison operators: `==`, `!=` -### Match -Patterns must be consecutive integers: +Comparison operators on conditions: `==`, `!=`, `<`, `<=`. There is **no** `>` +or `>=` — flip the operands to get the same effect. + +### `match` + +Patterns must be a contiguous run of integers: + ``` match value: case 5: @@ -222,16 +270,24 @@ match value: result = 700 ``` -### match_range +The matched value must lie inside the listed range; out-of-range values produce +undefined behaviour. Use a `debug_assert` (or `assert`, if you want it to be +enforced by the proof) to guard the input. + +### `match_range` -Compile-time construct that expands into a match statement, useful for dispatching to functions with const parameters based on runtime values. Results are always immutable. +`match_range` is the workhorse for *dispatching a runtime value to a const- +parameter function*. It is a compile-time construct that expands into a +forward-declared variable plus a `match` over a contiguous range of integers. ``` result = match_range(n, range(1, 5), lambda i: compute(i)) ``` -Expands to: + +expands to + ``` -result: Imu # auto-generated forward declaration (always immutable) +result: Imu match n: case 1: result = compute(1) case 2: result = compute(2) @@ -239,58 +295,71 @@ match n: case 4: result = compute(4) ``` -**Multiple continuous ranges** with different lambdas: +You can chain several `(range, lambda)` pairs, provided the ranges are +**contiguous** (the end of one is the start of the next): + ``` -result = match_range(n, - range(0, 1), lambda i: special_case(), - range(1, 8), lambda i: normal_case(i)) +result = match_range( + n, + range(0, 1), lambda i: special_case(), + range(1, 8), lambda i: normal_case(i), +) ``` -Expands to a match where case 0 uses `special_case()` and cases 1-7 use `normal_case(i)`. -Ranges must be continuous (end of one equals start of next). +Multiple return values are supported via tuple unpacking. The bindings produced +by `match_range` are always immutable — forward-declare with `: Mut` (and then +reassign) if you need them mutable later: -**Multiple return values:** ``` a, b = match_range(n, range(0, 4), lambda i: two_values(i)) ``` -**Common use case:** Dispatching runtime values to const-parameter functions: +Idiomatic use — dispatching a runtime length to a function that requires a +compile-time length: + ``` def helper_const(n: Const): - # function that requires compile-time n return n * n def compute(value): - result = match_range(value, range(0, 10), lambda i: helper_const(i)) - return result + debug_assert(value < 10) + return match_range(value, range(0, 10), lambda i: helper_const(i)) ``` -**IMPORTANT:** For both `match` and `match_range`, the programmer must ensure the value is within the specified range. Out-of-range values cause undefined behavior. Use `debug_assert` to validate: -``` -debug_assert(n < 10) -debug_assert(0 < n) -result = match_range(n, range(1, 10), lambda i: compute(i)) -``` +**Range validity is the caller's job.** A `match_range` whose input falls +outside any listed range is undefined behaviour at runtime — always pair it +with a `debug_assert` (or `assert`, if you want the proof to enforce it) on the +dispatched value. Skipping this guard is by far the most common source of +silent bugs in zkDSL. -### For Loops -``` -for i in range(0, 10): # standard loop - ... -for i in parallel_range(0, n): # iterations executed in parallel (see below) - ... -for i in unroll(0, 4): # unrolled at compile time - ... -``` -Use `unroll` when bounds are const or compile-time expansion is needed. +### For loops + +Three loop forms, all written `for i in (start, end):`. Bounds and +behaviour: + +| Loop form | When | +| ---------------------------- | ------------------------------------------------------------------------- | +| `for i in range(a, b):` | Runtime loop. Compiled into a recursive function (no `break`/`continue`). | +| `for i in unroll(a, b):` | Compile-time expansion; `a` and `b` must both be compile-time constants. | +| `for i in parallel_range(a, b):` | Runtime loop; iterations are executed in parallel by the runner via rayon. | + +`parallel_range` requires the loop body to be iteration-independent. The +runner executes the first iteration sequentially to learn its memory footprint, +then runs the rest of the iterations concurrently — so anything cross-iteration +must hold a-priori, since there is no synchronization: + +- No `Mut` variables carried across iterations (each iteration writes only to + its own call frame and to addresses disjoint from every other iteration). +- Identical memory footprint per iteration. +- Identical hint consumption per iteration (witness hints, XMSS-specific + decomposition hints, Merkle hints, etc.). -**`parallel_range`** executes iterations concurrently using rayon. The produced bytecode is identical to `range`. Constraints: -- The loop body must be **iteration-independent**: no `Mut` variables carried - across iterations. Each iteration may only write to its own frame and to - external addresses that do not affect other iterations . -- The memory footprint (i.e. total memory usage) must be the same across iterations -- XMSS / Merkle hint consumption must be the same across iterations +These constraints are **not** checked at compile time. Violating them produces +silently wrong proofs. -**Mutable variables in non-unrolled loops:** Mutable variables can be modified inside non-unrolled loops. The compiler automatically transforms these into buffer-based implementations: +Mutable variables inside non-unrolled loops are supported transparently — the +compiler inserts a buffer array, stores per-iteration values into it, and reads +the final value back after the loop: ``` sum: Mut = 0 @@ -299,143 +368,239 @@ for i in range(1, 11): assert sum == 55 ``` -Loops limitations: -- no "continue" or "break" are supported yet -- the "return" keyword is not supported inside the body of a normal (non-unrolled) loop (because under the hood normal loops are transformed into recursive functions) +Loop limitations (current): + +- No `break` or `continue` (these forms are not in the grammar). +- No `return` inside the body of a non-unrolled loop (because such loops are + lowered to recursive functions). The compiler emits "Function return inside + a loop is not currently supported" if you try. + +### Statements without effect are rejected + +Every line must either be a declaration, an assignment, a control-flow form, an +assertion, a `return`, or a side-effecting call (`hint_witness`, precompile, +`print`, or a function call). A bare expression like `x + 1` on its own line is +a compile error. ## Expressions ### Arithmetic -- `+`, `-`, `*`, `/` (field operations): allowed at runtime -- `%` (modulo), `**` (exponentiation): only allowed at compile time -### Compound Assignment -Syntactic sugar for updating mutable variables: +`+`, `-`, `*`, `/` are field operations and work at runtime. + +`%` (modulo) and `**` (exponentiation) are **compile-time only** — both operands +must be constants known at compile time. + +### Compound assignment + ``` x: Mut = 10 -x += 5 # equivalent to: x = x + 5 -x -= 3 # equivalent to: x = x - 3 -x *= 2 # equivalent to: x = x * 2 -x /= 4 # equivalent to: x = x / 4 +x += 5 # x = x + 5 +x -= 3 # x = x - 3 +x *= 2 # x = x * 2 +x /= 4 # x = x / 4 ``` -### Built-in Functions -Only allowed at compile time: +Only a single target is allowed on the LHS of a compound assignment. + +### Compile-time built-ins + +These functions are evaluated at compile time only — their arguments must be +constants: ``` -log2_ceil(x) # ceiling of log2 -next_multiple_of(x, n) # smallest multiple of n >= x -div_ceil(a, b) # ceiling division: (a + b - 1) // b -div_floor(a, b) # floor division: a // b +log2_ceil(x) # ceil(log2(x)) +next_multiple_of(x, n) # smallest multiple of n that is >= x +div_ceil(a, b) # (a + b - 1) // b +div_floor(a, b) # a // b saturating_sub(a, b) # max(0, a - b) -len(array) # length of const array or vector +len(array) # length of a constant array (any depth) +``` + +### Reserved names + +These identifiers cannot be redefined as user functions, because the parser or +compiler intercepts calls to them: + +- Built-ins: `print`, `Array`, `len`, `hint_witness` +- Compile-time math: `log2_ceil`, `next_multiple_of`, `saturating_sub`, + `div_ceil`, `div_floor` +- Loop / control-flow forms: `range`, `parallel_range`, `match_range` +- Custom hints: every `hint_*` name (see [Hints] below) +- Poseidon16 precompiles: `poseidon16_compress`, `poseidon16_compress_half`, + `poseidon16_compress_hardcoded_left`, + `poseidon16_compress_half_hardcoded_left`, `poseidon16_permute` +- Extension-op precompiles: `add_ee`, `add_be`, `dot_product_ee`, + `dot_product_be`, `poly_eq_ee`, `poly_eq_be` + +### `_` (the discard target) + +Inside a tuple-unpacking LHS, `_` discards the value at that position. The +compiler rewrites each `_` to a fresh anonymous name so they don't collide. + +``` +_, b = swap(a, b) # only keep b +_ = compute() # discard a single return value ``` ## Assertions ``` -# constraint in proof +# Snark constraint (enforced by the proof) assert x == y assert x != y -assert x < y +assert x < y assert x <= y -# unconditional failure (panic) + +# Unconditional failure (compiles to a Panic) assert False -assert False, "error message" -# runtime check only (not constrained by the snark) +assert False, "human-readable message" + +# Runtime-only check; not part of the constraint system debug_assert(x == y) debug_assert(x != y) -debug_assert(x < y) +debug_assert(x < y) debug_assert(x <= y) ``` +`debug_assert` is for invariants the prover must respect but that the verifier +doesn't need to re-check — typically range-validity preconditions for `match` / +`match_range` dispatches. + +### Range checks: `assert a < b` and `assert a <= b` + +A signed inequality is implemented using DEREF (memory-access soundness on a +read-only memory of size `<= 2^MIN_LOG_MEMORY_SIZE`). The compiler automatically +emits the necessary helper hints, but **the right-hand side `b` must fit in +`2^16` (MIN_LOG_MEMORY_SIZE bits)** for the constraint to be sound. Compare +against larger constants by decomposing the value into bits first. + ## Comments ``` -# Single-line comment +# single-line comment """ -Multi-line comment -can span multiple lines +block comment """ ``` -## Imports +Both forms are stripped before the grammar runs. There is no docstring concept — +a `"""..."""` block is purely a comment. -``` -from utils import * # imports utils.py (relative to import root) -from dir.subdir.file import * # imports dir/subdir/file.py -``` +## Line continuation -## Memory Layout +As in Python: -The runner places the program's memory as: +- **Implicit** continuation inside `(...)`, `[...]`, or `{...}`. +- **Explicit** continuation with `\` at end of line. ``` -[ public_input | preamble_memory | runtime ] +result = function_call( + arg1, + arg2, + arg3, +) + +ARR = [ + 1, + 2, + 3, +] + +x = very_long_function_name(arg1, \ + arg2, \ + arg3) ``` -- `public_input` lives at `memory[0..public_input.len()]` (zero-padded to a power of two by the runner so it can be evaluated as a multilinear polynomial). -- `preamble_memory` is a region the runner reserves but does not initialize. The guest program is responsible for writing any constants it needs (e.g. `ZERO_VEC_PTR`, `ONE_EF_PTR`, etc.) in this area. +## Hints (prover-supplied data) + +A hint is data the *prover* writes into memory without adding any constraint — +the program must still constrain the written value if it wants the verifier to +believe anything about it. There are two flavours of hint: + +### `hint_witness("name", ptr)` -Prover-supplied witness data is fetched on demand with `hint_witness("name", ptr)`, where the string literal -names an entry in the witness's `hints: HashMap>>` map and -`ptr` is a caller-allocated buffer. Each call writes the next unused `Vec` -under that name (per-name running index) into the buffer at `ptr`. The guest -is responsible for allocating `ptr` with enough room; the witness's length is -trusted. -`hint_witness` +Pulls the next chunk of witness data registered under the string label `name`, +and writes it into the buffer at `ptr`. Witness data lives in the +`ExecutionWitness::hints: HashMap>>` map (each name has a +list of byte-buffers, consumed in order). The guest is responsible for +allocating `ptr` large enough; the length is implicit and trusted. ``` data_buf = Array(64) -hint_witness("input_data", data_buf) # writes next `input_data` entry into data_buf +hint_witness("input_data", data_buf) n = data_buf[0] -# ... ``` -### Built-in Hints +### Custom hints -hints = prover-supplied values at runtime (without adding snark constraints). Like `hint_witness`, they are bare statements (no return value) — the caller allocates any destination memory and is responsible for constraining the written values. +Each hint has a fixed argument count and writes its result(s) into caller-provided +buffers. The hint *suggests* a value — your program must add the constraints +that bind the value to its specification. -| Hint | Signature | Writes | -| --------------------------------- | --------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | -| `hint_decompose_bits` | `(to_decompose, ptr, num_bits, endianness)` | `num_bits` field elements at `ptr` (the 0/1 bit decomposition of `to_decompose`); `endianness` is `0` for big-endian, `1` for little-endian | -| `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` else `0` | -| `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr` | -| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a/b)` at `q_ptr` and `a mod b` at `r_ptr` (requires `b != 0`) | -| `hint_decompose_bits_xmss` | `(decomposed_ptr, remaining_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | XMSS-specific decomposition (see `crates/lean_vm/src/isa/hint.rs`) | -| `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, remaining_ptr, value, chunk_size)` | Merkle/WHIR-specific decomposition | +| Hint | Arguments | Effect | +| --------------------------------- | --------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | +| `hint_decompose_bits` | `(value, ptr, n_bits)` | Writes `n_bits` big-endian 0/1 field elements at `ptr` (MSB at `ptr[0]`). Requires `n_bits <= 31`. | +| `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, value, chunk_size)` | Writes `24 / chunk_size` little-endian `chunk_size`-bit chunks of `value` at `decomposed_ptr` (`chunk_size` must divide 24). | +| `hint_decompose_bits_xmss` | `(decomposed_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | For each of `num_to_decompose` values at `to_decompose_ptr[..]`, writes its `24 / chunk_size` little-endian chunks at `decomposed_ptr`. | +| `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` (canonical integer compare), else `0`. | +| `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr`. | +| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a / b)` at `q_ptr`, `a mod b` at `r_ptr` (requires `b != 0`). | -Hints only *suggest* a value; the guest must add appropriate constraints to bind that value to its specification. +## Precompiles +### Poseidon16 family -## Precompiles +leanVM has one Poseidon2 width-16 precompile table; the zkDSL exposes five +specializations that all hit the same table. -### poseidon16_compress -Always in "compression" mode ``` poseidon16_compress(left, right, output) ``` -- `left`: pointer to 8 field elements -- `right`: pointer to 8 field elements -- `res`: pointer to result (8 elements) -### Extension Operations +Standard compression: writes the 8-cell compressed output of `Poseidon2(left || right) + left` +to `m[output..output+8]`. `left` and `right` are 8-cell buffers; `output` is an +8-cell destination. + +``` +poseidon16_compress_half(left, right, output) +``` + +Same as `poseidon16_compress`, but only the first 4 output cells are +constrained — `output[4..8]` is unconstrained. Useful when the consumer only +cares about half of the digest. + +``` +poseidon16_compress_hardcoded_left(left, right, output, offset) +``` + +Like `poseidon16_compress`, except the first 4 cells of the *left* input are +read from the **compile-time** address `offset` instead of `m[left..left+4]`. +The remaining 4 cells of the left input still come from `m[left..left+4]`. Used +e.g. for XMSS Merkle hashing where one half of the input is the public parameter +(stored at a fixed address). -Six built-in functions route through a single `extension_op` precompile table. Each combines an element-wise operation with an accumulation over `length` element pairs. +``` +poseidon16_compress_half_hardcoded_left(left, right, output, offset) +``` + +Composition of `_compress_half` and `_compress_hardcoded_left`: hardcoded left +prefix at `offset`, only the first 4 output cells constrained. ``` -func(ptr_a, ptr_b, ptr_result) # length defaults to 1 -func(ptr_a, ptr_b, ptr_result, length) # explicit length (N elements) +poseidon16_permute(left, right, output) ``` -**Operand types (suffix):** -- `_ee`: both `ptr_a` and `ptr_b` point to extension field elements (5 consecutive field elements each, stride = DIM) -- `_be`: `ptr_a` points to base field elements (stride 1), `ptr_b` points to extension field elements (stride DIM) +Raw Poseidon2 permutation (no feed-forward addition). Writes the full 16 output +cells to `m[output..output+16]` in natural order. Used for the Fiat-Shamir +sponge. -`ptr_result` always points to a single extension field element (DIM=5 field elements). +### Extension field operations -**Operations:** +Six built-in functions all route through one `extension_op` precompile table. +Each combines a fixed element-wise operation with an accumulation over `length` +element pairs: | Function | Element-wise | Accumulation | | ----------------------------------- | --------------------------------- | -------------------- | @@ -443,40 +608,50 @@ func(ptr_a, ptr_b, ptr_result, length) # explicit length (N elements) | `dot_product_ee` / `dot_product_be` | `e_i = a_i * b_i` | `result = sum(e_i)` | | `poly_eq_ee` / `poly_eq_be` | `e_i = a_i*b_i + (1-a_i)*(1-b_i)` | `result = prod(e_i)` | -**Note:** `length` must be a compile-time constant. For runtime-known lengths, use `match_range` to dispatch (see example below). +``` +func(ptr_a, ptr_b, ptr_result) # length defaults to 1 +func(ptr_a, ptr_b, ptr_result, length) # explicit length (N element pairs) +``` + +**Operand suffix:** + +- `_ee`: both `ptr_a` and `ptr_b` point to *extension* field elements (5 base-field + cells each, stride `DIM = 5`). +- `_be`: `ptr_a` points to *base* field elements (stride 1); `ptr_b` points to + *extension* field elements (stride `DIM = 5`). + +`ptr_result` always points to a single extension-field element (5 cells). +**`length` must be a compile-time constant.** For a runtime length, dispatch +through `match_range`: + +``` +def dot_product_ee_dynamic(a, b, res, n): + debug_assert(n <= 256) + match_range(n, range(1, 257), lambda i: dot_product_ee(a, b, res, i)) ``` -# Multiply two extension field elements (length=1, default) -dot_product_ee(x, y, z) # z = x * y -# Copy extension element (multiply by [1,0,0,0,0]). -# `ONE_EF_PTR` is a guest-program constant that the program must materialize -# in its preamble memory at startup; see `crates/rec_aggregation/zkdsl_implem/utils.py` -# for an example (`build_preamble_memory`). +Common idioms: + +``` +# Multiply two extension elements (length defaults to 1) +dot_product_ee(x, y, z) # z = x * y + +# Copy an extension element by multiplying by [1, 0, 0, 0, 0] +# ONE_EF_PTR is a guest-program constant that you materialize in the preamble dot_product_ee(src, ONE_EF_PTR, dst) -# Dot product of N extension field elements +# Dot products dot_product_ee(coeffs, basis, result, N) - -# Dot product with base-field scalars dot_product_be(alpha_powers, coeffs, result, N) -# Extension field addition: c = a + b -add_ee(a, b, c) - -# Extension field subtraction via constraint: c = a - b <=> b + c = a -add_ee(b, c, a) +# Extension addition / subtraction +add_ee(a, b, c) # c = a + b +add_ee(b, c, a) # c = a - b, expressed as a constraint (b + c = a) # Equality polynomial: eq(a, b) = a*b + (1-a)*(1-b) poly_eq_ee(a, b, eq_result) - -# Multi-point equality polynomial: prod_{i=0}^{n-1} eq(a[i], b[i]) -poly_eq_ee(a, b, result, n) - -# Runtime-known length via match_range -def dot_product_ee_dynamic(a, b, res, n): - debug_assert(n <= 256) - match_range(n, range(1, 257), lambda i: dot_product_ee(a, b, res, i)) +poly_eq_ee(a, b, result, n) # multi-point eq: prod_i eq(a[i], b[i]) ``` ## Debugging @@ -486,7 +661,60 @@ print(value) print(a, b, c) ``` -## Example +`print` flushes its output during execution; **a Rust-side panic mid-program drops +buffered prints**. When you need a print to survive a panic, temporarily change +the print hint in `lean_vm/src/isa/hint.rs (Self::Print)` to `eprint!` directly. + +## Memory layout + +The runner lays out memory as + +``` +[ public_input (zero-padded) | preamble_memory | runtime ] +``` + +- `public_input` lives at `memory[0..public_input.len()]` and is zero-padded to + the next power of two by the runner, so it can be evaluated as a multilinear + polynomial. +- `preamble_memory` is a region of `witness.preamble_memory_len` cells the + runner reserves but does **not** initialize. The guest program is expected + to fill this region with whatever helper constants it relies on (e.g. a + vector of zeros for `dot_product_ee`-as-copy, an extension-field one for + multiply-by-one tricks, a vector of ones for batched accumulations, …) at + the start of `main`. The names and offsets of these constants are not part + of the VM contract — each program defines its own. See + `crates/rec_aggregation/zkdsl_implem/utils.py (build_preamble_memory)` for + a concrete example. +- The runtime region holds the program's stack frames, working memory, and any + prover-supplied witness data, all governed by the write-once rule. + +## Tips and gotchas + +1. Prefer `unroll` over `range` for small, fixed-size loops — no buffer + bookkeeping, no recursive-function overhead. +2. Reach for `: Const` parameters when the function body needs `unroll` over the + parameter, or when array sizes depend on it. +3. `if` / `elif` branches that assign to the same outer variable should + forward-declare it (`x: Imu` or `x: Mut`) before the branch. +4. **`match`** / **`match_range`** dispatch is undefined for out-of-range + values — always pair it with a `debug_assert` (or `assert`) on the value. +5. `match` patterns must be contiguous integers; if you need gaps, restructure + into an `if` chain or pad with an empty arm. +6. `assert a < b` and `assert a <= b` are range-checked under the assumption + that `b <= 2^MIN_LOG_MEMORY_SIZE = 2^16`. Larger comparisons must be done + with explicit bit decomposition (`hint_decompose_bits` + manual checks). +7. Inline functions cannot have `: Mut` parameters and cannot return + conditionally — use a regular function for those cases. +8. `parallel_range` requires per-iteration determinism in memory and hints; a + single divergent iteration breaks proving. +9. **A variable that's assigned inside an `if` nested in an `unroll` loop may + silently fail to remain in scope after the loop.** When you're dispatching + over per-iteration compile-time constants, prefer a flat top-level + `if`/`elif` chain (one branch per iteration value) over `unroll` + nested + `if`. This affects compile-time dispatch only; runtime `if` inside `range` + loops is unaffected. + +## A simple example ``` SIZE = 8 @@ -506,54 +734,13 @@ def compute_sum(ptr, n: Const): return acc ``` -## Line Continuation - -Like Python, lines can be continued in two ways: +## Worked example: sugar -> ISA -### Implicit continuation (inside parentheses/brackets/braces) +This shows how the front-end normalizes a small program with mutable variables in +a runtime loop down to a form close to the ISA. The compiler does this +automatically; you don't have to write the intermediate forms. -Expressions inside `()`, `[]`, or `{}` can span multiple lines without any special syntax: - -``` -result = function_call( - arg1, - arg2, - arg3 -) - -ARR = [ - 1, - 2, - 3, -] -``` - -### Explicit continuation with backslash - -Long lines can also be split using `\` at the end of a line: - -``` -x = very_long_function_name(arg1, \ - arg2, \ - arg3) - -y = 1 + 2 + \ - 3 + 4 -``` - -The `\` and following newline are replaced with a single space. Any whitespace after `\` and before the newline is ignored. - -## Tips - -1. Use `unroll` for small, fixed-size loops -2. Use `const` parameters when loop bounds depend on arguments -3. Use `mut` sparingly - immutable is easier to verify -4. Use `x: Imu` or `x: Mut` for forward-declaring variables that will be assigned in branches -5. Match patterns must be consecutive integers (can start from any value) - -## Example: From high level syntactic sugar to minimal ISA, with read-only memory - -Take the following program: +Starting program: ``` def main(): @@ -571,7 +758,8 @@ def main(): return ``` -First, we use buffers to handle mutable variables across (non-unrolled) loops. +Step 1 — replace mutable-across-loop variables with index buffers, since memory +is write-once: ``` def main(): @@ -602,8 +790,7 @@ def main(): return ``` -Then, use auxiliary variables to transform it into SSA form (Static Single-Assignment): - +Step 2 — SSA-rename all reassignments to fresh names: ``` def main(): @@ -634,7 +821,7 @@ def main(): return ``` -Finally, transform the loop into a recursive function: +Step 3 — lower the runtime loop to a recursive function: ``` def main(): @@ -674,14 +861,17 @@ def loop(i, x_buff, y_buff): ## Dev experience -If using VScode, add the following to your local settings `.vscode/settings.json` : +For Python tooling/linting on zkDSL files (which import `snark_lib` at the top), +point your editor at the compiler crate. With VSCode: ```json -{ - "python.analysis.extraPaths": [ - "./crates/lean_compiler" - ], +{ + "python.analysis.extraPaths": [ + "./crates/lean_compiler" + ] } ``` -(you will get better linting for the zkDSL files starting with `from snark_lib import *`, since it will expose zkDSL special functions from `crates/lean_compiler/snark_lib.py`). \ No newline at end of file +This makes the stubs in `crates/lean_compiler/snark_lib.py` visible to your +language server, so completion / type-checks light up correctly inside `.py` +zkDSL sources. From 47c01af71afe8c11713800ae8ebdbbb78b261c4e Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 27 May 2026 15:28:39 +0400 Subject: [PATCH 02/13] wip --- crates/lean_compiler/zkDSL.md | 178 ++++++++++++++++------------------ 1 file changed, 84 insertions(+), 94 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 4246b1bc..3a447f96 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -14,7 +14,7 @@ sequence of: Execution starts at `def main(): ...`. -``` +```python from snark_lib import * # Python compatibility shim, stripped by the compiler from dir.file import * # other .py files in the import root from ..parent_module import * # parent-directory imports @@ -39,7 +39,7 @@ python program.py ## Imports -``` +```python from utils import * # imports utils.py (resolved from the import root) from dir.subdir.file import * # nested module from ..module import * # parent-directory import (relative to current file) @@ -54,7 +54,7 @@ name in two imported files cause a compile-time error. Constants live at the top of the file, outside any function. By convention they are UPPERCASE. -``` +```python X = 42 ARR = [1, 2, 3] NESTED = [[1, 2], [3]] @@ -62,21 +62,21 @@ NESTED = [[1, 2], [3]] ### Nested (multi-dimensional, possibly ragged) constant arrays -``` +```python MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] DEEP = [[[1, 2], [3]], [[4, 5, 6]]] ``` Indexed access uses chained subscripts at compile time: -``` +```python x = MATRIX[0][2] # 3 y = DEEP[1][0][1] # 5 ``` `len()` works at every depth, including on a row addressed by a constant index: -``` +```python len(MATRIX) # 3 len(MATRIX[0]) # 3 len(DEEP[0][0]) # 2 @@ -88,7 +88,7 @@ below), as do iterator variables of an `unroll` loop (see [For loops] below) — those are the two ways to get a value the compiler can substitute at expansion time. Example: iterating a ragged 2D table: -``` +```python MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] def main(): @@ -102,7 +102,7 @@ def main(): ## Functions -``` +```python def add(a, b): return a + b @@ -120,16 +120,16 @@ agree. A function that "returns nothing" uses a bare `return`. ### Parameter modifiers -| Syntax | Meaning | -| ---------- | --------------------------------------------------------------------------------- | -| `x` | normal (immutable) parameter | -| `x: Const` | compile-time-known value; enables `unroll`/array sizes that depend on the param | +| Syntax | Meaning | +| ---------- | ----------------------------------------------------------------------------------- | +| `x` | normal (immutable) parameter | +| `x: Const` | compile-time-known value; enables `unroll`/array sizes that depend on the param | | `x: Mut` | locally mutable parameter (reassignable inside the function — caller is unaffected) | All parameters are pass-by-value. Use return values to propagate results — there are no out-parameters. -``` +```python def repeat(n: Const): # Const enables unroll(0, n) sum: Mut = 0 for i in unroll(0, n): @@ -147,7 +147,7 @@ def double(x: Mut): # Mut: only the local copy is reassignable instruction. Useful for small helpers, and for cases where the body must "see" the caller's `: Const` context. -``` +```python @inline def square(x): return x * x @@ -178,7 +178,7 @@ call site. Use `x: Imu` when you want an immutable binding but the value comes from a branch: -``` +```python result: Imu if cond == 1: result = 10 @@ -189,7 +189,7 @@ else: Use `x: Mut` when you want to keep mutating the variable after the branch: -``` +```python x: Mut if cond == 1: x = 10 @@ -202,7 +202,7 @@ x = x + 1 # OK: x is mutable To make a single component of a tuple-return mutable, forward-declare it: -``` +```python b: Mut a, b, c = some_function() b = b + 1 # OK @@ -211,7 +211,7 @@ b = b + 1 # OK ## Memory and arrays -``` +```python buffer = Array(16) # allocate 16 field elements buffer[0] = 42 x = buffer[5] @@ -229,7 +229,7 @@ value; the runner handles both. Memory is **write-once**: a cell may be written more than once only if all writes store the same value. The second write of a different value is a runtime error at the point of the write. -``` +```python arr = Array(3) arr[0] = 10 arr[0] = 10 # OK: same value @@ -244,7 +244,7 @@ arithmetic (`ptr + offset`) is the way to address into sub-regions. ### `if` / `elif` / `else` -``` +```python if x == 0: y = 1 elif x == 1: @@ -260,14 +260,11 @@ or `>=` — flip the operands to get the same effect. Patterns must be a contiguous run of integers: -``` +```python match value: - case 5: - result = 500 - case 6: - result = 600 - case 7: - result = 700 + case 5: result = 500 + case 6: result = 600 + case 7: result = 700 ``` The matched value must lie inside the listed range; out-of-range values produce @@ -280,13 +277,13 @@ enforced by the proof) to guard the input. parameter function*. It is a compile-time construct that expands into a forward-declared variable plus a `match` over a contiguous range of integers. -``` +```python result = match_range(n, range(1, 5), lambda i: compute(i)) ``` expands to -``` +```python result: Imu match n: case 1: result = compute(1) @@ -298,26 +295,24 @@ match n: You can chain several `(range, lambda)` pairs, provided the ranges are **contiguous** (the end of one is the start of the next): -``` -result = match_range( - n, - range(0, 1), lambda i: special_case(), - range(1, 8), lambda i: normal_case(i), -) +```python +result = match_range(n, + range(0, 1), lambda i: special_case(), + range(1, 8), lambda i: normal_case(i)) ``` Multiple return values are supported via tuple unpacking. The bindings produced by `match_range` are always immutable — forward-declare with `: Mut` (and then reassign) if you need them mutable later: -``` +```python a, b = match_range(n, range(0, 4), lambda i: two_values(i)) ``` Idiomatic use — dispatching a runtime length to a function that requires a compile-time length: -``` +```python def helper_const(n: Const): return n * n @@ -337,10 +332,10 @@ silent bugs in zkDSL. Three loop forms, all written `for i in (start, end):`. Bounds and behaviour: -| Loop form | When | -| ---------------------------- | ------------------------------------------------------------------------- | -| `for i in range(a, b):` | Runtime loop. Compiled into a recursive function (no `break`/`continue`). | -| `for i in unroll(a, b):` | Compile-time expansion; `a` and `b` must both be compile-time constants. | +| Loop form | When | +| -------------------------------- | -------------------------------------------------------------------------- | +| `for i in range(a, b):` | Runtime loop. Compiled into a recursive function (no `break`/`continue`). | +| `for i in unroll(a, b):` | Compile-time expansion; `a` and `b` must both be compile-time constants. | | `for i in parallel_range(a, b):` | Runtime loop; iterations are executed in parallel by the runner via rayon. | `parallel_range` requires the loop body to be iteration-independent. The @@ -361,7 +356,7 @@ Mutable variables inside non-unrolled loops are supported transparently — the compiler inserts a buffer array, stores per-iteration values into it, and reads the final value back after the loop: -``` +```python sum: Mut = 0 for i in range(1, 11): sum += i @@ -393,7 +388,7 @@ must be constants known at compile time. ### Compound assignment -``` +```python x: Mut = 10 x += 5 # x = x + 5 x -= 3 # x = x - 3 @@ -408,7 +403,7 @@ Only a single target is allowed on the LHS of a compound assignment. These functions are evaluated at compile time only — their arguments must be constants: -``` +```python log2_ceil(x) # ceil(log2(x)) next_multiple_of(x, n) # smallest multiple of n that is >= x div_ceil(a, b) # (a + b - 1) // b @@ -438,29 +433,34 @@ compiler intercepts calls to them: Inside a tuple-unpacking LHS, `_` discards the value at that position. The compiler rewrites each `_` to a fresh anonymous name so they don't collide. -``` +```python _, b = swap(a, b) # only keep b _ = compute() # discard a single return value ``` ## Assertions -``` -# Snark constraint (enforced by the proof) +Snark constraints (enforced by the proof): + +```python assert x == y assert x != y assert x < y assert x <= y +``` + +Unconditional failure (compiles to a Panic): -# Unconditional failure (compiles to a Panic) +```python assert False assert False, "human-readable message" +``` + +Runtime-only checks; not part of the constraint system. Same four comparison +operators (`==`, `!=`, `<`, `<=`): -# Runtime-only check; not part of the constraint system -debug_assert(x == y) -debug_assert(x != y) -debug_assert(x < y) -debug_assert(x <= y) +```python +debug_assert(x < y) ``` `debug_assert` is for invariants the prover must respect but that the verifier @@ -477,7 +477,7 @@ against larger constants by decomposing the value into bits first. ## Comments -``` +```python # single-line comment """ @@ -495,22 +495,12 @@ As in Python: - **Implicit** continuation inside `(...)`, `[...]`, or `{...}`. - **Explicit** continuation with `\` at end of line. -``` -result = function_call( - arg1, - arg2, - arg3, -) - -ARR = [ - 1, - 2, - 3, -] - -x = very_long_function_name(arg1, \ - arg2, \ - arg3) +```python +result = function_call(arg1, + arg2, + arg3) # implicit continuation inside parens +y = 1 + 2 + \ + 3 + 4 # explicit continuation with backslash ``` ## Hints (prover-supplied data) @@ -527,7 +517,7 @@ and writes it into the buffer at `ptr`. Witness data lives in the list of byte-buffers, consumed in order). The guest is responsible for allocating `ptr` large enough; the length is implicit and trusted. -``` +```python data_buf = Array(64) hint_witness("input_data", data_buf) n = data_buf[0] @@ -539,14 +529,14 @@ Each hint has a fixed argument count and writes its result(s) into caller-provid buffers. The hint *suggests* a value — your program must add the constraints that bind the value to its specification. -| Hint | Arguments | Effect | -| --------------------------------- | --------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | -| `hint_decompose_bits` | `(value, ptr, n_bits)` | Writes `n_bits` big-endian 0/1 field elements at `ptr` (MSB at `ptr[0]`). Requires `n_bits <= 31`. | -| `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, value, chunk_size)` | Writes `24 / chunk_size` little-endian `chunk_size`-bit chunks of `value` at `decomposed_ptr` (`chunk_size` must divide 24). | -| `hint_decompose_bits_xmss` | `(decomposed_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | For each of `num_to_decompose` values at `to_decompose_ptr[..]`, writes its `24 / chunk_size` little-endian chunks at `decomposed_ptr`. | -| `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` (canonical integer compare), else `0`. | -| `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr`. | -| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a / b)` at `q_ptr`, `a mod b` at `r_ptr` (requires `b != 0`). | +| Hint | Arguments | Effect | +| --------------------------------- | ------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------- | +| `hint_decompose_bits` | `(value, ptr, n_bits)` | Writes `n_bits` big-endian 0/1 field elements at `ptr` (MSB at `ptr[0]`). Requires `n_bits <= 31`. | +| `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, value, chunk_size)` | Writes `24 / chunk_size` little-endian `chunk_size`-bit chunks of `value` at `decomposed_ptr` (`chunk_size` must divide 24). | +| `hint_decompose_bits_xmss` | `(decomposed_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | For each of `num_to_decompose` values at `to_decompose_ptr[..]`, writes its `24 / chunk_size` little-endian chunks at `decomposed_ptr`. | +| `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` (canonical integer compare), else `0`. | +| `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr`. | +| `hint_div_floor` | `(a, b, q_ptr, r_ptr)` | `floor(a / b)` at `q_ptr`, `a mod b` at `r_ptr` (requires `b != 0`). | ## Precompiles @@ -555,7 +545,7 @@ that bind the value to its specification. leanVM has one Poseidon2 width-16 precompile table; the zkDSL exposes five specializations that all hit the same table. -``` +```python poseidon16_compress(left, right, output) ``` @@ -563,7 +553,7 @@ Standard compression: writes the 8-cell compressed output of `Poseidon2(left || to `m[output..output+8]`. `left` and `right` are 8-cell buffers; `output` is an 8-cell destination. -``` +```python poseidon16_compress_half(left, right, output) ``` @@ -571,7 +561,7 @@ Same as `poseidon16_compress`, but only the first 4 output cells are constrained — `output[4..8]` is unconstrained. Useful when the consumer only cares about half of the digest. -``` +```python poseidon16_compress_hardcoded_left(left, right, output, offset) ``` @@ -581,14 +571,14 @@ The remaining 4 cells of the left input still come from `m[left..left+4]`. Used e.g. for XMSS Merkle hashing where one half of the input is the public parameter (stored at a fixed address). -``` +```python poseidon16_compress_half_hardcoded_left(left, right, output, offset) ``` Composition of `_compress_half` and `_compress_hardcoded_left`: hardcoded left prefix at `offset`, only the first 4 output cells constrained. -``` +```python poseidon16_permute(left, right, output) ``` @@ -608,7 +598,7 @@ element pairs: | `dot_product_ee` / `dot_product_be` | `e_i = a_i * b_i` | `result = sum(e_i)` | | `poly_eq_ee` / `poly_eq_be` | `e_i = a_i*b_i + (1-a_i)*(1-b_i)` | `result = prod(e_i)` | -``` +```python func(ptr_a, ptr_b, ptr_result) # length defaults to 1 func(ptr_a, ptr_b, ptr_result, length) # explicit length (N element pairs) ``` @@ -625,7 +615,7 @@ func(ptr_a, ptr_b, ptr_result, length) # explicit length (N element pairs) **`length` must be a compile-time constant.** For a runtime length, dispatch through `match_range`: -``` +```python def dot_product_ee_dynamic(a, b, res, n): debug_assert(n <= 256) match_range(n, range(1, 257), lambda i: dot_product_ee(a, b, res, i)) @@ -633,7 +623,7 @@ def dot_product_ee_dynamic(a, b, res, n): Common idioms: -``` +```python # Multiply two extension elements (length defaults to 1) dot_product_ee(x, y, z) # z = x * y @@ -656,7 +646,7 @@ poly_eq_ee(a, b, result, n) # multi-point eq: prod_i eq(a[i], b[i]) ## Debugging -``` +```python print(value) print(a, b, c) ``` @@ -669,7 +659,7 @@ the print hint in `lean_vm/src/isa/hint.rs (Self::Print)` to `eprint!` directly. The runner lays out memory as -``` +```python [ public_input (zero-padded) | preamble_memory | runtime ] ``` @@ -716,7 +706,7 @@ The runner lays out memory as ## A simple example -``` +```python SIZE = 8 def main(): @@ -742,7 +732,7 @@ automatically; you don't have to write the intermediate forms. Starting program: -``` +```python def main(): x: Mut = 0 y: Mut = 3 @@ -761,7 +751,7 @@ def main(): Step 1 — replace mutable-across-loop variables with index buffers, since memory is write-once: -``` +```python def main(): x: Mut = 0 y: Mut = 3 @@ -792,7 +782,7 @@ def main(): Step 2 — SSA-rename all reassignments to fresh names: -``` +```python def main(): x = 0 y = 3 @@ -823,7 +813,7 @@ def main(): Step 3 — lower the runtime loop to a recursive function: -``` +```python def main(): x = 0 y = 3 From 814955cb46a5348fcf45dc0f16779340e55241c8 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 27 May 2026 15:50:50 +0400 Subject: [PATCH 03/13] w --- crates/lean_compiler/zkDSL.md | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 3a447f96..b4452acd 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -1,9 +1,10 @@ # zkDSL Language Reference -The zkDSL is a Python-syntax language that compiles to leanVM bytecode (4 instructions -+ 2 precompile tables). It is restricted enough that every `.py` source file also -runs as plain Python (using `crates/lean_compiler/snark_lib.py` as a stub library), -which lets you sanity-check programs with a regular interpreter before compiling. +The zkDSL is a Python-syntax language that compiles to leanVM bytecode (4 basic instructions and 2 special ones (precompile): poseidon / extension operations). + +Source files use the `.py` extension. They are **not** currently runnable as +real Python, but the syntax is kept Python-compatible so that one day they +could be (TODO). Programs are organized as one or more `.py` files. The toplevel of each file is a sequence of: @@ -15,12 +16,14 @@ sequence of: Execution starts at `def main(): ...`. ```python -from snark_lib import * # Python compatibility shim, stripped by the compiler -from dir.file import * # other .py files in the import root -from ..parent_module import * # parent-directory imports +from snark_lib import * # only there to keep the Python linter happy; stripped by the zkDSL compiler +from utils import * # import other file X = 42 # constants must come before functions -ARR = [1, 2, 3] +# array constants (or arbitrary dimmensions: 1D, 2D, etc) +ARR_1D = [1, 2, 3] +ARR_2D = [[1, 2, 3], [], [10, 4]] +ARR_3D = [[[1, 2, 3], [7, 8], [9]], [], [[10], [10, 4]]] def main(): # required entry point ... @@ -29,14 +32,6 @@ def helper(): # other functions ... ``` -The compiler strips the `from snark_lib import *` line (and only that line) so the -same source is valid Python. To run a `.py` file under regular Python for testing: - -```bash -export PYTHONPATH=/path/to/repo/crates/lean_compiler -python program.py -``` - ## Imports ```python @@ -46,13 +41,12 @@ from ..module import * # parent-directory import (relative to current ``` Imports are wildcard-only (`import *`). Each module is loaded once even if imported -multiple times; circular imports are detected and rejected. Constants with the same +multiple times; circular imports are rejected. Constants with the same name in two imported files cause a compile-time error. ## Constants -Constants live at the top of the file, outside any function. By convention they are -UPPERCASE. +Constants live at the top of the file, outside any function. ```python X = 42 @@ -123,7 +117,7 @@ agree. A function that "returns nothing" uses a bare `return`. | Syntax | Meaning | | ---------- | ----------------------------------------------------------------------------------- | | `x` | normal (immutable) parameter | -| `x: Const` | compile-time-known value; enables `unroll`/array sizes that depend on the param | +| `x: Const` | compile-time-known value| | `x: Mut` | locally mutable parameter (reassignable inside the function — caller is unaffected) | All parameters are pass-by-value. Use return values to propagate results — there From 93a8d8315acdc2470d541bbcf62db143aa6cb978 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 27 May 2026 16:59:10 +0400 Subject: [PATCH 04/13] zkDSL: remove 'mutable' function arguments --- .../lean_compiler/src/a_simplify_lang/mod.rs | 16 +---- crates/lean_compiler/src/lang.rs | 6 -- .../src/parser/parsers/function.rs | 22 +++---- .../lean_compiler/tests/test_data/error_7.py | 2 +- .../tests/test_data/program_111.py | 7 ++- .../tests/test_data/program_144.py | 7 ++- .../tests/test_data/program_170.py | 2 +- .../tests/test_data/program_43.py | 9 +-- .../tests/test_data/program_57.py | 27 +++++---- .../tests/test_data/program_99.py | 7 ++- crates/lean_compiler/zkDSL.md | 26 +++++---- .../rec_aggregation/zkdsl_implem/recursion.py | 10 +++- crates/rec_aggregation/zkdsl_implem/whir.py | 58 ++++++++++++------- 13 files changed, 106 insertions(+), 93 deletions(-) diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index 9894ac16..96c1e657 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -310,21 +310,14 @@ pub fn simplify_program(mut program: Program) -> Result { let mut array_manager = ArrayManager::default(); let mut mut_tracker = MutableVarTracker::default(); - // Register mutable arguments and capture their initial versioned names - // BEFORE simplifying the body + // All arguments are immutable; record them as assigned to detect illegal reassignment. let arguments: Vec = func .arguments .iter() .map(|arg| { assert!(!arg.is_const); - if arg.is_mutable { - mut_tracker.register_mutable(&arg.name); - // Capture the initial versioned name (version 0) - mut_tracker.current_name(&arg.name) - } else { - mut_tracker.assigned.insert(arg.name.clone()); - arg.name.clone() - } + mut_tracker.assigned.insert(arg.name.clone()); + arg.name.clone() }) .collect(); @@ -376,9 +369,6 @@ fn compile_time_transform_in_program( .collect(); for func in inlined_functions.values() { - if func.has_mutable_arguments() { - return Err("Inlined functions with mutable arguments are not supported yet".to_string()); - } if func.has_const_arguments() { return Err(format!( "Inlined function should not have \"Const\" arguments (function \"{}\")", diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 938aa024..4ceaead1 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -32,7 +32,6 @@ impl Program { pub struct FunctionArg { pub name: Var, pub is_const: bool, - pub is_mutable: bool, } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] @@ -48,9 +47,6 @@ impl Function { pub fn has_const_arguments(&self) -> bool { self.arguments.iter().any(|arg| arg.is_const) } - pub fn has_mutable_arguments(&self) -> bool { - self.arguments.iter().any(|arg| arg.is_mutable) - } } pub type Var = String; @@ -902,8 +898,6 @@ impl Display for Function { .map(|arg| { if arg.is_const { format!("const {}", arg.name) - } else if arg.is_mutable { - format!("mut {}", arg.name) } else { arg.name.to_string() } diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 8fb1cee6..11f2a6ba 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -142,22 +142,24 @@ impl Parse for ParameterParser { let mut inner = pair.into_inner(); let name = next_inner_pair(&mut inner, "parameter name")?.as_str().to_string(); - // Check for optional type annotation (: Const or : Mut) - let (is_const, is_mutable) = if let Some(annotation) = inner.next() { + // Check for optional type annotation (: Const). ': Mut' parameters are forbidden. + let is_const = if let Some(annotation) = inner.next() { match annotation.as_str().trim() { - ": Const" => (true, false), - ": Mut" => (false, true), + ": Const" => true, + ": Mut" => { + return Err(SemanticError::new(format!( + "Parameter '{name}' cannot be declared ': Mut'. Mutable parameters are not allowed; \ + introduce a local '{name}_mut: Mut = {name}' instead." + )) + .into()); + } other => return Err(SemanticError::new(format!("Invalid parameter annotation: {other}")).into()), } } else { - (false, false) + false }; - Ok(FunctionArg { - name, - is_const, - is_mutable, - }) + Ok(FunctionArg { name, is_const }) } } diff --git a/crates/lean_compiler/tests/test_data/error_7.py b/crates/lean_compiler/tests/test_data/error_7.py index 94c262a6..e4435908 100644 --- a/crates/lean_compiler/tests/test_data/error_7.py +++ b/crates/lean_compiler/tests/test_data/error_7.py @@ -1,7 +1,7 @@ from snark_lib import * -# Error: inline functions with parameters: Mut are not supported +# Error: function parameters cannot be declared ': Mut' def main(): return diff --git a/crates/lean_compiler/tests/test_data/program_111.py b/crates/lean_compiler/tests/test_data/program_111.py index e497afdb..e13d3e32 100644 --- a/crates/lean_compiler/tests/test_data/program_111.py +++ b/crates/lean_compiler/tests/test_data/program_111.py @@ -78,10 +78,11 @@ def chain_compute(x, y): a2, b2, c2 = step_compute(a1, b1) return a2, b2, c1 + c2 -def nested_mut_params(base: Mut): +def nested_mut_params(base): + acc: Mut = base for i in unroll(0, 3): - base = base + i * 2 - return base + acc = acc + i * 2 + return acc def state_machine_step(current_state, phase): result: Imu diff --git a/crates/lean_compiler/tests/test_data/program_144.py b/crates/lean_compiler/tests/test_data/program_144.py index f6b60dff..b9dcd456 100644 --- a/crates/lean_compiler/tests/test_data/program_144.py +++ b/crates/lean_compiler/tests/test_data/program_144.py @@ -374,9 +374,10 @@ def test_multi_return(flag): # Helper function for TEST 22 -def func_with_mut_param(x: Mut, flag): +def func_with_mut_param(x, flag): + y: Mut = x if flag == 1: - x = x * 10 + y = y * 10 else: assert False - return x + return y diff --git a/crates/lean_compiler/tests/test_data/program_170.py b/crates/lean_compiler/tests/test_data/program_170.py index 96c00012..f5652fe7 100644 --- a/crates/lean_compiler/tests/test_data/program_170.py +++ b/crates/lean_compiler/tests/test_data/program_170.py @@ -15,7 +15,7 @@ def multi_return(a, b): def multi_line_params( a, - b: Mut, + b, c: Const, ): return a + b + c diff --git a/crates/lean_compiler/tests/test_data/program_43.py b/crates/lean_compiler/tests/test_data/program_43.py index aed3a00b..d0199de4 100644 --- a/crates/lean_compiler/tests/test_data/program_43.py +++ b/crates/lean_compiler/tests/test_data/program_43.py @@ -7,7 +7,8 @@ def main(): return -def increment_twice(x: Mut): - x = x + 1 - x = x + 1 - return x +def increment_twice(x): + y: Mut = x + y = y + 1 + y = y + 1 + return y diff --git a/crates/lean_compiler/tests/test_data/program_57.py b/crates/lean_compiler/tests/test_data/program_57.py index 3d4965c4..4a96412e 100644 --- a/crates/lean_compiler/tests/test_data/program_57.py +++ b/crates/lean_compiler/tests/test_data/program_57.py @@ -10,19 +10,22 @@ def main(): return -def step1(n: Mut): - n = n * 2 - n = n + 1 - return n +def step1(n): + m: Mut = n + m = m * 2 + m = m + 1 + return m -def step2(n: Mut): - n = n * 3 - n = n + 2 - return n +def step2(n): + m: Mut = n + m = m * 3 + m = m + 2 + return m -def step3(n: Mut): - n = n * 4 - n = n + 3 - return n +def step3(n): + m: Mut = n + m = m * 4 + m = m + 3 + return m diff --git a/crates/lean_compiler/tests/test_data/program_99.py b/crates/lean_compiler/tests/test_data/program_99.py index 40d61c71..66b2c119 100644 --- a/crates/lean_compiler/tests/test_data/program_99.py +++ b/crates/lean_compiler/tests/test_data/program_99.py @@ -7,7 +7,8 @@ def main(): return -def accumulate(x: Mut): +def accumulate(x): + acc: Mut = x for i in unroll(0, 3): - x = x + i - return x + acc = acc + i + return acc diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index b4452acd..3cb2126e 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -114,14 +114,14 @@ agree. A function that "returns nothing" uses a bare `return`. ### Parameter modifiers -| Syntax | Meaning | -| ---------- | ----------------------------------------------------------------------------------- | -| `x` | normal (immutable) parameter | -| `x: Const` | compile-time-known value| -| `x: Mut` | locally mutable parameter (reassignable inside the function — caller is unaffected) | +| Syntax | Meaning | +| ---------- | ------------------------ | +| `x` | normal (immutable) parameter | +| `x: Const` | compile-time-known value | -All parameters are pass-by-value. Use return values to propagate results — there -are no out-parameters. +All parameters are pass-by-value and immutable inside the function — use return +values to propagate results (there are no out-parameters). If you need a locally +mutable copy of a parameter, introduce a `: Mut` local at the top of the body: ```python def repeat(n: Const): # Const enables unroll(0, n) @@ -130,9 +130,10 @@ def repeat(n: Const): # Const enables unroll(0, n) sum = sum + i return sum -def double(x: Mut): # Mut: only the local copy is reassignable - x = x * 2 - return x +def double(x): # parameter is immutable; shadow with a local + y: Mut = x + y = y * 2 + return y ``` ### Inline functions @@ -149,7 +150,6 @@ def square(x): Constraints on inline functions: -- No `: Mut` parameters allowed. - Exactly one `return`, placed at the top level of the body — not nested inside `if`, a loop, or `match`. Inlining rewrites the `return` into a plain assignment, so early or conditional returns cannot be expressed. @@ -687,7 +687,9 @@ The runner lays out memory as 6. `assert a < b` and `assert a <= b` are range-checked under the assumption that `b <= 2^MIN_LOG_MEMORY_SIZE = 2^16`. Larger comparisons must be done with explicit bit decomposition (`hint_decompose_bits` + manual checks). -7. Inline functions cannot have `: Mut` parameters and cannot return +7. Function parameters are always immutable. To mutate a parameter's value + inside a function, introduce a local `: Mut` alias at the top of the body + (e.g. `y: Mut = x`). Inline functions additionally cannot return conditionally — use a regular function for those cases. 8. `parallel_range` requires per-iteration determinism in memory and hints; a single divergent iteration breaks proving. diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 3b949c44..487eddb3 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -610,7 +610,8 @@ def fingerprint_n(domsep, data_evals, n, logup_alphas_eq_poly): return res -def verify_gkr_quotient(fs: Mut, n_vars): +def verify_gkr_quotient(prev_fs, n_vars): + fs: Mut = prev_fs fs, nums = fs_receive_ef_inlined(fs, LOGUP_GKR_N_COEFFS_SENT) fs, denoms = fs_receive_ef_inlined(fs, LOGUP_GKR_N_COEFFS_SENT) @@ -653,13 +654,16 @@ def verify_gkr_quotient(fs: Mut, n_vars): ) -def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den): +def verify_gkr_quotient_step(prev_fs, n_vars, point, claim_num, claim_den): + fs: Mut = prev_fs fs = fs_duplex(fs) fs, alpha = fs_sample_ef(fs) alpha_mul_claim_den = mul_extension_ret(alpha, claim_den) num_plus_alpha_mul_claim_den = add_extension_ret(claim_num, alpha_mul_claim_den) postponed_point = Array((n_vars + 1) * DIM) - fs, postponed_value = sumcheck_verify_reversed_helper(fs, n_vars, num_plus_alpha_mul_claim_den, 3, postponed_point) + fs, postponed_value = sumcheck_verify_reversed_helper( + fs, n_vars, num_plus_alpha_mul_claim_den, 3, postponed_point + ) fs, inner_evals = fs_receive_ef_inlined(fs, 4) a_num = inner_evals b_num = inner_evals + DIM diff --git a/crates/rec_aggregation/zkdsl_implem/whir.py b/crates/rec_aggregation/zkdsl_implem/whir.py index 3124f253..a3c84cb8 100644 --- a/crates/rec_aggregation/zkdsl_implem/whir.py +++ b/crates/rec_aggregation/zkdsl_implem/whir.py @@ -16,14 +16,17 @@ def whir_open( - fs: Mut, + prev_fs, n_vars, initial_log_inv_rate, - root: Mut, + prev_root, ood_points_commit, combination_randomness_powers_0, - claimed_sum: Mut, + prev_claimed_sum, ): + fs: Mut = prev_fs + root: Mut = prev_root + claimed_sum: Mut = prev_claimed_sum n_rounds, n_final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding = get_whir_params( n_vars, initial_log_inv_rate ) @@ -175,13 +178,15 @@ def whir_open( return fs, folding_randomness_global, s, final_value, end_sum -def sumcheck_verify(fs: Mut, n_steps, claimed_sum, degree: Const): +def sumcheck_verify(fs, n_steps, claimed_sum, degree: Const): challenges = Array(n_steps * DIM) - fs, new_claimed_sum = sumcheck_verify_helper(fs, n_steps, claimed_sum, degree, challenges) - return fs, challenges, new_claimed_sum + new_fs, new_claimed_sum = sumcheck_verify_helper(fs, n_steps, claimed_sum, degree, challenges) + return new_fs, challenges, new_claimed_sum -def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, challenges): +def sumcheck_verify_helper(prev_fs, n_steps, prev_claimed_sum, degree: Const, challenges): + fs: Mut = prev_fs + claimed_sum: Mut = prev_claimed_sum for sc_round in range(0, n_steps): fs, poly = fs_receive_ef_inlined(fs, degree + 1) polynomial_sum_at_0_and_1(poly, degree, claimed_sum) @@ -192,10 +197,10 @@ def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, ch return fs, claimed_sum -def sumcheck_verify_reversed(fs: Mut, n_steps, claimed_sum: Mut, degree: Const): +def sumcheck_verify_reversed(fs, n_steps, claimed_sum, degree: Const): challenges = Array(n_steps * DIM) - fs, new_claimed_sum = sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree, challenges) - return fs, challenges, new_claimed_sum + new_fs, final_claimed_sum = sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree, challenges) + return new_fs, challenges, final_claimed_sum def sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree: Const, challenges): @@ -208,7 +213,9 @@ def sumcheck_verify_reversed_helper(fs, n_steps, claimed_sum, degree: Const, cha return new_fd, final_sum -def sumcheck_verify_reversed_helper_const(fs: Mut, n_steps: Const, claimed_sum: Mut, degree: Const, challenges): +def sumcheck_verify_reversed_helper_const(prev_fs, n_steps: Const, prev_claimed_sum, degree: Const, challenges): + fs: Mut = prev_fs + claimed_sum: Mut = prev_claimed_sum for sc_round in unroll(0, n_steps): fs, poly = fs_receive_ef_inlined(fs, degree + 1) polynomial_sum_at_0_and_1(poly, degree, claimed_sum) @@ -219,7 +226,9 @@ def sumcheck_verify_reversed_helper_const(fs: Mut, n_steps: Const, claimed_sum: return fs, claimed_sum -def sumcheck_verify_with_grinding(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, folding_grinding_bits): +def sumcheck_verify_with_grinding(prev_fs, n_steps, prev_claimed_sum, degree: Const, folding_grinding_bits): + fs: Mut = prev_fs + claimed_sum: Mut = prev_claimed_sum challenges = Array(n_steps * DIM) for sc_round in range(0, n_steps): fs, poly = fs_receive_ef_inlined(fs, degree + 1) @@ -285,7 +294,7 @@ def decompose_and_verify_merkle_batch_const( def sample_stir_indexes_and_fold( - fs: Mut, + prev_fs, num_queries, merkle_leaves_in_basefield, folding_factor, @@ -295,6 +304,7 @@ def sample_stir_indexes_and_fold( folding_randomness, query_grinding_bits, ): + fs: Mut = prev_fs folded_domain_size = domain_size - folding_factor fs = fs_grinding(fs, query_grinding_bits) @@ -335,7 +345,7 @@ def sample_stir_indexes_and_fold( def whir_round( - fs: Mut, + prev_fs, prev_root, folding_factor, two_pow_folding_factor, @@ -347,6 +357,7 @@ def whir_round( num_ood, folding_grinding_bits, ): + fs: Mut = prev_fs fs, folding_randomness, new_claimed_sum_a = sumcheck_verify_with_grinding( fs, folding_factor, claimed_sum, 2, folding_grinding_bits ) @@ -398,21 +409,24 @@ def polynomial_sum_at_0_and_1(coeffs, degree, dst): return -def parse_commitment(fs: Mut, num_ood): +def parse_commitment(fs, num_ood): root: Imu ood_points: Imu ood_evals: Imu debug_assert(num_ood < 5) debug_assert(num_ood != 0) - fs, root, ood_points, ood_evals = match_range(num_ood, range(1, 5), lambda n: parse_whir_commitment_const(fs, n)) - return fs, root, ood_points, ood_evals + new_fs, root, ood_points, ood_evals = match_range( + num_ood, range(1, 5), lambda n: parse_whir_commitment_const(fs, n) + ) + return new_fs, root, ood_points, ood_evals -def parse_whir_commitment_const(fs: Mut, num_ood: Const): - fs, root = fs_receive_chunks(fs, 1) - fs, ood_points = fs_sample_many_ef(fs, num_ood) - fs, ood_evals = fs_receive_ef_inlined(fs, num_ood) - return fs, root, ood_points, ood_evals +def parse_whir_commitment_const(fs, num_ood: Const): + new_fs: Mut + new_fs, root = fs_receive_chunks(fs, 1) + new_fs, ood_points = fs_sample_many_ef(new_fs, num_ood) + new_fs, ood_evals = fs_receive_ef_inlined(new_fs, num_ood) + return new_fs, root, ood_points, ood_evals @inline From 6101d1d0c5e096c4be882df4999f9df9f586368c Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 27 May 2026 17:39:00 +0400 Subject: [PATCH 05/13] w --- crates/lean_compiler/zkDSL.md | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 3cb2126e..30eb9f00 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -78,9 +78,9 @@ len(DEEP[0][0]) # 2 When `len()` is applied with a variable index (`len(ARR[i])`), `i` must be a compile-time constant. `: Const` parameters always qualify (see [Functions] -below), as do iterator variables of an `unroll` loop (see [For loops] below) — -those are the two ways to get a value the compiler can substitute at expansion -time. Example: iterating a ragged 2D table: +below), as do iterator variables of an `unroll` loop (see [For loops] below). + +Example: iterating a ragged 2D table: ```python MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] @@ -112,16 +112,13 @@ Every function must contain at least one `return`. The compiler infers the numbe of returned values from the `return` statements; all `return`s in a function must agree. A function that "returns nothing" uses a bare `return`. -### Parameter modifiers + +### Parameter types | Syntax | Meaning | | ---------- | ------------------------ | -| `x` | normal (immutable) parameter | -| `x: Const` | compile-time-known value | - -All parameters are pass-by-value and immutable inside the function — use return -values to propagate results (there are no out-parameters). If you need a locally -mutable copy of a parameter, introduce a `: Mut` local at the top of the body: +| `x` | normal (immutable) runtime parameter | +| `x: Const` | compile-time parameter | ```python def repeat(n: Const): # Const enables unroll(0, n) @@ -138,9 +135,8 @@ def double(x): # parameter is immutable; shadow with a local ### Inline functions -`@inline` expands a function at every call site instead of generating a call -instruction. Useful for small helpers, and for cases where the body must "see" the -caller's `: Const` context. +`@inline` expands a function at every call site instead of generating a JUMP +instruction to another part of the bytecode. Useful for performance (calling a function costs a few cycles). ```python @inline @@ -148,15 +144,7 @@ def square(x): return x * x ``` -Constraints on inline functions: - -- Exactly one `return`, placed at the top level of the body — not nested inside - `if`, a loop, or `match`. Inlining rewrites the `return` into a plain - assignment, so early or conditional returns cannot be expressed. - -If you need conditional returns, use a normal (non-`@inline`) function. Combine -it with `: Const` parameters when you need compile-time specialization at the -call site. +Constraints on inline functions (compiler limitations): Exactly one `return`, placed as the last statement of the body, not nested inside `if`, a loop, or `match`. Inlining rewrites the `return` into a plain assignment in place, so early or conditional returns cannot be expressed. ## Variables From 895bfb14840e02890b6a4e30d1e0a706c12d150f Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 27 May 2026 18:24:48 +0400 Subject: [PATCH 06/13] zkDSL: rename "Imu" -> "Imm" --- crates/lean_compiler/snark_lib.py | 2 +- crates/lean_compiler/src/grammar.pest | 4 +- crates/lean_compiler/src/lang.rs | 2 +- .../src/parser/parsers/statement.rs | 4 +- .../lean_compiler/tests/test_data/error_13.py | 2 +- .../tests/test_data/program_100.py | 4 +- .../tests/test_data/program_109.py | 4 +- .../tests/test_data/program_110.py | 2 +- .../tests/test_data/program_111.py | 2 +- .../tests/test_data/program_112.py | 18 +-- .../tests/test_data/program_114.py | 6 +- .../tests/test_data/program_143.py | 14 +- .../tests/test_data/program_144.py | 50 +++--- .../tests/test_data/program_17.py | 2 +- .../tests/test_data/program_170.py | 4 +- .../tests/test_data/program_171.py | 144 +++++++++--------- .../tests/test_data/program_172.py | 2 +- .../tests/test_data/program_174.py | 10 +- .../tests/test_data/program_67.py | 2 +- .../tests/test_data/program_68.py | 4 +- .../tests/test_data/program_69.py | 18 +-- .../tests/test_data/program_83.py | 2 +- .../tests/test_data/soundness_2.py | 4 +- .../tests/test_data/soundness_5.py | 2 +- crates/lean_compiler/zkDSL.md | 33 ++-- .../zkdsl_implem/fiat_shamir.py | 4 +- .../rec_aggregation/zkdsl_implem/hashing.py | 4 +- .../rec_aggregation/zkdsl_implem/recursion.py | 2 +- crates/rec_aggregation/zkdsl_implem/utils.py | 6 +- crates/rec_aggregation/zkdsl_implem/whir.py | 16 +- 30 files changed, 182 insertions(+), 191 deletions(-) diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index f11c8138..bec59983 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -6,7 +6,7 @@ # Type annotations Mut = Any Const = Any -Imu = Any +Imm = Any # @inline decorator (does nothing in Python execution) diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index 16b041a6..6640d2de 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -49,9 +49,9 @@ return_statement = { "return" ~ (("(" ~ tuple_expression ~ ")") | tuple_expressi mut_keyword = @{ "mut" ~ !(ASCII_ALPHANUMERIC | "_") } mut_annotation = { ":" ~ "Mut" } -im_annotation = { ":" ~ "Imu" } +im_annotation = { ":" ~ "Imm" } -// Forward declaration: x: Imu or x: Mut (not followed by =) +// Forward declaration: x: Imm or x: Mut (not followed by =) forward_declaration = { identifier ~ (im_annotation | mut_annotation) ~ !("=") } // General assignment: LHS is optional, RHS is any expression diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 4ceaead1..46be4269 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -662,7 +662,7 @@ impl Line { if *is_mutable { format!("{var}: Mut") } else { - format!("{var}: Imu") + format!("{var}: Imm") } } Self::Statement { targets, value, .. } => { diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index 5e683514..c79edc67 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -284,7 +284,7 @@ impl Parse for AssertParser { } } -/// Parser for forward declarations: `x: Imu` or `x: Mut` +/// Parser for forward declarations: `x: Imm` or `x: Mut` pub struct ForwardDeclarationParser; impl Parse for ForwardDeclarationParser { @@ -294,7 +294,7 @@ impl Parse for ForwardDeclarationParser { // Parse variable name let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); - // Check for : Mut or : Imu annotation + // Check for : Mut or : Imm annotation let annotation = next_inner_pair(&mut inner, "type annotation")?; let is_mutable = annotation.as_rule() == Rule::mut_annotation; diff --git a/crates/lean_compiler/tests/test_data/error_13.py b/crates/lean_compiler/tests/test_data/error_13.py index 0450a5aa..2224ac22 100644 --- a/crates/lean_compiler/tests/test_data/error_13.py +++ b/crates/lean_compiler/tests/test_data/error_13.py @@ -2,7 +2,7 @@ def main(): - a: Imu + a: Imm a = 0 a = a + 1 if a == 1: diff --git a/crates/lean_compiler/tests/test_data/program_100.py b/crates/lean_compiler/tests/test_data/program_100.py index 5eeac3b6..a81d0480 100644 --- a/crates/lean_compiler/tests/test_data/program_100.py +++ b/crates/lean_compiler/tests/test_data/program_100.py @@ -2,8 +2,8 @@ def main(): - x: Imu - y: Imu + x: Imm + y: Imm cond = 1 if cond == 1: diff --git a/crates/lean_compiler/tests/test_data/program_109.py b/crates/lean_compiler/tests/test_data/program_109.py index b04f2608..77e86702 100644 --- a/crates/lean_compiler/tests/test_data/program_109.py +++ b/crates/lean_compiler/tests/test_data/program_109.py @@ -9,10 +9,10 @@ def main(): def test_func(a, b): x = 1 - mut_x_2: Imu + mut_x_2: Imm match a: case 0: - mut_x_1: Imu + mut_x_1: Imm mut_x_1 = x + 2 match b: case 0: diff --git a/crates/lean_compiler/tests/test_data/program_110.py b/crates/lean_compiler/tests/test_data/program_110.py index f863f968..df47ee5c 100644 --- a/crates/lean_compiler/tests/test_data/program_110.py +++ b/crates/lean_compiler/tests/test_data/program_110.py @@ -79,7 +79,7 @@ def main(): result = complex_compute(3, 4, 5) assert result == 47 - fwd_val: Imu + fwd_val: Imm cond = 1 if cond == 0: fwd_val = 100 diff --git a/crates/lean_compiler/tests/test_data/program_111.py b/crates/lean_compiler/tests/test_data/program_111.py index e13d3e32..9634df0f 100644 --- a/crates/lean_compiler/tests/test_data/program_111.py +++ b/crates/lean_compiler/tests/test_data/program_111.py @@ -85,7 +85,7 @@ def nested_mut_params(base): return acc def state_machine_step(current_state, phase): - result: Imu + result: Imm if phase == 0: if current_state == 0: result = 1 diff --git a/crates/lean_compiler/tests/test_data/program_112.py b/crates/lean_compiler/tests/test_data/program_112.py index 7cd0fe3b..51a9de2e 100644 --- a/crates/lean_compiler/tests/test_data/program_112.py +++ b/crates/lean_compiler/tests/test_data/program_112.py @@ -2,7 +2,7 @@ def main(): - result1: Imu + result1: Imm outer_sel = 1 match outer_sel: case 0: @@ -18,8 +18,8 @@ def main(): result1 = 456 assert result1 == 456 - counter: Imu - flag: Imu + counter: Imm + flag: Imm phase = 1 if phase == 0: @@ -40,8 +40,8 @@ def main(): assert counter2 == 15 assert flag2 == 400 - x: Imu - y: Imu + x: Imm + y: Imm init_sel = 0 if init_sel == 0: @@ -61,7 +61,7 @@ def main(): assert x2 == 220 assert y2 == 20 - outcome: Imu + outcome: Imm selector = 4 match selector: case 0: @@ -78,9 +78,9 @@ def main(): outcome = compute_outcome(5, 25) assert outcome == 84 - p: Imu - q: Imu - r: Imu + p: Imm + q: Imm + r: Imm s1 = 1 if s1 == 1: diff --git a/crates/lean_compiler/tests/test_data/program_114.py b/crates/lean_compiler/tests/test_data/program_114.py index 1d0b0b95..c92f7044 100644 --- a/crates/lean_compiler/tests/test_data/program_114.py +++ b/crates/lean_compiler/tests/test_data/program_114.py @@ -31,8 +31,8 @@ def main(): result4 = complex_nested_compute(2, 1, 3) assert result4 == 280 - fwd_x: Imu - fwd_y: Imu + fwd_x: Imm + fwd_y: Imm mode = 2 if mode == 0: @@ -90,7 +90,7 @@ def sum_array_func(arr, n: Const): def complex_nested_compute(outer, inner, depth): - result: Imu + result: Imm if outer == 0: result = 100 diff --git a/crates/lean_compiler/tests/test_data/program_143.py b/crates/lean_compiler/tests/test_data/program_143.py index 08be3b5d..1e16c510 100644 --- a/crates/lean_compiler/tests/test_data/program_143.py +++ b/crates/lean_compiler/tests/test_data/program_143.py @@ -71,7 +71,7 @@ def main(): assert sum == 10 # Test 13: Inline functions in if condition (comparison) - result13: Imu + result13: Imm if incr(incr(0)) == 2: result13 = 100 else: @@ -79,7 +79,7 @@ def main(): assert result13 == 100 # Test 14: Nested inline calls in both sides of if condition - result14: Imu + result14: Imm if double(3) == triple(2): result14 = 1 else: @@ -88,7 +88,7 @@ def main(): assert result14 == 1 # Test 15: Inline calls inside if/else branches - result15: Imu + result15: Imm if 1 == 1: result15 = incr(incr(incr(10))) else: @@ -96,7 +96,7 @@ def main(): assert result15 == 13 # Test 16: Multiple nested inline calls in if condition - result16: Imu + result16: Imm if incr(double(incr(1))) == 5: # incr(1) = 2, double(2) = 4, incr(4) = 5 result16 = 200 @@ -105,7 +105,7 @@ def main(): assert result16 == 200 # Test 17: Inline call with != comparison - result17: Imu + result17: Imm if incr(5) != 5: result17 = 300 else: @@ -152,7 +152,7 @@ def main(): assert sum23 == 21 # Test 24: Chained else-if with inline conditions - result24: Imu + result24: Imm x24 = 5 if incr(x24) == 4: result24 = 1 @@ -194,7 +194,7 @@ def double(x): @inline def triple(x): y: Mut = x - two: Imu + two: Imm match y - x + 1: case 0: assert False diff --git a/crates/lean_compiler/tests/test_data/program_144.py b/crates/lean_compiler/tests/test_data/program_144.py index b9dcd456..9ceb190f 100644 --- a/crates/lean_compiler/tests/test_data/program_144.py +++ b/crates/lean_compiler/tests/test_data/program_144.py @@ -8,7 +8,7 @@ def main(): # ========================================================================== # TEST 1: Basic - panic in else branch (the original bug case) # ========================================================================== - two: Imu + two: Imm if 1 == 1: two = 2 else: @@ -18,7 +18,7 @@ def main(): # ========================================================================== # TEST 2: panic in then branch # ========================================================================== - three: Imu + three: Imm if 1 != 1: assert False else: @@ -28,9 +28,9 @@ def main(): # ========================================================================== # TEST 3: Multiple mutable variables, panic in else # ========================================================================== - a: Imu - b: Imu - c: Imu + a: Imm + b: Imm + c: Imm if 1 == 1: a = 10 b = 20 @@ -44,7 +44,7 @@ def main(): # ========================================================================== # TEST 4: Nested if with panic in inner else # ========================================================================== - x: Imu + x: Imm if 1 == 1: if 2 == 2: x = 42 @@ -79,7 +79,7 @@ def main(): # ========================================================================== # TEST 7: Chain of else-if with panic in final else # ========================================================================== - result: Imu + result: Imm selector = 1 if selector == 0: result = 100 @@ -94,7 +94,7 @@ def main(): # ========================================================================== # TEST 8: Match with panic in one arm # ========================================================================== - matched: Imu + matched: Imm tag = 1 match tag: case 0: @@ -108,7 +108,7 @@ def main(): # ========================================================================== # TEST 9: Match where only one arm doesn't panic # ========================================================================== - only_valid: Imu + only_valid: Imm tag2 = 2 match tag2: case 0: @@ -124,7 +124,7 @@ def main(): # ========================================================================== # TEST 10: Panic in deeply nested structure # ========================================================================== - deep: Imu + deep: Imm if 1 == 1: if 1 == 1: if 1 == 1: @@ -151,7 +151,7 @@ def main(): # ========================================================================== # TEST 12: Forward declared with = None panic in branch # ========================================================================== - fwd: Imu + fwd: Imm cond = 1 if cond == 1: fwd = 777 @@ -162,8 +162,8 @@ def main(): # ========================================================================== # TEST 13: Both mutable and immutable forward decl with panic # ========================================================================== - imm: Imu - mtbl: Imu + imm: Imm + mtbl: Imm flag = 0 if flag == 0: imm = 100 @@ -206,7 +206,7 @@ def main(): # ========================================================================== # TEST 17: Nested match with panic # ========================================================================== - nested_match: Imu + nested_match: Imm outer = 1 match outer: case 0: @@ -223,7 +223,7 @@ def main(): # ========================================================================== # TEST 18: If inside match with panic # ========================================================================== - if_in_match: Imu + if_in_match: Imm m18_sel = 0 match m18_sel: case 0: @@ -239,7 +239,7 @@ def main(): # ========================================================================== # TEST 19: Match inside if with panic # ========================================================================== - match_in_if: Imu + match_in_if: Imm cond19 = 1 if cond19 == 1: tag19 = 1 @@ -255,7 +255,7 @@ def main(): # ========================================================================== # TEST 20: Panic after partial assignment # ========================================================================== - partial: Imu + partial: Imm check = 0 if check == 0: partial_tmp: Mut = 1 @@ -288,7 +288,7 @@ def main(): # ========================================================================== # TEST 23: Multiple levels - if/match/if with panics # ========================================================================== - multi_level: Imu + multi_level: Imm c1 = 1 if c1 == 1: s1 = 0 @@ -308,7 +308,7 @@ def main(): # ========================================================================== # TEST 24: Panic in both outer branches but inner assigns # ========================================================================== - inner_assigns: Imu + inner_assigns: Imm outer24 = 0 match outer24: case 0: @@ -324,9 +324,9 @@ def main(): # ========================================================================== # TEST 25: Complex - multiple vars, nested, with panics # ========================================================================== - va: Imu - vb: Imu - vc: Imu + va: Imm + vb: Imm + vc: Imm outer25 = 1 if outer25 == 1: @@ -353,7 +353,7 @@ def main(): # Helper function for TEST 14 def test_early_return(flag): - result: Imu + result: Imm if flag == 1: result = 10 else: @@ -363,8 +363,8 @@ def test_early_return(flag): # Helper function for TEST 15 def test_multi_return(flag): - a: Imu - b: Imu + a: Imm + b: Imm if flag == 1: a = 100 b = 200 diff --git a/crates/lean_compiler/tests/test_data/program_17.py b/crates/lean_compiler/tests/test_data/program_17.py index d9274d42..092e2fc2 100644 --- a/crates/lean_compiler/tests/test_data/program_17.py +++ b/crates/lean_compiler/tests/test_data/program_17.py @@ -7,7 +7,7 @@ def main(): def func(): - a: Imu + a: Imm if 0 == 0: a = aux() return a diff --git a/crates/lean_compiler/tests/test_data/program_170.py b/crates/lean_compiler/tests/test_data/program_170.py index f5652fe7..62f5d2e4 100644 --- a/crates/lean_compiler/tests/test_data/program_170.py +++ b/crates/lean_compiler/tests/test_data/program_170.py @@ -30,14 +30,14 @@ def main(): x = 5 y = 10 - z: Imu + z: Imm if x + y == 15: z = 1 else: z = 0 assert z == 1 - w: Imu + w: Imm if x + y * 2 == 25: w = 100 else: diff --git a/crates/lean_compiler/tests/test_data/program_171.py b/crates/lean_compiler/tests/test_data/program_171.py index bb3c47c6..4af97391 100644 --- a/crates/lean_compiler/tests/test_data/program_171.py +++ b/crates/lean_compiler/tests/test_data/program_171.py @@ -1,7 +1,7 @@ from snark_lib import * # Comprehensive test for inlining with mutable variables in branches -# Tests: @inline functions, Mut/Imu variables, match, if/else, loops, nesting +# Tests: @inline functions, Mut/Imm variables, match, if/else, loops, nesting # ============================================================================ # Simple inline functions with mutable variables @@ -118,7 +118,7 @@ def inline_with_if(x): @inline def inline_with_match(selector): """Inline function that itself contains match""" - out: Imu + out: Imm match selector: case 0: out = 1000 @@ -132,7 +132,7 @@ def inline_with_match(selector): @inline def inline_with_nested_branch(a, b): """Inline with nested if inside match""" - res: Imu + res: Imm match a: case 0: if b == 0: @@ -321,7 +321,7 @@ def main(): # TEST 1: Basic inline in match arms (different inlined vars per arm) # This was the original bug - each arm gets its own inlined variable names # ------------------------------------------------------------------- - res1: Imu + res1: Imm match 0: case 0: res1 = count_up(5) @@ -329,7 +329,7 @@ def main(): res1 = count_up(10) assert res1 == 5 - res2: Imu + res2: Imm match 1: case 0: res2 = count_up(5) @@ -340,7 +340,7 @@ def main(): # ------------------------------------------------------------------- # TEST 2: Different inline functions in different arms # ------------------------------------------------------------------- - res3: Imu + res3: Imm match 0: case 0: res3 = count_up(3) @@ -350,7 +350,7 @@ def main(): res3 = double_count(3) assert res3 == 3 - res4: Imu + res4: Imm match 1: case 0: res4 = count_up(3) @@ -360,7 +360,7 @@ def main(): res4 = double_count(3) assert res4 == 3 # 0+1+2 - res5: Imu + res5: Imm match 2: case 0: res5 = count_up(3) @@ -392,7 +392,7 @@ def main(): # ------------------------------------------------------------------- # TEST 4: Multiple inlines in same arm # ------------------------------------------------------------------- - multi: Imu + multi: Imm match 0: case 0: a = count_up(3) @@ -406,7 +406,7 @@ def main(): # ------------------------------------------------------------------- # TEST 5: Nested inline functions in match arms # ------------------------------------------------------------------- - nested1: Imu + nested1: Imm match 0: case 0: nested1 = outer_with_inner(4) @@ -416,7 +416,7 @@ def main(): # = 0 + 0 + 1 + 3 = 4 assert nested1 == 4 - nested2: Imu + nested2: Imm match 1: case 0: nested2 = outer_with_inner(4) @@ -428,7 +428,7 @@ def main(): # ------------------------------------------------------------------- # TEST 6: Deep nesting in match # ------------------------------------------------------------------- - deep1: Imu + deep1: Imm match 0: case 0: deep1 = deep_nested(3) @@ -442,14 +442,14 @@ def main(): # ------------------------------------------------------------------- # TEST 7: Inline in if/else branches # ------------------------------------------------------------------- - if_res1: Imu + if_res1: Imm if 1 == 1: if_res1 = count_up(7) else: if_res1 = count_up(3) assert if_res1 == 7 - if_res2: Imu + if_res2: Imm if 1 == 0: if_res2 = count_up(7) else: @@ -459,7 +459,7 @@ def main(): # ------------------------------------------------------------------- # TEST 8: Nested if/else with inlines # ------------------------------------------------------------------- - nested_if: Imu + nested_if: Imm if 1 == 1: if 2 == 2: nested_if = sum_range(0, 5) @@ -472,7 +472,7 @@ def main(): # ------------------------------------------------------------------- # TEST 9: Match inside if with inlines # ------------------------------------------------------------------- - mixed: Imu + mixed: Imm if 1 == 1: match 1: case 0: @@ -486,7 +486,7 @@ def main(): # ------------------------------------------------------------------- # TEST 10: If inside match with inlines # ------------------------------------------------------------------- - mixed2: Imu + mixed2: Imm match 0: case 0: if 1 == 1: @@ -500,7 +500,7 @@ def main(): # ------------------------------------------------------------------- # TEST 11: Complex mutable variables in inline # ------------------------------------------------------------------- - cx: Imu + cx: Imm match 0: case 0: cx = complex_muts(4) @@ -519,7 +519,7 @@ def main(): # TEST 12: Mix of Mut and immutable in branches with inlines # ------------------------------------------------------------------- outer_mut: Mut = 10 - inner_imu: Imu + inner_imu: Imm match 0: case 0: local_imm = with_immutable(3) @@ -535,7 +535,7 @@ def main(): # ------------------------------------------------------------------- # TEST 13: Inline inside unroll loop inside match # ------------------------------------------------------------------- - unroll_in_match: Imu + unroll_in_match: Imm match 0: case 0: acc: Mut = 0 @@ -550,10 +550,10 @@ def main(): # ------------------------------------------------------------------- # TEST 14: Multiple match levels with different inlines at each # ------------------------------------------------------------------- - multi_match: Imu + multi_match: Imm match 1: case 0: - inner: Imu + inner: Imm match 0: case 0: inner = count_up(2) @@ -561,7 +561,7 @@ def main(): inner = count_up(3) multi_match = inner case 1: - inner2: Imu + inner2: Imm match 1: case 0: inner2 = sum_range(0, 2) @@ -573,7 +573,7 @@ def main(): # ------------------------------------------------------------------- # TEST 15: Same inline function called multiple times in same arm # ------------------------------------------------------------------- - same_fn: Imu + same_fn: Imm match 0: case 0: r1 = count_up(3) @@ -607,7 +607,7 @@ def main(): # ------------------------------------------------------------------- # TEST 17: Variables declared inside only some branches # ------------------------------------------------------------------- - outside: Imu + outside: Imm match 0: case 0: local_only_here = count_up(5) @@ -622,7 +622,7 @@ def main(): # ------------------------------------------------------------------- # TEST 18: Very deeply nested structure # ------------------------------------------------------------------- - very_deep: Imu + very_deep: Imm if 1 == 1: match 0: case 0: @@ -668,7 +668,7 @@ def main(): # ------------------------------------------------------------------- # TEST 20: Inline result used immediately in arithmetic in branch # ------------------------------------------------------------------- - arith: Imu + arith: Imm match 0: case 0: arith = count_up(3) * 10 + sum_range(0, 3) * 100 @@ -684,7 +684,7 @@ def main(): # ------------------------------------------------------------------- # TEST 21: Inline containing if/else in different match arms # ------------------------------------------------------------------- - t21: Imu + t21: Imm match 0: case 0: t21 = inline_with_if(0) @@ -693,7 +693,7 @@ def main(): # inline_with_if(0): result=100, result=100+0=100 assert t21 == 100 - t21b: Imu + t21b: Imm match 1: case 0: t21b = inline_with_if(0) @@ -705,7 +705,7 @@ def main(): # ------------------------------------------------------------------- # TEST 22: Inline containing match in different branches # ------------------------------------------------------------------- - t22: Imu + t22: Imm match 0: case 0: t22 = inline_with_match(0) @@ -715,7 +715,7 @@ def main(): t22 = inline_with_match(2) assert t22 == 1000 - t22b: Imu + t22b: Imm match 2: case 0: t22b = inline_with_match(0) @@ -728,7 +728,7 @@ def main(): # ------------------------------------------------------------------- # TEST 23: Inline with nested branches called in nested branches # ------------------------------------------------------------------- - t23: Imu + t23: Imm match 0: case 0: if 1 == 1: @@ -740,7 +740,7 @@ def main(): # inline_with_nested_branch(0, 1): a=0 -> if b==0 else -> 20 assert t23 == 20 - t23b: Imu + t23b: Imm match 1: case 0: t23b = inline_with_nested_branch(0, 0) @@ -752,8 +752,8 @@ def main(): # ------------------------------------------------------------------- # TEST 24: Multi-return inline in match arms # ------------------------------------------------------------------- - t24a: Imu - t24b: Imu + t24a: Imm + t24b: Imm match 0: case 0: t24a, t24b = multi_return_inline(5) @@ -763,8 +763,8 @@ def main(): assert t24a == 5 assert t24b == 110 - t24c: Imu - t24d: Imu + t24c: Imm + t24d: Imm match 1: case 0: t24c, t24d = multi_return_inline(5) @@ -777,9 +777,9 @@ def main(): # ------------------------------------------------------------------- # TEST 25: Triple return inline in branches # ------------------------------------------------------------------- - t25a: Imu - t25b: Imu - t25c: Imu + t25a: Imm + t25b: Imm + t25c: Imm match 0: case 0: t25a, t25b, t25c = triple_return(10) @@ -793,7 +793,7 @@ def main(): # ------------------------------------------------------------------- # TEST 26: 4-level deep inline nesting in match arms # ------------------------------------------------------------------- - t26: Imu + t26: Imm match 0: case 0: t26 = level_a(1) @@ -809,7 +809,7 @@ def main(): # = (1+2) + 20 + 200 + 2000 = 2223 assert t26 == 2223 - t26b: Imu + t26b: Imm match 3: case 0: t26b = level_a(5) @@ -825,7 +825,7 @@ def main(): # ------------------------------------------------------------------- # TEST 27: Inline with Array in match arms # ------------------------------------------------------------------- - t27: Imu + t27: Imm match 0: case 0: t27 = inline_with_array(10) @@ -834,7 +834,7 @@ def main(): # inline_with_array(10): 10+11+12+13 = 46 assert t27 == 46 - t27b: Imu + t27b: Imm match 1: case 0: t27b = inline_with_array(10) @@ -846,7 +846,7 @@ def main(): # ------------------------------------------------------------------- # TEST 28: Inline modifying array in branches # ------------------------------------------------------------------- - t28: Imu + t28: Imm match 0: case 0: t28 = inline_modify_array(1) @@ -858,7 +858,7 @@ def main(): # ------------------------------------------------------------------- # TEST 29: Chained inline calls in match arms # ------------------------------------------------------------------- - t29: Imu + t29: Imm match 0: case 0: # chain_a(5)=7, chain_b(7)=28, chain_c(28)=48 @@ -867,7 +867,7 @@ def main(): t29 = chain_a(100) assert t29 == 48 - t29b: Imu + t29b: Imm match 1: case 0: t29b = chain_c(chain_b(chain_a(1))) @@ -879,7 +879,7 @@ def main(): # ------------------------------------------------------------------- # TEST 30: Different chain patterns in different arms # ------------------------------------------------------------------- - t30: Imu + t30: Imm match 0: case 0: t30 = chain_a(chain_a(chain_a(0))) @@ -890,7 +890,7 @@ def main(): # chain_a(0)=2, chain_a(2)=4, chain_a(4)=6 assert t30 == 6 - t30b: Imu + t30b: Imm match 1: case 0: t30b = chain_a(chain_a(chain_a(0))) @@ -904,7 +904,7 @@ def main(): # ------------------------------------------------------------------- # TEST 31: Stress test - many variables inline in match # ------------------------------------------------------------------- - t31: Imu + t31: Imm match 0: case 0: t31 = many_vars(0) @@ -918,7 +918,7 @@ def main(): # ------------------------------------------------------------------- # TEST 32: Multiple multi-return inlines in same arm # ------------------------------------------------------------------- - t32_sum: Imu + t32_sum: Imm match 0: case 0: a1, b1 = multi_return_inline(3) @@ -936,7 +936,7 @@ def main(): # ------------------------------------------------------------------- # TEST 33: 5-way match with all different inline types # ------------------------------------------------------------------- - t33: Imu + t33: Imm match 0: case 0: t33 = count_up(10) @@ -950,7 +950,7 @@ def main(): t33 = inline_with_array(1) assert t33 == 10 - t33b: Imu + t33b: Imm match 4: case 0: t33b = count_up(10) @@ -968,10 +968,10 @@ def main(): # ------------------------------------------------------------------- # TEST 34: Triple nested match with inlines at each level # ------------------------------------------------------------------- - t34: Imu + t34: Imm match 0: case 0: - inner1: Imu + inner1: Imm match 1: case 0: tmp34a = count_up(2) @@ -986,10 +986,10 @@ def main(): assert t34 == 1423 # Additional triple nesting test - without forward declaration inside innermost - t34b: Imu + t34b: Imm match 0: case 0: - mid1: Imu + mid1: Imm match 0: case 0: # Use inline directly without forward declaration @@ -1003,10 +1003,10 @@ def main(): assert t34b == 1105 # Test forward declaration with nested match and inline - t34c: Imu + t34c: Imm match 0: case 0: - val34c: Imu + val34c: Imm match 0: case 0: val34c = sum_range(0, 5) @@ -1041,12 +1041,12 @@ def main(): assert deep_mut == 210 # ------------------------------------------------------------------- - # TEST 36: Multiple forward-declared Imu assigned via inlines + # TEST 36: Multiple forward-declared Imm assigned via inlines # ------------------------------------------------------------------- - fwd1: Imu - fwd2: Imu - fwd3: Imu - fwd4: Imu + fwd1: Imm + fwd2: Imm + fwd3: Imm + fwd4: Imm match 0: case 0: fwd1 = count_up(1) @@ -1086,7 +1086,7 @@ def main(): # ------------------------------------------------------------------- # TEST 38: If-else-if chain with different inlines # ------------------------------------------------------------------- - t38: Imu + t38: Imm if 0 == 1: t38 = count_up(100) else: @@ -1126,7 +1126,7 @@ def main(): # ------------------------------------------------------------------- # TEST 40: Inline returning mutable at different states # ------------------------------------------------------------------- - t40: Imu + t40: Imm match 0: case 0: # complex_muts returns computation of interdependent muts @@ -1164,7 +1164,7 @@ def main(): # TEST 42: Deeply nested with mixed mutable tracking # ------------------------------------------------------------------- outer_m: Mut = 100 - t42: Imu + t42: Imm if 1 == 1: outer_m = outer_m + 50 match 0: @@ -1197,14 +1197,14 @@ def main(): # ------------------------------------------------------------------- # TEST 43: All arms have different nesting patterns # ------------------------------------------------------------------- - t43: Imu + t43: Imm match 0: case 0: # Flat t43 = count_up(5) case 1: # One level nested - if_inner: Imu + if_inner: Imm if 1 == 1: if_inner = sum_range(0, 10) else: @@ -1212,7 +1212,7 @@ def main(): t43 = if_inner case 2: # Two levels nested - m_inner: Imu + m_inner: Imm match 0: case 0: m_inner = level_a(1) @@ -1221,7 +1221,7 @@ def main(): t43 = m_inner case 3: # Three levels nested - deep_inner: Imu + deep_inner: Imm if 1 == 1: match 0: case 0: @@ -1271,7 +1271,7 @@ def main(): # ------------------------------------------------------------------- # TEST 45: Inline calling another inline that has internal branches # ------------------------------------------------------------------- - t45: Imu + t45: Imm match 0: case 0: # outer_with_inner calls inner_loop diff --git a/crates/lean_compiler/tests/test_data/program_172.py b/crates/lean_compiler/tests/test_data/program_172.py index da813966..ba54fc05 100644 --- a/crates/lean_compiler/tests/test_data/program_172.py +++ b/crates/lean_compiler/tests/test_data/program_172.py @@ -8,7 +8,7 @@ def helper_const(n: Const): def main(): - # Test 1: Basic match_range - no forward declaration needed (auto-generated as Imu) + # Test 1: Basic match_range - no forward declaration needed (auto-generated as Imm) x = 2 r1 = match_range(x, range(0, 4), lambda i: i * 100) assert r1 == 200 diff --git a/crates/lean_compiler/tests/test_data/program_174.py b/crates/lean_compiler/tests/test_data/program_174.py index 19be7961..94e1f4d1 100644 --- a/crates/lean_compiler/tests/test_data/program_174.py +++ b/crates/lean_compiler/tests/test_data/program_174.py @@ -40,7 +40,7 @@ def main(): def match_start_at_1(x): - result: Imu + result: Imm match x: case 1: result = 100 @@ -54,7 +54,7 @@ def match_start_at_1(x): def match_start_at_5(x): - result: Imu + result: Imm match x: case 5: result = 50 @@ -68,7 +68,7 @@ def match_start_at_5(x): def match_start_at_10(x): - result: Imu + result: Imm match x: case 10: result = 1000 @@ -95,7 +95,7 @@ def match_nonzero_mutable(x): def nested_nonzero_match(outer, inner): - result: Imu + result: Imm match outer: case 1: match inner: @@ -125,7 +125,7 @@ def nested_nonzero_match(outer, inner): def nonzero_match_in_if(cond, x): - result: Imu + result: Imm if cond == 0: result = 0 else: diff --git a/crates/lean_compiler/tests/test_data/program_67.py b/crates/lean_compiler/tests/test_data/program_67.py index ed80db93..5c91e661 100644 --- a/crates/lean_compiler/tests/test_data/program_67.py +++ b/crates/lean_compiler/tests/test_data/program_67.py @@ -2,7 +2,7 @@ def main(): - mut_a: Imu + mut_a: Imm mut_a = 5 assert mut_a == 5 return diff --git a/crates/lean_compiler/tests/test_data/program_68.py b/crates/lean_compiler/tests/test_data/program_68.py index 13eb50c6..a4797324 100644 --- a/crates/lean_compiler/tests/test_data/program_68.py +++ b/crates/lean_compiler/tests/test_data/program_68.py @@ -10,10 +10,10 @@ def main(): def test_func(a, b): x = 1 - mut_x_2: Imu + mut_x_2: Imm match a: case 0: - mut_x_1: Imu + mut_x_1: Imm mut_x_1 = x + 2 match b: case 0: diff --git a/crates/lean_compiler/tests/test_data/program_69.py b/crates/lean_compiler/tests/test_data/program_69.py index 4c19fda8..ecc0051e 100644 --- a/crates/lean_compiler/tests/test_data/program_69.py +++ b/crates/lean_compiler/tests/test_data/program_69.py @@ -15,20 +15,20 @@ def main(): def compute(a, b, c): base = 1000 - outer_val: Imu - mid_val: Imu - inner_val: Imu + outer_val: Imm + mid_val: Imm + inner_val: Imm match a: case 0: outer_val = 5 - local_a: Imu + local_a: Imm local_a = a + outer_val match b: case 0: mid_val = 3 - local_b: Imu + local_b: Imm local_b = local_a + mid_val match c: @@ -38,7 +38,7 @@ def compute(a, b, c): inner_val = base + local_b + c case 1: mid_val = 7 - local_b: Imu + local_b: Imm local_b = local_a + mid_val match c: @@ -48,13 +48,13 @@ def compute(a, b, c): inner_val = base + local_b + c case 1: outer_val = 15 - local_a: Imu + local_a: Imm local_a = a + outer_val match b: case 0: mid_val = 20 - local_b: Imu + local_b: Imm local_b = local_a + mid_val match c: @@ -64,7 +64,7 @@ def compute(a, b, c): inner_val = base + local_b + c case 1: mid_val = 30 - local_b: Imu + local_b: Imm local_b = local_a + mid_val match c: diff --git a/crates/lean_compiler/tests/test_data/program_83.py b/crates/lean_compiler/tests/test_data/program_83.py index aa423c9f..36bd35d3 100644 --- a/crates/lean_compiler/tests/test_data/program_83.py +++ b/crates/lean_compiler/tests/test_data/program_83.py @@ -2,7 +2,7 @@ def main(): - x: Imu + x: Imm cond = 1 if cond == 1: x = 10 diff --git a/crates/lean_compiler/tests/test_data/soundness_2.py b/crates/lean_compiler/tests/test_data/soundness_2.py index 3fa2867b..630d756a 100644 --- a/crates/lean_compiler/tests/test_data/soundness_2.py +++ b/crates/lean_compiler/tests/test_data/soundness_2.py @@ -12,7 +12,7 @@ def main(): offset = p[6] total = p[7] - computed: Imu + computed: Imm match mode: case 0: computed = add_op(x, y) @@ -24,7 +24,7 @@ def main(): computed = combined(x, y) assert computed == expected - adjusted: Imu + adjusted: Imm if flag == 0: adjusted = bump(secondary, 1) elif flag == 1: diff --git a/crates/lean_compiler/tests/test_data/soundness_5.py b/crates/lean_compiler/tests/test_data/soundness_5.py index 3333ab2f..eaa31b7c 100644 --- a/crates/lean_compiler/tests/test_data/soundness_5.py +++ b/crates/lean_compiler/tests/test_data/soundness_5.py @@ -36,7 +36,7 @@ def main(): assert paired_sum(seed, n) == paired - chosen: Imu + chosen: Imm if flag == 1: chosen = seed else: diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 30eb9f00..8fd7e6f0 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -152,16 +152,16 @@ Constraints on inline functions (compiler limitations): Exactly one `return`, pl | ------------- | ---------- | ---------------------------------------------- | | `x = 10` | immutable | cannot be reassigned | | `x: Mut = 10` | mutable | reassignable | -| `x: Imu` | immutable | forward declaration; assign exactly once later | +| `x: Imm` | immutable | forward declaration; assign exactly once later | | `x: Mut` | mutable | forward declaration; reassignable later | ### Forward declarations -Use `x: Imu` when you want an immutable binding but the value comes from a +Use `x: Imm` when you want an immutable binding but the value comes from a branch: ```python -result: Imu +result: Imm if cond == 1: result = 10 else: @@ -196,7 +196,10 @@ b = b + 1 # OK ```python buffer = Array(16) # allocate 16 field elements buffer[0] = 42 -x = buffer[5] +buffer[0] = 42 # Valid +# buffer[0] = 41 # ERROR: conflicting write (read only memory) +buffer[5] = 34 +x = buffer[5] # x = 34 matrix = Array(64) # 2D via manual indexing matrix[row * 8 + col] = value @@ -206,21 +209,9 @@ ptr2[0] = 100 # same as buffer[5] = 100 ``` `Array(n)` returns a pointer to a freshly allocated block of `n` field -elements. `n` may be a compile-time constant (the common case) or a runtime -value; the runner handles both. Memory is **write-once**: a cell may be -written more than once only if all writes store the same value. The second -write of a different value is a runtime error at the point of the write. - -```python -arr = Array(3) -arr[0] = 10 -arr[0] = 10 # OK: same value -arr[0] = 20 # ERROR: conflicting write -``` - -`Array` cells are not implicitly mutable — if you need a running accumulator, -use `x: Mut` for the variable and only commit final values to memory. Pointer -arithmetic (`ptr + offset`) is the way to address into sub-regions. +elements. `n` may be a compile-time constant (more efficient, analogy: allocated on the stack) or a runtime +value (less efficient, analogy: allocated on the heap). Memory is **write-once**: a cell may be +written more than once only if all writes store the same value. ## Control flow @@ -266,7 +257,7 @@ result = match_range(n, range(1, 5), lambda i: compute(i)) expands to ```python -result: Imu +result: Imm match n: case 1: result = compute(1) case 2: result = compute(2) @@ -667,7 +658,7 @@ The runner lays out memory as 2. Reach for `: Const` parameters when the function body needs `unroll` over the parameter, or when array sizes depend on it. 3. `if` / `elif` branches that assign to the same outer variable should - forward-declare it (`x: Imu` or `x: Mut`) before the branch. + forward-declare it (`x: Imm` or `x: Mut`) before the branch. 4. **`match`** / **`match_range`** dispatch is undefined for out-of-range values — always pair it with a `debug_assert` (or `assert`) on the value. 5. `match` patterns must be contiguous integers; if you need gaps, restructure diff --git a/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py b/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py index e7155545..27895a4e 100644 --- a/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py +++ b/crates/rec_aggregation/zkdsl_implem/fiat_shamir.py @@ -177,8 +177,8 @@ def fs_receive_ef_inlined(fs, n): def fs_receive_ef_by_log_dynamic(fs, log_n, min_value: Const, max_value: Const): debug_assert(log_n < max_value) debug_assert(min_value <= log_n) - new_fs: Imu - ef_ptr: Imu + new_fs: Imm + ef_ptr: Imm new_fs, ef_ptr = match_range(log_n, range(min_value, max_value), lambda ln: fs_receive_ef(fs, 2**ln)) return new_fs, ef_ptr diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index 5146ee82..00b8a700 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -111,8 +111,8 @@ def euclidian_div_runtime(a, b): # Requires: # 1 <= b < 2^14 # floor(a / b) < 2^16 (so that q*b + r stays well below p) - q: Imu - r: Imu + q: Imm + r: Imm hint_div_floor(a, b, q, r) assert r < b assert q < 2 ** 16 diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 487eddb3..af081d4c 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -709,7 +709,7 @@ def compute_total_gkr_n_vars(log_memory, log_bytecode_padded, tables_heights): def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, logup_alphas_eq_poly): - res: Imu + res: Imm debug_assert(table_index < N_TABLES) match table_index: case 0: diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py index b95eedfa..d4b2e983 100644 --- a/crates/rec_aggregation/zkdsl_implem/utils.py +++ b/crates/rec_aggregation/zkdsl_implem/utils.py @@ -236,7 +236,7 @@ def mle_of_01234567_etc(point, n): @inline def checked_less_than(a, b): - res: Imu + res: Imm hint_less_than(a, b, res) assert res * (1 - res) == 0 if res == 1: @@ -249,7 +249,7 @@ def checked_less_than(a, b): @inline def maximum(a, b): is_a_less_than_b = checked_less_than(a, b) - res: Imu + res: Imm if is_a_less_than_b == 1: res = b else: @@ -809,7 +809,7 @@ def _verify_log2_large(n, log2: Const): def log2_ceil_runtime(n): # requires: 2 < n <= 2^30 - log2: Imu + log2: Imm hint_log2_ceil(n, log2) assert log2 < 31 if two_exp(log2) != n: diff --git a/crates/rec_aggregation/zkdsl_implem/whir.py b/crates/rec_aggregation/zkdsl_implem/whir.py index a3c84cb8..d14a10ef 100644 --- a/crates/rec_aggregation/zkdsl_implem/whir.py +++ b/crates/rec_aggregation/zkdsl_implem/whir.py @@ -42,7 +42,7 @@ def whir_open( domain_sz: Mut = n_vars + initial_log_inv_rate for r in range(0, n_rounds): - is_first_round: Imu + is_first_round: Imm if r == 0: is_first_round = 1 else: @@ -313,7 +313,7 @@ def sample_stir_indexes_and_fold( merkle_leaves = Array(num_queries) circle_values = Array(num_queries) - n_chunks_per_answer: Imu + n_chunks_per_answer: Imm # the number of chunk of 8 field elements per merkle leaf opened if merkle_leaves_in_basefield == 1: n_chunks_per_answer = two_pow_folding_factor @@ -410,9 +410,9 @@ def polynomial_sum_at_0_and_1(coeffs, degree, dst): def parse_commitment(fs, num_ood): - root: Imu - ood_points: Imu - ood_evals: Imu + root: Imm + ood_points: Imm + ood_evals: Imm debug_assert(num_ood < 5) debug_assert(num_ood != 0) new_fs, root, ood_points, ood_evals = match_range( @@ -441,15 +441,15 @@ def get_whir_params(n_vars, log_inv_rate): debug_assert(MIN_WHIR_LOG_INV_RATE <= log_inv_rate) debug_assert(log_inv_rate <= MAX_WHIR_LOG_INV_RATE) - num_queries: Imu + num_queries: Imm num_queries = get_num_queries(log_inv_rate, n_vars) - query_grinding_bits: Imu + query_grinding_bits: Imm query_grinding_bits = get_query_grinding_bits(log_inv_rate, n_vars) num_oods = get_num_oods(log_inv_rate, n_vars) - folding_grinding: Imu + folding_grinding: Imm folding_grinding = get_folding_grinding(log_inv_rate, n_vars) return n_rounds, final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding From 88c33e31dc42163bf6d80d27728dd0e8ceae9322 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 28 May 2026 01:10:58 +0400 Subject: [PATCH 07/13] wip --- crates/lean_compiler/zkDSL.md | 46 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 8fd7e6f0..ca153384 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -227,34 +227,37 @@ else: ``` Comparison operators on conditions: `==`, `!=`, `<`, `<=`. There is **no** `>` -or `>=` — flip the operands to get the same effect. +or `>=` (flip the operands to get the same effect). ### `match` -Patterns must be a contiguous run of integers: +Patterns must be a set of integers of the form [n, n+1, n + 2, ...]: ```python match value: - case 5: result = 500 - case 6: result = 600 - case 7: result = 700 + case 5: + result = 500 + do_stuf() + case 6: + result = 600 + do_other_stuf() + case 7: + result = 700 + ... ``` The matched value must lie inside the listed range; out-of-range values produce -undefined behaviour. Use a `debug_assert` (or `assert`, if you want it to be -enforced by the proof) to guard the input. +undefined behaviour: **It's the responsability of the program to ensure this** (no checks added by the compiler). Letting a prover-controlled value escape the range in a `range` is a critical vulnerability. ### `match_range` -`match_range` is the workhorse for *dispatching a runtime value to a const- -parameter function*. It is a compile-time construct that expands into a -forward-declared variable plus a `match` over a contiguous range of integers. +`match_range` enables to automatically generate a `match` with repeated arms. ```python result = match_range(n, range(1, 5), lambda i: compute(i)) ``` -expands to +is expanded by the compiler to: ```python result: Imm @@ -265,7 +268,7 @@ match n: case 4: result = compute(4) ``` -You can chain several `(range, lambda)` pairs, provided the ranges are +It's possible to chain several `(range, lambda)` pairs, provided the ranges are **contiguous** (the end of one is the start of the next): ```python @@ -275,35 +278,30 @@ result = match_range(n, ``` Multiple return values are supported via tuple unpacking. The bindings produced -by `match_range` are always immutable — forward-declare with `: Mut` (and then +by `match_range` are always immutable. Forward-declare with `: Mut` (and then reassign) if you need them mutable later: ```python +a: Mut a, b = match_range(n, range(0, 4), lambda i: two_values(i)) +a += 1 ``` -Idiomatic use — dispatching a runtime length to a function that requires a -compile-time length: +Idiomatic use: enables to dispatch a runtime value to a const-parameter function. ```python def helper_const(n: Const): return n * n def compute(value): - debug_assert(value < 10) + assert value < 10 return match_range(value, range(0, 10), lambda i: helper_const(i)) ``` - -**Range validity is the caller's job.** A `match_range` whose input falls -outside any listed range is undefined behaviour at runtime — always pair it -with a `debug_assert` (or `assert`, if you want the proof to enforce it) on the -dispatched value. Skipping this guard is by far the most common source of -silent bugs in zkDSL. +Similar to `match`, range validity of the matched value is the responsibility of the program, not the compiler. Letting a prover-controlled value escape the range in a `match_range` is a critical vulnerability. ### For loops -Three loop forms, all written `for i in (start, end):`. Bounds and -behaviour: +Three loop forms, all written `for i in (start, end):`. Ranging from `start`, `start + 1`, ..., up to `end - 1`. | Loop form | When | | -------------------------------- | -------------------------------------------------------------------------- | From 0e91d490c18c0642eae3c90300f3ca294c3147c2 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 28 May 2026 01:45:23 +0400 Subject: [PATCH 08/13] loops --- crates/lean_compiler/zkDSL.md | 136 ++++++++++++++++++++-------------- 1 file changed, 80 insertions(+), 56 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index ca153384..63ab82f1 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -301,31 +301,17 @@ Similar to `match`, range validity of the matched value is the responsibility of ### For loops -Three loop forms, all written `for i in (start, end):`. Ranging from `start`, `start + 1`, ..., up to `end - 1`. +Three loop forms, all written `for i in (start, end):`. The +iterator visits `start, start + 1, ..., end - 1`. -| Loop form | When | -| -------------------------------- | -------------------------------------------------------------------------- | -| `for i in range(a, b):` | Runtime loop. Compiled into a recursive function (no `break`/`continue`). | -| `for i in unroll(a, b):` | Compile-time expansion; `a` and `b` must both be compile-time constants. | -| `for i in parallel_range(a, b):` | Runtime loop; iterations are executed in parallel by the runner via rayon. | +Restrictions shared by all three forms: -`parallel_range` requires the loop body to be iteration-independent. The -runner executes the first iteration sequentially to learn its memory footprint, -then runs the rest of the iterations concurrently — so anything cross-iteration -must hold a-priori, since there is no synchronization: +- No `break` or `continue` (not in the grammar). -- No `Mut` variables carried across iterations (each iteration writes only to - its own call frame and to addresses disjoint from every other iteration). -- Identical memory footprint per iteration. -- Identical hint consumption per iteration (witness hints, XMSS-specific - decomposition hints, Merkle hints, etc.). +#### `range(a, b)`: runtime loop -These constraints are **not** checked at compile time. Violating them produces -silently wrong proofs. - -Mutable variables inside non-unrolled loops are supported transparently — the -compiler inserts a buffer array, stores per-iteration values into it, and reads -the final value back after the loop: +The general-purpose runtime loop. `a` and `b` may be runtime values. The +compiler lowers the loop to a recursive function. ```python sum: Mut = 0 @@ -334,12 +320,50 @@ for i in range(1, 11): assert sum == 55 ``` -Loop limitations (current): +Mutable variables carried across iterations are supported transparently. + +*Under the hood: the compiler inserts a buffer array, stores the per-iteration value into it, and reads the final value back after the loop.* + +Restrictions: No `return` inside the body + +*Under the hood: because the loop is lowered to a recursive function.* + +#### `unroll(a, b)`: compile-time unrolling + +The loop is expanded at compile time: the body is duplicated once per iteration +with `i` substituted by its concrete value. Both `a` and `b` must be +compile-time constants. + +```python +for i in unroll(0, 4): + buffer[i] = i * i +``` + +#### `parallel_range(a, b)` — parallel runtime loop -- No `break` or `continue` (these forms are not in the grammar). -- No `return` inside the body of a non-unrolled loop (because such loops are - lowered to recursive functions). The compiler emits "Function return inside - a loop is not currently supported" if you try. +**`parallel_range` compiles to exactly the same bytecode as `range`.** It +differs only in the runner's scheduling policy: iterations are dispatched +concurrently across worker threads rather than evaluated in sequence. The only advantage is faster witness generation. +Iteration `a` is executed first, in isolation, to determine the per-iteration +memory footprint; the remaining iterations are then evaluated in parallel +without inter-iteration synchronization. + +```python +for i in parallel_range(0, n): + process(i, inputs[i], outputs[i]) +``` + +Because there is no synchronization, the loop body must be +iteration-independent: + +- No `Mut` variables carried across iterations (each iteration writes only to + its own call frame and to addresses disjoint from every other iteration). +- Identical memory footprint per iteration. +- Identical hint consumption per iteration (witness hints, XMSS-specific + decomposition hints, Merkle hints, etc.). + +These constraints are **not** checked at compile time. Violating them produces +silently wrong proofs. ### Statements without effect are rejected @@ -383,21 +407,6 @@ saturating_sub(a, b) # max(0, a - b) len(array) # length of a constant array (any depth) ``` -### Reserved names - -These identifiers cannot be redefined as user functions, because the parser or -compiler intercepts calls to them: - -- Built-ins: `print`, `Array`, `len`, `hint_witness` -- Compile-time math: `log2_ceil`, `next_multiple_of`, `saturating_sub`, - `div_ceil`, `div_floor` -- Loop / control-flow forms: `range`, `parallel_range`, `match_range` -- Custom hints: every `hint_*` name (see [Hints] below) -- Poseidon16 precompiles: `poseidon16_compress`, `poseidon16_compress_half`, - `poseidon16_compress_hardcoded_left`, - `poseidon16_compress_half_hardcoded_left`, `poseidon16_permute` -- Extension-op precompiles: `add_ee`, `add_be`, `dot_product_ee`, - `dot_product_be`, `poly_eq_ee`, `poly_eq_be` ### `_` (the discard target) @@ -411,7 +420,14 @@ _ = compute() # discard a single return value ## Assertions -Snark constraints (enforced by the proof): +The zkDSL provides two assertion forms with very different semantics: + +| Form | Enforced by | Use for | +| -------------- | ------------------- | -------------------------------------------------------- | +| `assert` | The proof system | Invariants the verifier must check | +| `debug_assert` | The prover only (at witness generation) | Sanity checks; preconditions the verifier does not need to re-check | + +### `assert`: proof-enforced constraint ```python assert x == y @@ -420,31 +436,39 @@ assert x < y assert x <= y ``` -Unconditional failure (compiles to a Panic): +The four supported comparison operators are `==`, `!=`, `<`, `<=` (no `>` or +`>=`; flip the operands). + + +### Range checks: `assert a < b` and `assert a <= b` + +Inequalities are proved via DEREF, which relies on the soundness of memory +accesses into a read-only memory of size `<= 2^MIN_LOG_MEMORY_SIZE` (currently +`MIN_LOG_MEMORY_SIZE = 16`). For the constraint to be sound, the right-hand +side `b` must therefore satisfy `b <= 2^16`. To compare against a larger +constant, first decompose the value into bits and assert the bound piecewise. + +#### Explicit panic + +`assert False` is the unconditional failure form. It compiles to a Panic and +accepts an optional message: ```python assert False assert False, "human-readable message" ``` -Runtime-only checks; not part of the constraint system. Same four comparison -operators (`==`, `!=`, `<`, `<=`): +### `debug_assert`: sanity checks at witness generation ```python debug_assert(x < y) ``` -`debug_assert` is for invariants the prover must respect but that the verifier -doesn't need to re-check — typically range-validity preconditions for `match` / -`match_range` dispatches. - -### Range checks: `assert a < b` and `assert a <= b` - -A signed inequality is implemented using DEREF (memory-access soundness on a -read-only memory of size `<= 2^MIN_LOG_MEMORY_SIZE`). The compiler automatically -emits the necessary helper hints, but **the right-hand side `b` must fit in -`2^16` (MIN_LOG_MEMORY_SIZE bits)** for the constraint to be sound. Compare -against larger constants by decomposing the value into bits first. +`debug_assert` accepts the same four comparison operators. It is evaluated by +the prover at trace-generation time and does **not** emit any constraint, so +the verifier never re-checks it. Use it for invariants the prover is expected +to maintain but that the verifier can take for granted — typically the +range-validity preconditions of `match` / `match_range` dispatches. ## Comments From fa65f9dac8a44ec106158617415ece0ffd383f3a Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 28 May 2026 02:01:30 +0400 Subject: [PATCH 09/13] wip --- crates/lean_compiler/zkDSL.md | 115 +++++++++++++++++++++++++++------- 1 file changed, 94 insertions(+), 21 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 63ab82f1..25019172 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -442,11 +442,15 @@ The four supported comparison operators are `==`, `!=`, `<`, `<=` (no `>` or ### Range checks: `assert a < b` and `assert a <= b` -Inequalities are proved via DEREF, which relies on the soundness of memory -accesses into a read-only memory of size `<= 2^MIN_LOG_MEMORY_SIZE` (currently -`MIN_LOG_MEMORY_SIZE = 16`). For the constraint to be sound, the right-hand -side `b` must therefore satisfy `b <= 2^16`. To compare against a larger -constant, first decompose the value into bits and assert the bound piecewise. +**The program must ensure `b <= 2^16`.** The compiler does not check this +(`b` may be a runtime value). Violating the bound is a critical soundness +vulnerability. + +*Under the hood: the compiler proves `a < b` by emitting two DEREF instructions, +which check that `a` and `b - 1 - a` are both valid memory addresses. An +address is valid iff it is `< M`, where `M` is the memory size. To stay sound +for every admissible memory size, the construction relies on the smallest one, +`M_min = 2^16` (= `2^MIN_LOG_MEMORY_SIZE`), giving the bound `b <= 2^16`.* #### Explicit panic @@ -480,14 +484,11 @@ block comment """ ``` -Both forms are stripped before the grammar runs. There is no docstring concept — -a `"""..."""` block is purely a comment. - ## Line continuation As in Python: -- **Implicit** continuation inside `(...)`, `[...]`, or `{...}`. +- **Implicit** continuation inside `(...)` or `[...]`. - **Explicit** continuation with `\` at end of line. ```python @@ -506,23 +507,95 @@ believe anything about it. There are two flavours of hint: ### `hint_witness("name", ptr)` -Pulls the next chunk of witness data registered under the string label `name`, -and writes it into the buffer at `ptr`. Witness data lives in the -`ExecutionWitness::hints: HashMap>>` map (each name has a -list of byte-buffers, consumed in order). The guest is responsible for -allocating `ptr` large enough; the length is implicit and trusted. +Writes the next buffer queued under the label `name` into memory starting at +`ptr`. The guest must allocate `ptr` large enough to hold the data; no length +is checked at runtime. + +The buffer comes from the host (Rust side), not from the guest. Before +running the program, the host fills `ExecutionWitness::hints` with one queue +of buffers per label; each `hint_witness("name", ptr)` call pops the next +buffer from `hints["name"]`. + +`ExecutionWitness` lives in `crates/lean_vm/src/execution/runner.rs`: + +```rust +pub struct ExecutionWitness { + ... + pub hints: HashMap>>, + ... +} +``` + +Each map key is a label; the value is the **ordered list of buffers** the +guest will consume under that label. The N-th `hint_witness("name", ptr)` call +the guest executes pops the N-th `Vec` from `hints["name"]` and writes it +at `ptr`. + +For example, the guest below issues three `hint_witness` calls — two against +`"input_data"` and one against `"other_stuff"`: ```python -data_buf = Array(64) -hint_witness("input_data", data_buf) -n = data_buf[0] +data_buf_1 = Array(64) +hint_witness("input_data", data_buf_1) +n = data_buf_1[0] + +data_buf_2 = Array(64) +hint_witness("input_data", data_buf_2) +m = data_buf_2[3] +assert n == m + 8 + +data_buf_3 = Array(10) +hint_witness("other_stuff", data_buf_3) +... +``` + +The matching Rust side must register two buffers under `"input_data"` (in +the order the guest will read them) and one under `"other_stuff"`: + +```rust +let mut hints: HashMap>> = HashMap::new(); +hints.insert( + "input_data".to_string(), + vec![ + first_input_buffer, // consumed by the first hint_witness("input_data", ...) + second_input_buffer, // consumed by the second hint_witness("input_data", ...) + ], +); +hints.insert("other_stuff".to_string(), vec![other_buffer]); + +let witness = ExecutionWitness { hints, ..Default::default() }; ``` +A missing label, or running out of buffers under a label, is a runner-side +panic: each call requires its corresponding entry to exist. + ### Custom hints -Each hint has a fixed argument count and writes its result(s) into caller-provided -buffers. The hint *suggests* a value — your program must add the constraints -that bind the value to its specification. +Custom hints are a fixed set of built-in calls the prover uses to compute +values that would be expensive to derive in-circuit — bit +decompositions, comparisons, integer division, etc. Each is invoked like an +ordinary function and writes its result into a caller-supplied memory +location. + +Like every hint, **the result is unconstrained**: the verifier checks +nothing about the hinted value. The guest program must add its own +constraints binding the hinted bits / quotient / remainder / boolean to the +original input — otherwise a malicious prover can substitute any value. The +typical pattern is "hint, then assert the relationship": + +```python +# hint the bits... +bits = Array(8) +hint_decompose_bits(value, bits, 8) +# ...then constrain them to actually equal `value` +acc: Mut = 0 +for i in unroll(0, 8): + assert bits[i] * (bits[i] - 1) == 0 # boolean + acc = acc * 2 + bits[i] +assert acc == value +``` + +The full list: | Hint | Arguments | Effect | | --------------------------------- | ------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------- | @@ -537,7 +610,7 @@ that bind the value to its specification. ### Poseidon16 family -leanVM has one Poseidon2 width-16 precompile table; the zkDSL exposes five +leanVM has one Poseidon width-16 precompile table; the zkDSL exposes five specializations that all hit the same table. ```python From 00136c781a304ffa868a77f246383a776f786b0c Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 28 May 2026 02:40:16 +0400 Subject: [PATCH 10/13] wip --- crates/lean_compiler/zkDSL.md | 118 ++++++++++++++-------------------- 1 file changed, 50 insertions(+), 68 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 25019172..12d3007b 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -115,10 +115,10 @@ agree. A function that "returns nothing" uses a bare `return`. ### Parameter types -| Syntax | Meaning | -| ---------- | ------------------------ | +| Syntax | Meaning | +| ---------- | ------------------------------------ | | `x` | normal (immutable) runtime parameter | -| `x: Const` | compile-time parameter | +| `x: Const` | compile-time parameter | ```python def repeat(n: Const): # Const enables unroll(0, n) @@ -422,9 +422,9 @@ _ = compute() # discard a single return value The zkDSL provides two assertion forms with very different semantics: -| Form | Enforced by | Use for | -| -------------- | ------------------- | -------------------------------------------------------- | -| `assert` | The proof system | Invariants the verifier must check | +| Form | Enforced by | Use for | +| -------------- | --------------------------------------- | ------------------------------------------------------------------- | +| `assert` | The proof system | Invariants the verifier must check | | `debug_assert` | The prover only (at witness generation) | Sanity checks; preconditions the verifier does not need to re-check | ### `assert`: proof-enforced constraint @@ -608,80 +608,61 @@ The full list: ## Precompiles -### Poseidon16 family - -leanVM has one Poseidon width-16 precompile table; the zkDSL exposes five -specializations that all hit the same table. - -```python -poseidon16_compress(left, right, output) -``` - -Standard compression: writes the 8-cell compressed output of `Poseidon2(left || right) + left` -to `m[output..output+8]`. `left` and `right` are 8-cell buffers; `output` is an -8-cell destination. - -```python -poseidon16_compress_half(left, right, output) -``` - -Same as `poseidon16_compress`, but only the first 4 output cells are -constrained — `output[4..8]` is unconstrained. Useful when the consumer only -cares about half of the digest. +Precompiles are special instructions in the leanVM ISA, alongside the four +basic ones (ADD, MUL, DEREF, JUMP). The zkDSL exposes them as built-in +functions. There are two families: Poseidon hashing and extension-field +operations. -```python -poseidon16_compress_hardcoded_left(left, right, output, offset) -``` +### Poseidon16 family -Like `poseidon16_compress`, except the first 4 cells of the *left* input are -read from the **compile-time** address `offset` instead of `m[left..left+4]`. -The remaining 4 cells of the left input still come from `m[left..left+4]`. Used -e.g. for XMSS Merkle hashing where one half of the input is the public parameter -(stored at a fixed address). +The variants are as follows: -```python -poseidon16_compress_half_hardcoded_left(left, right, output, offset) -``` +- **compress vs. permute** — `compress` applies the feed-forward addition + (`Poseidon(L || R) + L`); `permute` is the raw 16-cell permutation. +- **full vs. half output** — `_half` constrains only the first 4 output cells + (the rest are unconstrained); useful when the consumer only cares about + half a digest. +- **hardcoded-left** — `_hardcoded_left` reads the first 4 cells of the left + input from a compile-time address instead of from `m[L..L+4]`; the last 4 + cells of the left input still come from memory. -Composition of `_compress_half` and `_compress_hardcoded_left`: hardcoded left -prefix at `offset`, only the first 4 output cells constrained. +Common arguments: `L`, `R` are 8-cell input buffers; `O` is the output +buffer; `off` (where present) is a compile-time address. -```python -poseidon16_permute(left, right, output) -``` +| Function | Cells written to `O` | Notes | +| ------------------------------------------------------- | -------------------- | ----------------------------------------- | +| `poseidon16_compress(L, R, O)` | `O[0..8]` | `Poseidon(L \|\| R) + L` | +| `poseidon16_compress_half(L, R, O)` | `O[0..4]` | `O[4..8]` is unconstrained | +| `poseidon16_compress_hardcoded_left(L, R, O, off)` | `O[0..8]` | left = `m[off..off+4] \|\| m[L..L+4]` | +| `poseidon16_compress_half_hardcoded_left(L, R, O, off)` | `O[0..4]` | half-output + hardcoded-left composition | +| `poseidon16_permute(L, R, O)` | `O[0..16]` | raw Poseidon permutation, no feed-forward | -Raw Poseidon2 permutation (no feed-forward addition). Writes the full 16 output -cells to `m[output..output+16]` in natural order. Used for the Fiat-Shamir -sponge. +### Extension-field operations -### Extension field operations +Six built-in functions, each combines a fixed element-wise operation with an +accumulation over `length` element pairs: -Six built-in functions all route through one `extension_op` precompile table. -Each combines a fixed element-wise operation with an accumulation over `length` -element pairs: +| Function | Element-wise `e_i` | Result | +| ----------------------------------- | -------------------------------- | ----------- | +| `add_ee` / `add_be` | `a_i + b_i` | `sum(e_i)` | +| `dot_product_ee` / `dot_product_be` | `a_i * b_i` | `sum(e_i)` | +| `poly_eq_ee` / `poly_eq_be` | `a_i * b_i + (1 - a_i)(1 - b_i)` | `prod(e_i)` | -| Function | Element-wise | Accumulation | -| ----------------------------------- | --------------------------------- | -------------------- | -| `add_ee` / `add_be` | `e_i = a_i + b_i` | `result = sum(e_i)` | -| `dot_product_ee` / `dot_product_be` | `e_i = a_i * b_i` | `result = sum(e_i)` | -| `poly_eq_ee` / `poly_eq_be` | `e_i = a_i*b_i + (1-a_i)*(1-b_i)` | `result = prod(e_i)` | +Signature (the same for all six): ```python func(ptr_a, ptr_b, ptr_result) # length defaults to 1 -func(ptr_a, ptr_b, ptr_result, length) # explicit length (N element pairs) +func(ptr_a, ptr_b, ptr_result, length) # length must be a compile-time constant ``` -**Operand suffix:** - -- `_ee`: both `ptr_a` and `ptr_b` point to *extension* field elements (5 base-field - cells each, stride `DIM = 5`). -- `_be`: `ptr_a` points to *base* field elements (stride 1); `ptr_b` points to - *extension* field elements (stride `DIM = 5`). - -`ptr_result` always points to a single extension-field element (5 cells). +- `ptr_result` points to a single extension-field element (5 cells). +- The `_ee` / `_be` suffix selects the layout of the operands: + - `_ee`: both `ptr_a` and `ptr_b` point to extension elements (stride + `DIM = 5`). + - `_be`: `ptr_a` points to base-field elements (stride 1); `ptr_b` points + to extension elements (stride `DIM = 5`). -**`length` must be a compile-time constant.** For a runtime length, dispatch -through `match_range`: +For a runtime length, dispatch through `match_range`: ```python def dot_product_ee_dynamic(a, b, res, n): @@ -696,16 +677,17 @@ Common idioms: dot_product_ee(x, y, z) # z = x * y # Copy an extension element by multiplying by [1, 0, 0, 0, 0] -# ONE_EF_PTR is a guest-program constant that you materialize in the preamble +# (ONE_EF_PTR is a constant materialized in the preamble) dot_product_ee(src, ONE_EF_PTR, dst) # Dot products dot_product_ee(coeffs, basis, result, N) dot_product_be(alpha_powers, coeffs, result, N) -# Extension addition / subtraction +# Extension addition / subtraction (the second form uses write-once memory +# to turn an addition into the subtraction constraint b + c = a) add_ee(a, b, c) # c = a + b -add_ee(b, c, a) # c = a - b, expressed as a constraint (b + c = a) +add_ee(b, c, a) # c = a - b # Equality polynomial: eq(a, b) = a*b + (1-a)*(1-b) poly_eq_ee(a, b, eq_result) From 1dd713761ed50e9ef5fb9fc178dd069472038d4c Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 28 May 2026 03:17:30 +0400 Subject: [PATCH 11/13] hardcode public memory of 8 everywhere --- crates/lean_compiler/src/lib.rs | 8 +- crates/lean_compiler/tests/test_compiler.rs | 85 +++++++++++-------- .../tests/test_data/program_15.py | 2 +- .../tests/test_data/program_179.py | 2 +- .../lean_compiler/tests/test_performance.rs | 9 +- crates/lean_compiler/zkDSL.md | 74 +++++++--------- crates/lean_prover/src/lib.rs | 4 - crates/lean_prover/src/prove_execution.rs | 13 +-- crates/lean_prover/src/test_zkvm.rs | 8 +- crates/lean_prover/src/trace_gen.rs | 2 - crates/lean_prover/src/verify_execution.rs | 12 +-- crates/lean_vm/src/diagnostics/exec_result.rs | 1 - crates/lean_vm/src/execution/runner.rs | 20 ++--- .../rec_aggregation/src/type_1_aggregation.rs | 2 +- .../rec_aggregation/src/type_2_aggregation.rs | 2 +- 15 files changed, 115 insertions(+), 129 deletions(-) diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index 590d58de..d99d1c2e 100644 --- a/crates/lean_compiler/src/lib.rs +++ b/crates/lean_compiler/src/lib.rs @@ -150,7 +150,11 @@ pub fn compile_program(input: &ProgramSource) -> Bytecode { try_compile_program(input).unwrap() } -pub fn try_compile_and_run(input: &ProgramSource, public_input: &[F], profiler: bool) -> Result { +pub fn try_compile_and_run( + input: &ProgramSource, + public_input: &[F; PUBLIC_INPUT_LEN], + profiler: bool, +) -> Result { let bytecode = try_compile_program(input)?; let witness = ExecutionWitness::default(); let result = try_execute_bytecode(&bytecode, public_input, &witness, profiler)?; @@ -158,7 +162,7 @@ pub fn try_compile_and_run(input: &ProgramSource, public_input: &[F], profiler: Ok(result.metadata.display()) } -pub fn compile_and_run(input: &ProgramSource, public_input: &[F], profiler: bool) { +pub fn compile_and_run(input: &ProgramSource, public_input: &[F; PUBLIC_INPUT_LEN], profiler: bool) { let summary = try_compile_and_run(input, public_input, profiler).unwrap(); println!("{summary}"); } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 04b3f70e..cc8ac473 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -4,27 +4,6 @@ use backend::{BasedVectorSpace, PrimeCharacteristicRing}; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; -use utils::poseidon16_compress; - -#[test] -fn test_poseidon() { - let program = r#" -def main(): - a = 0 - b = a + 8 - c = Array(8) - poseidon16_compress(a, b, c) - - for i in range(0, 8): - cc = c[i] - print(cc) - return - "#; - let public_input: [F; 16] = (0..16).map(F::new).collect::>().try_into().unwrap(); - compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); - - let _ = dbg!(poseidon16_compress(public_input)); -} #[test] fn test_div_extension_field() { @@ -32,13 +11,16 @@ fn test_div_extension_field() { DIM = 5 def main(): - n = 0 - d = n + DIM - q = n + 2 * DIM + nd = Array(2 * DIM) + hint_witness("nd", nd) + n = nd + d = nd + DIM + expected_q = Array(DIM) + hint_witness("q", expected_q) computed_q_1 = div_ext_1(n, d) computed_q_2 = div_ext_2(n, d) - assert_eq_ext(computed_q_2, q) - assert_eq_ext(computed_q_1, q) + assert_eq_ext(computed_q_2, expected_q) + assert_eq_ext(computed_q_1, expected_q) return def assert_eq_ext(x, y): @@ -61,12 +43,19 @@ def div_ext_2(n, d): let n: EF = rng.random(); let d: EF = rng.random(); let q = n / d; - let mut public_input = vec![]; - public_input.extend(n.as_basis_coefficients_slice()); - public_input.extend(d.as_basis_coefficients_slice()); - public_input.extend(q.as_basis_coefficients_slice()); - public_input.resize(16, F::ZERO); - compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); + let mut nd_buf: Vec = Vec::new(); + nd_buf.extend(n.as_basis_coefficients_slice()); + nd_buf.extend(d.as_basis_coefficients_slice()); + let q_buf: Vec = q.as_basis_coefficients_slice().to_vec(); + let mut hints = std::collections::HashMap::new(); + hints.insert("nd".to_string(), vec![nd_buf]); + hints.insert("q".to_string(), vec![q_buf]); + let witness = ExecutionWitness { + hints, + ..ExecutionWitness::default() + }; + let bytecode = compile_program(&ProgramSource::Raw(program.to_string())); + try_execute_bytecode(&bytecode, &[F::ZERO; PUBLIC_INPUT_LEN], &witness, false).unwrap(); } fn test_data_dir() -> String { @@ -134,7 +123,7 @@ fn test_all_programs() { Ok(b) => b, Err(err) => panic!("Program {} failed to compile: {:?}", path, err), }; - if let Err(err) = try_execute_bytecode(&bytecode, &[], &witness, false) { + if let Err(err) = try_execute_bytecode(&bytecode, &[F::ZERO; PUBLIC_INPUT_LEN], &witness, false) { panic!("Program {} failed with error: {:?}", path, err); } } @@ -176,7 +165,13 @@ def func(a, b): return "#; let bytecode = compile_program(&ProgramSource::Raw(program.to_string())); - let n_cycles = execute_bytecode(&bytecode, &[], &ExecutionWitness::default(), false).n_cycles(); + let n_cycles = execute_bytecode( + &bytecode, + &[F::ZERO; PUBLIC_INPUT_LEN], + &ExecutionWitness::default(), + false, + ) + .n_cycles(); assert!(n_cycles < 1100); } @@ -205,10 +200,20 @@ def factorial(n): let compiled_parallel = compile_program(&ProgramSource::Raw(program.replace("loop", "parallel_range"))); let time_sequential = Instant::now(); - let exec_seq = execute_bytecode(&compiled_sequencial, &[], &ExecutionWitness::default(), false); + let exec_seq = execute_bytecode( + &compiled_sequencial, + &[F::ZERO; PUBLIC_INPUT_LEN], + &ExecutionWitness::default(), + false, + ); let duration_sequential = time_sequential.elapsed(); let time_parallel = Instant::now(); - let exec_par = execute_bytecode(&compiled_parallel, &[], &ExecutionWitness::default(), false); + let exec_par = execute_bytecode( + &compiled_parallel, + &[F::ZERO; PUBLIC_INPUT_LEN], + &ExecutionWitness::default(), + false, + ); let duration_parallel = time_parallel.elapsed(); assert_eq!(exec_seq.metadata.stdout, exec_par.metadata.stdout); @@ -249,7 +254,13 @@ fn test_soundness_suite() { ("soundness_5", &[3, 4, 7, 19, 49, 28, 1, 3], &[(0, 4), (1, 5), (2, 8), (3, 20), (4, 50), (5, 29), (6, 0), (6, 2), (7, 4)]), ]; - let to_input = |v: &[u32]| v.iter().copied().map(F::new).collect::>(); + let to_input = |v: &[u32]| -> [F; PUBLIC_INPUT_LEN] { + let mut out = [F::ZERO; PUBLIC_INPUT_LEN]; + for (slot, &x) in out.iter_mut().zip(v) { + *slot = F::new(x); + } + out + }; for &(name, valid, perturbations) in cases { let path = format!("{}/{}.py", test_data_dir(), name); diff --git a/crates/lean_compiler/tests/test_data/program_15.py b/crates/lean_compiler/tests/test_data/program_15.py index 6ea149c2..55433c26 100644 --- a/crates/lean_compiler/tests/test_data/program_15.py +++ b/crates/lean_compiler/tests/test_data/program_15.py @@ -1,6 +1,6 @@ from snark_lib import * -ONE_EF_PTR = 1 # right after the (empty-public-input) zero-padded cell at memory[0] +ONE_EF_PTR = 8 # right after the 8-cell public input region def main(): diff --git a/crates/lean_compiler/tests/test_data/program_179.py b/crates/lean_compiler/tests/test_data/program_179.py index 84d1f0b0..521d0af6 100644 --- a/crates/lean_compiler/tests/test_data/program_179.py +++ b/crates/lean_compiler/tests/test_data/program_179.py @@ -1,6 +1,6 @@ from snark_lib import * -ONE_EF_PTR = 1 # right after the (empty-public-input) zero-padded cell at memory[0] +ONE_EF_PTR = 8 # right after the 8-cell public input region def main(): diff --git a/crates/lean_compiler/tests/test_performance.rs b/crates/lean_compiler/tests/test_performance.rs index 723b9bd4..893883e2 100644 --- a/crates/lean_compiler/tests/test_performance.rs +++ b/crates/lean_compiler/tests/test_performance.rs @@ -1,3 +1,4 @@ +use backend::PrimeCharacteristicRing; use lean_compiler::*; use lean_vm::*; @@ -9,7 +10,13 @@ fn test_data_dir() -> String { /// Helper to get the number of cycles for a program file fn get_cycle_count(path: &str) -> usize { let bytecode = compile_program(&ProgramSource::Filepath(path.to_string())); - let result = try_execute_bytecode(&bytecode, &[], &ExecutionWitness::default(), false).unwrap(); + let result = try_execute_bytecode( + &bytecode, + &[F::ZERO; PUBLIC_INPUT_LEN], + &ExecutionWitness::default(), + false, + ) + .unwrap(); result.pcs.len() } diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 12d3007b..0dc57386 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -639,30 +639,25 @@ buffer; `off` (where present) is a compile-time address. ### Extension-field operations -Six built-in functions, each combines a fixed element-wise operation with an -accumulation over `length` element pairs: - -| Function | Element-wise `e_i` | Result | -| ----------------------------------- | -------------------------------- | ----------- | -| `add_ee` / `add_be` | `a_i + b_i` | `sum(e_i)` | -| `dot_product_ee` / `dot_product_be` | `a_i * b_i` | `sum(e_i)` | -| `poly_eq_ee` / `poly_eq_be` | `a_i * b_i + (1 - a_i)(1 - b_i)` | `prod(e_i)` | - -Signature (the same for all six): +Six built-in functions, each reading two length-`n` vectors `a` and `b` and +writing one extension-field element to `result`. `n` defaults to `1` and must +be a compile-time constant when given. ```python -func(ptr_a, ptr_b, ptr_result) # length defaults to 1 -func(ptr_a, ptr_b, ptr_result, length) # length must be a compile-time constant +add_ee(a, b, result, n=1) # result = sum_i (a[i] + b[i]) +dot_product_ee(a, b, result, n=1) # result = sum_i a[i] * b[i] +poly_eq_ee(a, b, result, n=1) # result = prod_i (a[i]*b[i] + (1-a[i])*(1-b[i])) ``` -- `ptr_result` points to a single extension-field element (5 cells). -- The `_ee` / `_be` suffix selects the layout of the operands: - - `_ee`: both `ptr_a` and `ptr_b` point to extension elements (stride - `DIM = 5`). - - `_be`: `ptr_a` points to base-field elements (stride 1); `ptr_b` points - to extension elements (stride `DIM = 5`). +The `_ee` suffix means both `a` and `b` are vectors of *extension*-field +elements (each occupying `DIM = 5` consecutive cells). The `_be` variants +(`add_be`, `dot_product_be`, `poly_eq_be`) are identical except `a` is a +vector of *base*-field elements (1 cell each); `b` and `result` are still +extension-field. + +`result` always points to a single extension-field element (5 cells). -For a runtime length, dispatch through `match_range`: +For a runtime `n`, dispatch through `match_range`: ```python def dot_product_ee_dynamic(a, b, res, n): @@ -673,25 +668,16 @@ def dot_product_ee_dynamic(a, b, res, n): Common idioms: ```python -# Multiply two extension elements (length defaults to 1) -dot_product_ee(x, y, z) # z = x * y +# Multiply two extension elements (n defaults to 1) +dot_product_ee(x, y, z) # z = x * y -# Copy an extension element by multiplying by [1, 0, 0, 0, 0] +# Copy an extension element by multiplying by 1 # (ONE_EF_PTR is a constant materialized in the preamble) dot_product_ee(src, ONE_EF_PTR, dst) -# Dot products -dot_product_ee(coeffs, basis, result, N) -dot_product_be(alpha_powers, coeffs, result, N) - -# Extension addition / subtraction (the second form uses write-once memory -# to turn an addition into the subtraction constraint b + c = a) -add_ee(a, b, c) # c = a + b -add_ee(b, c, a) # c = a - b - -# Equality polynomial: eq(a, b) = a*b + (1-a)*(1-b) -poly_eq_ee(a, b, eq_result) -poly_eq_ee(a, b, result, n) # multi-point eq: prod_i eq(a[i], b[i]) +# Extension subtraction: write-once memory turns "c = a + b" into +# the constraint "b + c = a", i.e. c = a - b +add_ee(b, c, a) # c = a - b ``` ## Debugging @@ -710,19 +696,19 @@ the print hint in `lean_vm/src/isa/hint.rs (Self::Print)` to `eprint!` directly. The runner lays out memory as ```python -[ public_input (zero-padded) | preamble_memory | runtime ] +[ public_input (PUBLIC_INPUT_LEN cells) | preamble_memory | runtime ] ``` -- `public_input` lives at `memory[0..public_input.len()]` and is zero-padded to - the next power of two by the runner, so it can be evaluated as a multilinear - polynomial. +- `public_input` is fixed at `PUBLIC_INPUT_LEN = DIGEST_LEN = 8` cells (a hash + digest), occupying `memory[0..8]`. - `preamble_memory` is a region of `witness.preamble_memory_len` cells the - runner reserves but does **not** initialize. The guest program is expected - to fill this region with whatever helper constants it relies on (e.g. a - vector of zeros for `dot_product_ee`-as-copy, an extension-field one for - multiply-by-one tricks, a vector of ones for batched accumulations, …) at - the start of `main`. The names and offsets of these constants are not part - of the VM contract — each program defines its own. See + runner reserves immediately after the public input but does **not** + initialize. The guest program is expected to fill this region with whatever + helper constants it relies on (e.g. a vector of zeros for + `dot_product_ee`-as-copy, an extension-field one for multiply-by-one tricks, + a vector of ones for batched accumulations, …) at the start of `main`. The + names and offsets of these constants are not part of the VM contract — each + program defines its own. See `crates/rec_aggregation/zkdsl_implem/utils.py (build_preamble_memory)` for a concrete example. - The runtime region holds the program's stack frames, working memory, and any diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index d5daff25..82d61f39 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -67,7 +67,6 @@ pub enum ProverError { Runner(RunnerError), UnknownMessage, MultipleMessages, - InvalidPublicInputSize { expected: usize, actual: usize }, } impl From for ProverError { @@ -89,9 +88,6 @@ impl Display for ProverError { Self::Runner(e) => write!(f, "{}", e), Self::UnknownMessage => write!(f, "Unknown message, not part of the type2"), Self::MultipleMessages => write!(f, "Multiple common messages in the type2"), - Self::InvalidPublicInputSize { expected, actual } => { - write!(f, "Invalid public input size: expected {}, actual {}", expected, actual) - } } } } diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 1f985908..fb8fe1e3 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -19,7 +19,7 @@ pub struct ExecutionProof { pub fn prove_execution( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: &ExecutionWitness, whir_config: &WhirConfigBuilder, vm_profiler: bool, @@ -27,15 +27,8 @@ pub fn prove_execution( check_rate(whir_config.starting_log_inv_rate) .map_err(|err| panic!("{err}")) .unwrap(); - if public_input.len() != PUBLIC_INPUT_LEN { - return Err(ProverError::InvalidPublicInputSize { - expected: PUBLIC_INPUT_LEN, - actual: public_input.len(), - }); - } let ExecutionTrace { traces, - public_memory_size, mut memory, // padded with zeros to next power of two metadata, } = info_span!("Witness generation").in_scope(|| -> Result<_, ProverError> { @@ -232,8 +225,8 @@ pub fn prove_execution( committed_statements.get_mut(table).unwrap().push(claim); } - let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(public_memory_size))); - let public_memory_eval = (&memory[..public_memory_size]).evaluate(&public_memory_random_point); + let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(PUBLIC_INPUT_LEN))); + let public_memory_eval = (&memory[..PUBLIC_INPUT_LEN]).evaluate(&public_memory_random_point); let previous_statements = vec![ SparseStatement::new( diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 91b3f76b..767a2756 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -97,7 +97,7 @@ fn all_precompiles_flags(loop_iters: usize) -> CompilationFlags { } } -fn all_precompiles_witness() -> (Vec, ExecutionWitness) { +fn all_precompiles_witness() -> ([F; PUBLIC_INPUT_LEN], ExecutionWitness) { let mut rng = StdRng::seed_from_u64(0); let mut scratch = F::zero_vec(8192); @@ -194,7 +194,7 @@ fn all_precompiles_witness() -> (Vec, ExecutionWitness) { .fold(EF::ONE, |acc, x| acc * x); scratch[1300..][..DIMENSION].copy_from_slice(poly_eq_ee_result.as_basis_coefficients_slice()); - let mut public_input = vec![F::ZERO; PUBLIC_INPUT_LEN]; + let mut public_input = [F::ZERO; PUBLIC_INPUT_LEN]; public_input[..4].copy_from_slice(&hardcoded_prefix); let mut hints = std::collections::HashMap::new(); @@ -326,7 +326,7 @@ def fibonacci_const(a, b, n: Const): ); } -fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { +fn test_zk_vm_helper(program_str: &str, public_input: &[F; PUBLIC_INPUT_LEN]) { test_zk_vm_helper_with_witness( program_str, public_input, @@ -337,7 +337,7 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { fn test_zk_vm_helper_with_witness( program_str: &str, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: ExecutionWitness, flags: CompilationFlags, ) { diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 45fc3057..5bee4a97 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -6,7 +6,6 @@ use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_iter_mut}; #[derive(Debug)] pub struct ExecutionTrace { pub traces: BTreeMap, - pub public_memory_size: usize, pub memory: Vec, // of length a multiple of public_memory_size pub metadata: ExecutionMetadata, } @@ -171,7 +170,6 @@ pub fn get_execution_trace( ExecutionTrace { traces, - public_memory_size: execution_result.public_memory_size, memory: memory_padded, metadata: execution_result.metadata, } diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index c1886acd..173d2a8d 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -14,7 +14,7 @@ pub struct ProofVerificationDetails { pub fn verify_execution( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], proof: Proof, ) -> Result<(ProofVerificationDetails, RawProof), ProofError> { if bytecode.log_size() > MAX_BYTECODE_LOG_SIZE { @@ -23,9 +23,6 @@ pub fn verify_execution( max_log_size: MAX_BYTECODE_LOG_SIZE, }); } - if public_input.len() != PUBLIC_INPUT_LEN { - return Err(ProofError::InvalidProof); - } let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone(), fiat_shamir_domain_sep(bytecode))?; verifier_state.observe_scalars(public_input); @@ -58,8 +55,6 @@ pub fn verify_execution( return Err(ProofError::InvalidProof); } - let public_memory = padd_with_zero_to_next_power_of_two(public_input); - if !(MIN_LOG_MEMORY_SIZE..=MAX_LOG_MEMORY_SIZE).contains(&log_memory) { return Err(ProofError::InvalidProof); } @@ -175,9 +170,8 @@ pub fn verify_execution( return Err(ProofError::InvalidProof); } - let public_memory_random_point = - MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_memory.len()))); - let public_memory_eval = public_memory.evaluate(&public_memory_random_point); + let public_memory_random_point = MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_input.len()))); + let public_memory_eval = public_input.evaluate(&public_memory_random_point); let previous_statements = vec![ SparseStatement::new( diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs index dcb1ae0c..2024fa08 100644 --- a/crates/lean_vm/src/diagnostics/exec_result.rs +++ b/crates/lean_vm/src/diagnostics/exec_result.rs @@ -72,7 +72,6 @@ impl ExecutionMetadata { #[derive(Debug)] pub struct ExecutionResult { pub runtime_memory_size: usize, - pub public_memory_size: usize, pub memory: Memory, pub pcs: Vec, pub fps: Vec, diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 6124e41a..844585e4 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -1,6 +1,6 @@ //! VM execution runner -use crate::core::{DIMENSION, F}; +use crate::core::{DIMENSION, F, PUBLIC_INPUT_LEN}; use crate::diagnostics::{ExecutionMetadata, ExecutionResult, RunnerError}; use crate::execution::memory::MemoryAccess; use crate::execution::{ExecutionHistory, Memory}; @@ -10,7 +10,7 @@ use crate::isa::instruction::{InstructionContext, InstructionCounts}; use crate::{ALL_TABLES, CodeAddress, HintExecutionContext, MemOrConstant, N_TABLES, STARTING_PC, Table, TableTrace}; use backend::*; use std::collections::{BTreeMap, BTreeSet, HashMap}; -use utils::{ToUsize, padd_with_zero_to_next_power_of_two}; +use utils::ToUsize; use super::memory::SegmentMemory; @@ -27,7 +27,7 @@ pub struct ExecutionWitness { pub fn try_execute_bytecode( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: &ExecutionWitness, profiling: bool, ) -> Result { @@ -58,7 +58,7 @@ pub fn try_execute_bytecode( pub fn execute_bytecode( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: &ExecutionWitness, profiling: bool, ) -> ExecutionResult { @@ -237,7 +237,7 @@ fn resolve_deref_hints(memory: &mut Memory, pending: &[(usize, usize)]) { #[allow(clippy::too_many_arguments)] fn execute_bytecode_helper( bytecode: &Bytecode, - public_input: &[F], + public_input: &[F; PUBLIC_INPUT_LEN], witness: &ExecutionWitness, std_out: &mut String, instruction_history: &mut ExecutionHistory, @@ -248,10 +248,9 @@ fn execute_bytecode_helper( .iter() .map(|(name, entries)| (name.clone(), NamedHintCursor::new(entries))) .collect(); - let public_memory = padd_with_zero_to_next_power_of_two(public_input); - let public_memory_size = public_memory.len(); + let public_memory = public_input.to_vec(); let mut memory = Memory::new(public_memory); - let mut fp = public_memory_size + witness.preamble_memory_len; + let mut fp = PUBLIC_INPUT_LEN + witness.preamble_memory_len; fp = fp.next_multiple_of(DIMENSION); let initial_ap = fp + bytecode.starting_frame_memory; let mut pc = STARTING_PC; @@ -325,7 +324,7 @@ fn execute_bytecode_helper( } else { None }; - let runtime_memory_size = memory.0.len() - public_memory_size - witness.preamble_memory_len; + let runtime_memory_size = memory.0.len() - PUBLIC_INPUT_LEN - witness.preamble_memory_len; let used_memory_cells = memory.0.par_iter().filter(|&&x| x.is_some()).count(); let metadata = ExecutionMetadata { cycles: trace.pcs.len(), @@ -333,7 +332,7 @@ fn execute_bytecode_helper( n_poseidons: trace.tables[&Table::poseidon16()].columns[0].len(), n_extension_ops: trace.tables[&Table::extension_op()].columns[0].len(), bytecode_size: bytecode.code.len(), - public_input_size: public_input.len(), + public_input_size: PUBLIC_INPUT_LEN, runtime_memory: runtime_memory_size, memory_usage_percent: used_memory_cells as f64 / memory.0.len() as f64 * 100.0, stdout: std::mem::take(std_out), @@ -341,7 +340,6 @@ fn execute_bytecode_helper( }; Ok(ExecutionResult { runtime_memory_size: no_vec_runtime_memory, - public_memory_size, memory, pcs: trace.pcs, fps: trace.fps, diff --git a/crates/rec_aggregation/src/type_1_aggregation.rs b/crates/rec_aggregation/src/type_1_aggregation.rs index fa27d019..bdf29d02 100644 --- a/crates/rec_aggregation/src/type_1_aggregation.rs +++ b/crates/rec_aggregation/src/type_1_aggregation.rs @@ -273,7 +273,7 @@ pub(crate) fn aggregate_type_1_with_min_padding( &reduced_claims.final_claim_flat(), bytecode, ); - let public_input = poseidon_compress_slice(&pub_input_data).to_vec(); + let public_input = poseidon_compress_slice(&pub_input_data); let mut claimed: HashSet = HashSet::new(); let mut dup_pub_keys: Vec = Vec::new(); diff --git a/crates/rec_aggregation/src/type_2_aggregation.rs b/crates/rec_aggregation/src/type_2_aggregation.rs index fbb3449c..5910c863 100644 --- a/crates/rec_aggregation/src/type_2_aggregation.rs +++ b/crates/rec_aggregation/src/type_2_aggregation.rs @@ -110,7 +110,7 @@ pub fn merge_many_type_1( let digests: Vec<[F; DIGEST_LEN]> = verified_children.iter().map(|v| v.input_data_hash).collect(); let pub_input_data = build_type2_input_data(&digests, &reduced_claims.final_claim_flat()); - let public_input_digest = poseidon_compress_slice(&pub_input_data).to_vec(); + let public_input_digest = poseidon_compress_slice(&pub_input_data); let bytecode_value_hint_blobs: Vec> = verified_children .iter() From 587ef3e2ece25f584f774c97055aa052f6058fa6 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 28 May 2026 03:35:52 +0400 Subject: [PATCH 12/13] w --- README.md | 1 + crates/lean_compiler/zkDSL.md | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d8a40144..cce64ba4 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ Minimal hash-based zkVM, targeting recursion and aggregation of hash-based signa

Documentation + zkDSL reference Python verifier

diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 0dc57386..e95e9e10 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -1,6 +1,6 @@ # zkDSL Language Reference -The zkDSL is a Python-syntax language that compiles to leanVM bytecode (4 basic instructions and 2 special ones (precompile): poseidon / extension operations). +The zkDSL is a Python-syntax language that compiles to leanVM bytecode (4 basic instructions and 2 special ones (precompile): poseidon / extension operations). For the underlying VM, and proving system, see [`minimal_zkVM.pdf`](../../minimal_zkVM.pdf). Source files use the `.py` extension. They are **not** currently runnable as real Python, but the syntax is kept Python-compatible so that one day they From d815925724df45791e62413edc2726c53b3553f5 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 28 May 2026 03:54:49 +0400 Subject: [PATCH 13/13] wip --- crates/lean_compiler/zkDSL.md | 92 ++++++++++------------------------- 1 file changed, 26 insertions(+), 66 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index e95e9e10..08d67c37 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -6,6 +6,20 @@ Source files use the `.py` extension. They are **not** currently runnable as real Python, but the syntax is kept Python-compatible so that one day they could be (TODO). +## Dev experience + +To recycle python tooling/linting on zkDSL files (which import [`snark_lib`](snark_lib.py)), point your editor at the compiler crate. With VSCode (for instance in `leanMultisig/.vscode/settings.json`): + +```json +{ + "python.analysis.extraPaths": [ + "./crates/lean_compiler" + ] +} +``` + +## Entrypoint + Programs are organized as one or more `.py` files. The toplevel of each file is a sequence of: @@ -707,66 +721,28 @@ The runner lays out memory as helper constants it relies on (e.g. a vector of zeros for `dot_product_ee`-as-copy, an extension-field one for multiply-by-one tricks, a vector of ones for batched accumulations, …) at the start of `main`. The - names and offsets of these constants are not part of the VM contract — each - program defines its own. See + names and offsets of these constants are not enshrined within leanVM. See `crates/rec_aggregation/zkdsl_implem/utils.py (build_preamble_memory)` for a concrete example. - The runtime region holds the program's stack frames, working memory, and any prover-supplied witness data, all governed by the write-once rule. -## Tips and gotchas +## Tips -1. Prefer `unroll` over `range` for small, fixed-size loops — no buffer - bookkeeping, no recursive-function overhead. +1. Prefer `unroll` over `range` for small, fixed-size loops. 2. Reach for `: Const` parameters when the function body needs `unroll` over the - parameter, or when array sizes depend on it. + parameter. 3. `if` / `elif` branches that assign to the same outer variable should forward-declare it (`x: Imm` or `x: Mut`) before the branch. -4. **`match`** / **`match_range`** dispatch is undefined for out-of-range - values — always pair it with a `debug_assert` (or `assert`) on the value. -5. `match` patterns must be contiguous integers; if you need gaps, restructure - into an `if` chain or pad with an empty arm. -6. `assert a < b` and `assert a <= b` are range-checked under the assumption - that `b <= 2^MIN_LOG_MEMORY_SIZE = 2^16`. Larger comparisons must be done - with explicit bit decomposition (`hint_decompose_bits` + manual checks). 7. Function parameters are always immutable. To mutate a parameter's value inside a function, introduce a local `: Mut` alias at the top of the body - (e.g. `y: Mut = x`). Inline functions additionally cannot return - conditionally — use a regular function for those cases. -8. `parallel_range` requires per-iteration determinism in memory and hints; a - single divergent iteration breaks proving. -9. **A variable that's assigned inside an `if` nested in an `unroll` loop may - silently fail to remain in scope after the loop.** When you're dispatching - over per-iteration compile-time constants, prefer a flat top-level - `if`/`elif` chain (one branch per iteration value) over `unroll` + nested - `if`. This affects compile-time dispatch only; runtime `if` inside `range` - loops is unaffected. + (e.g. `y: Mut = x`). -## A simple example +## Example -```python -SIZE = 8 +Look at the recursive aggregation program (to aggregate XMSS) at its entrypoint [main.py](../rec_aggregation/zkdsl_implem/main.py). -def main(): - arr = Array(SIZE) - for i in unroll(0, SIZE): - arr[i] = i * i - sum = compute_sum(arr, SIZE) - assert sum == 140 - return - -def compute_sum(ptr, n: Const): - acc: Mut = 0 - for i in unroll(0, n): - acc = acc + ptr[i] - return acc -``` - -## Worked example: sugar -> ISA - -This shows how the front-end normalizes a small program with mutable variables in -a runtime loop down to a form close to the ISA. The compiler does this -automatically; you don't have to write the intermediate forms. +## Compilation step-by-step: zkDSL -> ISA Starting program: @@ -786,7 +762,7 @@ def main(): return ``` -Step 1 — replace mutable-across-loop variables with index buffers, since memory +Step 1 — the compiler replaces mutable-across-loop variables with index buffers, since memory is write-once: ```python @@ -862,14 +838,14 @@ def main(): x_buff[0] = x2 y_buff = Array(size + 1) y_buff[0] = y2 - loop(4, x_buff, y_buff) + loop_helper(4, x_buff, y_buff) x3 = x_buff[size] y3 = y_buff[size] assert x3 == 35 assert y3 == 40 return -def loop(i, x_buff, y_buff): +def loop_helper(i, x_buff, y_buff): if i == 6: return else: @@ -883,23 +859,7 @@ def loop(i, x_buff, y_buff): next_idx = buff_idx + 1 x_buff[next_idx] = x_body3 y_buff[next_idx] = y_body3 - loop(i + 1, x_buff, y_buff) + loop_helper(i + 1, x_buff, y_buff) return ``` -## Dev experience - -For Python tooling/linting on zkDSL files (which import `snark_lib` at the top), -point your editor at the compiler crate. With VSCode: - -```json -{ - "python.analysis.extraPaths": [ - "./crates/lean_compiler" - ] -} -``` - -This makes the stubs in `crates/lean_compiler/snark_lib.py` visible to your -language server, so completion / type-checks light up correctly inside `.py` -zkDSL sources.