Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tools/hermes/src/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ fn generate_body(
// 4. User Logic
out.push_str(&format!("({})", logic));

// 5. Automatic Validity Invariants
// For every binder we exposed (ret, x_final, etc.), we enforce logical validity.
// "Component 4 ... Return Value Injection ... Mutable Borrow Injection"
// "Action: Append the validity check ... Result: ... /\ Verifiable.is_valid ret"
for binder in binders {
if binder != "_" {
out.push_str(&format!(" /\\ Verifiable.is_valid {}", binder));
}
}
Comment on lines +218 to +222
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling format! inside a loop to build a string can be inefficient due to repeated memory allocations. A more idiomatic and performant approach is to use std::fmt::Write's write! macro, which appends to the string without creating intermediate String objects.

You'll need to add use std::fmt::Write; at the top of the file.

Suggested change
for binder in binders {
if binder != "_" {
out.push_str(&format!(" /\\ Verifiable.is_valid {}", binder));
}
}
for binder in binders {
if binder != "_" {
// The write! macro is more efficient here.
write!(out, " /\\ Verifiable.is_valid {}", binder).unwrap();
}
}


Ok(out)
}

Expand Down
102 changes: 98 additions & 4 deletions tools/hermes/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use syn::{
visit::{self, Visit},
};

// ... (imports remain)

