diff --git a/tools/hermes/src/desugar.rs b/tools/hermes/src/desugar.rs index 441a7bcd9c..fb5bc6eb5a 100644 --- a/tools/hermes/src/desugar.rs +++ b/tools/hermes/src/desugar.rs @@ -211,6 +211,16 @@ fn generate_body( // 4. User Logic out.push_str(&format!("({})", logic)); + // 5. Automatic Validity Invariants + // For every binder we exposed (ret, x_final, etc.), we enforce logical validity. + // "Component 4 ... Return Value Injection ... Mutable Borrow Injection" + // "Action: Append the validity check ... Result: ... /\ Verifiable.is_valid ret" + for binder in binders { + if binder != "_" { + out.push_str(&format!(" /\\ Verifiable.is_valid {}", binder)); + } + } + Ok(out) } diff --git a/tools/hermes/src/parser.rs b/tools/hermes/src/parser.rs index cc9b006da9..0ba97826c8 100644 --- a/tools/hermes/src/parser.rs +++ b/tools/hermes/src/parser.rs @@ -12,6 +12,8 @@ use syn::{ visit::{self, Visit}, }; +// ... (imports remain) + /// Represents a function parsed from the source code, including its signature and attached specs. #[derive(Debug, Clone)] pub struct ParsedFunction { @@ -21,18 +23,28 @@ pub struct ParsedFunction { pub is_model: bool, } +/// Represents a struct parsed from the source code, including its invariant. +#[derive(Debug, Clone)] +pub struct ParsedStruct { + pub ident: syn::Ident, + pub generics: syn::Generics, + pub invariant: Option, +} + pub struct ExtractedBlocks { pub functions: Vec, + pub structs: Vec, } struct SpecVisitor { functions: Vec, + structs: Vec, errors: Vec, } impl SpecVisitor { fn new() -> Self { - Self { functions: Vec::new(), errors: Vec::new() } + Self { functions: Vec::new(), structs: Vec::new(), errors: Vec::new() } } fn check_attrs_for_misplaced_spec(&mut self, attrs: &[Attribute], item_kind: &str) { @@ -40,7 +52,7 @@ impl SpecVisitor { if let Some(doc_str) = parse_doc_attr(attr) { if doc_str.trim_start().starts_with("@") { self.errors.push(anyhow::anyhow!( - "Found `///@` spec usage on a {}, but it is only allowed on functions.", + "Found `///@` spec usage on a {}, but it is only allowed on functions or structs.", item_kind )); } @@ -110,7 +122,89 @@ impl<'ast> Visit<'ast> for SpecVisitor { } fn visit_item_struct(&mut self, node: &'ast syn::ItemStruct) { - self.check_attrs_for_misplaced_spec(&node.attrs, "struct"); + let mut invariant_lines = Vec::new(); + let mut current_mode = None; // None, Some("invariant") + + for attr in &node.attrs { + if let Some(doc_str) = parse_doc_attr(attr) { + let trimmed = doc_str.trim(); + if trimmed.starts_with('@') { + if let Some(content) = trimmed.strip_prefix("@ lean invariant") { + current_mode = Some("invariant"); + let mut content = content.trim(); + // Ignore if it's just the struct name or empty + // referencing node.ident + if content == node.ident.to_string() { + content = ""; + } + + // Strip "is_valid self :=" or "is_valid :=" + if let Some(rest) = content.strip_prefix("is_valid") { + let rest = rest.trim(); + if let Some(rest) = rest.strip_prefix("self") { + let rest = rest.trim(); + if let Some(rest) = rest.strip_prefix(":=") { + content = rest.trim(); + } + } else if let Some(rest) = rest.strip_prefix(":=") { + content = rest.trim(); + } + } + + if !content.is_empty() { + invariant_lines.push(content.to_string()); + } + } else { + match current_mode { + Some("invariant") => { + let content = &trimmed[1..]; + invariant_lines.push(content.to_string()); + } + None => { + // Only error if it looks like a spec attempt? + // For now, we update check_attrs_for_misplaced_spec to strictly call out non-struct/fn + // But here we just ignore or could error. + // Let's rely on the fact that if we didn't handle it here, it might be misplaced if we didn't check. + // Actually, we should probably support it. + self.errors.push(anyhow::anyhow!("Found `///@` line without preceding `lean invariant` on struct '{}'", node.ident)); + } + _ => {} + } + } + } + } + } + + let invariant = if !invariant_lines.is_empty() { + let mut full_inv = invariant_lines.join("\n").trim().to_string(); + // Strip "is_valid self :=" or "is_valid :=" + if let Some(rest) = full_inv.strip_prefix("is_valid") { + let rest = rest.trim(); + if let Some(rest) = rest.strip_prefix("self") { + let rest = rest.trim(); + if let Some(rest) = rest.strip_prefix(":=") { + full_inv = rest.trim().to_string(); + } + } else if let Some(rest) = rest.strip_prefix(":=") { + full_inv = rest.trim().to_string(); + } + } + Some(full_inv) + } else { + None + }; + + // We always collect structs now because we need to generate Verifiable instances for ALL structs + // Ensure we don't add duplicate structs if for some reason we visit twice (unlikely but safe) + // Checking by ident is enough for this context + if !self.structs.iter().any(|s| s.ident == node.ident) { + self.structs.push(ParsedStruct { + ident: node.ident.clone(), + generics: node.generics.clone(), + invariant, + }); + } + visit::visit_item_struct(self, node); } @@ -164,7 +258,7 @@ pub fn extract_blocks(content: &str) -> Result { bail!("Spec extraction failed:\n{}", msg); } - Ok(ExtractedBlocks { functions: visitor.functions }) + Ok(ExtractedBlocks { functions: visitor.functions, structs: visitor.structs }) } #[cfg(test)] diff --git a/tools/hermes/src/pipeline.rs b/tools/hermes/src/pipeline.rs index 82fdd199e1..644087bfa5 100644 --- a/tools/hermes/src/pipeline.rs +++ b/tools/hermes/src/pipeline.rs @@ -171,12 +171,14 @@ fn stitch_user_proofs( sorry_mode: Sorry, ) -> Result<()> { let mut all_functions = Vec::new(); + let mut all_structs = Vec::new(); if let Some(path) = source_file { if path.exists() { let content = fs::read_to_string(path)?; let extracted = extract_blocks(&content)?; all_functions.extend(extracted.functions); + all_structs.extend(extracted.structs); } } else { let src_dir = crate_root.join("src"); @@ -187,12 +189,13 @@ fn stitch_user_proofs( let content = fs::read_to_string(entry.path())?; let extracted = extract_blocks(&content)?; all_functions.extend(extracted.functions); + all_structs.extend(extracted.structs); } } } } - generate_lean_file(dest, crate_name_snake, crate_name_camel, &all_functions, sorry_mode) + generate_lean_file(dest, crate_name_snake, crate_name_camel, &all_functions, &all_structs, sorry_mode) } fn generate_lean_file( @@ -200,16 +203,43 @@ fn generate_lean_file( namespace_name: &str, import_name: &str, functions: &[ParsedFunction], + structs: &[crate::parser::ParsedStruct], sorry_mode: Sorry, ) -> Result<()> { let mut content = String::new(); content.push_str(&format!("import {}\n", import_name)); content.push_str("import Aeneas\n"); - content.push_str("open Aeneas Aeneas.Std Result Error\n\n"); + content.push_str("open Aeneas Aeneas.Std Result Error\n"); + content.push_str("set_option linter.unusedVariables false\n\n"); content.push_str(&format!("namespace {}\n\n", namespace_name)); - // Inject OfNat instances to support numeric literals in specs (e.g. `x > 0`) - // We use wrapping construction (BitVec.ofNat) to avoid needing in-bounds proofs. + // Inject Prelude: Verifiable Class and Primitives + content.push_str( + " +class Verifiable (α : Type) where + is_valid : α -> Prop + +attribute [simp] Verifiable.is_valid + +instance : Verifiable U8 where is_valid _ := True +instance : Verifiable U16 where is_valid _ := True +instance : Verifiable U32 where is_valid _ := True +instance : Verifiable U64 where is_valid _ := True +instance : Verifiable U128 where is_valid _ := True +instance : Verifiable I8 where is_valid _ := True +instance : Verifiable I16 where is_valid _ := True +instance : Verifiable I32 where is_valid _ := True +instance : Verifiable I64 where is_valid _ := True +instance : Verifiable I128 where is_valid _ := True +instance : Verifiable Usize where is_valid _ := True +instance : Verifiable Isize where is_valid _ := True +instance : Verifiable Bool where is_valid _ := True +instance : Verifiable Unit where is_valid _ := True + +" + ); + + // Inject OfNat instances content.push_str( " instance : OfNat U32 n where ofNat := UScalar.mk (BitVec.ofNat 32 n) @@ -217,9 +247,51 @@ instance : OfNat I32 n where ofNat := IScalar.mk (BitVec.ofNat 32 n) instance : OfNat Usize n where ofNat := UScalar.mk (BitVec.ofNat System.Platform.numBits n) instance : OfNat Isize n where ofNat := IScalar.mk (BitVec.ofNat System.Platform.numBits n) -", +" ); + // Struct Instances + // Dedup structs just in case + let mut unique_structs = Vec::new(); + let mut seen_structs = std::collections::HashSet::new(); + for st in structs { + if seen_structs.insert(st.ident.to_string()) { + unique_structs.push(st); + } + } + + for st in unique_structs { + let name = &st.ident; + let mut invariant = st.invariant.as_deref().unwrap_or("True"); + if invariant.is_empty() { + invariant = "True"; + } + + // Handle Generics: [Verifiable T] for each T + let mut generic_params = String::new(); + let mut generic_constraints = String::new(); + let mut type_args = String::new(); + + for param in &st.generics.params { + if let syn::GenericParam::Type(t) = param { + generic_params.push_str(&format!("{{{}}} ", t.ident)); + generic_constraints.push_str(&format!("[Verifiable {}] ", t.ident)); + type_args.push_str(&format!("{} ", t.ident)); + } + } + + // Format: instance {T} [Verifiable T] : Verifiable (Wrapper T) where + let type_str = if type_args.is_empty() { + name.to_string() + } else { + format!("({} {})", name, type_args.trim()) + }; + + let header = format!("instance {}{} : Verifiable {} where", generic_params, generic_constraints, type_str); + content.push_str(&header); + content.push_str(&format!("\n is_valid self := {}\n\n", invariant)); + } + for func in functions { let spec_content = match &func.spec { Some(s) => s, @@ -259,6 +331,22 @@ instance : OfNat Isize n where ofNat := IScalar.mk (BitVec.ofNat System.Platform if let Some(args) = desugared.signature_args { signature_parts.push(args); } + + // INJECT ARGUMENT VALIDITY CHECKS + // For each arg `x : T`, inject `(h_x : Verifiable.is_valid x)` + // We need to parse inputs to get names. + for arg in &inputs { + if let syn::FnArg::Typed(pat_type) = arg { + if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { + let name = &pat_ident.ident; + // We assume the type is verifiable. + // The signature args in `desugared.signature_args` already listed them as `(x : T)`. + // We just append validity hypotheses. + // Note: This relies on `x` being available in scope, which it is in the signature. + signature_parts.push(format!("(h_{}_valid : Verifiable.is_valid {})", name, name)); + } + } + } for req in desugared.extra_args { signature_parts.push(req); diff --git a/tools/hermes/src/translator.rs b/tools/hermes/src/translator.rs index 7019ee3490..210e57840b 100644 --- a/tools/hermes/src/translator.rs +++ b/tools/hermes/src/translator.rs @@ -10,6 +10,8 @@ impl SignatureTranslator { if let GenericParam::Type(type_param) = param { let name = &type_param.ident; context.push_str(&format!("{{{name} : Type}} ")); + // Inject Verifiable constraint + context.push_str(&format!("[inst{name}Verifiable : Verifiable {name}] ")); } } diff --git a/tools/hermes/tests/cases/success/generic_id.rs b/tools/hermes/tests/cases/success/generic_id.rs index 1cdf006435..a77f3ec7a6 100644 --- a/tools/hermes/tests/cases/success/generic_id.rs +++ b/tools/hermes/tests/cases/success/generic_id.rs @@ -11,7 +11,7 @@ ///@ lean spec id (x : T) ///@ ensures |ret| ret = x ///@ proof -///@ simp [id] +///@ simp_all [id] pub fn id(x: T) -> T { x } diff --git a/tools/hermes/tests/cases/success/invariant.rs b/tools/hermes/tests/cases/success/invariant.rs new file mode 100644 index 0000000000..0051184db2 --- /dev/null +++ b/tools/hermes/tests/cases/success/invariant.rs @@ -0,0 +1,48 @@ +///@ lean invariant MyStruct +///@ is_valid self := self.val < 100 +pub struct MyStruct { + pub val: u32, +} + +///@ lean invariant Wrapper +///@ is_valid self := self.inner.val > 0 +pub struct Wrapper { + pub inner: MyStruct, + pub data: T, +} + +///@ lean spec use_invariant (s : MyStruct) +///@ ensures |ret| ret = s.val /\ ret < 100 +///@ proof +///@ simp_all [use_invariant] +pub fn use_invariant(s: MyStruct) -> u32 { + s.val +} + +///@ lean spec generic_invariant (w : Wrapper U32) +///@ ensures |ret| ret = w.inner.val /\ ret > 0 +///@ proof +///@ simp_all [generic_invariant] +pub fn generic_invariant(w: Wrapper) -> u32 { + w.inner.val +} + +///@ lean spec make_mystruct (val : U32) +///@ requires h : val < 100 +///@ ensures |ret| ret.val = val +///@ proof +///@ simp_all [make_mystruct] +pub fn make_mystruct(val: u32) -> MyStruct { + MyStruct { val } +} + +///@ lean spec make_wrapper (inner : MyStruct) (data : U32) +///@ requires h : inner.val > 0 +///@ ensures |ret| ret.inner = inner /\ ret.data = data +///@ proof +///@ simp_all [make_wrapper] +pub fn make_wrapper(inner: MyStruct, data: u32) -> Wrapper { + Wrapper { inner, data } +} + +fn main() {} diff --git a/tools/hermes/tests/cases/success/shadow_model.rs b/tools/hermes/tests/cases/success/shadow_model.rs index 7f8942dc16..ed76d25ae0 100644 --- a/tools/hermes/tests/cases/success/shadow_model.rs +++ b/tools/hermes/tests/cases/success/shadow_model.rs @@ -9,7 +9,7 @@ pub unsafe fn safe_div(a: u32, b: u32) -> u32 { ///@ ensures |ret| ret.val = a.val ///@ proof ///@ rw [wrapper] -///@ have ⟨ ret, h ⟩ := safe_div_spec a 1#u32 (by native_decide) +///@ have ⟨ ret, h ⟩ := safe_div_spec a 1#u32 (by simp) (by simp) (by native_decide) ///@ simp [h.1] ///@ simp_all [Nat.div_one, h.2] pub fn wrapper(a: u32) -> u32 {