From acd87ab9f8b6323be79e069e4a24176ce9b6f1cb Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 25 May 2026 17:09:52 +0400 Subject: [PATCH 1/3] remove DynArray from recursion.py --- crates/rec_aggregation/src/compilation.rs | 14 ++ .../rec_aggregation/zkdsl_implem/recursion.py | 141 +++++++++--------- 2 files changed, 83 insertions(+), 72 deletions(-) diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 5db63d36..11d552a5 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -275,6 +275,7 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree let mut n_air_columns = vec![]; let mut n_air_shift_columns = vec![]; let mut n_air_constraints = vec![]; + let mut one_buses_all_cols = vec![]; for table in ALL_TABLES { let mut table_domseps = vec![]; let mut table_data_cols = vec![]; @@ -322,12 +323,25 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree one_buses_data_offsets.push(format!("[{}]", table_data_offsets.join(", "))); one_buses_new_cols.push(format!("[{}]", table_new_cols.join(", "))); + let mut sorted_seen: Vec = seen_cols.iter().copied().collect(); + sorted_seen.sort(); + one_buses_all_cols.push(format!( + "[{}]", + sorted_seen.iter().map(usize::to_string).collect::>().join(", ") + )); + num_cols_air.push(table.n_columns().to_string()); air_degrees.push(table.degree_air().to_string()); n_air_columns.push(table.n_columns().to_string()); n_air_shift_columns.push(table.n_shift_columns().to_string()); n_air_constraints.push(table.n_constraints().to_string()); } + let max_num_cols_air = ALL_TABLES.iter().map(|t| t.n_columns()).max().unwrap(); + replacements.insert("MAX_NUM_COLS_AIR_PLACEHOLDER".to_string(), max_num_cols_air.to_string()); + replacements.insert( + "ONE_BUSES_ALL_COLS_PLACEHOLDER".to_string(), + format!("[{}]", one_buses_all_cols.join(", ")), + ); replacements.insert( "ONE_BUSES_DOMSEPS_PLACEHOLDER".to_string(), format!("[{}]", one_buses_domseps.join(", ")), diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index bcd33598..a65ca749 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -25,6 +25,8 @@ ONE_BUSES_NEW_COLS = ONE_BUSES_NEW_COLS_PLACEHOLDER # [[[_; n_new]; num_buses]; N_TABLES] NUM_COLS_AIR = NUM_COLS_AIR_PLACEHOLDER +MAX_NUM_COLS_AIR = MAX_NUM_COLS_AIR_PLACEHOLDER # max(NUM_COLS_AIR[t] for t in 0..N_TABLES) +ONE_BUSES_ALL_COLS = ONE_BUSES_ALL_COLS_PLACEHOLDER # [[col, ...], _; N_TABLES] — sorted union of cols across all Multiplicity::One buses per table AIR_DEGREES = AIR_DEGREES_PLACEHOLDER # [_; N_TABLES] MAX_AIR_FULL_DEGREE = MAX_AIR_FULL_DEGREE_PLACEHOLDER @@ -221,18 +223,10 @@ def recursion(inner_public_memory, bytecode_hash_domsep): # Per-table data accumulators (indexed by table_index). bus_numerators_values = Array(N_TABLES * DIM) bus_denominators_values = Array(N_TABLES * DIM) - pcs_points = DynArray([]) - pcs_values = DynArray([]) - pcs_values_shift = DynArray([]) - for table_index in unroll(0, N_TABLES): - pcs_points.push(DynArray([])) - pcs_values.push(DynArray([])) - pcs_values[table_index].push(DynArray([])) - pcs_values_shift.push(DynArray([])) - pcs_values_shift[table_index].push(DynArray([])) - for _ in unroll(0, NUM_COLS_AIR[table_index]): - pcs_values[table_index][0].push(DynArray([])) - pcs_values_shift[table_index][0].push(DynArray([])) + pcs_inner_points = Array(N_TABLES) + pcs_vals_logup = Array(N_TABLES * MAX_NUM_COLS_AIR) + pcs_vals_air = Array(N_TABLES * MAX_NUM_COLS_AIR) + pcs_shifts_air = Array(N_TABLES * MAX_NUM_COLS_AIR) for table_index in unroll(0, N_TABLES): log_n_rows = table_log_heights[table_index] @@ -240,7 +234,7 @@ def recursion(inner_public_memory, bytecode_hash_domsep): offset: Mut = gkr_table_base_offset[table_index] inner_point = point_gkr + (n_vars_logup_gkr - log_n_rows) * DIM - pcs_points[table_index].push(inner_point) + pcs_inner_points[table_index] = inner_point # Bus (data flow between tables — Multiplicity::Column) prefix = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr) @@ -270,14 +264,13 @@ def recursion(inner_public_memory, bytecode_hash_domsep): for i in unroll(0, n_new): new_col = ONE_BUSES_NEW_COLS[table_index][one_bus_idx][i] - debug_assert(len(pcs_values[table_index][0][new_col]) == 0) - pcs_values[table_index][0][new_col].push(new_evals + i * DIM) + pcs_vals_logup[table_index * MAX_NUM_COLS_AIR + new_col] = new_evals + i * DIM data_evals = Array(n_data * DIM) for i in unroll(0, n_data): data_col = ONE_BUSES_DATA_COLS[table_index][one_bus_idx][i] data_ofs = ONE_BUSES_DATA_OFFSETS[table_index][one_bus_idx][i] - src = pcs_values[table_index][0][data_col][0] + src = pcs_vals_logup[table_index * MAX_NUM_COLS_AIR + data_col] if data_ofs == 0: copy_5(src, data_evals + i * DIM) if data_ofs != 0: @@ -336,7 +329,6 @@ def recursion(inner_public_memory, bytecode_hash_domsep): check_sum: Mut = ZERO_VEC_PTR for table_index in unroll(0, N_TABLES): log_n_rows = table_log_heights[table_index] - total_num_cols = NUM_COLS_AIR[table_index] n_flat_columns = N_AIR_COLUMNS[table_index] n_shift_columns = N_AIR_SHIFT_COLUMNS[table_index] alpha_offset = AIR_ALPHA_OFFSETS[table_index] @@ -347,7 +339,7 @@ def recursion(inner_public_memory, bytecode_hash_domsep): table_index, inner_evals, air_alpha_powers + alpha_offset * DIM, logup_alphas_eq_poly ) - bus_point = pcs_points[table_index][0] + bus_point = pcs_inner_points[table_index] eq_val = poly_eq_extension_dynamic_ret(bus_point, all_challenges, log_n_rows) k_t = product_first_n(all_challenges + log_n_rows * DIM, n_max - log_n_rows) @@ -355,19 +347,13 @@ def recursion(inner_public_memory, bytecode_hash_domsep): contribution = mul_extension_ret(k_t, mul_extension_ret(eq_val, air_constraints_eval)) check_sum = add_extension_ret(check_sum, contribution) - pcs_points[table_index].push(all_challenges) - pcs_values[table_index].push(DynArray([])) - pcs_values_shift[table_index].push(DynArray([])) - last_index = len(pcs_values[table_index]) - 1 - for _ in unroll(0, total_num_cols): - pcs_values[table_index][last_index].push(DynArray([])) - pcs_values_shift[table_index][last_index].push(DynArray([])) + # AIR block (i=1): all flat cols 0..n_flat_columns populated; shifts 0..n_shift_columns populated. for i in unroll(0, n_flat_columns): - pcs_values[table_index][last_index][i].push(inner_evals + i * DIM) + pcs_vals_air[table_index * MAX_NUM_COLS_AIR + i] = inner_evals + i * DIM if n_shift_columns != 0: evals_shift = inner_evals + n_flat_columns * DIM for i in unroll(0, n_shift_columns): - pcs_values_shift[table_index][last_index][i].push(evals_shift + i * DIM) + pcs_shifts_air[table_index * MAX_NUM_COLS_AIR + i] = evals_shift + i * DIM # verify that the AIR-batched sumcheck is valid copy_5(check_sum, batched_air_final_value) @@ -403,25 +389,29 @@ def recursion(inner_public_memory, bytecode_hash_domsep): curr_randomness += DIM whir_sum = add_extension_ret(mul_extension_ret(embed_in_ef(ENDING_PC), curr_randomness), whir_sum) curr_randomness += DIM - debug_assert(len(pcs_points[table_index]) == len(pcs_values[table_index])) - for i in unroll(0, len(pcs_values[table_index])): - # next_mle-weighted (shift) values come first - for j in unroll(0, len(pcs_values_shift[table_index][i])): - if len(pcs_values_shift[table_index][i][j]) == 1: - whir_sum = add_extension_ret( - mul_extension_ret(pcs_values_shift[table_index][i][j][0], curr_randomness), - whir_sum, - ) - curr_randomness += DIM - # eq-weighted (up) values - for j in unroll(0, len(pcs_values[table_index][i])): - debug_assert(len(pcs_values[table_index][i][j]) < 2) - if len(pcs_values[table_index][i][j]) == 1: - whir_sum = add_extension_ret( - mul_extension_ret(pcs_values[table_index][i][j][0], curr_randomness), - whir_sum, - ) - curr_randomness += DIM + + # LOGUP + for k in unroll(0, len(ONE_BUSES_ALL_COLS[table_index])): + col = ONE_BUSES_ALL_COLS[table_index][k] + whir_sum = add_extension_ret( + mul_extension_ret(pcs_vals_logup[table_index * MAX_NUM_COLS_AIR + col], curr_randomness), + whir_sum, + ) + curr_randomness += DIM + + # AIR + for j in unroll(0, N_AIR_SHIFT_COLUMNS[table_index]): + whir_sum = add_extension_ret( + mul_extension_ret(pcs_shifts_air[table_index * MAX_NUM_COLS_AIR + j], curr_randomness), + whir_sum, + ) + curr_randomness += DIM + for j in unroll(0, N_AIR_COLUMNS[table_index]): + whir_sum = add_extension_ret( + mul_extension_ret(pcs_vals_air[table_index * MAX_NUM_COLS_AIR + j], curr_randomness), + whir_sum, + ) + curr_randomness += DIM folding_randomness_global: Mut s: Mut @@ -523,32 +513,39 @@ def recursion(inner_public_memory, bytecode_hash_domsep): folding_randomness_global, total_num_cols, ) - for i in unroll(0, len(pcs_points[table_index])): - point = pcs_points[table_index][i] - inner_folding = folding_randomness_global + (stacked_n_vars - log_n_rows) * DIM - n_shift_columns = N_AIR_SHIFT_COLUMNS[table_index] - - # next_mle (shift) values - if n_shift_columns != 0: - next_factor = next_mle(point, inner_folding, log_n_rows) - for j in unroll(0, total_num_cols): - if len(pcs_values_shift[table_index][i][j]) == 1: - prefix = column_prefixes + j * DIM - s = add_extension_ret( - s, - mul_extension_ret(mul_extension_ret(curr_randomness, prefix), next_factor), - ) - curr_randomness += DIM - # eq (flat) values - eq_factor = poly_eq_extension_dynamic_ret(point, inner_folding, log_n_rows) - for j in unroll(0, total_num_cols): - if len(pcs_values[table_index][i][j]) == 1: - prefix = column_prefixes + j * DIM - s = add_extension_ret( - s, - mul_extension_ret(mul_extension_ret(curr_randomness, prefix), eq_factor), - ) - curr_randomness += DIM + inner_folding = folding_randomness_global + (stacked_n_vars - log_n_rows) * DIM + n_shift_columns = N_AIR_SHIFT_COLUMNS[table_index] + + # LOGUP + point_logup = pcs_inner_points[table_index] + eq_factor_logup = poly_eq_extension_dynamic_ret(point_logup, inner_folding, log_n_rows) + for k in unroll(0, len(ONE_BUSES_ALL_COLS[table_index])): + col = ONE_BUSES_ALL_COLS[table_index][k] + prefix = column_prefixes + col * DIM + s = add_extension_ret( + s, + mul_extension_ret(mul_extension_ret(curr_randomness, prefix), eq_factor_logup), + ) + curr_randomness += DIM + + # AIR + if n_shift_columns != 0: + next_factor = next_mle(all_challenges, inner_folding, log_n_rows) + for j in unroll(0, n_shift_columns): + prefix = column_prefixes + j * DIM + s = add_extension_ret( + s, + mul_extension_ret(mul_extension_ret(curr_randomness, prefix), next_factor), + ) + curr_randomness += DIM + eq_factor_air = poly_eq_extension_dynamic_ret(all_challenges, inner_folding, log_n_rows) + for j in unroll(0, N_AIR_COLUMNS[table_index]): + prefix = column_prefixes + j * DIM + s = add_extension_ret( + s, + mul_extension_ret(mul_extension_ret(curr_randomness, prefix), eq_factor_air), + ) + curr_randomness += DIM copy_5(mul_extension_ret(s, final_value), end_sum) From 808e164f556d0a4ff0f104030b6ca4fb58491ba2 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 25 May 2026 17:23:32 +0400 Subject: [PATCH 2/3] remove DynArray from the compiler --- crates/lean_compiler/snark_lib.py | 18 - .../lean_compiler/src/a_simplify_lang/mod.rs | 661 +----------------- crates/lean_compiler/src/grammar.pest | 17 - crates/lean_compiler/src/lang.rs | 127 +--- .../src/parser/parsers/expression.rs | 39 +- .../src/parser/parsers/function.rs | 2 - .../src/parser/parsers/statement.rs | 112 +-- crates/lean_compiler/tests/test_compiler.rs | 2 +- .../lean_compiler/tests/test_data/error_17.py | 12 - .../lean_compiler/tests/test_data/error_18.py | 9 - .../lean_compiler/tests/test_data/error_19.py | 12 - .../lean_compiler/tests/test_data/error_20.py | 9 - .../lean_compiler/tests/test_data/error_21.py | 8 - .../lean_compiler/tests/test_data/error_22.py | 8 - .../lean_compiler/tests/test_data/error_23.py | 9 - .../tests/test_data/program_11.py | 13 - .../tests/test_data/program_146.py | 11 - .../tests/test_data/program_147.py | 14 - .../tests/test_data/program_148.py | 14 - .../tests/test_data/program_149.py | 16 - .../tests/test_data/program_150.py | 15 - .../tests/test_data/program_151.py | 14 - .../tests/test_data/program_152.py | 17 - .../tests/test_data/program_153.py | 28 - .../tests/test_data/program_154.py | 12 - .../tests/test_data/program_155.py | 17 - .../tests/test_data/program_156.py | 20 - .../tests/test_data/program_157.py | 31 - .../tests/test_data/program_158.py | 33 - .../tests/test_data/program_159.py | 21 - .../tests/test_data/program_160.py | 15 - .../tests/test_data/program_161.py | 14 - .../tests/test_data/program_162.py | 14 - .../tests/test_data/program_163.py | 16 - .../tests/test_data/program_164.py | 18 - .../tests/test_data/program_165.py | 98 --- .../tests/test_data/program_166.py | 51 -- .../tests/test_data/program_168.py | 86 --- .../tests/test_data/program_170.py | 5 - crates/lean_compiler/zkDSL.md | 67 +- 40 files changed, 42 insertions(+), 1663 deletions(-) delete mode 100644 crates/lean_compiler/tests/test_data/error_17.py delete mode 100644 crates/lean_compiler/tests/test_data/error_18.py delete mode 100644 crates/lean_compiler/tests/test_data/error_19.py delete mode 100644 crates/lean_compiler/tests/test_data/error_20.py delete mode 100644 crates/lean_compiler/tests/test_data/error_21.py delete mode 100644 crates/lean_compiler/tests/test_data/error_22.py delete mode 100644 crates/lean_compiler/tests/test_data/error_23.py delete mode 100644 crates/lean_compiler/tests/test_data/program_11.py delete mode 100644 crates/lean_compiler/tests/test_data/program_146.py delete mode 100644 crates/lean_compiler/tests/test_data/program_147.py delete mode 100644 crates/lean_compiler/tests/test_data/program_148.py delete mode 100644 crates/lean_compiler/tests/test_data/program_149.py delete mode 100644 crates/lean_compiler/tests/test_data/program_150.py delete mode 100644 crates/lean_compiler/tests/test_data/program_151.py delete mode 100644 crates/lean_compiler/tests/test_data/program_152.py delete mode 100644 crates/lean_compiler/tests/test_data/program_153.py delete mode 100644 crates/lean_compiler/tests/test_data/program_154.py delete mode 100644 crates/lean_compiler/tests/test_data/program_155.py delete mode 100644 crates/lean_compiler/tests/test_data/program_156.py delete mode 100644 crates/lean_compiler/tests/test_data/program_157.py delete mode 100644 crates/lean_compiler/tests/test_data/program_158.py delete mode 100644 crates/lean_compiler/tests/test_data/program_159.py delete mode 100644 crates/lean_compiler/tests/test_data/program_160.py delete mode 100644 crates/lean_compiler/tests/test_data/program_161.py delete mode 100644 crates/lean_compiler/tests/test_data/program_162.py delete mode 100644 crates/lean_compiler/tests/test_data/program_163.py delete mode 100644 crates/lean_compiler/tests/test_data/program_164.py delete mode 100644 crates/lean_compiler/tests/test_data/program_165.py delete mode 100644 crates/lean_compiler/tests/test_data/program_166.py delete mode 100644 crates/lean_compiler/tests/test_data/program_168.py diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index 53f119e6..f11c8138 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -45,24 +45,6 @@ def __len__(self): return -# DynArray - dynamic array with push/pop (compile-time construct) -class DynArray: - def __init__(self, initial: list): - self._data = list(initial) - - def __getitem__(self, idx): - return self._data[idx] - - def __len__(self): - return len(self._data) - - def push(self, value): - self._data.append(value) - - def pop(self): - self._data.pop() - - def poseidon16_compress(left, right, output): _ = left, right, output diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index 4e1bf388..9894ac16 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -285,8 +285,6 @@ pub fn simplify_program(mut program: Program) -> Result { // Remove all inlined functions (they've been inlined) program.functions.retain(|_, func| !func.inlined); - validate_program_vectors(&program)?; - // Remove all const functions - they should all have been specialized by now let const_func_names: Vec<_> = program .functions @@ -311,7 +309,6 @@ pub fn simplify_program(mut program: Program) -> Result { for (name, func) in &program.functions { let mut array_manager = ArrayManager::default(); let mut mut_tracker = MutableVarTracker::default(); - let mut vec_tracker = VectorTracker::default(); // Register mutable arguments and capture their initial versioned names // BEFORE simplifying the body @@ -335,7 +332,6 @@ pub fn simplify_program(mut program: Program) -> Result { counters: &mut counters, array_manager: &mut array_manager, mut_tracker: &mut mut_tracker, - vec_tracker: &mut vec_tracker, }; let simplified_instructions = simplify_lines( &ctx, @@ -364,153 +360,6 @@ pub fn simplify_program(mut program: Program) -> Result { Ok(simple_program) } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum TreeVec { - Scalar(S), - Vector(Vec>), -} - -pub type VectorLenValue = TreeVec<()>; -pub type VectorValue = TreeVec; - -#[derive(Debug, Clone, Default)] -pub struct TreeVecTracker { - vectors: BTreeMap>, -} - -pub type VectorLenTracker = TreeVecTracker<()>; -type VectorTracker = TreeVecTracker; - -impl TreeVecTracker { - fn register(&mut self, var: &Var, value: TreeVec) { - self.vectors.insert(var.clone(), value); - } - - fn is_vector(&self, var: &Var) -> bool { - self.vectors.contains_key(var) - } - - pub fn get(&self, var: &Var) -> Option<&TreeVec> { - self.vectors.get(var) - } - - fn get_mut(&mut self, var: &Var) -> Option<&mut TreeVec> { - self.vectors.get_mut(var) - } -} - -impl TreeVec { - pub fn push(&mut self, elem: Self) { - match self { - Self::Vector(v) => v.push(elem), - _ => panic!("push on scalar"), - } - } - - pub fn pop(&mut self) -> Option { - match self { - Self::Vector(v) => v.pop(), - _ => panic!("pop on scalar"), - } - } - - pub fn len(&self) -> usize { - match self { - Self::Vector(v) => v.len(), - _ => panic!("len on scalar"), - } - } - - pub fn is_vector(&self) -> bool { - matches!(self, Self::Vector(_)) - } - - fn get(&self, i: usize) -> Option<&Self> { - match self { - Self::Vector(v) => v.get(i), - _ => None, - } - } - - fn get_mut(&mut self, i: usize) -> Option<&mut Self> { - match self { - Self::Vector(v) => v.get_mut(i), - _ => None, - } - } - - pub fn navigate(&self, idx: &[usize]) -> Option<&Self> { - idx.iter().try_fold(self, |v, &i| v.get(i)) - } - - pub fn navigate_mut(&mut self, idx: &[usize]) -> Option<&mut Self> { - idx.iter().try_fold(self, |v, &i| v.get_mut(i)) - } -} - -fn scalar_indices(indices: &[Expression]) -> Option> { - indices - .iter() - .map(|idx| idx.as_scalar().map(|f| f.to_usize())) - .collect() -} - -fn simplify_const_indices( - ctx: &SimplifyContext<'_>, - state: &mut SimplifyState<'_>, - const_malloc: &mut ConstMalloc, - indices: &[Expression], - lines: &mut Vec, - op: &str, - location: SourceLocation, -) -> Result, String> { - indices - .iter() - .map(|idx| { - let simplified = simplify_expr(ctx, state, const_malloc, idx, lines)?; - let const_val = simplified - .as_constant() - .ok_or_else(|| format!("{op} index must be a compile-time constant, at {location}"))?; - let val = const_val - .naive_eval() - .ok_or_else(|| format!("{op} index must be evaluable at compile time, at {location}"))?; - Ok(val.to_usize()) - }) - .collect() -} - -/// Navigate to the (sub-)vector to act upon and verify it is a vector. -fn navigate_vector_target_mut<'a>( - tracker: &'a mut VectorTracker, - vector: &Var, - indices: &[usize], - op: &str, - location: SourceLocation, -) -> Result<&'a mut VectorValue, String> { - let root = tracker - .get_mut(vector) - .ok_or_else(|| format!("{op} called on non-vector variable '{vector}', at {location}"))?; - let target = root - .navigate_mut(indices) - .ok_or_else(|| format!("{op} target index out of bounds, at {location}"))?; - if !target.is_vector() { - return Err(format!("{op} target must be a vector, not a scalar, at {location}")); - } - Ok(target) -} - -fn build_vector_len_value(elements: &[VecLiteral]) -> VectorLenValue { - VectorLenValue::Vector( - elements - .iter() - .map(|elem| match elem { - VecLiteral::Vec(inner) => build_vector_len_value(inner), - VecLiteral::Expr(_) => VectorLenValue::Scalar(()), - }) - .collect(), - ) -} - fn compile_time_transform_in_program( program: &mut Program, unroll_counter: &mut Counter, @@ -587,7 +436,6 @@ fn compile_time_transform_in_lines( inline_counter: &mut Counter, parent_const_var_exprs: &BTreeMap, ) -> Result<(), String> { - let mut vector_len_tracker = VectorLenTracker::default(); let mut const_var_exprs: BTreeMap = parent_const_var_exprs.clone(); // used to simplify expressions containing variables with known constant values let mut i = 0; @@ -601,8 +449,7 @@ fn compile_time_transform_in_lines( value, location, } = line - && let Some(expanded) = - try_expand_match_range(value, targets, *location, const_arrays, &vector_len_tracker)? + && let Some(expanded) = try_expand_match_range(value, targets, *location, const_arrays)? { lines.splice(i..=i, expanded); continue; @@ -610,7 +457,7 @@ fn compile_time_transform_in_lines( for expr in line.expressions_mut() { substitute_const_vars_in_expr(expr, &const_var_exprs); - compile_time_transform_in_expr(expr, const_arrays, &vector_len_tracker); + compile_time_transform_in_expr(expr, const_arrays); } // Extract nested calls to functions requiring preprocessing (inlined or const-arg) @@ -689,54 +536,6 @@ fn compile_time_transform_in_lines( } } - Line::VecDeclaration { var, elements, .. } => { - vector_len_tracker.register(var, build_vector_len_value(elements)); - } - - Line::Push { - vector, - indices, - element, - .. - } => { - let const_indices = - scalar_indices(indices).ok_or_else(|| "push with non-constant indices".to_string())?; - let new_element = match element { - VecLiteral::Vec(inner) => build_vector_len_value(inner), - VecLiteral::Expr(_) => VectorLenValue::Scalar(()), - }; - let target = vector_len_tracker - .get_mut(vector) - .ok_or_else(|| "pushing to undeclared vector".to_string())? - .navigate_mut(&const_indices) - .ok_or_else(|| "push target index out of bounds".to_string())?; - if !target.is_vector() { - return Err("push target is not a vector".to_string()); - } - target.push(new_element); - } - - Line::Pop { - vector, - indices, - location, - } => { - let const_indices = scalar_indices(indices) - .ok_or_else(|| format!("line {}: pop with non-constant indices", location))?; - let target = vector_len_tracker - .get_mut(vector) - .ok_or_else(|| format!("line {}: pop on undeclared vector '{}'", location, vector))? - .navigate_mut(&const_indices) - .ok_or_else(|| format!("line {}: pop target index out of bounds", location))?; - if !target.is_vector() { - return Err(format!("line {}: pop target is not a vector", location)); - } - if target.len() == 0 { - return Err(format!("line {}: pop on empty vector", location)); - } - target.pop(); - } - Line::IfCondition { condition, then_branch, @@ -820,7 +619,6 @@ fn try_expand_match_range( targets: &[AssignmentTarget], location: SourceLocation, const_arrays: &BTreeMap, - vector_len: &VectorLenTracker, ) -> Result>, String> { let Expression::FunctionCall { function_name, args, .. @@ -881,11 +679,11 @@ fn try_expand_match_range( return Err("match_range: expected range(start, end)".into()); } let start = ra[0] - .compile_time_eval(const_arrays, vector_len) + .compile_time_eval(const_arrays) .ok_or(format!("match_range: range start must be constant (at {location})"))? .to_usize(); let end = ra[1] - .compile_time_eval(const_arrays, vector_len) + .compile_time_eval(const_arrays) .ok_or(format!("match_range: range end must be constant (at {location})"))? .to_usize(); @@ -1097,19 +895,15 @@ fn extract_preprocessed_calls( } } -fn compile_time_transform_in_expr( - expr: &mut Expression, - const_arrays: &BTreeMap, - vector_len_tracker: &VectorLenTracker, -) -> bool { +fn compile_time_transform_in_expr(expr: &mut Expression, const_arrays: &BTreeMap) -> bool { if expr.is_scalar() { return false; } let mut changed = false; for inner_expr in expr.inner_exprs_mut() { - changed |= compile_time_transform_in_expr(inner_expr, const_arrays, vector_len_tracker); + changed |= compile_time_transform_in_expr(inner_expr, const_arrays); } - if let Some(scalar) = expr.compile_time_eval(const_arrays, vector_len_tracker) { + if let Some(scalar) = expr.compile_time_eval(const_arrays) { *expr = Expression::scalar(scalar); changed = true; } @@ -1553,167 +1347,8 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { } } Line::Panic { .. } | Line::LocationReport { .. } => {} - Line::VecDeclaration { var, elements, .. } => { - // Check expressions in vec elements - check_vec_literal_scoping(elements, ctx); - // Add the vector variable to scope - ctx.add_var(var); - } - Line::Push { - vector, - indices, - element, - .. - } => { - // Check the vector variable is in scope - assert!(ctx.defines(vector), "Vector variable '{}' not in scope", vector); - // Check indices are in scope - for idx in indices { - check_expr_scoping(idx, ctx); - } - // Check the pushed element - check_vec_literal_element_scoping(element, ctx); - } - Line::Pop { vector, indices, .. } => { - // Check the vector variable is in scope - assert!(ctx.defines(vector), "Vector variable '{}' not in scope", vector); - // Check indices are in scope - for idx in indices { - check_expr_scoping(idx, ctx); - } - } - } - } -} - -fn validate_program_vectors(program: &Program) -> Result<(), String> { - let inlined_functions = program.inlined_function_names(); - for f in program.functions.values() { - validate_vectors(&f.body, &BTreeSet::new(), &inlined_functions, None)?; - } - Ok(()) -} - -fn validate_vectors( - lines: &[Line], - outer: &BTreeSet, - inlined: &BTreeSet, - restrict: Option, -) -> Result<(), String> { - let mut local: BTreeSet = BTreeSet::new(); - macro_rules! all { - () => { - outer.union(&local).cloned().collect::>() - }; - } - - for line in lines { - match line { - Line::VecDeclaration { var, elements, .. } => { - local.insert(var.clone()); - validate_vec_lit(elements, &all!(), inlined)?; - } - Line::Push { - vector, - element, - location, - .. - } => { - if restrict.is_some() && outer.contains(vector) { - return Err(format!("line {}: push to outer-scope vector '{}'", location, vector)); - } - validate_vec_lit(std::slice::from_ref(element), &all!(), inlined)?; - if !local.contains(vector) && !outer.contains(vector) { - return Err(format!("line {}: unknown vector '{}'", location, vector)); - } - } - Line::Pop { vector, location, .. } => { - if restrict.is_some() && outer.contains(vector) { - return Err(format!("line {}: pop from outer-scope vector '{}'", location, vector)); - } - if !local.contains(vector) && !outer.contains(vector) { - return Err(format!("line {}: unknown vector '{}'", location, vector)); - } - } - Line::Statement { value, .. } => { - check_vec_in_call(value, &all!(), inlined)?; - } - Line::IfCondition { - then_branch, - else_branch, - location, - .. - } => { - validate_vectors(then_branch, &all!(), inlined, Some(*location))?; - validate_vectors(else_branch, &all!(), inlined, Some(*location))?; - } - Line::ForLoop { - body, - loop_kind, - location, - .. - } => { - validate_vectors( - body, - &all!(), - inlined, - if loop_kind.is_unroll() { None } else { Some(*location) }, - )?; - } - Line::Match { arms, location, .. } => { - for (_, arm) in arms { - validate_vectors(arm, &all!(), inlined, Some(*location))?; - } - } - _ => {} } } - Ok(()) -} - -fn validate_vec_lit(elems: &[VecLiteral], vecs: &BTreeSet, inlined: &BTreeSet) -> Result<(), String> { - for e in elems { - match e { - VecLiteral::Expr(expr) => check_vec_in_call(expr, vecs, inlined)?, - VecLiteral::Vec(inner) => validate_vec_lit(inner, vecs, inlined)?, - } - } - Ok(()) -} - -fn check_vec_in_call(expr: &Expression, vecs: &BTreeSet, inlined: &BTreeSet) -> Result<(), String> { - if let Expression::FunctionCall { - function_name, - args, - location, - } = expr - && !inlined.contains(function_name) - { - for arg in args { - if let Expression::Value(SimpleExpr::Memory(VarOrConstMallocAccess::Var(v))) = arg - && vecs.contains(v) - { - return Err(format!( - "line {}: vector '{}' passed to function '{}'", - location, v, function_name - )); - } - } - } - Ok(()) -} - -fn check_vec_literal_scoping(elements: &[VecLiteral], ctx: &Context) { - for elem in elements { - check_vec_literal_element_scoping(elem, ctx); - } -} - -fn check_vec_literal_element_scoping(elem: &VecLiteral, ctx: &Context) { - match elem { - VecLiteral::Expr(expr) => check_expr_scoping(expr, ctx), - VecLiteral::Vec(inner) => check_vec_literal_scoping(inner, ctx), - } } fn check_expr_scoping(expr: &Expression, ctx: &Context) { @@ -1772,7 +1407,6 @@ struct SimplifyState<'a> { counters: &'a mut Counters, array_manager: &'a mut ArrayManager, mut_tracker: &'a mut MutableVarTracker, - vec_tracker: &'a mut VectorTracker, } #[derive(Debug, Clone, Default)] @@ -1954,50 +1588,6 @@ impl ArrayManager { } } -fn build_vector_value( - ctx: &SimplifyContext<'_>, - state: &mut SimplifyState<'_>, - const_malloc: &mut ConstMalloc, - elements: &[VecLiteral], - lines: &mut Vec, - location: SourceLocation, -) -> Result { - let mut vec_elements = Vec::new(); - - for elem in elements { - vec_elements.push(build_vector_value_from_element( - ctx, - state, - const_malloc, - elem, - lines, - location, - )?); - } - - Ok(VectorValue::Vector(vec_elements)) -} - -fn build_vector_value_from_element( - ctx: &SimplifyContext<'_>, - state: &mut SimplifyState<'_>, - const_malloc: &mut ConstMalloc, - element: &VecLiteral, - lines: &mut Vec, - location: SourceLocation, -) -> Result { - match element { - VecLiteral::Vec(inner) => build_vector_value(ctx, state, const_malloc, inner, lines, location), - VecLiteral::Expr(expr) => { - // Scalar expression - create auxiliary variable and emit assignment - let aux_var = state.counters.aux_var(); - let simplified_value = simplify_expr(ctx, state, const_malloc, expr, lines)?; - lines.push(SimpleLine::equality(aux_var.clone(), simplified_value)); - Ok(VectorValue::Scalar(aux_var)) - } - } -} - #[allow(clippy::too_many_arguments)] fn simplify_lines( ctx: &SimplifyContext<'_>, @@ -2380,45 +1970,25 @@ fn simplify_lines( res.push(SimpleLine::equality(target_var, simplified_val)); } Expression::ArrayAccess { array, index } => { - // Check if array is a vector (needs to be handled before simplifying indices) - let versioned_array = array.as_var().map(|n| state.mut_tracker.current_name(n)); - if let Some(versioned_array) = &versioned_array - && state.vec_tracker.is_vector(versioned_array) - { - // Use simplify_expr which handles vectors correctly - let simplified_val = simplify_expr( - ctx, - state, - const_malloc, - &Expression::ArrayAccess { - array: array.clone(), - index: index.clone(), - }, - &mut res, - )?; - let target_var = get_target_var_name(state, var, *is_mutable)?; - res.push(SimpleLine::equality(target_var, simplified_val)); - } else { - // Pre-simplify indices before version update - let simplified_index = index - .iter() - .map(|idx| simplify_expr(ctx, state, const_malloc, idx, &mut res)) - .collect::, _>>()?; - let target_var = get_target_var_name(state, var, *is_mutable)?; - if state.mut_tracker.is_ssa_reassignment(var) { - res.push(SimpleLine::ForwardDeclaration { - var: target_var.clone(), - }); - } - handle_array_assignment( - state, - const_malloc, - &mut res, - array, - &simplified_index, - ArrayAccessType::VarIsAssigned(target_var), - ); + // Pre-simplify indices before version update + let simplified_index = index + .iter() + .map(|idx| simplify_expr(ctx, state, const_malloc, idx, &mut res)) + .collect::, _>>()?; + let target_var = get_target_var_name(state, var, *is_mutable)?; + if state.mut_tracker.is_ssa_reassignment(var) { + res.push(SimpleLine::ForwardDeclaration { + var: target_var.clone(), + }); } + handle_array_assignment( + state, + const_malloc, + &mut res, + array, + &simplified_index, + ArrayAccessType::VarIsAssigned(target_var), + ); } Expression::MathExpr(operation, args) => { let args_simplified = args @@ -2650,16 +2220,13 @@ fn simplify_lines( // Snapshot state before processing branches let mut_tracker_snapshot = state.mut_tracker.clone(); - let vec_tracker_snapshot = state.vec_tracker.clone(); let mut array_manager_then = state.array_manager.clone(); let mut mut_tracker_then = state.mut_tracker.clone(); - let mut vec_tracker_then = state.vec_tracker.clone(); let mut state_then = SimplifyState { counters: state.counters, array_manager: &mut array_manager_then, mut_tracker: &mut mut_tracker_then, - vec_tracker: &mut vec_tracker_then, }; let then_branch_simplified = simplify_lines( ctx, @@ -2677,13 +2244,11 @@ fn simplify_lines( // Restore state for else branch let mut mut_tracker_else = mut_tracker_snapshot.clone(); - let mut vec_tracker_else = vec_tracker_snapshot.clone(); let mut state_else = SimplifyState { counters: state.counters, array_manager: &mut array_manager_else, mut_tracker: &mut mut_tracker_else, - vec_tracker: &mut vec_tracker_else, }; let else_branch_simplified = simplify_lines( ctx, @@ -2852,41 +2417,6 @@ fn simplify_lines( Line::LocationReport { location } => { res.push(SimpleLine::LocationReport { location: *location }); } - Line::VecDeclaration { - var, - elements, - location, - } => { - let vector_value = build_vector_value(ctx, state, const_malloc, elements, &mut res, *location)?; - state.vec_tracker.register(var, vector_value); - // No SimpleLine for the variable itself - vector metadata is compile-time only - } - Line::Push { - vector, - indices, - element, - location, - } => { - let const_indices = - simplify_const_indices(ctx, state, const_malloc, indices, &mut res, "push", *location)?; - let new_element = - build_vector_value_from_element(ctx, state, const_malloc, element, &mut res, *location)?; - let target = navigate_vector_target_mut(state.vec_tracker, vector, &const_indices, "push", *location)?; - target.push(new_element); - } - Line::Pop { - vector, - indices, - location, - } => { - let const_indices = - simplify_const_indices(ctx, state, const_malloc, indices, &mut res, "pop", *location)?; - let target = navigate_vector_target_mut(state.vec_tracker, vector, &const_indices, "pop", *location)?; - if target.len() == 0 { - return Err(format!("pop on empty vector, at {}", location)); - } - target.pop(); - } } } @@ -2934,39 +2464,6 @@ fn simplify_expr( let versioned_array = array_var_name.map(|n| state.mut_tracker.current_name(n)); - // Check for compile-time vector access - if let Some(versioned) = &versioned_array - && state.vec_tracker.is_vector(versioned) - { - // Vector access - indices must all be compile-time constant - // First, simplify all indices (this may mutate state) - let mut const_indices: Vec = Vec::new(); - for idx in index { - let simplified = simplify_expr(ctx, state, const_malloc, idx, lines)?; - let SimpleExpr::Constant(const_expr) = simplified else { - return Err("Vector index must be compile-time constant".to_string()); - }; - let val = const_expr - .naive_eval() - .ok_or_else(|| "Cannot evaluate vector index".to_string())? - .to_usize(); - const_indices.push(val); - } - - // Now we can borrow vec_tracker again - let vector_value = state.vec_tracker.get(versioned).unwrap(); - - // Navigate to the element - let element = vector_value - .navigate(&const_indices) - .ok_or_else(|| format!("Vector index out of bounds: {:?}", const_indices))?; - - return match element { - VectorValue::Scalar(var) => Ok(var.clone().into()), - VectorValue::Vector(_) => Err("Cannot use nested vector as expression value".to_string()), - }; - } - assert_eq!(index.len(), 1); let index = index[0].clone(); @@ -3060,40 +2557,7 @@ fn simplify_expr( Ok(VarOrConstMallocAccess::Var(result_var).into()) } - Expression::Len { array, indices } => { - // Check for compile-time vector len() - let versioned_array = state.mut_tracker.current_name(array); - if state.vec_tracker.is_vector(&versioned_array) { - // Evaluate indices at compile time - first simplify to avoid borrow issues - let mut const_indices: Vec = Vec::new(); - for idx in indices { - let simplified = simplify_expr(ctx, state, const_malloc, idx, lines)?; - let SimpleExpr::Constant(const_expr) = simplified else { - return Err("Vector len() index must be compile-time constant".to_string()); - }; - let val = const_expr - .naive_eval() - .ok_or_else(|| "Cannot evaluate len() index".to_string())? - .to_usize(); - const_indices.push(val); - } - - // Now we can borrow vec_tracker again - let vector_value = state.vec_tracker.get(&versioned_array).unwrap(); - - // Navigate and get length - let target = if const_indices.is_empty() { - vector_value - } else { - vector_value - .navigate(&const_indices) - .ok_or_else(|| "len() index out of bounds".to_string())? - }; - - return Ok(SimpleExpr::Constant(ConstExpression::from(target.len()))); - } - - // Fall through to const array handling (should be unreachable for vectors) + Expression::Len { .. } => { unreachable!("len() should have been resolved at parse time for const arrays") } Expression::Lambda { .. } => Err("Lambda expressions can only be used as arguments to match_range".to_string()), @@ -3218,76 +2682,12 @@ pub fn find_variable_usage( on_new_expr(end, &internal_vars, &mut external_vars); } Line::Panic { .. } | Line::LocationReport { .. } => {} - Line::VecDeclaration { var, elements, .. } => { - // Process expressions in vec elements - process_vec_elements_usage(elements, &internal_vars, &mut external_vars, const_arrays); - // Add the vector variable to internal vars - internal_vars.insert(var.clone()); - } - Line::Push { - vector, - indices, - element, - .. - } => { - // The vector variable is used - if !internal_vars.contains(vector) { - external_vars.insert(vector.clone()); - } - // Process index expressions - for idx in indices { - on_new_expr(idx, &internal_vars, &mut external_vars); - } - // Process the pushed element - process_vec_element_usage(element, &internal_vars, &mut external_vars, const_arrays); - } - Line::Pop { vector, indices, .. } => { - // The vector variable is used - if !internal_vars.contains(vector) { - external_vars.insert(vector.clone()); - } - // Process index expressions - for idx in indices { - on_new_expr(idx, &internal_vars, &mut external_vars); - } - } } } (internal_vars, external_vars) } -fn process_vec_elements_usage( - elements: &[VecLiteral], - internal_vars: &BTreeSet, - external_vars: &mut BTreeSet, - const_arrays: &BTreeMap, -) { - for elem in elements { - process_vec_element_usage(elem, internal_vars, external_vars, const_arrays); - } -} - -fn process_vec_element_usage( - elem: &VecLiteral, - internal_vars: &BTreeSet, - external_vars: &mut BTreeSet, - const_arrays: &BTreeMap, -) { - match elem { - VecLiteral::Expr(expr) => { - for var in vars_in_expression(expr, const_arrays) { - if !internal_vars.contains(&var) && !const_arrays.contains_key(&var) { - external_vars.insert(var); - } - } - } - VecLiteral::Vec(inner) => { - process_vec_elements_usage(inner, internal_vars, external_vars, const_arrays); - } - } -} - enum VarTransform { ReplaceWithExpr(SimpleExpr), Rename(String), @@ -3373,15 +2773,6 @@ fn transform_vars_in_lines(lines: &mut [Line], transform: &impl Fn(&Var) -> VarT Line::ForLoop { iterator, .. } => { transform(iterator).apply_to_var(iterator); } - Line::VecDeclaration { var, .. } => { - transform(var).apply_to_var(var); - } - Line::Push { vector, .. } => { - transform(vector).apply_to_var(vector); - } - Line::Pop { vector, .. } => { - transform(vector).apply_to_var(vector); - } _ => {} } } diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index 0ee1c499..16b041a6 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -41,23 +41,9 @@ simple_statement = { return_statement | assert_statement | debug_assert_statement | - vec_declaration | - push_statement | - pop_statement | assignment } -// Vector declaration: var = vec![...] (vectors are implicitly mutable for push) -vec_declaration = { identifier ~ "=" ~ vec_literal } - -// Push statement: vec_var.push(element) or vec_var[i][j].push(element) -push_statement = { push_target ~ "." ~ "push" ~ "(" ~ vec_element ~ ")" } -push_target = { identifier ~ ("[" ~ expression ~ "]")* } - -// Pop statement: vec_var.pop() or vec_var[i][j].pop() -pop_statement = { pop_target ~ "." ~ "pop" ~ "(" ~ ")" } -pop_target = { identifier ~ ("[" ~ expression ~ "]")* } - return_statement = { "return" ~ (("(" ~ tuple_expression ~ ")") | tuple_expression)? } @@ -131,9 +117,6 @@ primary = { // Lambda expression: lambda param: body lambda_expr = { "lambda" ~ identifier ~ ":" ~ expression } -// DynArray literal: DynArray([elem1, elem2, ...]) - compile-time dynamic arrays -vec_literal = { "DynArray" ~ "(" ~ "[" ~ (vec_element ~ ("," ~ vec_element)* ~ ","?)? ~ "]" ~ ")" } -vec_element = { vec_literal | expression } function_call_expr = { identifier ~ "(" ~ tuple_expression? ~ ")" } hint_witness_expr = { "hint_witness" ~ "(" ~ string_literal ~ "," ~ expression ~ ")" } log2_ceil_expr = { "log2_ceil" ~ "(" ~ expression ~ ")" } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 5accb007..938aa024 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -4,7 +4,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Display, Formatter}; use utils::ToUsize; -use crate::a_simplify_lang::{VarOrConstMallocAccess, VectorLenTracker}; +use crate::a_simplify_lang::VarOrConstMallocAccess; use crate::{F, parser::ConstArrayValue}; pub use lean_vm::{FileId, FunctionName, SourceLocation}; @@ -366,27 +366,16 @@ impl From for Expression { } impl Expression { - pub fn compile_time_eval( - &self, - const_arrays: &BTreeMap, - vector_len: &VectorLenTracker, - ) -> Option { + pub fn compile_time_eval(&self, const_arrays: &BTreeMap) -> Option { // Handle Len specially since it needs const_arrays if let Self::Len { array, indices } = self { let idx = indices .iter() - .map(|e| e.compile_time_eval(const_arrays, vector_len)) + .map(|e| e.compile_time_eval(const_arrays)) .collect::>>()?; - if let Some(arr) = const_arrays.get(array) { - let target = arr.navigate(&idx)?; - return Some(F::from_usize(target.len())); - } - if let Some(arr) = vector_len.get(array) { - let usize_idx: Vec = idx.iter().map(|f| f.to_usize()).collect(); - let target = arr.navigate(&usize_idx)?; - return Some(F::from_usize(target.len())); - } - return None; + let arr = const_arrays.get(array)?; + let target = arr.navigate(&idx)?; + return Some(F::from_usize(target.len())); } self.eval_with( &|value: &SimpleExpr| value.as_constant()?.naive_eval(), @@ -510,31 +499,6 @@ impl AssignmentTarget { } } -/// A compile-time dynamic array literal: DynArray(elem1, elem2, ...) -/// Elements can be expressions or nested DynArray literals. -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum VecLiteral { - /// A scalar expression element - Expr(Expression), - /// A nested vector literal - Vec(Vec), -} - -impl VecLiteral { - pub fn all_exprs_mut_in_slice(arr: &mut [Self]) -> Vec<&mut Expression> { - let mut exprs = Vec::new(); - for elem in arr { - match elem { - Self::Expr(expr) => exprs.push(expr), - Self::Vec(nested) => { - exprs.extend(Self::all_exprs_mut_in_slice(nested)); - } - } - } - exprs - } -} - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum LoopKind { Range, @@ -597,25 +561,6 @@ pub enum Line { LocationReport { location: SourceLocation, }, - /// Compile-time dynamic array declaration: var = DynArray(...) - VecDeclaration { - var: Var, - elements: Vec, - location: SourceLocation, - }, - /// Compile-time vector push: push(vec_var, element) or push(vec_var[i][j], element) - Push { - vector: Var, - indices: Vec, - element: VecLiteral, - location: SourceLocation, - }, - /// Compile-time vector pop: vec_var.pop() or vec_var[i][j].pop() - Pop { - vector: Var, - indices: Vec, - location: SourceLocation, - }, } /// A context specifying which variables are in scope. @@ -802,33 +747,6 @@ impl Line { Some(msg) => format!("assert False, \"{msg}\""), None => "assert False".to_string(), }, - Self::VecDeclaration { var, elements, .. } => { - format!("{var} = DynArray({})", elements.len()) - } - Self::Push { - vector, - indices, - element, - .. - } => { - format!( - "{}[{}].push({})", - vector, - indices.iter().map(|i| format!("{i}")).collect::>().join("]["), - element - ) - } - Self::Pop { vector, indices, .. } => { - if indices.is_empty() { - format!("{}.pop()", vector) - } else { - format!( - "{}[{}].pop()", - vector, - indices.iter().map(|i| format!("{i}")).collect::>().join("][") - ) - } - } }; format!("{spaces}{line_str}") } @@ -847,10 +765,7 @@ impl Line { | Self::Assert { .. } | Self::FunctionRet { .. } | Self::Panic { .. } - | Self::LocationReport { .. } - | Self::VecDeclaration { .. } - | Self::Push { .. } - | Self::Pop { .. } => vec![], + | Self::LocationReport { .. } => vec![], } } @@ -868,10 +783,7 @@ impl Line { | Self::Assert { .. } | Self::FunctionRet { .. } | Self::Panic { .. } - | Self::LocationReport { .. } - | Self::VecDeclaration { .. } - | Self::Push { .. } - | Self::Pop { .. } => vec![], + | Self::LocationReport { .. } => vec![], } } @@ -895,13 +807,6 @@ impl Line { vec![start, end] } Self::FunctionRet { return_data } => return_data.iter_mut().collect(), - Self::Push { indices, element, .. } => { - let mut exprs = indices.iter_mut().collect::>(); - exprs.extend(VecLiteral::all_exprs_mut_in_slice(std::slice::from_mut(element))); - exprs - } - Self::Pop { indices, .. } => indices.iter_mut().collect(), - Self::VecDeclaration { elements, .. } => VecLiteral::all_exprs_mut_in_slice(elements), Self::ForwardDeclaration { .. } | Self::Panic { .. } | Self::LocationReport { .. } => vec![], } } @@ -925,22 +830,6 @@ impl Display for ConstantValue { } } -impl Display for VecLiteral { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Expr(expr) => write!(f, "{expr}"), - Self::Vec(elements) => { - let elements_str = elements - .iter() - .map(|elem| format!("{elem}")) - .collect::>() - .join(", "); - write!(f, "DynArray([{elements_str}])") - } - } - } -} - impl Display for SimpleExpr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs b/crates/lean_compiler/src/parser/parsers/expression.rs index fe43f289..004834ca 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -5,7 +5,7 @@ use super::literal::{VarOrConstantParser, evaluate_const_expr}; use super::{ConstArrayValue, Parse, ParseContext, next_inner_pair}; use crate::lang::MathOperation; use crate::{ - lang::{ConstExpression, ConstantValue, Expression, SimpleExpr, VecLiteral}, + lang::{ConstExpression, ConstantValue, Expression, SimpleExpr}, parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, @@ -221,40 +221,3 @@ impl Parse for LenParser { }) } } - -/// Parser for vec![...] literals (compile-time vectors) -/// Parses into the VecLiteral enum (separate from Expression) -pub struct VecLiteralParser; - -impl Parse for VecLiteralParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - // vec_literal = { "vec!" ~ "[" ~ (vec_element ~ ("," ~ vec_element)*)? ~ "]" } - // vec_element = { vec_literal | expression } - let elements: Vec = pair - .into_inner() - .map(|elem_pair| VecElementParser.parse(elem_pair, ctx)) - .collect::, _>>()?; - - Ok(VecLiteral::Vec(elements)) - } -} - -/// Parser for vec element (either a nested vec_literal or an expression) -pub struct VecElementParser; - -impl Parse for VecElementParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - match pair.as_rule() { - Rule::vec_element => { - // vec_element contains either vec_literal or expression - let inner = next_inner_pair(&mut pair.into_inner(), "vec element")?; - match inner.as_rule() { - Rule::vec_literal => VecLiteralParser.parse(inner, ctx), - _ => Ok(VecLiteral::Expr(ExpressionParser.parse(inner, ctx)?)), - } - } - Rule::vec_literal => VecLiteralParser.parse(pair, ctx), - _ => Ok(VecLiteral::Expr(ExpressionParser.parse(pair, ctx)?)), - } - } -} diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 576e5af6..40256ac9 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -15,8 +15,6 @@ pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ // Built-in functions "print", "Array", - "DynArray", - "push", // Compile-time vector push // Compile-time only functions "len", "log2_ceil", diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index 4f001693..5e683514 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -1,13 +1,13 @@ use lean_vm::{Boolean, BooleanExpr}; use utils::ToUsize; -use super::expression::{ExpressionParser, VecElementParser, VecLiteralParser}; +use super::expression::ExpressionParser; use super::function::{AssignmentParser, TupleExpressionParser}; use super::literal::ConstExprParser; use super::{Parse, ParseContext, next_inner_pair, push_statement_with_location}; use crate::{ SourceLineNumber, - lang::{Expression, Line, LoopKind, SourceLocation, VecLiteral}, + lang::{Expression, Line, LoopKind, SourceLocation}, parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, @@ -36,9 +36,6 @@ impl Parse for StatementParser { Rule::return_statement => ReturnStatementParser.parse(simple_inner, ctx), Rule::assert_statement => AssertParser::.parse(simple_inner, ctx), Rule::debug_assert_statement => AssertParser::.parse(simple_inner, ctx), - Rule::vec_declaration => VecDeclarationParser.parse(simple_inner, ctx), - Rule::push_statement => PushStatementParser.parse(simple_inner, ctx), - Rule::pop_statement => PopStatementParser.parse(simple_inner, ctx), _ => Err(SemanticError::new("Unknown simple statement").into()), } } @@ -287,111 +284,6 @@ impl Parse for AssertParser { } } -/// Parser for vector declarations: `var = vec![...]` (vectors are implicitly mutable for push) -pub struct VecDeclarationParser; - -impl Parse for VecDeclarationParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let line_number = pair.line_col().0; - let mut inner = pair.into_inner(); - - // Parse variable name - let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); - - // Parse the vec_literal - let vec_literal_pair = next_inner_pair(&mut inner, "vec literal")?; - let vec_literal = VecLiteralParser.parse(vec_literal_pair, ctx)?; - - // Extract elements from the VecLiteral::Vec - let elements = match vec_literal { - VecLiteral::Vec(elems) => elems, - VecLiteral::Expr(_) => { - return Err(SemanticError::new("Expected vec literal, got expression").into()); - } - }; - - Ok(Line::VecDeclaration { - var, - elements, - location: SourceLocation { - file_id: ctx.current_file_id, - line_number, - }, - }) - } -} - -/// Parser for push statements: `vec_var.push(element);` or `vec_var[i][j].push(element);` -pub struct PushStatementParser; - -impl Parse for PushStatementParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let line_number = pair.line_col().0; - let mut inner = pair.into_inner(); - - // Parse the push_target (identifier with optional indices) - let push_target = next_inner_pair(&mut inner, "push target")?; - let mut target_inner = push_target.into_inner(); - - // First element is the vector variable name - let vector = next_inner_pair(&mut target_inner, "vector variable")? - .as_str() - .to_string(); - - // Remaining elements are index expressions - let indices: Vec = target_inner - .map(|idx_pair| ExpressionParser.parse(idx_pair, ctx)) - .collect::, _>>()?; - - // Parse the element to push (vec_element can be vec_literal or expression) - let element_pair = next_inner_pair(&mut inner, "push element")?; - let element = VecElementParser.parse(element_pair, ctx)?; - - Ok(Line::Push { - vector, - indices, - element, - location: SourceLocation { - file_id: ctx.current_file_id, - line_number, - }, - }) - } -} - -/// Parser for pop statements: `vec_var.pop();` or `vec_var[i][j].pop();` -pub struct PopStatementParser; - -impl Parse for PopStatementParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let line_number = pair.line_col().0; - let mut inner = pair.into_inner(); - - // Parse the pop_target (identifier with optional indices) - let pop_target = next_inner_pair(&mut inner, "pop target")?; - let mut target_inner = pop_target.into_inner(); - - // First element is the vector variable name - let vector = next_inner_pair(&mut target_inner, "vector variable")? - .as_str() - .to_string(); - - // Remaining elements are index expressions - let indices: Vec = target_inner - .map(|idx_pair| ExpressionParser.parse(idx_pair, ctx)) - .collect::, _>>()?; - - Ok(Line::Pop { - vector, - indices, - location: SourceLocation { - file_id: ctx.current_file_id, - line_number, - }, - }) - } -} - /// Parser for forward declarations: `x: Imu` or `x: Mut` pub struct ForwardDeclarationParser; diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 962d333c..17e13f7a 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -124,7 +124,7 @@ fn test_all_programs() { println!("Found {} test programs", paths.len()); // Reserve a 5-cell preamble for the programs that materialize a local - // ONE_EF_PTR (program_15, program_166, program_179). + // ONE_EF_PTR (program_15, program_179). let witness = ExecutionWitness { preamble_memory_len: 5, ..ExecutionWitness::default() diff --git a/crates/lean_compiler/tests/test_data/error_17.py b/crates/lean_compiler/tests/test_data/error_17.py deleted file mode 100644 index 0c0e8381..00000000 --- a/crates/lean_compiler/tests/test_data/error_17.py +++ /dev/null @@ -1,12 +0,0 @@ -from snark_lib import * - - -# Error: push to outer-scope vector inside else branch -def main(): - v = DynArray([1, 2, 3]) # Vector in outer scope - x = 5 - if x == 5: - x = 6 - else: - v.push(4) # Error: cannot push to outer-scope vector inside if/else - return diff --git a/crates/lean_compiler/tests/test_data/error_18.py b/crates/lean_compiler/tests/test_data/error_18.py deleted file mode 100644 index 86866ef7..00000000 --- a/crates/lean_compiler/tests/test_data/error_18.py +++ /dev/null @@ -1,9 +0,0 @@ -from snark_lib import * - - -# Error: push to outer-scope vector inside non-unrolled loop -def main(): - v = DynArray([1, 2, 3]) # Vector in outer scope - for i in range(0, 5): - v.push(i) # Error: cannot push to outer-scope vector inside non-unrolled loop - return diff --git a/crates/lean_compiler/tests/test_data/error_19.py b/crates/lean_compiler/tests/test_data/error_19.py deleted file mode 100644 index 774778da..00000000 --- a/crates/lean_compiler/tests/test_data/error_19.py +++ /dev/null @@ -1,12 +0,0 @@ -from snark_lib import * - - -# Error: vector passed to non-inlined function -def main(): - v = DynArray([1, 2, 3]) - process(v) # Error: vectors cannot be passed to non-inlined functions - return - - -def process(x): - return diff --git a/crates/lean_compiler/tests/test_data/error_20.py b/crates/lean_compiler/tests/test_data/error_20.py deleted file mode 100644 index 9f089b33..00000000 --- a/crates/lean_compiler/tests/test_data/error_20.py +++ /dev/null @@ -1,9 +0,0 @@ -from snark_lib import * - - -# timing matters -def main(): - v = DynArray([]) - print(v[0]) - v.push(10) - return diff --git a/crates/lean_compiler/tests/test_data/error_21.py b/crates/lean_compiler/tests/test_data/error_21.py deleted file mode 100644 index c3ccc449..00000000 --- a/crates/lean_compiler/tests/test_data/error_21.py +++ /dev/null @@ -1,8 +0,0 @@ -from snark_lib import * - - -# Error test: pop on empty vector -def main(): - v = DynArray([]) - v.pop() - return diff --git a/crates/lean_compiler/tests/test_data/error_22.py b/crates/lean_compiler/tests/test_data/error_22.py deleted file mode 100644 index b478bcff..00000000 --- a/crates/lean_compiler/tests/test_data/error_22.py +++ /dev/null @@ -1,8 +0,0 @@ -from snark_lib import * - - -# Error test: pop on empty nested vector -def main(): - v = DynArray([DynArray([])]) - v[0].pop() - return diff --git a/crates/lean_compiler/tests/test_data/error_23.py b/crates/lean_compiler/tests/test_data/error_23.py deleted file mode 100644 index 38945b13..00000000 --- a/crates/lean_compiler/tests/test_data/error_23.py +++ /dev/null @@ -1,9 +0,0 @@ -from snark_lib import * - - -# Error test: pop from outer-scope vector in non-unroll loop -def main(): - v = DynArray([1, 2, 3]) - for i in range(0, 2): - v.pop() - return diff --git a/crates/lean_compiler/tests/test_data/program_11.py b/crates/lean_compiler/tests/test_data/program_11.py deleted file mode 100644 index e0b4f20b..00000000 --- a/crates/lean_compiler/tests/test_data/program_11.py +++ /dev/null @@ -1,13 +0,0 @@ -from snark_lib import * - -ARR = [0, 1, 2, 3, 4] - - -def main(): - vector = DynArray([]) - for i in unroll(0, len(ARR)): - k = ARR[i] - vector.push(DynArray([])) - vector[k].push(k) - assert vector[k][0] == k - return diff --git a/crates/lean_compiler/tests/test_data/program_146.py b/crates/lean_compiler/tests/test_data/program_146.py deleted file mode 100644 index a30d01dc..00000000 --- a/crates/lean_compiler/tests/test_data/program_146.py +++ /dev/null @@ -1,11 +0,0 @@ -from snark_lib import * - - -# Test basic DynArray([]) creation and indexing -def main(): - v = DynArray([1, 2, 3]) - assert v[0] == 1 - assert v[1] == 2 - assert v[2] == 3 - assert len(v) == 3 - return diff --git a/crates/lean_compiler/tests/test_data/program_147.py b/crates/lean_compiler/tests/test_data/program_147.py deleted file mode 100644 index 4eaf9f56..00000000 --- a/crates/lean_compiler/tests/test_data/program_147.py +++ /dev/null @@ -1,14 +0,0 @@ -from snark_lib import * - - -# Test .push() on vectors -def main(): - v = DynArray([1, 2, 3]) - assert len(v) == 3 - v.push(4) - assert len(v) == 4 - assert v[3] == 4 - v.push(5) - assert len(v) == 5 - assert v[4] == 5 - return diff --git a/crates/lean_compiler/tests/test_data/program_148.py b/crates/lean_compiler/tests/test_data/program_148.py deleted file mode 100644 index 1312850e..00000000 --- a/crates/lean_compiler/tests/test_data/program_148.py +++ /dev/null @@ -1,14 +0,0 @@ -from snark_lib import * - - -# Test nested vectors -def main(): - v = DynArray([DynArray([1, 2]), DynArray([3, 4, 5])]) - assert len(v) == 2 - assert len(v[0]) == 2 - assert len(v[1]) == 3 - assert v[0][0] == 1 - assert v[0][1] == 2 - assert v[1][0] == 3 - assert v[1][2] == 5 - return diff --git a/crates/lean_compiler/tests/test_data/program_149.py b/crates/lean_compiler/tests/test_data/program_149.py deleted file mode 100644 index 1831734e..00000000 --- a/crates/lean_compiler/tests/test_data/program_149.py +++ /dev/null @@ -1,16 +0,0 @@ -from snark_lib import * - - -# Test push with nested vectors -def main(): - v = DynArray([DynArray([1, 2])]) - assert len(v) == 1 - v.push(DynArray([3, 4])) - assert len(v) == 2 - assert v[1][0] == 3 - assert v[1][1] == 4 - v.push(DynArray([5, 6, 7])) - assert len(v) == 3 - assert len(v[2]) == 3 - assert v[2][2] == 7 - return diff --git a/crates/lean_compiler/tests/test_data/program_150.py b/crates/lean_compiler/tests/test_data/program_150.py deleted file mode 100644 index cd05f6f9..00000000 --- a/crates/lean_compiler/tests/test_data/program_150.py +++ /dev/null @@ -1,15 +0,0 @@ -from snark_lib import * - - -# Test vectors with unrolled loops -def main(): - v = DynArray([]) - for i in unroll(0, 5): - v.push(i * 2) - assert len(v) == 5 - assert v[0] == 0 - assert v[1] == 2 - assert v[2] == 4 - assert v[3] == 6 - assert v[4] == 8 - return diff --git a/crates/lean_compiler/tests/test_data/program_151.py b/crates/lean_compiler/tests/test_data/program_151.py deleted file mode 100644 index f9bbd26c..00000000 --- a/crates/lean_compiler/tests/test_data/program_151.py +++ /dev/null @@ -1,14 +0,0 @@ -from snark_lib import * - - -# Test vectors with expression elements -def main(): - x = 10 - y = 20 - v = DynArray([x, y, x + y, x * 2]) - assert len(v) == 4 - assert v[0] == 10 - assert v[1] == 20 - assert v[2] == 30 - assert v[3] == 20 - return diff --git a/crates/lean_compiler/tests/test_data/program_152.py b/crates/lean_compiler/tests/test_data/program_152.py deleted file mode 100644 index eb9c9fd9..00000000 --- a/crates/lean_compiler/tests/test_data/program_152.py +++ /dev/null @@ -1,17 +0,0 @@ -from snark_lib import * - - -# Test vectors with nested unrolled loops -def main(): - v = DynArray([]) - for i in unroll(0, 3): - for j in unroll(0, 2): - v.push(i * 10 + j) - assert len(v) == 6 - assert v[0] == 0 # i=0, j=0 - assert v[1] == 1 # i=0, j=1 - assert v[2] == 10 # i=1, j=0 - assert v[3] == 11 # i=1, j=1 - assert v[4] == 20 # i=2, j=0 - assert v[5] == 21 # i=2, j=1 - return diff --git a/crates/lean_compiler/tests/test_data/program_153.py b/crates/lean_compiler/tests/test_data/program_153.py deleted file mode 100644 index 664402cc..00000000 --- a/crates/lean_compiler/tests/test_data/program_153.py +++ /dev/null @@ -1,28 +0,0 @@ -from snark_lib import * - - -# Test pushing nested vectors in unrolled loop -def main(): - v = DynArray([]) - for i in unroll(0, 3): - v.push(DynArray([i, i + 1, i + 2])) - assert len(v) == 3 - assert len(v[0]) == 3 - assert len(v[1]) == 3 - assert len(v[2]) == 3 - - # v[0] = [0, 1, 2] - assert v[0][0] == 0 - assert v[0][1] == 1 - assert v[0][2] == 2 - - # v[1] = [1, 2, 3] - assert v[1][0] == 1 - assert v[1][1] == 2 - assert v[1][2] == 3 - - # v[2] = [2, 3, 4] - assert v[2][0] == 2 - assert v[2][1] == 3 - assert v[2][2] == 4 - return diff --git a/crates/lean_compiler/tests/test_data/program_154.py b/crates/lean_compiler/tests/test_data/program_154.py deleted file mode 100644 index 8cb0c7fc..00000000 --- a/crates/lean_compiler/tests/test_data/program_154.py +++ /dev/null @@ -1,12 +0,0 @@ -from snark_lib import * - - -# Test accessing vector elements inside unrolled loop -def main(): - v = DynArray([10, 20, 30, 40, 50]) - - sum: Mut = 0 - for i in unroll(0, 5): - sum = sum + v[i] - assert sum == 150 - return diff --git a/crates/lean_compiler/tests/test_data/program_155.py b/crates/lean_compiler/tests/test_data/program_155.py deleted file mode 100644 index 709840d2..00000000 --- a/crates/lean_compiler/tests/test_data/program_155.py +++ /dev/null @@ -1,17 +0,0 @@ -from snark_lib import * - - -# Test building vector and then reading in separate unrolled loops -def main(): - # Build a vector of squares - squares = DynArray([]) - for i in unroll(0, 6): - squares.push(i * i) - - # Verify in a separate loop - for i in unroll(0, 6): - assert squares[i] == i * i - - # Also check len - assert len(squares) == 6 - return diff --git a/crates/lean_compiler/tests/test_data/program_156.py b/crates/lean_compiler/tests/test_data/program_156.py deleted file mode 100644 index af062869..00000000 --- a/crates/lean_compiler/tests/test_data/program_156.py +++ /dev/null @@ -1,20 +0,0 @@ -from snark_lib import * - - -# Test vector with expression using loop variable -def main(): - # Build Fibonacci-like sequence using vector - fib = DynArray([1, 1]) - for i in unroll(2, 8): - fib.push(fib[i - 1] + fib[i - 2]) - - assert len(fib) == 8 - assert fib[0] == 1 - assert fib[1] == 1 - assert fib[2] == 2 - assert fib[3] == 3 - assert fib[4] == 5 - assert fib[5] == 8 - assert fib[6] == 13 - assert fib[7] == 21 - return diff --git a/crates/lean_compiler/tests/test_data/program_157.py b/crates/lean_compiler/tests/test_data/program_157.py deleted file mode 100644 index 2181399f..00000000 --- a/crates/lean_compiler/tests/test_data/program_157.py +++ /dev/null @@ -1,31 +0,0 @@ -from snark_lib import * - - -# Test pushing to nested vectors with indices -def main(): - # Create a vector of empty vectors - v = DynArray([DynArray([]), DynArray([]), DynArray([])]) - - # Push to nested vectors using indices - v[0].push(10) - v[0].push(20) - v[1].push(30) - v[2].push(40) - v[2].push(50) - v[2].push(60) - - # Verify structure - assert len(v) == 3 - assert len(v[0]) == 2 - assert len(v[1]) == 1 - assert len(v[2]) == 3 - - # Verify values - assert v[0][0] == 10 - assert v[0][1] == 20 - assert v[1][0] == 30 - assert v[2][0] == 40 - assert v[2][1] == 50 - assert v[2][2] == 60 - - return diff --git a/crates/lean_compiler/tests/test_data/program_158.py b/crates/lean_compiler/tests/test_data/program_158.py deleted file mode 100644 index 8f56540c..00000000 --- a/crates/lean_compiler/tests/test_data/program_158.py +++ /dev/null @@ -1,33 +0,0 @@ -from snark_lib import * -# Test pushing to nested vectors in unrolled loops - -ARR = [3, 4, 5] - - -def main(): - # Create a 3-element vector of empty vectors - rows = DynArray([DynArray([]), DynArray([]), DynArray([])]) - - # Fill each row with its index repeated - for i in unroll(0, 3): - for j in unroll(0, 3): - rows[i].push(i + len(ARR) - 1 + ARR[0] - 3 + j - 2) - - # Verify structure - assert len(rows) == 3 - assert len(rows[0]) == 3 - assert len(rows[1]) == 3 - assert len(rows[2]) == 3 - - # Verify values: rows[i][j] == i + j - assert rows[0][0] == 0 - assert rows[0][1] == 1 - assert rows[0][2] == 2 - assert rows[1][0] == 1 - assert rows[1][1] == 2 - assert rows[1][len(ARR) - 1 + ARR[0] - 3] == 3 - assert rows[2][0] == 2 - assert rows[2][1] == 3 - assert rows[2][2] == 4 - - return diff --git a/crates/lean_compiler/tests/test_data/program_159.py b/crates/lean_compiler/tests/test_data/program_159.py deleted file mode 100644 index 1d86b9d7..00000000 --- a/crates/lean_compiler/tests/test_data/program_159.py +++ /dev/null @@ -1,21 +0,0 @@ -from snark_lib import * - - -# Test: local vectors inside if/else branches are allowed -def main(): - x = 5 - if x == 5: - # Local vector in then branch - allowed - v = DynArray([1, 2, 3]) - v.push(4) - assert v[3] == 4 - assert len(v) == 4 - w = DynArray([]) - w.push(100) - assert w[0] == 100 - else: - # Different local vector in else branch - allowed (no clash, different control flow) - w = DynArray([10, 20]) - w.push(30) - assert w[2] == 30 - return diff --git a/crates/lean_compiler/tests/test_data/program_160.py b/crates/lean_compiler/tests/test_data/program_160.py deleted file mode 100644 index 1659f99d..00000000 --- a/crates/lean_compiler/tests/test_data/program_160.py +++ /dev/null @@ -1,15 +0,0 @@ -from snark_lib import * - - -# Test: local vectors inside non-unrolled loops are allowed -# This just tests that local vector creation and push works inside a loop -def main(): - for i in range(0, 3): - # Local vector created fresh each iteration - allowed - v = DynArray([1, 2, 3]) - v.push(4) - # Use the vector within the same iteration - assert v[0] == 1 - assert v[3] == 4 - assert len(v) == 4 - return diff --git a/crates/lean_compiler/tests/test_data/program_161.py b/crates/lean_compiler/tests/test_data/program_161.py deleted file mode 100644 index da7b127c..00000000 --- a/crates/lean_compiler/tests/test_data/program_161.py +++ /dev/null @@ -1,14 +0,0 @@ -from snark_lib import * - -# Test: compile-time true condition allows push to outer-scope vector in then branch -FLAG = 1 - - -def main(): - v = DynArray([1, 2, 3]) - if FLAG == 1: - v.push(4) # OK: condition is compile-time true, branch is inlined - else: - v.push(5) - assert v[3] == 4 - return diff --git a/crates/lean_compiler/tests/test_data/program_162.py b/crates/lean_compiler/tests/test_data/program_162.py deleted file mode 100644 index 74268a82..00000000 --- a/crates/lean_compiler/tests/test_data/program_162.py +++ /dev/null @@ -1,14 +0,0 @@ -from snark_lib import * - -# Test: compile-time false condition allows push to outer-scope vector in else branch -FLAG = 0 - - -def main(): - v = DynArray([1, 2, 3]) - if FLAG == 1: - v.push(4) - else: - v.push(5) # OK: condition is compile-time false, else branch is inlined - assert v[3] == 5 - return diff --git a/crates/lean_compiler/tests/test_data/program_163.py b/crates/lean_compiler/tests/test_data/program_163.py deleted file mode 100644 index 507acf7d..00000000 --- a/crates/lean_compiler/tests/test_data/program_163.py +++ /dev/null @@ -1,16 +0,0 @@ -from snark_lib import * - -# Test: compile-time condition using const array access -ARR = [0, 1, 2] - - -def main(): - v = DynArray([]) - v.push(10) - if ARR[1] == 1: - v.push(20) # OK: ARR[1] == 1 is compile-time true - else: - v.push(30) - assert v[0] == 10 - assert v[1] == 20 - return diff --git a/crates/lean_compiler/tests/test_data/program_164.py b/crates/lean_compiler/tests/test_data/program_164.py deleted file mode 100644 index 140275e8..00000000 --- a/crates/lean_compiler/tests/test_data/program_164.py +++ /dev/null @@ -1,18 +0,0 @@ -from snark_lib import * - -# Test: nested compile-time conditions -A = 1 -B = 2 - - -def main(): - v = DynArray([]) - if A == 1: - if B == 2: - v.push(100) # OK: both conditions are compile-time true - else: - v.push(200) - else: - v.push(300) - assert v[0] == 100 - return diff --git a/crates/lean_compiler/tests/test_data/program_165.py b/crates/lean_compiler/tests/test_data/program_165.py deleted file mode 100644 index ef5bc8d7..00000000 --- a/crates/lean_compiler/tests/test_data/program_165.py +++ /dev/null @@ -1,98 +0,0 @@ -from snark_lib import * -# Comprehensive test: nested unrolled loops, vectors with pushes in various scopes - - -def main(): - # === PART 1: Basic nested loops over 2D vector === - - outer = DynArray([DynArray([1, 2]), DynArray([10, 20, 30])]) - - total: Mut = 0 - for i in unroll(0, len(outer)): - row_sum: Mut = 0 - for j in unroll(0, len(outer[i])): - row_sum = row_sum + outer[i][j] - total = total + row_sum - assert total == 63 - - # === PART 2: Push new row to outer, iterate again === - - outer.push(DynArray([100, 200])) - assert len(outer) == 3 - - total: Mut = 0 - for i in unroll(0, len(outer)): - row_sum: Mut = 0 - for j in unroll(0, len(outer[i])): - row_sum = row_sum + outer[i][j] - total = total + row_sum - assert total == 363 - - # === PART 3: Multiple vectors cross product === - - v1 = DynArray([1, 2, 3]) - v2 = DynArray([10, 20]) - - cross_sum: Mut = 0 - for i in unroll(0, len(v1)): - for j in unroll(0, len(v2)): - cross_sum = cross_sum + v1[i] * v2[j] - assert cross_sum == 180 - - v2.push(30) - cross_sum: Mut = 0 - for i in unroll(0, len(v1)): - for j in unroll(0, len(v2)): - cross_sum = cross_sum + v1[i] * v2[j] - assert cross_sum == 360 - - # === PART 4: Accumulator reused without reset === - - data = DynArray([5, 10, 15, 20]) - - acc: Mut = 0 - for i in unroll(0, len(data)): - acc = acc + data[i] - assert acc == 50 - - for i in unroll(0, len(data)): - acc = acc + data[i] * data[i] - assert acc == 800 - - # === PART 5: if inside unrolled loop (compile-time condition) === - - data2 = DynArray([1, 2, 3, 4]) - acc2: Mut = 0 - for i in unroll(0, len(data2)): - acc2 = acc2 + data2[i] - if i == 2: - acc2 = acc2 * 2 - assert acc2 == 16 - - assert inlined() == 5 - - return - - -def inlined(): - v = DynArray([1, 2, 3]) - sum: Mut = 0 - for i in unroll(0, len(v)): - sum = sum + v[i] - debug_assert(sum == 6) - v.push(4) - assert len(v) == 4 - sum: Mut = 0 - for i in unroll(0, len(v)): - sum += v[i] - assert sum == 10 - w = DynArray([]) - for i in unroll(0, 5): - w.push(DynArray([])) - for j in unroll(0, i): - w[i].push(1) - sum: Mut = 0 - for j in unroll(0, len(w[i])): - sum += w[i][j] - assert sum == i - return len(w) diff --git a/crates/lean_compiler/tests/test_data/program_166.py b/crates/lean_compiler/tests/test_data/program_166.py deleted file mode 100644 index 5d653fc3..00000000 --- a/crates/lean_compiler/tests/test_data/program_166.py +++ /dev/null @@ -1,51 +0,0 @@ -from snark_lib import * - -DIM = 5 -ONE_EF_PTR = 1 # right after the (empty-public-input) zero-padded cell at memory[0] - - -def main(): - init_one_ef() - v = DynArray([1, 2, 3]) - sum1: Mut = 0 - for i in unroll(0, len(v)): - sum1 = sum1 + v[i] - assert sum1 == 6 - v.push(4) - assert len(v) == 4 - sum2: Mut = 0 - for i in unroll(0, len(v)): - sum2 = sum2 + v[i] - assert sum2 == 10 - # Test nested vectors with len(w[i]) - w = DynArray([]) - for i in unroll(0, 5): - w.push(DynArray([])) - for j in unroll(0, i): - w[i].push(1) - assert len(w[i]) == i - assert len(w) == 5 - a = Array(DIM) - for i in unroll(0, DIM): - a[i] = 1 - w.push(DynArray([a])) - b = Array(DIM) - copy_5(w[5][0], b) - return - - -@inline -def copy_5(a, b): - dot_product_ee(a, ONE_EF_PTR, b) - return - - -@inline -def init_one_ef(): - one_ef = ONE_EF_PTR - one_ef[0] = 1 - one_ef[1] = 0 - one_ef[2] = 0 - one_ef[3] = 0 - one_ef[4] = 0 - return diff --git a/crates/lean_compiler/tests/test_data/program_168.py b/crates/lean_compiler/tests/test_data/program_168.py deleted file mode 100644 index a529c6a0..00000000 --- a/crates/lean_compiler/tests/test_data/program_168.py +++ /dev/null @@ -1,86 +0,0 @@ -from snark_lib import * -# Comprehensive test for vector.pop() - - -def main(): - # Basic pop on simple vector - v1 = DynArray([1, 2, 3, 4, 5]) - assert len(v1) == 5 - v1.pop() - assert len(v1) == 4 - v1.pop() - v1.pop() - assert len(v1) == 2 - # v1 should now be [1, 2] - assert v1[0] == 1 - assert v1[1] == 2 - - """ - multi line - comment - """ - - # Pop in unrolled loop - v2 = DynArray([10, 20, 30, 40, 50]) - for i in unroll(0, 3): - v2.pop() - assert len(v2) == 2 - assert v2[0] == 10 - assert v2[1] == 20 - - # Pop from nested vector - matrix = DynArray([DynArray([1, 2, 3]), DynArray([4, 5, 6, 7]), DynArray([8, 9])]) - assert len(matrix[0]) == 3 - assert len(matrix[1]) == 4 - matrix[1].pop() - assert len(matrix[1]) == 3 - matrix[0].pop() - matrix[0].pop() - assert len(matrix[0]) == 1 - assert matrix[0][0] == 1 - assert matrix[1][0] == 4 - assert matrix[1][1] == 5 - assert matrix[1][2] == 6 - - # Pop outer vector element - matrix.pop() - assert len(matrix) == 2 - - # Mix push and pop - v3 = DynArray([100]) - v3.push(200) - v3.push(300) - assert len(v3) == 3 - v3.pop() - assert len(v3) == 2 - v3.push(400) - assert len(v3) == 3 - assert v3[0] == 100 - assert v3[1] == 200 - assert v3[2] == 400 - - # Pop until one element remains - v4 = DynArray([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - for i in unroll(0, 9): - v4.pop() - assert len(v4) == 1 - assert v4[0] == 1 - - # Pop on nested vector with index expression - nested = DynArray([DynArray([DynArray([1, 2, 3])])]) - nested[0][0].pop() - assert len(nested[0][0]) == 2 - - # Build vector, pop some, then iterate - v5 = DynArray([]) - for i in unroll(0, 5): - v5.push(i * 10) - v5.pop() - v5.pop() - sum: Mut = 0 - for i in unroll(0, len(v5)): - sum = sum + v5[i] - # v5 = [0, 10, 20], sum = 30 - assert sum == 30 - - return diff --git a/crates/lean_compiler/tests/test_data/program_170.py b/crates/lean_compiler/tests/test_data/program_170.py index 8af7976f..96c00012 100644 --- a/crates/lean_compiler/tests/test_data/program_170.py +++ b/crates/lean_compiler/tests/test_data/program_170.py @@ -25,11 +25,6 @@ def main(): result = add_four(1, 2, 3, 4) assert result == 10 - arr = DynArray([1, 2, 3]) - assert arr[0] == 1 - assert arr[1] == 2 - assert arr[2] == 3 - nested = add_four(1, add_four(10, 20, 30, 40), 2, 3) assert nested == 106 diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 218be263..5ea8d46a 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -14,7 +14,7 @@ def helper(): # other functions (optional) ... ``` -The `from snark_lib import *` line imports Python definitions for zkDSL primitives (Array, DynArray, Mut, Const, etc.), allowing `.py` files to be executed as normal Python scripts for testing. The zkDSL compiler ignores this import line. +The `from snark_lib import *` line imports Python definitions for zkDSL primitives (Array, Mut, Const, etc.), allowing `.py` files to be executed as normal Python scripts for testing. The zkDSL compiler ignores this import line. To run zkDSL files as Python scripts, run from the file's directory with PYTHONPATH pointing to the lean_compiler crate (for snark_lib.py): ```bash @@ -197,65 +197,6 @@ arr[0] = 20 # ERROR: different value at same location Use `mut` variables when you need mutability, the compiler cannot handle mutability on hand-written allocated memory ("Array(...)"). -## DynArray (Compile-Time Dynamic Arrays) - -DynArrays are compile-time constructs for building dynamic arrays. Unlike `Array`, DynArrays track structure at compile time—each element gets its own memory slot. - -``` -v = DynArray([1, 2, 3]) # create dynamic array -v.push(4) # append element -v.pop() # remove last element (does not return it) -x = v[2] # access (index must be compile-time constant) -n = len(v) # get length -``` - -### Nested DynArrays - -``` -matrix = DynArray([DynArray([1, 2]), DynArray([3, 4, 5])]) -matrix[1].push(6) # push to inner array -matrix[0].pop() # pop from inner array -x = matrix[0][0] # x = 1 -n = len(matrix[1]) # n = 4 -``` - -### Building DynArrays in Loops - -Use `unroll` loops to build arrays dynamically: - -``` -v = DynArray([]) -for i in unroll(0, 5): - v.push(i * i) # v = [0, 1, 4, 9, 16] -``` - -### Restrictions - -DynArrays are compile-time only. The compiler must know the exact structure at every point: - -1. **Indices must be compile-time constants** (literals or unroll loop variables) -2. **Push/pop to outer-scope arrays forbidden** inside `if/else`, `match`, or non-unrolled loops -3. **DynArrays cannot be passed to non-inlined functions** -4. **Pop on empty array is a compile error** - -``` -# OK: local array in branch -if cond == 1: - v = DynArray([1, 2]) - v.push(3) - -# ERROR: push to outer-scope array in branch -v = DynArray([1, 2]) -if cond == 1: - v.push(3) # compile error - -# OK: same variable name in different branches -if cond == 1: - v = DynArray([1]) -else: - v = DynArray([2, 3]) # different structure, but only one executes -``` - ## Control Flow ### If/Else @@ -580,11 +521,11 @@ result = function_call( arg3 ) -arr = DynArray([ +ARR = [ 1, 2, - 3 -]) + 3, +] ``` ### Explicit continuation with backslash From e77b6fe62f2963a79a03c0264f42f5277b7830fe Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 25 May 2026 17:28:09 +0400 Subject: [PATCH 3/3] comments --- crates/lean_compiler/src/parser/parsers/expression.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs b/crates/lean_compiler/src/parser/parsers/expression.rs index 004834ca..776ef5fa 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -148,7 +148,7 @@ impl Parse for ArrayAccessParser { } } -/// Parser for len() expressions on const arrays and vectors (supports indexed access like len(ARR[i])). +/// Parser for len() expressions on const arrays (supports indexed access like len(ARR[i])). pub struct LenParser; impl Parse for LenParser { @@ -214,7 +214,7 @@ impl Parse for LenParser { } } - // Defer evaluation for non-const arrays (could be vectors) or non-const indices + // Defer evaluation when indices aren't all const at parse time (e.g., `len(ARR[i])` inside an unroll loop). Ok(Expression::Len { array: ident, indices: index_exprs,