/// Represents a function parsed from the source code, including its signature and attached specs.
#[derive(Debug, Clone)]
pub struct ParsedFunction {
Expand All @@ -21,26 +23,36 @@ pub struct ParsedFunction {
pub is_model: bool,
}

/// Represents a struct parsed from the source code, including its invariant.
#[derive(Debug, Clone)]
pub struct ParsedStruct {
pub ident: syn::Ident,
pub generics: syn::Generics,
pub invariant: Option<String>,
}

pub struct ExtractedBlocks {
pub functions: Vec<ParsedFunction>,
pub structs: Vec<ParsedStruct>,
}

struct SpecVisitor {
functions: Vec<ParsedFunction>,
structs: Vec<ParsedStruct>,
errors: Vec<anyhow::Error>,
}

impl SpecVisitor {
fn new() -> Self {
Self { functions: Vec::new(), errors: Vec::new() }
Self { functions: Vec::new(), structs: Vec::new(), errors: Vec::new() }
}

fn check_attrs_for_misplaced_spec(&mut self, attrs: &[Attribute], item_kind: &str) {
for attr in attrs {
if let Some(doc_str) = parse_doc_attr(attr) {
if doc_str.trim_start().starts_with("@") {
self.errors.push(anyhow::anyhow!(
"Found `///@` spec usage on a {}, but it is only allowed on functions.",
"Found `///@` spec usage on a {}, but it is only allowed on functions or structs.",
item_kind
));
}
Expand Down Expand Up @@ -110,7 +122,89 @@ impl<'ast> Visit<'ast> for SpecVisitor {
}

fn visit_item_struct(&mut self, node: &'ast syn::ItemStruct) {
self.check_attrs_for_misplaced_spec(&node.attrs, "struct");
let mut invariant_lines = Vec::new();
let mut current_mode = None; // None, Some("invariant")

for attr in &node.attrs {
if let Some(doc_str) = parse_doc_attr(attr) {
let trimmed = doc_str.trim();
if trimmed.starts_with('@') {
if let Some(content) = trimmed.strip_prefix("@ lean invariant") {
current_mode = Some("invariant");
let mut content = content.trim();
// Ignore if it's just the struct name or empty
// referencing node.ident
if content == node.ident.to_string() {
content = "";
}

// Strip "is_valid self :=" or "is_valid :="
if let Some(rest) = content.strip_prefix("is_valid") {
let rest = rest.trim();
if let Some(rest) = rest.strip_prefix("self") {
let rest = rest.trim();
if let Some(rest) = rest.strip_prefix(":=") {
content = rest.trim();
}
} else if let Some(rest) = rest.strip_prefix(":=") {
content = rest.trim();
}
}

if !content.is_empty() {
invariant_lines.push(content.to_string());
}
} else {
match current_mode {
Some("invariant") => {
let content = &trimmed[1..];
invariant_lines.push(content.to_string());
}
None => {
// Only error if it looks like a spec attempt?
// For now, we update check_attrs_for_misplaced_spec to strictly call out non-struct/fn
// But here we just ignore or could error.
// Let's rely on the fact that if we didn't handle it here, it might be misplaced if we didn't check.
// Actually, we should probably support it.
self.errors.push(anyhow::anyhow!("Found `///@` line without preceding `lean invariant` on struct '{}'", node.ident));
}
_ => {}
}
}
}
}
}

let invariant = if !invariant_lines.is_empty() {
let mut full_inv = invariant_lines.join("\n").trim().to_string();
// Strip "is_valid self :=" or "is_valid :="
if let Some(rest) = full_inv.strip_prefix("is_valid") {
let rest = rest.trim();
if let Some(rest) = rest.strip_prefix("self") {
let rest = rest.trim();
if let Some(rest) = rest.strip_prefix(":=") {
full_inv = rest.trim().to_string();
}
} else if let Some(rest) = rest.strip_prefix(":=") {
full_inv = rest.trim().to_string();
}
}
Some(full_inv)
} else {
None
};
Comment on lines +178 to +195
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for stripping prefixes from the joined invariant string is very similar to the logic on lines 141-152, which operates on the first line of an invariant. This duplication can be avoided by extracting the logic into a helper function. This would improve maintainability and reduce the chance of bugs if the logic needs to be updated in the future.

A single helper function could be used in both places to keep the parsing logic consistent and DRY (Don't Repeat Yourself).


// We always collect structs now because we need to generate Verifiable instances for ALL structs
// Ensure we don't add duplicate structs if for some reason we visit twice (unlikely but safe)
// Checking by ident is enough for this context
if !self.structs.iter().any(|s| s.ident == node.ident) {
self.structs.push(ParsedStruct {
ident: node.ident.clone(),
generics: node.generics.clone(),
invariant,
});
}

visit::visit_item_struct(self, node);
}

Expand Down Expand Up @@ -164,7 +258,7 @@ pub fn extract_blocks(content: &str) -> Result<ExtractedBlocks> {
bail!("Spec extraction failed:\n{}", msg);
}

Ok(ExtractedBlocks { functions: visitor.functions })
Ok(ExtractedBlocks { functions: visitor.functions, structs: visitor.structs })
}

#[cfg(test)]
Expand Down
98 changes: 93 additions & 5 deletions tools/hermes/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,14 @@ fn stitch_user_proofs(
sorry_mode: Sorry,
) -> Result<()> {
let mut all_functions = Vec::new();
let mut all_structs = Vec::new();

if let Some(path) = source_file {
if path.exists() {
let content = fs::read_to_string(path)?;
let extracted = extract_blocks(&content)?;
all_functions.extend(extracted.functions);
all_structs.extend(extracted.structs);
}
} else {
let src_dir = crate_root.join("src");
Expand All @@ -187,39 +189,109 @@ fn stitch_user_proofs(
let content = fs::read_to_string(entry.path())?;
let extracted = extract_blocks(&content)?;
all_functions.extend(extracted.functions);
all_structs.extend(extracted.structs);
}
}
}
}

generate_lean_file(dest, crate_name_snake, crate_name_camel, &all_functions, sorry_mode)
generate_lean_file(dest, crate_name_snake, crate_name_camel, &all_functions, &all_structs, sorry_mode)
}

fn generate_lean_file(
dest: &Path,
namespace_name: &str,
import_name: &str,
functions: &[ParsedFunction],
structs: &[crate::parser::ParsedStruct],
sorry_mode: Sorry,
) -> Result<()> {
let mut content = String::new();
content.push_str(&format!("import {}\n", import_name));
content.push_str("import Aeneas\n");
content.push_str("open Aeneas Aeneas.Std Result Error\n\n");
content.push_str("open Aeneas Aeneas.Std Result Error\n");
content.push_str("set_option linter.unusedVariables false\n\n");
content.push_str(&format!("namespace {}\n\n", namespace_name));

// Inject OfNat instances to support numeric literals in specs (e.g. `x > 0`)
// We use wrapping construction (BitVec.ofNat) to avoid needing in-bounds proofs.
// Inject Prelude: Verifiable Class and Primitives
content.push_str(
"
class Verifiable (α : Type) where
is_valid : α -> Prop

attribute [simp] Verifiable.is_valid

instance : Verifiable U8 where is_valid _ := True
instance : Verifiable U16 where is_valid _ := True
instance : Verifiable U32 where is_valid _ := True
instance : Verifiable U64 where is_valid _ := True
instance : Verifiable U128 where is_valid _ := True
instance : Verifiable I8 where is_valid _ := True
instance : Verifiable I16 where is_valid _ := True
instance : Verifiable I32 where is_valid _ := True
instance : Verifiable I64 where is_valid _ := True
instance : Verifiable I128 where is_valid _ := True
instance : Verifiable Usize where is_valid _ := True
instance : Verifiable Isize where is_valid _ := True
instance : Verifiable Bool where is_valid _ := True
instance : Verifiable Unit where is_valid _ := True

"
);

// Inject OfNat instances
content.push_str(
"
instance : OfNat U32 n where ofNat := UScalar.mk (BitVec.ofNat 32 n)
instance : OfNat I32 n where ofNat := IScalar.mk (BitVec.ofNat 32 n)
instance : OfNat Usize n where ofNat := UScalar.mk (BitVec.ofNat System.Platform.numBits n)
instance : OfNat Isize n where ofNat := IScalar.mk (BitVec.ofNat System.Platform.numBits n)

",
"
);

// Struct Instances
// Dedup structs just in case
let mut unique_structs = Vec::new();
let mut seen_structs = std::collections::HashSet::new();
for st in structs {
if seen_structs.insert(st.ident.to_string()) {
unique_structs.push(st);
}
}

for st in unique_structs {
let name = &st.ident;
let mut invariant = st.invariant.as_deref().unwrap_or("True");
if invariant.is_empty() {
invariant = "True";
}

// Handle Generics: [Verifiable T] for each T
let mut generic_params = String::new();
let mut generic_constraints = String::new();
let mut type_args = String::new();

for param in &st.generics.params {
if let syn::GenericParam::Type(t) = param {
generic_params.push_str(&format!("{{{}}} ", t.ident));
generic_constraints.push_str(&format!("[Verifiable {}] ", t.ident));
type_args.push_str(&format!("{} ", t.ident));
}
}

// Format: instance {T} [Verifiable T] : Verifiable (Wrapper T) where
let type_str = if type_args.is_empty() {
name.to_string()
} else {
format!("({} {})", name, type_args.trim())
};

let header = format!("instance {}{} : Verifiable {} where", generic_params, generic_constraints, type_str);
content.push_str(&header);
content.push_str(&format!("\n is_valid self := {}\n\n", invariant));
}

for func in functions {
let spec_content = match &func.spec {
Some(s) => s,
Expand Down Expand Up @@ -259,6 +331,22 @@ instance : OfNat Isize n where ofNat := IScalar.mk (BitVec.ofNat System.Platform
if let Some(args) = desugared.signature_args {
signature_parts.push(args);
}

// INJECT ARGUMENT VALIDITY CHECKS
// For each arg `x : T`, inject `(h_x : Verifiable.is_valid x)`
// We need to parse inputs to get names.
for arg in &inputs {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
let name = &pat_ident.ident;
// We assume the type is verifiable.
// The signature args in `desugared.signature_args` already listed them as `(x : T)`.
// We just append validity hypotheses.
// Note: This relies on `x` being available in scope, which it is in the signature.
signature_parts.push(format!("(h_{}_valid : Verifiable.is_valid {})", name, name));
}
}
}

for req in desugared.extra_args {
signature_parts.push(req);
Expand Down
2 changes: 2 additions & 0 deletions tools/hermes/src/translator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ impl SignatureTranslator {
if let GenericParam::Type(type_param) = param {
let name = &type_param.ident;
context.push_str(&format!("{{{name} : Type}} "));
// Inject Verifiable constraint
context.push_str(&format!("[inst{name}Verifiable : Verifiable {name}] "));
}
}

Expand Down
2 changes: 1 addition & 1 deletion tools/hermes/tests/cases/success/generic_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
///@ lean spec id (x : T)
///@ ensures |ret| ret = x
///@ proof
///@ simp [id]
///@ simp_all [id]
pub fn id<T>(x: T) -> T {
x
}
Expand Down
48 changes: 48 additions & 0 deletions tools/hermes/tests/cases/success/invariant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
///@ lean invariant MyStruct
///@ is_valid self := self.val < 100
pub struct MyStruct {
pub val: u32,
}

///@ lean invariant Wrapper
///@ is_valid self := self.inner.val > 0
pub struct Wrapper<T> {
pub inner: MyStruct,
pub data: T,
}

///@ lean spec use_invariant (s : MyStruct)
///@ ensures |ret| ret = s.val /\ ret < 100
///@ proof
///@ simp_all [use_invariant]
pub fn use_invariant(s: MyStruct) -> u32 {
s.val
}

///@ lean spec generic_invariant (w : Wrapper U32)
///@ ensures |ret| ret = w.inner.val /\ ret > 0
///@ proof
///@ simp_all [generic_invariant]
pub fn generic_invariant<T>(w: Wrapper<T>) -> u32 {
w.inner.val
}

///@ lean spec make_mystruct (val : U32)
///@ requires h : val < 100
///@ ensures |ret| ret.val = val
///@ proof
///@ simp_all [make_mystruct]
pub fn make_mystruct(val: u32) -> MyStruct {
MyStruct { val }
}

///@ lean spec make_wrapper (inner : MyStruct) (data : U32)
///@ requires h : inner.val > 0
///@ ensures |ret| ret.inner = inner /\ ret.data = data
///@ proof
///@ simp_all [make_wrapper]
pub fn make_wrapper(inner: MyStruct, data: u32) -> Wrapper<u32> {
Wrapper { inner, data }
}

fn main() {}
2 changes: 1 addition & 1 deletion tools/hermes/tests/cases/success/shadow_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub unsafe fn safe_div(a: u32, b: u32) -> u32 {
///@ ensures |ret| ret.val = a.val
///@ proof
///@ rw [wrapper]
///@ have ⟨ ret, h ⟩ := safe_div_spec a 1#u32 (by native_decide)
///@ have ⟨ ret, h ⟩ := safe_div_spec a 1#u32 (by simp) (by simp) (by native_decide)
///@ simp [h.1]
///@ simp_all [Nat.div_one, h.2]
pub fn wrapper(a: u32) -> u32 {
Expand Down
Loading