diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index f7f43c61f..5df9eedc5 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -1639,7 +1639,6 @@ fn handle_inlined_functions_helper( } fn handle_const_arguments(program: &mut Program) { - // TODO this doesnt suupport const functions calling other const functions let mut new_functions = BTreeMap::::new(); let constant_functions = program .functions @@ -1648,9 +1647,45 @@ fn handle_const_arguments(program: &mut Program) { .map(|(name, func)| (name.clone(), func.clone())) .collect::>(); + // First pass: process non-const functions that call const functions for func in program.functions.values_mut() { - handle_const_arguments_helper(&mut func.body, &constant_functions, &mut new_functions); + if !func.has_const_arguments() { + handle_const_arguments_helper(&mut func.body, &constant_functions, &mut new_functions); + } } + + // Process newly created const functions recursively until no more changes + let mut changed = true; + while changed { + changed = false; + let mut additional_functions = BTreeMap::new(); + + // Collect all function names to process + let function_names: Vec = new_functions.keys().cloned().collect(); + + for name in function_names { + if let Some(func) = new_functions.get_mut(&name) { + let initial_count = additional_functions.len(); + handle_const_arguments_helper( + &mut func.body, + &constant_functions, + &mut additional_functions, + ); + if additional_functions.len() > initial_count { + changed = true; + } + } + } + + // Add any newly discovered functions + for (name, func) in additional_functions { + if let std::collections::btree_map::Entry::Vacant(e) = new_functions.entry(name) { + e.insert(func); + changed = true; + } + } + } + for (name, func) in new_functions { assert!(!program.functions.contains_key(&name),); program.functions.insert(name, func); diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 6743c0954..a96c202a7 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -441,3 +441,26 @@ fn test_match() { // "#; // compile_and_run(program, &[], &[]); // } + +#[test] +fn test_const_functions_calling_const_functions() { + // Test that const functions can call other const functions + let program = r#" + fn main() { + y = compute_value(3); + print(y); + return; + } + + fn compute_value(const n) -> 1 { + result = complex_computation(n, 5); + return result; + } + + fn complex_computation(const a, const b) -> 1 { + return a * a + b * b; + } + "#; + + compile_and_run(program, &[], &[], false); +}