From bd89247aeb894206bc753b58edcc9caf4fc21b8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Sat, 7 Mar 2026 16:14:47 -0300 Subject: [PATCH 01/31] Adding proof reconstruction scaffold --- Blaster.lean | 1 + Blaster/Reconstruct.lean | 3 +++ Blaster/Reconstruct/Basic.lean | 1 + Blaster/Reconstruct/Tactic.lean | 36 +++++++++++++++++++++++++++++++++ Blaster/Reconstruct/Trace.lean | 14 +++++++++++++ Tests/Reconstruct/Basic.lean | 12 +++++++++++ 6 files changed, 67 insertions(+) create mode 100644 Blaster/Reconstruct.lean create mode 100644 Blaster/Reconstruct/Basic.lean create mode 100644 Blaster/Reconstruct/Tactic.lean create mode 100644 Blaster/Reconstruct/Trace.lean create mode 100644 Tests/Reconstruct/Basic.lean diff --git a/Blaster.lean b/Blaster.lean index c84f819..5baa1e4 100644 --- a/Blaster.lean +++ b/Blaster.lean @@ -4,5 +4,6 @@ import Blaster.Command import Blaster.Logging import Blaster.Optimize +import Blaster.Reconstruct import Blaster.Smt import Blaster.StateMachine diff --git a/Blaster/Reconstruct.lean b/Blaster/Reconstruct.lean new file mode 100644 index 0000000..444bcc7 --- /dev/null +++ b/Blaster/Reconstruct.lean @@ -0,0 +1,3 @@ +import Blaster.Reconstruct.Trace +import Blaster.Reconstruct.Tactic +import Blaster.Reconstruct.Basic diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/Blaster/Reconstruct/Basic.lean @@ -0,0 +1 @@ + diff --git a/Blaster/Reconstruct/Tactic.lean b/Blaster/Reconstruct/Tactic.lean new file mode 100644 index 0000000..aed03e0 --- /dev/null +++ b/Blaster/Reconstruct/Tactic.lean @@ -0,0 +1,36 @@ +import Lean + +import Blaster.Reconstruct.Trace + +open Lean Elab Tactic Meta + +namespace Blaster.Reconstruct + +def traceToSimpLemmas (trace : RewriteTrace) : List Name := + trace.filterMap λ step => + match step with + | .Rewrite lemmaName => some lemmaName + | .Unfold fname => some fname + | .RewriteWithHyp _ => none + +def namesToSimpArgs (names : List Name) : MetaM (Array (TSyntax `Lean.Parser.Tactic.simpLemma)) := + names.toArray.mapM fun name => do + `(Lean.Parser.Tactic.simpLemma| $(mkIdent name):ident) + +def buildSimpTactic (args : Array (TSyntax `Lean.Parser.Tactic.simpLemma)) : MetaM Syntax := + `(tactic| simp only [$args,*]) + +def reconstructFromTrace (trace : RewriteTrace) : TacticM Unit := do + let names := traceToSimpLemmas trace + let args <- namesToSimpArgs names + let tac <- buildSimpTactic args + evalTactic tac + +elab "reconstruct" trace:term : tactic => do + let traceExpr <- + Lean.Elab.Tactic.elabTerm trace (some (mkConst `Blaster.Reconstruct.RewriteTrace)) + let traceVal <- + unsafe Lean.Meta.evalExpr RewriteTrace (mkConst `Blaster.Reconstruct.RewriteTrace) traceExpr + reconstructFromTrace traceVal + +end Blaster.Reconstruct diff --git a/Blaster/Reconstruct/Trace.lean b/Blaster/Reconstruct/Trace.lean new file mode 100644 index 0000000..7b928cf --- /dev/null +++ b/Blaster/Reconstruct/Trace.lean @@ -0,0 +1,14 @@ +import Lean + +open Lean + +namespace Blaster.Reconstruct + +inductive RewriteStep where + | Rewrite (lemmaName : Name) + | Unfold (fname : Name) + | RewriteWithHyp (hyp : Expr) + +abbrev RewriteTrace := List RewriteStep + +end Blaster.Reconstruct diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean new file mode 100644 index 0000000..5de10ec --- /dev/null +++ b/Tests/Reconstruct/Basic.lean @@ -0,0 +1,12 @@ +import Blaster.Reconstruct + +open Blaster.Reconstruct + +example (x : Nat) : x + 0 = x := by + reconstruct [.Rewrite `Nat.add_zero] + +example (x : Nat) : 0 + x = x := by + reconstruct [.Rewrite `Nat.zero_add] + +example (x : Nat) : 0 + x + 0 = x := by + reconstruct [.Rewrite `Nat.zero_add, .Rewrite `Nat.add_zero] From 4c8b5c44d0f70c3190a9fa460c8e11c3ec95074c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Sat, 7 Mar 2026 17:14:28 -0300 Subject: [PATCH 02/31] Removing tactic script approach --- Blaster/Reconstruct.lean | 1 - Blaster/Reconstruct/Tactic.lean | 36 --------------------------------- Tests/Reconstruct/Basic.lean | 11 ---------- 3 files changed, 48 deletions(-) delete mode 100644 Blaster/Reconstruct/Tactic.lean diff --git a/Blaster/Reconstruct.lean b/Blaster/Reconstruct.lean index 444bcc7..d24c1ba 100644 --- a/Blaster/Reconstruct.lean +++ b/Blaster/Reconstruct.lean @@ -1,3 +1,2 @@ import Blaster.Reconstruct.Trace -import Blaster.Reconstruct.Tactic import Blaster.Reconstruct.Basic diff --git a/Blaster/Reconstruct/Tactic.lean b/Blaster/Reconstruct/Tactic.lean deleted file mode 100644 index aed03e0..0000000 --- a/Blaster/Reconstruct/Tactic.lean +++ /dev/null @@ -1,36 +0,0 @@ -import Lean - -import Blaster.Reconstruct.Trace - -open Lean Elab Tactic Meta - -namespace Blaster.Reconstruct - -def traceToSimpLemmas (trace : RewriteTrace) : List Name := - trace.filterMap λ step => - match step with - | .Rewrite lemmaName => some lemmaName - | .Unfold fname => some fname - | .RewriteWithHyp _ => none - -def namesToSimpArgs (names : List Name) : MetaM (Array (TSyntax `Lean.Parser.Tactic.simpLemma)) := - names.toArray.mapM fun name => do - `(Lean.Parser.Tactic.simpLemma| $(mkIdent name):ident) - -def buildSimpTactic (args : Array (TSyntax `Lean.Parser.Tactic.simpLemma)) : MetaM Syntax := - `(tactic| simp only [$args,*]) - -def reconstructFromTrace (trace : RewriteTrace) : TacticM Unit := do - let names := traceToSimpLemmas trace - let args <- namesToSimpArgs names - let tac <- buildSimpTactic args - evalTactic tac - -elab "reconstruct" trace:term : tactic => do - let traceExpr <- - Lean.Elab.Tactic.elabTerm trace (some (mkConst `Blaster.Reconstruct.RewriteTrace)) - let traceVal <- - unsafe Lean.Meta.evalExpr RewriteTrace (mkConst `Blaster.Reconstruct.RewriteTrace) traceExpr - reconstructFromTrace traceVal - -end Blaster.Reconstruct diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 5de10ec..8b13789 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -1,12 +1 @@ -import Blaster.Reconstruct -open Blaster.Reconstruct - -example (x : Nat) : x + 0 = x := by - reconstruct [.Rewrite `Nat.add_zero] - -example (x : Nat) : 0 + x = x := by - reconstruct [.Rewrite `Nat.zero_add] - -example (x : Nat) : 0 + x + 0 = x := by - reconstruct [.Rewrite `Nat.zero_add, .Rewrite `Nat.add_zero] From 3045abc2e2d0feda125e1249433d88f31e64da4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Sat, 7 Mar 2026 17:16:15 -0300 Subject: [PATCH 03/31] Including reconstruction trace in optimizeEnv --- Blaster/Optimize/Env.lean | 16 +++++++++++++--- Blaster/Optimize/OptimizeStack.lean | 4 ++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/Blaster/Optimize/Env.lean b/Blaster/Optimize/Env.lean index ce38004..15c2dfd 100644 --- a/Blaster/Optimize/Env.lean +++ b/Blaster/Optimize/Env.lean @@ -2,10 +2,11 @@ import Lean import Blaster.Optimize.Expr import Blaster.Optimize.MatchInfo import Blaster.Optimize.Opaque +import Blaster.Reconstruct.Trace import Blaster.Smt.Term import Blaster.Command.Options -open Lean Meta Blaster.Smt Blaster.Options +open Lean Meta Blaster.Smt Blaster.Options Blaster.Reconstruct namespace Blaster.Optimize @@ -289,6 +290,9 @@ structure OptimizeEnv where -/ restart : Bool + /-- Trace of rewrite steps performed during optimization, used for proof reconstruction. -/ + rewriteTrace : RewriteTrace + /-- local declaration context -/ ctx : LocalDeclContext @@ -307,6 +311,7 @@ instance : Inhabited OptimizeEnv where memCache := default, options := default, restart := false, + rewriteTrace := [], ctx := default } @@ -526,6 +531,11 @@ def setNormalizeFunCall (b : Bool) : TranslateEnvT Unit := do def setInFunApp (b : Bool) : TranslateEnvT Unit := do modify (fun env => { env with optEnv.options.inFunApp := b }) +/-- add a rewrite step to the reconstruction trace. -/ +@[always_inline, inline] +def addTraceStep (step : RewriteStep) : TranslateEnvT Unit := do + modify (fun env => { env with optEnv.rewriteTrace := env.optEnv.rewriteTrace ++ [step] }) + @[always_inline, inline] def updateHypothesis (h : HypothesisContext) (localCache : RewriteCacheMap) : TranslateEnvT Unit := do modify (fun env => { env with optEnv.hypothesisContext := h, optEnv.localRewriteCache := localCache}) @@ -663,7 +673,7 @@ def mkExpr (a : Expr) (cacheResult := true) : TranslateEnvT Expr := do /-- Return `true` only when both hypothesisMap and matchInContext are empty and isRefHyp flag is not set -/ @[always_inline, inline] def isGlobalContext : TranslateEnvT Bool := do - let ⟨_, ⟨_, _, _, _, _, _, _, _, hypothesisContext, matchInContext, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, _, _, _, _, _, _, _, hypothesisContext, matchInContext, _, _, _, _, _⟩⟩ ← get return hypothesisContext.hypothesisMap.size == 0 && matchInContext.size == 0 /-- Perform the following: @@ -1712,7 +1722,7 @@ where An error is triggered if no corresponding entry can be found in `recFunMap`. -/ def hasRecFunInst? (instApp : Expr) : TranslateEnvT (Option Expr) := do - let ⟨_, ⟨_, _, _, _, _,recFunInstCache,_,recFunMap, _, _, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, _, _, _, _,recFunInstCache,_,recFunMap, _, _, _, _, _, _, _⟩⟩ ← get match recFunInstCache.get? instApp with | some fbody => -- retrieve function application from recFunMap diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 1c47098..b03aa5f 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -67,7 +67,7 @@ abbrev OptimizeContinuity := Sum (List OptimizeStack) Expr @[always_inline, inline] def mkHypStackContext (h : UpdatedHypContext) : TranslateEnvT HypsStackContext := do - let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, hypothesisContext, _, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, hypothesisContext, _, _, _, _, _, _⟩⟩ ← get if h.1 then updateHypothesis h.2 Std.HashMap.emptyWithCapacity return {newHCtx := h, oldHCtx := some hypothesisContext, oldCache := some localRewriteCache} @@ -82,7 +82,7 @@ def resetHypContext (h : HypsStackContext) : TranslateEnvT Unit := do @[always_inline, inline] def mkMatchStackContext (h : MatchContextMap) : TranslateEnvT MatchStackContext := do - let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, _, matchInContext, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, _, matchInContext, _, _, _, _, _⟩⟩ ← get updateMatchContext h Std.HashMap.emptyWithCapacity return {oldMatchCtx := matchInContext, oldCache := localRewriteCache} From b096f4b040616a6d1362b441873ca5a91eb526ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Sun, 8 Mar 2026 02:10:14 -0300 Subject: [PATCH 04/31] Removing rewriteTrace from OptimizeEnv --- Blaster/Optimize/Env.lean | 16 +++------------- Blaster/Optimize/OptimizeStack.lean | 4 ++-- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/Blaster/Optimize/Env.lean b/Blaster/Optimize/Env.lean index 15c2dfd..ce38004 100644 --- a/Blaster/Optimize/Env.lean +++ b/Blaster/Optimize/Env.lean @@ -2,11 +2,10 @@ import Lean import Blaster.Optimize.Expr import Blaster.Optimize.MatchInfo import Blaster.Optimize.Opaque -import Blaster.Reconstruct.Trace import Blaster.Smt.Term import Blaster.Command.Options -open Lean Meta Blaster.Smt Blaster.Options Blaster.Reconstruct +open Lean Meta Blaster.Smt Blaster.Options namespace Blaster.Optimize @@ -290,9 +289,6 @@ structure OptimizeEnv where -/ restart : Bool - /-- Trace of rewrite steps performed during optimization, used for proof reconstruction. -/ - rewriteTrace : RewriteTrace - /-- local declaration context -/ ctx : LocalDeclContext @@ -311,7 +307,6 @@ instance : Inhabited OptimizeEnv where memCache := default, options := default, restart := false, - rewriteTrace := [], ctx := default } @@ -531,11 +526,6 @@ def setNormalizeFunCall (b : Bool) : TranslateEnvT Unit := do def setInFunApp (b : Bool) : TranslateEnvT Unit := do modify (fun env => { env with optEnv.options.inFunApp := b }) -/-- add a rewrite step to the reconstruction trace. -/ -@[always_inline, inline] -def addTraceStep (step : RewriteStep) : TranslateEnvT Unit := do - modify (fun env => { env with optEnv.rewriteTrace := env.optEnv.rewriteTrace ++ [step] }) - @[always_inline, inline] def updateHypothesis (h : HypothesisContext) (localCache : RewriteCacheMap) : TranslateEnvT Unit := do modify (fun env => { env with optEnv.hypothesisContext := h, optEnv.localRewriteCache := localCache}) @@ -673,7 +663,7 @@ def mkExpr (a : Expr) (cacheResult := true) : TranslateEnvT Expr := do /-- Return `true` only when both hypothesisMap and matchInContext are empty and isRefHyp flag is not set -/ @[always_inline, inline] def isGlobalContext : TranslateEnvT Bool := do - let ⟨_, ⟨_, _, _, _, _, _, _, _, hypothesisContext, matchInContext, _, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, _, _, _, _, _, _, _, hypothesisContext, matchInContext, _, _, _, _⟩⟩ ← get return hypothesisContext.hypothesisMap.size == 0 && matchInContext.size == 0 /-- Perform the following: @@ -1722,7 +1712,7 @@ where An error is triggered if no corresponding entry can be found in `recFunMap`. -/ def hasRecFunInst? (instApp : Expr) : TranslateEnvT (Option Expr) := do - let ⟨_, ⟨_, _, _, _, _,recFunInstCache,_,recFunMap, _, _, _, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, _, _, _, _,recFunInstCache,_,recFunMap, _, _, _, _, _, _⟩⟩ ← get match recFunInstCache.get? instApp with | some fbody => -- retrieve function application from recFunMap diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index b03aa5f..1c47098 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -67,7 +67,7 @@ abbrev OptimizeContinuity := Sum (List OptimizeStack) Expr @[always_inline, inline] def mkHypStackContext (h : UpdatedHypContext) : TranslateEnvT HypsStackContext := do - let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, hypothesisContext, _, _, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, hypothesisContext, _, _, _, _, _⟩⟩ ← get if h.1 then updateHypothesis h.2 Std.HashMap.emptyWithCapacity return {newHCtx := h, oldHCtx := some hypothesisContext, oldCache := some localRewriteCache} @@ -82,7 +82,7 @@ def resetHypContext (h : HypsStackContext) : TranslateEnvT Unit := do @[always_inline, inline] def mkMatchStackContext (h : MatchContextMap) : TranslateEnvT MatchStackContext := do - let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, _, matchInContext, _, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, _, matchInContext, _, _, _, _⟩⟩ ← get updateMatchContext h Std.HashMap.emptyWithCapacity return {oldMatchCtx := matchInContext, oldCache := localRewriteCache} From 17f711bcdaef31e0c9a14eada6b34b202023041d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Mon, 9 Mar 2026 10:52:46 -0300 Subject: [PATCH 05/31] Proof certificate generation and propagation in optimization pipeline --- Blaster/Command/Tactic.lean | 11 +- Blaster/Optimize.lean | 2 +- Blaster/Optimize/Basic.lean | 144 +++++++++++------- Blaster/Optimize/OptimizeStack.lean | 93 ++++++----- .../Optimize/Rewriting/NormalizeMatch.lean | 5 +- Blaster/Optimize/Rewriting/OptimizeApp.lean | 91 ++++++----- Blaster/Optimize/Rewriting/OptimizeConst.lean | 4 +- Blaster/Optimize/Rewriting/OptimizeMatch.lean | 2 +- Blaster/Optimize/Rewriting/OptimizeNat.lean | 35 +++-- Blaster/Optimize/Types.lean | 15 ++ Blaster/Reconstruct.lean | 1 - Blaster/Reconstruct/Basic.lean | 6 + Blaster/Reconstruct/Trace.lean | 14 -- Blaster/Smt/Translate.lean | 11 +- Blaster/Smt/Translate/Application.lean | 7 +- Blaster/Smt/Translate/Quantifier.lean | 2 +- Blaster/StateMachine/BMC.lean | 5 +- Blaster/StateMachine/KInduction.lean | 6 +- Blaster/StateMachine/StateMachine.lean | 4 +- Tests/Reconstruct/Basic.lean | 2 + 20 files changed, 278 insertions(+), 182 deletions(-) create mode 100644 Blaster/Optimize/Types.lean delete mode 100644 Blaster/Reconstruct/Trace.lean diff --git a/Blaster/Command/Tactic.lean b/Blaster/Command/Tactic.lean index 44f61a4..8abdc48 100644 --- a/Blaster/Command/Tactic.lean +++ b/Blaster/Command/Tactic.lean @@ -36,12 +36,19 @@ def blasterTacticImp : Tactic := fun stx => let sOpts ← parseSolveOptions opts default let (goal, nbQuantifiers) ← revertHypotheses (← getMainGoal) let env := {(default : TranslateEnv) with optEnv.options.solverOptions := sOpts} - let ((result, optExpr), _) ← + let ((result, (optExpr, proof)), _) ← withTheReader Core.Context (fun ctx => { ctx with maxHeartbeats := 0 }) $ do IO.setNumHeartbeats 0 Translate.main (← goal.getType >>= instantiateMVars') (logUndetermined := false) |>.run env match result with - | .Valid => goal.admit -- TODO: replace with proof reconstruction + | .Valid => + match proof with + | some p => goal.assign p + | none => + try goal.refl + catch _ => + logWarning "blaster: proof reconstruction failed, closing with admit" + goal.admit | .Falsified cex => throwTacticEx `blaster goal "Goal was falsified (see counterexample above)" | .Undetermined => -- Replace the goal with the optimized expression diff --git a/Blaster/Optimize.lean b/Blaster/Optimize.lean index bd3034a..0db0214 100644 --- a/Blaster/Optimize.lean +++ b/Blaster/Optimize.lean @@ -1,3 +1,3 @@ - import Blaster.Optimize.Basic import Blaster.Optimize.Lemmas +import Blaster.Optimize.Types diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index 18ab33a..a0b2c08 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -11,161 +11,201 @@ open Lean Elab Command Term Meta Blaster.Options namespace Blaster.Optimize - -- TODO: update formalization with inference rule style notation. -partial def optimizeExprAux (stack : List OptimizeStack) : TranslateEnvT Expr := do +partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := none) : + TranslateEnvT OptimizeResult := do match stack with | .InitOptimizeExpr e :: xs => - match (← isInOptimizeEnvCache e xs) with + match (← isInOptimizeEnvCache e proof xs) with | Sum.inl i_stack => -- trace[Optimize.expr] "optimizing {← ppExpr e}" match e with | Expr.fvar _ => match (← normFVar e i_stack) with | Sum.inr e' => return e' - | Sum.inl stack' => optimizeExprAux stack' + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | Expr.sort l => -- sort is used for Type u, Prop, etc let s' ← mkExpr (Expr.sort (normLevel l)) - match (← stackContinuity i_stack s') with + match (← stackContinuity i_stack s' proof) with | Sum.inr e' => return e' - | Sum.inl stack' => optimizeExprAux stack' + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | Expr.lit .. => -- number or string literal - match (← stackContinuity i_stack (← mkExpr e)) with + match (← stackContinuity i_stack (← mkExpr e) proof) with | Sum.inr e' => return e' - | Sum.inl stack' => optimizeExprAux stack' + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | Expr.const .. => match (← normConst e i_stack) with | Sum.inr e' => return e' - | Sum.inl stack' => optimizeExprAux stack' + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | Expr.forallE n t b bi => - optimizeExprAux (.InitOptimizeExpr t :: .ForallWaitForType n bi b :: i_stack) + optimizeExprAux (.InitOptimizeExpr t :: .ForallWaitForType n bi b :: i_stack) proof | Expr.app .. => let (f, ras) := getAppFnWithArgs e -- check if f is a lambda term if f.isLambda then -- perform beta reduction and apply optimization - optimizeExprAux (.InitOptimizeExpr (betaLambda f ras) :: i_stack) + optimizeExprAux (.InitOptimizeExpr (betaLambda f ras) :: i_stack) proof else -- set inFunApp flag before optimizing `f` setInFunApp true let i_stack' := .AppWaitForConst ras :: i_stack - optimizeExprAux (.InitOptimizeExpr f :: i_stack') + optimizeExprAux (.InitOptimizeExpr f :: i_stack') proof - | Expr.lam n t b bi => optimizeExprAux (optimizeLambda n t b bi i_stack) + | Expr.lam n t b bi => optimizeExprAux (optimizeLambda n t b bi i_stack) proof - | Expr.letE _n _t v b _ => optimizeExprAux (inlineLet v b i_stack) -- inline let expression + -- inline let expression + | Expr.letE _n _t v b _ => optimizeExprAux (inlineLet v b i_stack) proof | Expr.mdata d me => if (isTaggedRecursiveCall e) then setNormalizeFunCall false - optimizeExprAux (.InitOptimizeExpr me :: .MDataRecCallWaitForExpr d :: i_stack) - else optimizeExprAux (.InitOptimizeExpr me :: i_stack) + optimizeExprAux + (.InitOptimizeExpr me :: .MDataRecCallWaitForExpr d :: i_stack) proof + else optimizeExprAux (.InitOptimizeExpr me :: i_stack) proof | Expr.proj n idx s => let i_stack' := .ProjWaitForExpr n idx :: i_stack - optimizeExprAux (.InitOptimizeExpr s :: i_stack') + optimizeExprAux (.InitOptimizeExpr s :: i_stack') proof | Expr.mvar .. => throwEnvError "optimizeExpr: unexpected meta variable {e}" | Expr.bvar .. => throwEnvError "optimizeExpr: unexpected bound variable {e}" | Sum.inr (Sum.inr e') => return e' - | Sum.inr (Sum.inl stack') => optimizeExprAux stack' + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inr (Sum.inl (nextStack, nextProof)) => optimizeExprAux nextStack nextProof | s@(.InitOpaqueRecExpr ..) :: xs | s@(.RecFunDefStorage ..) :: xs => - match (← normOpaqueAndRecFun s xs) with + match (← normOpaqueAndRecFun s xs proof) with | Sum.inr e => return e - | Sum.inl stack' => optimizeExprAux stack' + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | .AppOptimizeImplicitArgs f args idx startIdx stopIdx pInfo :: xs => if idx ≥ stopIdx then if isBlasterDiteConst f && args.size ≥ 2 then -- applying choice reduction to avoid optimizing unreachable arguments in Blaster.dite' let condOpt := .InitOptimizeExpr args[1]! - optimizeExprAux (condOpt :: .DiteChoiceWaitForCond f args pInfo startIdx :: xs) + optimizeExprAux (condOpt :: .DiteChoiceWaitForCond f args pInfo startIdx :: xs) proof else if let some mInfo ← isMatcher? f then -- try match reduction rule first to avoid unnecessary optimization on discriminators if let some r ← matchReduction? f args mInfo then - optimizeExprAux (.InitOptimizeExpr r :: xs) + optimizeExprAux (.InitOptimizeExpr r :: xs) proof else -- applying choice reduction and match constant propagation to -- avoid optimizing unreachable rhs in match -- only optimizing match discriminators first - optimizeExprAux (.MatchChoiceOptimizeDiscrs f args pInfo startIdx mInfo.getFirstDiscrPos mInfo :: xs) + optimizeExprAux + (.MatchChoiceOptimizeDiscrs + f args pInfo startIdx mInfo.getFirstDiscrPos mInfo :: xs) + proof -- try to apply funPropagation to avoid optimizing ite/match multiple times else if let some r ← funPropagation? f args (reorderArgs := true) then - optimizeExprAux (.InitOptimizeExpr r :: xs) + optimizeExprAux (.InitOptimizeExpr r :: xs) proof -- apply optimization on remaining explicit parameters before reduction - else optimizeExprAux (.AppOptimizeExplicitArgs f args startIdx args.size pInfo none :: xs) + else + optimizeExprAux + (.AppOptimizeExplicitArgs f args startIdx args.size pInfo none :: xs) + proof else if idx < pInfo.paramsInfo.size -- handle case when HOF is the returned type then if !pInfo.paramsInfo[idx]!.isExplicit - then optimizeExprAux (.InitOptimizeExpr args[idx]! :: stack) - else optimizeExprAux (.AppOptimizeImplicitArgs f args (idx + 1) startIdx stopIdx pInfo :: xs) - else optimizeExprAux (.AppOptimizeImplicitArgs f args (idx + 1) startIdx stopIdx pInfo :: xs) + then optimizeExprAux (.InitOptimizeExpr args[idx]! :: stack) proof + else + optimizeExprAux + (.AppOptimizeImplicitArgs f args (idx + 1) startIdx stopIdx pInfo :: xs) + proof + else + optimizeExprAux + (.AppOptimizeImplicitArgs f args (idx + 1) startIdx stopIdx pInfo :: xs) + proof | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo :: xs => if idx ≥ stopIdx then -- normalizing ite/match function application if let some re ← normChoiceApplication? f args then -- trace[Optimize.normChoiceApp] "normalizing choice application {reprStr f} {reprStr args} => {reprStr re}" - optimizeExprAux (.InitOptimizeExpr re :: xs) + optimizeExprAux (.InitOptimizeExpr re :: xs) proof -- apply match normalization rules else if let some argInfo := mInfo then match (← optimizeMatch f args argInfo xs) with | Sum.inr e' => return e' - | Sum.inl stack' => optimizeExprAux stack' + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof -- apply ite normalization rules only when fully applied else if isBlasterDiteConst f && args.size == 4 then match (← optimizeIfThenElse? f args xs) with | Sum.inr e' => return e' - | Sum.inl stack' => optimizeExprAux stack' + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof -- try to reduce app if all params are constructors else if let some re ← reduceApp? f args then -- trace[Optimize.reduceApp] "application reduction {reprStr f} {reprStr args} => {reprStr re}" - optimizeExprAux (.InitOptimizeExpr re :: xs) + optimizeExprAux (.InitOptimizeExpr re :: xs) proof -- unfold non-recursive and non-opaque functions -- NOTE: beta reduction performed by getUnfoldFunDef? when rf is a lambda term -- NOTE: we can only unfold once all parameters have been optimized. else if let some fdef ← getUnfoldFunDef? f args then -- trace[Optimize.unfoldDef] "unfolding function definition {reprStr f} {reprStr args} => {reprStr fdef}" - optimizeExprAux (.InitOptimizeExpr fdef :: xs) + optimizeExprAux (.InitOptimizeExpr fdef :: xs) proof -- normalizing partially apply function after unfolding non-opaque functions else if let some pe ← normPartialFun? f args then -- trace[Optimize.normPartial] "normalizing partial function {reprStr f} {reprStr args} => {reprStr pe}" - optimizeExprAux (.InitOptimizeExpr pe :: xs) + optimizeExprAux (.InitOptimizeExpr pe :: xs) proof -- applying optimization on opaque rec function and app and proceed with fun propagation rules - else optimizeExprAux (.InitOpaqueRecExpr f args :: xs) + else optimizeExprAux (.InitOpaqueRecExpr f args :: xs) proof else if idx < pInfo.paramsInfo.size then if pInfo.paramsInfo[idx]!.isExplicit - then optimizeExprAux (← optimizeExplicitArgs f args idx stopIdx pInfo mInfo stack xs) - else optimizeExprAux (.AppOptimizeExplicitArgs f args (idx + 1) stopIdx pInfo mInfo :: xs) - else optimizeExprAux (.InitOptimizeExpr args[idx]! :: stack) + then + optimizeExprAux + (← optimizeExplicitArgs f args idx stopIdx pInfo mInfo stack xs) + proof + else + optimizeExprAux + (.AppOptimizeExplicitArgs f args (idx + 1) stopIdx pInfo mInfo :: xs) + proof + else optimizeExprAux (.InitOptimizeExpr args[idx]! :: stack) proof | .MatchChoiceOptimizeDiscrs f args pInfo startArgIdx idx mInfo :: xs => if idx ≥ mInfo.getFirstAltPos then if let some r ← matchReduction? f args mInfo then - optimizeExprAux (.InitOptimizeExpr r :: xs) + optimizeExprAux (.InitOptimizeExpr r :: xs) proof else -- apply optimization on remaining explicit parameters before reduction -- keep matchInfo to avoid unnecessary query and to avoid optimizing discriminators again - optimizeExprAux (.AppOptimizeExplicitArgs f args startArgIdx args.size pInfo mInfo :: xs) - else optimizeExprAux (.InitOptimizeExpr args[idx]! :: stack) + optimizeExprAux + (.AppOptimizeExplicitArgs f args startArgIdx args.size pInfo mInfo :: xs) + proof + else optimizeExprAux (.InitOptimizeExpr args[idx]! :: stack) proof | .MatchRhsLambdaNext next :: xs => match next with | Expr.lam n t b bi => - optimizeExprAux (.InitOptimizeExpr t :: .MatchRhsLambdaWaitForType n bi b :: xs) + optimizeExprAux (.InitOptimizeExpr t :: .MatchRhsLambdaWaitForType n bi b :: xs) proof | _ => -- header on xs is expected to be .MatchRhsLambdaWaitForBody - match (← stackContinuity xs next) with - | Sum.inl stack' => optimizeExprAux stack' + match (← stackContinuity xs next proof) with + -- TODO: compose proof certificates with Eq.trans + -- when both proof and nextProof are present + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | _ => throwEnvError "optimizeExprAux: continuity expected for MatchRhsLambdaNext !!!" | _ => throwEnvError "optimizeExprAux: unexpected optimize stack continuity {reprStr stack} !!!" @@ -182,9 +222,9 @@ partial def optimizeExprAux (stack : List OptimizeStack) : TranslateEnvT Expr := normFVar (e : Expr) (stack : List OptimizeStack) : TranslateEnvT OptimizeContinuity := withLocalContext $ do match ← e.fvarId!.getValue? with - | none => stackContinuity stack (← mkExpr e) + | none => stackContinuity stack (← mkExpr e) proof | some v => - return Sum.inl (.InitOptimizeExpr (← instantiateMVars v) :: stack) + return Sum.inl (.InitOptimizeExpr (← instantiateMVars v) :: stack, none) /-- Given a function `f := Expr const n l` perform the following: - When `n := mInfo ∈ isMatcherCache` (i.e., match info already optimized) @@ -243,12 +283,12 @@ partial def optimizeExprAux (stack : List OptimizeStack) : TranslateEnvT Expr := @[always_inline, inline] -def optimizeExpr (e : Expr) : TranslateEnvT Expr := +def optimizeExpr (e : Expr) : TranslateEnvT OptimizeResult := optimizeExprAux [.InitOptimizeExpr e] /-- Same as optimizeExpr but updates local context before optimizing expression -/ @[always_inline, inline] -def optimizeExpr' (e : Expr) : TranslateEnvT Expr := do +def optimizeExpr' (e : Expr) : TranslateEnvT OptimizeResult := do -- set start local context updateLocalContext (← mkLocalContext) optimizeExprAux [.InitOptimizeExpr e ] @@ -272,9 +312,9 @@ def cacheOpaqueRecFun : TranslateEnvT Unit := do /-- Perform the following actions: - populate the recFunInstCache with default recursive function definitions. - - optimize expression `e` + - optimize expression `e`, returning an optional proof certificate alongside the result. -/ -def Optimize.main (e : Expr) : TranslateEnvT Expr := do +def Optimize.main (e : Expr) : TranslateEnvT OptimizeResult := do -- set start local context updateLocalContext (← mkLocalContext) -- populate recFunInstCache with recursive function definition. @@ -294,10 +334,10 @@ def command (sOpts: BlasterOptions) (e : Expr) : MetaM (Expr × TranslateEnv) := -- keep the current name generator and restore it afterwards let ngen ← getNGen let env := {(default : TranslateEnv) with optEnv.options.solverOptions := sOpts} - let res ← Optimize.main e|>.run env + let (⟨optExpr, _proof⟩, translateEnv) ← Optimize.main e|>.run env -- restore name generator setNGen ngen - return res + return (optExpr, translateEnv) initialize diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 1c47098..b6e2613 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -63,7 +63,7 @@ inductive OptimizeStack where | ProjWaitForExpr (n : Name) (idx : Nat) deriving Repr -abbrev OptimizeContinuity := Sum (List OptimizeStack) Expr +abbrev OptimizeContinuity := Sum (List OptimizeStack × Option Expr) OptimizeResult @[always_inline, inline] def mkHypStackContext (h : UpdatedHypContext) : TranslateEnvT HypsStackContext := do @@ -100,20 +100,23 @@ def mkLocalDeclStackContext (newCtx : LocalDeclContext) : TranslateEnvT LocalDec def resetLocalDeclContext (oldCtx : LocalDeclContext) : TranslateEnvT Unit := updateLocalContext oldCtx -def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := false) : TranslateEnvT OptimizeContinuity := do +def stackContinuity + (stack : List OptimizeStack) (optExpr : Expr) + (proof : Option Expr := none) (skipCache := false) + : TranslateEnvT OptimizeContinuity := do match stack with - | [] => return Sum.inr optExpr + | [] => return Sum.inr ⟨optExpr, proof⟩ | .InitOptimizeReturn e isGlobal :: xs => if !skipCache then updateOptimizeEnvCache e optExpr isGlobal match xs with - | [] => return Sum.inr optExpr - | _ => stackContinuity xs optExpr + | [] => return Sum.inr ⟨optExpr, proof⟩ + | _ => stackContinuity xs optExpr proof | .RecFunDefWaitForStorage args instApp subsInst params :: xs => -- optExpr corresponds to optimized rec fun body -- continuity with normOpaqueAndRecFun - return Sum.inl (.RecFunDefStorage args instApp subsInst params optExpr :: xs) + return Sum.inl (.RecFunDefStorage args instApp subsInst params optExpr :: xs, proof) | .ForallWaitForType n bi body :: xs => -- optExpr corresponds to optimized forall binder type @@ -125,7 +128,8 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := let hyps ← addHypotheses optExpr x isNotPropBody let hctx ← mkHypStackContext hyps let lctx ← mkLocalDeclStackContext (← mkLocalContext) - return Sum.inl ( .InitOptimizeExpr body' :: .ForallWaitForBody x optExpr hctx lctx :: xs) + return Sum.inl ( .InitOptimizeExpr body' :: .ForallWaitForBody x optExpr hctx lctx :: xs + , proof) | .ForallWaitForBody x t hctx lctx :: xs => -- optExpr corresponds to optimized forall body @@ -136,9 +140,10 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := resetLocalDeclContext lctx if ← isRestart then resetRestart - return Sum.inl (.InitOptimizeExpr e :: xs) + return Sum.inl (.InitOptimizeExpr e :: xs, proof) else -- continuity with optimizing next expression - stackContinuity xs (← mkExpr e) + let proof' ← proof.mapM (fun p => mkLambdaFVars #[x] p) + stackContinuity xs (← mkExpr e) proof' | .AppWaitForConst args :: xs => -- optExpr corresponds to optimized fun app @@ -147,7 +152,7 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := -- check if optExpr is a lambda if optExpr.isLambda then -- perform beta reduction and apply optimization - return Sum.inl (.InitOptimizeExpr (betaLambda optExpr args) :: xs) + return Sum.inl (.InitOptimizeExpr (betaLambda optExpr args) :: xs, proof) else let (rf, extraArgs) := getAppFnWithArgs optExpr let args := extraArgs ++ args @@ -156,7 +161,9 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := match (← hasUnOptMatchInfo? rf) with | none => -- continuity with optimization on implicit arguments - return Sum.inl (.AppOptimizeImplicitArgs rf args extraArgs.size extraArgs.size args.size pInfo :: xs) + return Sum.inl + (.AppOptimizeImplicitArgs + rf args extraArgs.size extraArgs.size args.size pInfo :: xs, proof) | some (mInfo, instApp) => -- continuity with optimizing match generic instance -- NOTE: instApp is expected to be a lambda term @@ -165,7 +172,8 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := match instApp with | Expr.lam n t b bi => let mWait := .OptimizeMatchInfoWaitForInst rf args extraArgs.size pInfo mInfo :: xs - return Sum.inl (.InitOptimizeExpr t :: .MatchRhsLambdaWaitForType n bi b :: mWait) + return Sum.inl ( .InitOptimizeExpr t :: .MatchRhsLambdaWaitForType n bi b :: mWait + , proof) | _ => throwEnvError "stackContinuity: lambda expected for match instance but got {reprStr instApp}" @@ -179,19 +187,24 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := -- apply optimization only on implicit parameters to remove mdata annotation -- we don't consider explicit parameters at this stage to avoid performing -- optimization on unreachable arguments - return Sum.inl (.AppOptimizeImplicitArgs f args startArgIdx startArgIdx args.size pInfo :: xs) + return Sum.inl + (.AppOptimizeImplicitArgs f args startArgIdx startArgIdx args.size pInfo :: xs, proof) else throwEnvError "stackContinuity: name expression for match application but got {reprStr f} !!!" | .AppOptimizeImplicitArgs f args idx startArgIdx stopIdx pInfo :: xs => -- optExpr corresponds to the optimized implicit argument referenced by idx. -- continuity with optimizing the next implicit argument. - return Sum.inl (.AppOptimizeImplicitArgs f (args.set! idx optExpr) (idx + 1) startArgIdx stopIdx pInfo :: xs) + return Sum.inl + (.AppOptimizeImplicitArgs f (args.set! idx optExpr) + (idx + 1) startArgIdx stopIdx pInfo :: xs, proof) | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo :: xs => -- optExpr corresponds to the optimized explicit argument referenced by idx. -- continuity with optimizing the next explicit argument. - return Sum.inl (.AppOptimizeExplicitArgs f (args.set! idx optExpr) (idx + 1) stopIdx pInfo mInfo :: xs) + return Sum.inl + (.AppOptimizeExplicitArgs f (args.set! idx optExpr) + (idx + 1) stopIdx pInfo mInfo :: xs, proof) | .DiteChoiceWaitForCond f args pInfo startArgIdx :: xs => -- optExpr corresponds to the optimized Blaster.dite' conditional, i.e., referenced by index 1. @@ -201,15 +214,19 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := -- - continuity with optimizing remaining explicit parameters before reduction -- NOTE: keep matchInfo to avoid unnecessary query and to avoid optimizing discriminators again if let some r ← optimizeDITEChoice f (args.set! 1 optExpr) then - return Sum.inl (.InitOptimizeExpr r :: xs) + return Sum.inl (.InitOptimizeExpr r :: xs, proof) else -- NOTE: keep matchInfo to avoid unnecessary query and to avoid optimizing discriminators again - return Sum.inl (.AppOptimizeExplicitArgs f (args.set! 1 optExpr) startArgIdx args.size pInfo none :: xs) + return Sum.inl + (.AppOptimizeExplicitArgs f (args.set! 1 optExpr) + startArgIdx args.size pInfo none :: xs, proof) | .MatchChoiceOptimizeDiscrs f args pInfo startArgIdx idx mInfo :: xs => -- optExpr corresponds to the optimized match discriminator referenced by idx. -- continuity with optimizing the next discriminator - return Sum.inl (.MatchChoiceOptimizeDiscrs f (args.set! idx optExpr) pInfo startArgIdx (idx + 1) mInfo :: xs) + return Sum.inl + (.MatchChoiceOptimizeDiscrs f (args.set! idx optExpr) + pInfo startArgIdx (idx + 1) mInfo :: xs, proof) | .LambdaWaitForType n bi body inDite :: xs => -- optExpr corresponds to optimized lambda type @@ -220,8 +237,8 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := if inDite then let hyps ← addHypotheses optExpr x let hypsCtx ← mkHypStackContext hyps - return Sum.inl (bodyOpt :: .LambdaWaitForBody x lctx (some hypsCtx) :: xs) - else return Sum.inl (bodyOpt :: .LambdaWaitForBody x lctx none :: xs) + return Sum.inl (bodyOpt :: .LambdaWaitForBody x lctx (some hypsCtx) :: xs, proof) + else return Sum.inl (bodyOpt :: .LambdaWaitForBody x lctx none :: xs, proof) | .LambdaWaitForBody x lctx hctx :: xs => -- optExpr corresponds to optimized lambda body @@ -229,7 +246,7 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := if let some h := hctx then resetHypContext h let e ← withLocalContext $ do mkLambdaExpr x optExpr resetLocalDeclContext lctx - stackContinuity xs e + stackContinuity xs e proof | .MatchRhsLambdaWaitForType n bi body :: xs => -- optExpr corresponds to optimized lambda type @@ -238,7 +255,7 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := withLocalDecl' n bi optExpr fun x => do let bodyOpt := .MatchRhsLambdaNext (instantiate1' body x) let lctx ← mkLocalDeclStackContext (← mkLocalContext) - return Sum.inl (bodyOpt :: .MatchRhsLambdaWaitForBody x lctx :: xs) + return Sum.inl (bodyOpt :: .MatchRhsLambdaWaitForBody x lctx :: xs, proof) | .MatchRhsLambdaWaitForBody x lctx :: xs => -- optExpr corresponds to optimized lambda body @@ -247,7 +264,7 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := -- rhs has not been optimized yet. let e ← withLocalContext $ do mkLambdaFVar x optExpr resetLocalDeclContext lctx - stackContinuity xs e + stackContinuity xs e proof | .MatchAltWaitForExpr params lctx mctx :: xs => -- optExpr corresponds to the optimized match rhs @@ -255,26 +272,26 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := let e ← withLocalContext $ do mkExpr (← mkLambdaFVars' params optExpr) resetMatchContext mctx resetLocalDeclContext lctx - stackContinuity xs e + stackContinuity xs e proof | .LetWaitForValue body :: xs => -- optExpr corresponds to the optimized let value -- continuity with optimizing body - return Sum.inl (.InitOptimizeExpr (body.instantiate1 optExpr) :: xs) + return Sum.inl (.InitOptimizeExpr (body.instantiate1 optExpr) :: xs, proof) | .MDataRecCallWaitForExpr d :: xs => -- optExpr corresponds to the annotated rec call that is optimized when `normalizeFunCall` is set to false -- continuity with optimizing next expression setNormalizeFunCall true - stackContinuity xs (← mkExpr (Expr.mdata d optExpr)) + stackContinuity xs (← mkExpr (Expr.mdata d optExpr)) proof | .ProjWaitForExpr n idx :: xs => -- optExpr corresponds to optimized projection structure if let some re ← optimizeProjection? n idx optExpr then - return Sum.inl (.InitOptimizeExpr re :: xs) + return Sum.inl (.InitOptimizeExpr re :: xs, proof) else -- continuity with optimizing next expression - stackContinuity xs (← mkExpr $ mkProj n idx optExpr) + stackContinuity xs (← mkExpr $ mkProj n idx optExpr) proof | _ => throwEnvError "stackContinuity: unexpected optimize stack continuity {reprStr stack} !!!" @@ -298,11 +315,12 @@ def stackContinuity (stack : List OptimizeStack) (optExpr : Expr) (skipCache := else return none @[always_inline, inline] -def mkOptimizeContinuity (e : Expr) (stack : List OptimizeStack) : TranslateEnvT OptimizeContinuity := do +def mkOptimizeContinuity (expr : Expr) (proof : Option Expr) (stack : List OptimizeStack) : + TranslateEnvT OptimizeContinuity := do if ← isRestart then resetRestart - return Sum.inl (.InitOptimizeExpr e :: stack) - else stackContinuity stack (← mkExpr e) + return Sum.inl (.InitOptimizeExpr expr :: stack, none) + else stackContinuity stack (← mkExpr expr) proof /-- Apply simplification/normalization rules on Blaster.dite' expressions. Assume that f = Expr.const ``Blaster.dite'. @@ -310,15 +328,16 @@ def mkOptimizeContinuity (e : Expr) (stack : List OptimizeStack) : TranslateEnvT @[always_inline, inline] def optimizeIfThenElse? (f : Expr) (args : Array Expr) (stack : List OptimizeStack) : TranslateEnvT OptimizeContinuity := withLocalContext $ do - mkOptimizeContinuity (← optimizeDITE f args) stack + mkOptimizeContinuity (← optimizeDITE f args) none stack @[always_inline, inline] -def isInOptimizeEnvCache (a : Expr) (stack : List OptimizeStack) : TranslateEnvT (Sum (List OptimizeStack) OptimizeContinuity) := do +def isInOptimizeEnvCache (expr : Expr) (proof : Option Expr) (stack : List OptimizeStack) : + TranslateEnvT (Sum (List OptimizeStack) OptimizeContinuity) := do -- NOTE: Always consider global context when `a` does not contain any FVar. - let isGlobal := !a.hasFVar || (← isGlobalContext) - match (← isInOptimizeCache? a isGlobal) with - | some b => Sum.inr <$> stackContinuity stack b - | none => return Sum.inl (.InitOptimizeReturn a isGlobal :: stack) + let isGlobal := !expr.hasFVar || (← isGlobalContext) + match (← isInOptimizeCache? expr isGlobal) with + | some b => Sum.inr <$> stackContinuity stack b proof + | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack) end Blaster.Optimize diff --git a/Blaster/Optimize/Rewriting/NormalizeMatch.lean b/Blaster/Optimize/Rewriting/NormalizeMatch.lean index a713bf8..ac88570 100644 --- a/Blaster/Optimize/Rewriting/NormalizeMatch.lean +++ b/Blaster/Optimize/Rewriting/NormalizeMatch.lean @@ -3,7 +3,6 @@ import Blaster.Optimize.Rewriting.Utils import Blaster.Optimize.Rewriting.OptimizeNat import Blaster.Optimize.Rewriting.OptimizeInt - open Lean Meta Elab namespace Blaster.Optimize @@ -163,7 +162,9 @@ partial def removeNamedPatternExpr (p : Expr) : TranslateEnvT Expr := do where optimizePattern (f : Expr) (args : Array Expr) : TranslateEnvT Expr := do match f with - | Expr.const ``Nat.add _ => optimizeNatAdd f args + | Expr.const ``Nat.add _ => + let ⟨expr, _proof⟩ ← optimizeNatAdd f args + return expr | Expr.const ``Int.neg _ => optimizeIntNeg f args | _ => return mkAppN f args diff --git a/Blaster/Optimize/Rewriting/OptimizeApp.lean b/Blaster/Optimize/Rewriting/OptimizeApp.lean index dbfe971..9d293fd 100644 --- a/Blaster/Optimize/Rewriting/OptimizeApp.lean +++ b/Blaster/Optimize/Rewriting/OptimizeApp.lean @@ -58,50 +58,56 @@ def reduceApp? (f : Expr) (args: Array Expr) : TranslateEnvT (Option Expr) := wi /-- Perform constant propagation and apply simplification and normalization rules on application expressions. -/ -def optimizeAppAux (f : Expr) (args: Array Expr) : TranslateEnvT Expr := do +def optimizeAppAux (f : Expr) (args: Array Expr) : TranslateEnvT OptimizeResult := do let args ← reorderOperands f args - if let some e ← optimizePropNot? f args then return e - if let some e ← optimizePropBinary? f args then return e - if let some e ← optimizeBoolNot? f args then return e - if let some e ← optimizeBoolBinary? f args then return e - if let some e ← optimizeEquality? f args then return e - if let some e ← optimizeNat? f args then return e - if let some e ← optimizeInt? f args then return e - if let some e ← optimizeExists? f args then return e - if let some e ← optimizeDecide? f args then return e - if let some e ← optimizeRelational? f args then return e - if let some e ← optimizeString? f args then return e + if let some e ← optimizePropNot? f args then return (OptimizeResult.mk e none) + if let some e ← optimizePropBinary? f args then return (OptimizeResult.mk e none) + if let some e ← optimizeBoolNot? f args then return (OptimizeResult.mk e none) + if let some e ← optimizeBoolBinary? f args then return (OptimizeResult.mk e none) + if let some e ← optimizeEquality? f args then return (OptimizeResult.mk e none) + if let some r ← optimizeNat? f args then return r + if let some e ← optimizeInt? f args then return (OptimizeResult.mk e none) + if let some e ← optimizeExists? f args then return (OptimizeResult.mk e none) + if let some e ← optimizeDecide? f args then return (OptimizeResult.mk e none) + if let some e ← optimizeRelational? f args then return (OptimizeResult.mk e none) + if let some e ← optimizeString? f args then return (OptimizeResult.mk e none) let appExpr := mkAppN f args - if (← isResolvableType appExpr) then return (← resolveTypeAbbrev appExpr) - return appExpr + if (← isResolvableType appExpr) then return ⟨← resolveTypeAbbrev appExpr, none⟩ + return ⟨appExpr, none⟩ /-- Perform the following: - - apply normalization and simplification rrules on the given application expression + - apply normalization and simplification rules on the given application expression - When restart flag is set: - add optimized application on continuation stack - Otherwise: - - try tp apply function propagation over ite and match: + - try to apply function propagation over ite and match: - When propagation rules are triggered: - add result on continuation stack - Otherwise: - cache normalized application - proceed with stack continuity + `incomingProof` is an optional proof certificate threaded through the optimizer stack. + NOTE: proof certificate composition via `Eq.trans` is not yet implemented (see TODO comments). + NOTE: skipPropCheck is set to `true` only when it is known beforehand that `f` is a recursive function for which `allExplicitParamsAreCtor f args (funPropagation := true)` returns `true`. -/ def optimizeApp (f : Expr) (args: Array Expr) - (stack : List OptimizeStack) (skipPropCheck := false) : TranslateEnvT OptimizeContinuity := do - let e ← optimizeAppAux f args + (stack : List OptimizeStack) (incomingProof : Option Expr := none) (skipPropCheck := false) : + TranslateEnvT OptimizeContinuity := do + let ⟨e, newProof⟩ ← optimizeAppAux f args + let proof := newProof.orElse (λ _ => incomingProof) if ← isRestart then resetRestart - return Sum.inl (.InitOptimizeExpr e :: stack) + return Sum.inl (.InitOptimizeExpr e :: stack, none) else match (← isFunPropagation? e) with - | some r => return Sum.inl (.InitOptimizeExpr r :: stack) - | none => stackContinuity stack (← mkExpr e) -- cache expression and proceed with continuity + | some r => return Sum.inl (.InitOptimizeExpr r :: stack, none) + | none => -- cache expression and proceed with continuity + return (← stackContinuity stack (← mkExpr e) proof) where @[always_inline, inline] @@ -154,11 +160,12 @@ def normPartialFun? (f : Expr) (args : Array Expr) : TranslateEnvT (Option Expr) Assumes that an entry exists for each opaque recursive function in `recFunMap` before optimization is performed (see function `cacheOpaqueRecFun`). -/ -def normOpaqueAndRecFun (s : OptimizeStack) (xs : List OptimizeStack) : - TranslateEnvT OptimizeContinuity := withLocalContext $ do +def normOpaqueAndRecFun + (s : OptimizeStack) (xs : List OptimizeStack) (proof : Option Expr := none) : + TranslateEnvT OptimizeContinuity := withLocalContext $ do match s with | .InitOpaqueRecExpr uf uargs => - let Expr.const n _ := uf | return (← stackContinuity xs (← mkAppExpr uf uargs)) + let Expr.const n _ := uf | return (← stackContinuity xs (← mkAppExpr uf uargs) proof) let isOpaqueRec ← isOpaqueRecFun uf uargs if (← isRecursiveFun n) || isOpaqueRec then @@ -166,7 +173,7 @@ def normOpaqueAndRecFun (s : OptimizeStack) (xs : List OptimizeStack) : -- call fun propagation to avoid optimizing rec body -- if rec function is an opaqueRec call app optimization first -- before calling fun propagation - optimizeApp uf uargs xs (skipPropCheck := true) + optimizeApp uf uargs xs proof (skipPropCheck := true) else -- trace[Optimize.recFun] "normalizing rec function {n}" let (f, args) ← resolveOpaque uf uargs isOpaqueRec @@ -178,10 +185,10 @@ def normOpaqueAndRecFun (s : OptimizeStack) (xs : List OptimizeStack) : let instApp ← getInstApp f params if (← isVisitedRecFun instApp) then -- trace[Optimize.recFun] "rec function instance {instApp} is in visiting cache" - optimizeRecApp uf f uargs params xs -- already cached + optimizeRecApp uf f uargs params xs proof -- already cached else if let some r ← hasRecFunInst? instApp then -- trace[Optimize.recFun] "rec function instance {instApp} is already equivalent to {reprStr r}" - optimizeRecApp uf r uargs params xs + optimizeRecApp uf r uargs params xs proof else cacheFunName instApp -- cache function name let some fbody ← getFunBody f @@ -191,15 +198,19 @@ def normOpaqueAndRecFun (s : OptimizeStack) (xs : List OptimizeStack) : -- trace[Optimize.recFun] "generalizing rec body for {n} got {reprStr fdef}" let subsInst ← opaqueInstApp uf uargs isOpaqueRec instApp -- optimize recursive fun definition and store - return Sum.inl (.InitOptimizeExpr fdef :: .RecFunDefWaitForStorage uargs instApp subsInst params :: xs) - else optimizeApp uf uargs xs -- optimizations on opaque functions + return Sum.inl + ( .InitOptimizeExpr fdef + :: .RecFunDefWaitForStorage uargs instApp subsInst params + :: xs + , none) + else optimizeApp uf uargs xs proof -- optimizations on opaque functions | .RecFunDefStorage uargs instApp subsInst params optDef => uncacheFunName instApp -- trace[Optimize.recFun] "optimized rec body for {reprStr subsInst} got {reprStr optDef}" let fn' ← storeRecFunDef subsInst params optDef -- trace[Optimize.recFun] "rec function instance {reprStr subsInst} is equivalent to {reprStr fn'}" - optimizeRecApp subsInst fn' uargs params xs + optimizeRecApp subsInst fn' uargs params xs proof | _ => throwEnvError "normOpaqueAndRecFun: unexpected continuity {reprStr s} !!!" @@ -262,25 +273,27 @@ def normOpaqueAndRecFun (s : OptimizeStack) (xs : List OptimizeStack) : - When `auxApp := fₑ x₀ ... xₙ` (default case) - return `optimizeApp fₑ x₀ ...xₙ` -/ - optimizeRecApp - (uf rf : Expr) (uargs : Array Expr) - (params : ImplicitParameters) (xs : List OptimizeStack) : TranslateEnvT OptimizeContinuity := do + optimizeRecApp + (uf rf : Expr) (uargs : Array Expr) + (params : ImplicitParameters) (xs : List OptimizeStack) + (proof : Option Expr := none) : TranslateEnvT OptimizeContinuity := do if params.isEmpty then - return ← stackContinuity xs (← mkExpr rf (cacheResult := !(normRecOpaque rf))) -- catch fun expression + -- catch fun expression + return ← stackContinuity xs (← mkExpr rf (cacheResult := !(normRecOpaque rf))) proof if exprEq uf rf then -- case for when same recursive call -- trace[Optimize.recFun.app] "same recursive call case {reprStr rf} {reprStr uargs}" if rf.isConst then - optimizeApp rf uargs xs + optimizeApp rf uargs xs proof else -- polyomrphic case: we need to remove the generic parameters let auxApp := rf.beta (← getEffectiveParams params) let (f, args) := getAppFnWithArgs auxApp - optimizeApp f args xs + optimizeApp f args xs proof else if rf.isConst then -- case when a polymorphic/non-polymorphic function is equivalent to another non-polymorphic one let eargs := Array.filterMap (λ p => if !p.isInstance then some p.effectiveArg else none) params -- trace[Optimize.recFun.app] "non-polymorphic equivalent case {reprStr rf} {reprStr eargs}" - optimizeApp rf eargs xs + optimizeApp rf eargs xs proof else let auxApp := rf.beta (← getEffectiveParams params) if auxApp.isLambda then @@ -288,11 +301,11 @@ def normOpaqueAndRecFun (s : OptimizeStack) (xs : List OptimizeStack) : let appCall := getLambdaBody auxApp let (f, largs) := getAppFnWithArgs appCall -- trace[Optimize.recFun.app] "partially applied case {reprStr appCall.getAppFn'} {reprStr largs[0:largs.size-auxApp.getNumHeadLambdas]}" - optimizeApp f (largs.take (largs.size-auxApp.getNumHeadLambdas)) xs + optimizeApp f (largs.take (largs.size-auxApp.getNumHeadLambdas)) xs proof else -- trace[Optimize.recFun.app] "polymorphic equivalent case {reprStr auxApp.getAppFn'} {reprStr auxApp.getAppArgs}" let (f, args) := getAppFnWithArgs auxApp - optimizeApp f args xs + optimizeApp f args xs proof initialize registerTraceClass `Optimize.recFun diff --git a/Blaster/Optimize/Rewriting/OptimizeConst.lean b/Blaster/Optimize/Rewriting/OptimizeConst.lean index c68d736..e5ecdf7 100644 --- a/Blaster/Optimize/Rewriting/OptimizeConst.lean +++ b/Blaster/Optimize/Rewriting/OptimizeConst.lean @@ -186,11 +186,11 @@ def normConst (e : Expr) (stack : List OptimizeStack) : TranslateEnvT OptimizeCo else if (← hasImplicitArgs e) then return none if (← isRecursiveFun f) then - return (some $ Sum.inl $ .InitOpaqueRecExpr e #[] :: stack) + return (some $ Sum.inl $ (.InitOpaqueRecExpr e #[] :: stack, none)) if (← isNotFoldable e #[]) then return none -- non recursive function case if let some fbody ← getFunBody e then - return (some $ Sum.inl $ .InitOptimizeExpr fbody :: stack) + return (some $ Sum.inl $ (.InitOptimizeExpr fbody :: stack, none)) else return none end Blaster.Optimize diff --git a/Blaster/Optimize/Rewriting/OptimizeMatch.lean b/Blaster/Optimize/Rewriting/OptimizeMatch.lean index 642997a..2345540 100644 --- a/Blaster/Optimize/Rewriting/OptimizeMatch.lean +++ b/Blaster/Optimize/Rewriting/OptimizeMatch.lean @@ -552,7 +552,7 @@ def optimizeMatch match (← normMatchExpr? args' mInfo') with | some mdef => -- trace[Optimize.normMatch] "normalizing match to ite {reprStr f'} {reprStr args'} => {reprStr mdef}" - return Sum.inl (.InitOptimizeExpr mdef :: xs) + return Sum.inl (.InitOptimizeExpr mdef :: xs, none) | _ => return (← stackContinuity xs (← mkAppExpr f' args')) diff --git a/Blaster/Optimize/Rewriting/OptimizeNat.lean b/Blaster/Optimize/Rewriting/OptimizeNat.lean index 3069ca1..119f116 100644 --- a/Blaster/Optimize/Rewriting/OptimizeNat.lean +++ b/Blaster/Optimize/Rewriting/OptimizeNat.lean @@ -1,30 +1,35 @@ import Lean +import Blaster.Optimize.Env import Blaster.Optimize.Rewriting.OptimizeEq import Blaster.Optimize.Rewriting.OptimizeRelational import Blaster.Optimize.Rewriting.Utils -import Blaster.Optimize.Env +import Blaster.Optimize.Types open Lean Meta namespace Blaster.Optimize /-- Apply the following simplification/normalization rules on `Nat.add` : - - 0 + n ==> n + - 0 + n ==> n [proof: Nat.zero_add] - N1 + N2 ===> N1 "+" N2 - N1 + (N2 + n) ==> (N1 "+" N2) + n - n1 + n2 ==> n2 + n1 (if n2 <ₒ n1) Assume that f = Expr.const ``Nat.add. An error is triggered when args.size ≠ 2 (i.e., only fully applied `Nat.add` expected at this stage) -/ -def optimizeNatAdd (f : Expr) (args : Array Expr) : TranslateEnvT Expr := do +def optimizeNatAdd (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult := do if args.size != 2 then throwEnvError "optimizeNatAdd: exactly two arguments expected" let op1 := args[0]! let op2 := args[1]! match isNatValue? op1, isNatValue? op2 with - | some 0, _ => return op2 - | some n1, some n2 => evalBinNatOp Nat.add n1 n2 + | some 0, _ => + let proof := mkApp (mkConst ``Nat.zero_add) op2 + return ⟨op2, some proof⟩ + | some n1, some n2 => + let expr <- evalBinNatOp Nat.add n1 n2 + return ⟨expr, none⟩ | nv1, _ => - if let some r ← cstAddProp? nv1 op2 then return r - return (mkApp2 f op1 op2) + if let some expr ← cstAddProp? nv1 op2 then return ⟨expr, none⟩ + return ⟨mkApp2 f op1 op2, none⟩ where /- Given `mv1` and `op2`, return `some ((N1 "+" N2) + n)` when @@ -332,17 +337,17 @@ def optimizeNatble (f : Expr) (b_args : Array Expr) : TranslateEnvT Expr := do /-- Apply simplification/normalization rules on `Nat` operators. -/ @[always_inline, inline] -def optimizeNat? (f : Expr) (args : Array Expr) : TranslateEnvT (Option Expr) := do +def optimizeNat? (f : Expr) (args : Array Expr) : TranslateEnvT (Option OptimizeResult) := do let Expr.const n _ := f | return none match n with | ``Nat.add => optimizeNatAdd f args - | ``Nat.sub => optimizeNatSub f args - | ``Nat.mul => optimizeNatMul f args - | ``Nat.div => optimizeNatDiv f args - | ``Nat.mod => optimizeNatMod f args - | ``Nat.beq => optimizeNatBeq f args - | ``Nat.ble => optimizeNatble f args - | ``Nat.pow => optimizeNatPow f args + | ``Nat.sub => return some ⟨← optimizeNatSub f args, none⟩ + | ``Nat.mul => return some ⟨← optimizeNatMul f args, none⟩ + | ``Nat.div => return some ⟨← optimizeNatDiv f args, none⟩ + | ``Nat.mod => return some ⟨← optimizeNatMod f args, none⟩ + | ``Nat.beq => return some ⟨← optimizeNatBeq f args, none⟩ + | ``Nat.ble => return some ⟨← optimizeNatble f args, none⟩ + | ``Nat.pow => return some ⟨← optimizeNatPow f args, none⟩ | _=> return none end Blaster.Optimize diff --git a/Blaster/Optimize/Types.lean b/Blaster/Optimize/Types.lean new file mode 100644 index 0000000..2d2d5fc --- /dev/null +++ b/Blaster/Optimize/Types.lean @@ -0,0 +1,15 @@ +import Lean + +namespace Blaster.Optimize + +open Lean + +/-- The result of a single optimization step with its proof. + - `optExpr` : the optimized expression + - `proof`: a term of type `original = expr` +-/ +structure OptimizeResult where + optExpr : Expr + proof : Option Expr + +end Blaster.Optimize diff --git a/Blaster/Reconstruct.lean b/Blaster/Reconstruct.lean index d24c1ba..7fcff4f 100644 --- a/Blaster/Reconstruct.lean +++ b/Blaster/Reconstruct.lean @@ -1,2 +1 @@ -import Blaster.Reconstruct.Trace import Blaster.Reconstruct.Basic diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index 8b13789..cdf9b3b 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -1 +1,7 @@ +import Lean +open Lean + +namespace Blaster.Reconstruct + +end Blaster.Reconstruct diff --git a/Blaster/Reconstruct/Trace.lean b/Blaster/Reconstruct/Trace.lean deleted file mode 100644 index 7b928cf..0000000 --- a/Blaster/Reconstruct/Trace.lean +++ /dev/null @@ -1,14 +0,0 @@ -import Lean - -open Lean - -namespace Blaster.Reconstruct - -inductive RewriteStep where - | Rewrite (lemmaName : Name) - | Unfold (fname : Name) - | RewriteWithHyp (hyp : Expr) - -abbrev RewriteTrace := List RewriteStep - -end Blaster.Reconstruct diff --git a/Blaster/Smt/Translate.lean b/Blaster/Smt/Translate.lean index df0607f..6b5a438 100644 --- a/Blaster/Smt/Translate.lean +++ b/Blaster/Smt/Translate.lean @@ -43,15 +43,16 @@ partial def translateExpr (e : Expr) (topLevel := true) : TranslateEnvT SmtTerm | Expr.sort _ => throwEnvError "translateExpr: unexpected sort type {reprStr e}" -- sort type are handled elsewhere visit e topLevel -def Translate.main (e : Expr) (logUndetermined := true) : TranslateEnvT (Result × Expr) := do +def Translate.main (e : Expr) (logUndetermined := true) : + TranslateEnvT (Result × Expr × Option Expr) := do let e' ← addAxioms (← toPropExpr e) (← findLocalAxioms) - let optExpr ← profileTask "Optimization" $ Optimize.main e' + let ⟨optExpr, proof⟩ ← profileTask "Optimization" $ Optimize.main e' trace[Translate.optExpr] "optimized expression: {← ppExpr optExpr}" match (toResult optExpr) with | res@(.Undetermined) => if (← get).optEnv.options.solverOptions.onlyOptimize then if logUndetermined then logResult res - return (res, optExpr) + return (res, optExpr, proof) else -- set backend solver setBlasterProcess @@ -63,10 +64,10 @@ def Translate.main (e : Expr) (logUndetermined := true) : TranslateEnvT (Result let res ← profileTask "Solve" checkSat if !isUndeterminedResult res || logUndetermined then logResult res discard $ exitSmt - return (res, optExpr) + return (res, optExpr, proof) | res => logResult res - return (res, optExpr) + return (res, optExpr, proof) where isTheoremExpr (e : Expr) : TranslateEnvT (Option Expr) := do diff --git a/Blaster/Smt/Translate/Application.lean b/Blaster/Smt/Translate/Application.lean index 9a089c2..49ba542 100644 --- a/Blaster/Smt/Translate/Application.lean +++ b/Blaster/Smt/Translate/Application.lean @@ -5,7 +5,6 @@ import Blaster.Smt.Env import Blaster.Smt.Translate.Match import Blaster.Smt.Translate.Quantifier - open Lean Meta Blaster.Optimize namespace Blaster.Smt @@ -754,7 +753,8 @@ def translateConst hasSorryTheorem e "translateConst: Theorem {n} has `sorry` demonstration" if info.type.isForall then throwEnvError "translateConst: Fully applied theorem expected but got {reprStr info.type}" - termTranslator (← optimizeExpr' info.type) + let ⟨optExpr, _proof⟩ <- optimizeExpr' info.type + termTranslator optExpr getAxiomOpaqueType (n : Name) : TranslateEnvT (Option Expr) := do match ← getConstEnvInfo n with @@ -981,7 +981,8 @@ def translateApp let ConstantInfo.thmInfo info ← getConstEnvInfo n | return none -- check if e has sorry demonstration and trigger error if this is the case hasSorryTheorem e "translateApp: Theorem {n} has `sorry` demonstration" - termTranslator (← optimizeExpr' (betaForAll info.type args)) + let ⟨optExpr, _proof⟩ ← optimizeExpr' (betaForAll info.type args) + termTranslator optExpr /-- Given `e := λ (x₁ : t₁) → λ (xₙ : tₙ) => b`, perform the following: - let V := [ v | v ∈ getFVarsInExpr b ∧ ¬ isType v.type ∧ ¬ isClassConstraintExpr v.type ∧ ¬ isTopLevelFVar v ] diff --git a/Blaster/Smt/Translate/Quantifier.lean b/Blaster/Smt/Translate/Quantifier.lean index 7083b64..076b665 100644 --- a/Blaster/Smt/Translate/Quantifier.lean +++ b/Blaster/Smt/Translate/Quantifier.lean @@ -980,7 +980,7 @@ where let selTerms ← mkCtorSelectorExpr recRule.ctor selectorIdx arg decl.type substituteList := (arg, selTerms.1) :: substituteList if (← isPropEnv decl.type) then - let optExpr ← optimizeExpr' decl.type + let ⟨optExpr, _proof⟩ ← optimizeExpr' decl.type -- apply substitue list on optExpr before translation let propTerm ← termTranslator (substituteList.foldr (fun a acc => acc.replace (substitutePred a)) optExpr) predTermCond := updatePredTerm predTermCond (andSmt (eqSmt selTerms.2 propTerm) selTerms.2) diff --git a/Blaster/StateMachine/BMC.lean b/Blaster/StateMachine/BMC.lean index 2bb2934..ecba416 100644 --- a/Blaster/StateMachine/BMC.lean +++ b/Blaster/StateMachine/BMC.lean @@ -51,7 +51,7 @@ partial def bmcStrategy (smInst : Expr) : TranslateEnvT Unit := do where optimizeState (iVar : Expr) (pState : Option Expr) : StateMachineEnvT Expr := do let env ← get - profileTask s!"Optimizing state at Depth {← getCurrentDepth}" + let ⟨optExpr, _proof⟩ ← profileTask s!"Optimizing state at Depth {← getCurrentDepth}" (do match pState with | none => -- depth 0 @@ -59,13 +59,14 @@ partial def bmcStrategy (smInst : Expr) : TranslateEnvT Unit := do | some state => Optimize.optimizeExpr' (mkApp5 (← mkNext) env.inputType env.stateType smInst iVar state) ) (verboseLevel := 2) + return optExpr analysisAtDepth (iVar : Expr) (state : Expr) : StateMachineEnvT Result := do let env ← get --- check invariant at step k let currDepth ← getCurrentDepth let invExpr := mkApp5 (← mkInvariants) env.inputType env.stateType smInst iVar state - let optExpr ← + let ⟨optExpr, _proof⟩ ← profileTask s!"Optimizing invariants at Depth {currDepth}" (Optimize.optimizeExpr invExpr) diff --git a/Blaster/StateMachine/KInduction.lean b/Blaster/StateMachine/KInduction.lean index 974e81b..29ac965 100644 --- a/Blaster/StateMachine/KInduction.lean +++ b/Blaster/StateMachine/KInduction.lean @@ -74,7 +74,7 @@ partial def kIndStrategy (smInst : Expr) : TranslateEnvT Unit := do | none => -- depth 0 withLocalDecl' (← nameAtDepth env.smName "state") BinderInfo.default env.stateType fun s => do let initState := mkApp4 (← mkInit) env.inputType env.stateType smInst iVar - let initEq ← + let ⟨initEq, _proof⟩ ← profileTask s!"Optimizing state at Depth {← getCurrentDepth}" (Optimize.main (mkApp3 (← mkEqOp) env.stateType s initState)) (verboseLevel := 2) @@ -92,7 +92,7 @@ partial def kIndStrategy (smInst : Expr) : TranslateEnvT Unit := do modify (fun env => { env with initFlag := some iflag }) f s | some state => - let state' ← + let ⟨state', _proof⟩ ← profileTask s!"Optimizing state at Depth {← getCurrentDepth}" (Optimize.optimizeExpr' (mkApp5 (← mkNext) env.inputType env.stateType smInst iVar state)) (verboseLevel := 2) @@ -104,7 +104,7 @@ partial def kIndStrategy (smInst : Expr) : TranslateEnvT Unit := do --- invariant at step k let currDepth ← getCurrentDepth let invExpr := mkApp5 (← mkInvariants) env.inputType env.stateType smInst iVar state - let optExpr ← + let ⟨optExpr, _proof⟩ ← profileTask s!"Optimizing invariants at Depth {currDepth}" (Optimize.optimizeExpr invExpr) diff --git a/Blaster/StateMachine/StateMachine.lean b/Blaster/StateMachine/StateMachine.lean index 551e1a0..fe7a160 100644 --- a/Blaster/StateMachine/StateMachine.lean +++ b/Blaster/StateMachine/StateMachine.lean @@ -160,7 +160,7 @@ def assertAssumptions (smInst : Expr) (iVar : Expr) (state : Expr) : StateMachin let currDepth ← getCurrentDepth translateAxioms currDepth let assumeExpr := mkApp5 (← mkAssumptions) env.inputType env.stateType smInst iVar state - let optExpr ← + let ⟨optExpr, _proof⟩ ← profileTask s!"Optimizing assumptions at Depth {currDepth}" (Optimize.optimizeExpr assumeExpr) @@ -200,7 +200,7 @@ def assertAssumptions (smInst : Expr) (iVar : Expr) (state : Expr) : StateMachin s!"Translating axioms at Depth {currDepth}" ( axioms.forM (fun e => do - let st ← translateExpr (← Optimize.optimizeExpr e) (topLevel := false) + let st ← translateExpr (← Optimize.optimizeExpr e).1 (topLevel := false) assertTerm st ) ) diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 8b13789..ac2702f 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -1 +1,3 @@ +import Blaster +example : ∀ {x : Nat}, 0 + x = x := by blaster From b16724be796d8c82de27ee91d4b90c3254f52dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:26:00 -0300 Subject: [PATCH 06/31] Indentation --- Blaster/Optimize/OptimizeStack.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index b6e2613..bf8754f 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -238,7 +238,7 @@ def stackContinuity let hyps ← addHypotheses optExpr x let hypsCtx ← mkHypStackContext hyps return Sum.inl (bodyOpt :: .LambdaWaitForBody x lctx (some hypsCtx) :: xs, proof) - else return Sum.inl (bodyOpt :: .LambdaWaitForBody x lctx none :: xs, proof) + else return Sum.inl (bodyOpt :: .LambdaWaitForBody x lctx none :: xs, proof) | .LambdaWaitForBody x lctx hctx :: xs => -- optExpr corresponds to optimized lambda body From c84cb0341021f4a7632a63169957f6ed324bbf42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:50:24 -0300 Subject: [PATCH 07/31] Importing Test module for proof reconstruction --- Tests/Basic.lean | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Tests/Basic.lean b/Tests/Basic.lean index 67bad6b..c6bfc2e 100644 --- a/Tests/Basic.lean +++ b/Tests/Basic.lean @@ -1,6 +1,5 @@ - import Tests.FixedIssues import Tests.Optimize +import Tests.Reconstruct.Basic import Tests.Smt import Tests.StateMachine - From 50b7b2e351c935193c5293e74312bceb4ffa04c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:51:04 -0300 Subject: [PATCH 08/31] Fallback when proof certificate type does not match goal --- Blaster/Command/Tactic.lean | 11 ++++++++++- Tests/Reconstruct/Basic.lean | 2 ++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/Blaster/Command/Tactic.lean b/Blaster/Command/Tactic.lean index 8abdc48..72c9097 100644 --- a/Blaster/Command/Tactic.lean +++ b/Blaster/Command/Tactic.lean @@ -43,7 +43,16 @@ def blasterTacticImp : Tactic := fun stx => match result with | .Valid => match proof with - | some p => goal.assign p + | some p => + -- verify certificate type matches goal before assigning, + -- as composition via Eq.trans is not yet fully implemented + let goalType ← goal.getType + let pType ← inferType p + if (← isDefEq pType goalType) then + goal.assign p + else + logWarning "blaster: proof reconstruction failed, closing with admit" + goal.admit | none => try goal.refl catch _ => diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index ac2702f..625e86f 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -1,3 +1,5 @@ import Blaster example : ∀ {x : Nat}, 0 + x = x := by blaster + +example : ∀ {x : Nat}, 0 + (0 + x) = x := by blaster From c85933bbb6a0287dc6874efa2644925b6e431e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:37:29 -0300 Subject: [PATCH 09/31] compose proof certificates for nested rewrites via congrArg --- Blaster/Command/Tactic.lean | 3 +- Blaster/Optimize/Basic.lean | 73 +++++++++------------ Blaster/Optimize/Env.lean | 36 +++++----- Blaster/Optimize/OptimizeStack.lean | 28 +++++--- Blaster/Optimize/Rewriting/OptimizeApp.lean | 17 ++++- Blaster/Reconstruct/Basic.lean | 35 +++++++++- Tests/Reconstruct/Basic.lean | 2 + 7 files changed, 122 insertions(+), 72 deletions(-) diff --git a/Blaster/Command/Tactic.lean b/Blaster/Command/Tactic.lean index 72c9097..dbadc94 100644 --- a/Blaster/Command/Tactic.lean +++ b/Blaster/Command/Tactic.lean @@ -44,8 +44,7 @@ def blasterTacticImp : Tactic := fun stx => | .Valid => match proof with | some p => - -- verify certificate type matches goal before assigning, - -- as composition via Eq.trans is not yet fully implemented + -- verify certificate type matches goal before assigning let goalType ← goal.getType let pType ← inferType p if (← isDefEq pType goalType) then diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index a0b2c08..d2ab290 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -6,8 +6,9 @@ import Blaster.Optimize.Rewriting.FunPropagation import Blaster.Optimize.Rewriting.OptimizeApp import Blaster.Optimize.Rewriting.OptimizeConst import Blaster.Optimize.Rewriting.OptimizeForAll +import Blaster.Reconstruct.Basic -open Lean Elab Command Term Meta Blaster.Options +open Lean Elab Command Term Meta Blaster.Options Blaster.Reconstruct namespace Blaster.Optimize @@ -17,84 +18,72 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := match stack with | .InitOptimizeExpr e :: xs => match (← isInOptimizeEnvCache e proof xs) with - | Sum.inl i_stack => + | Sum.inl (i_stack, i_proof) => -- trace[Optimize.expr] "optimizing {← ppExpr e}" match e with | Expr.fvar _ => - match (← normFVar e i_stack) with + match (← normFVar e i_stack i_proof) with | Sum.inr e' => return e' - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | Expr.sort l => -- sort is used for Type u, Prop, etc let s' ← mkExpr (Expr.sort (normLevel l)) - match (← stackContinuity i_stack s' proof) with + match (← stackContinuity i_stack s' i_proof) with | Sum.inr e' => return e' - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | Expr.lit .. => -- number or string literal - match (← stackContinuity i_stack (← mkExpr e) proof) with + match (← stackContinuity i_stack (← mkExpr e) i_proof) with | Sum.inr e' => return e' - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | Expr.const .. => match (← normConst e i_stack) with | Sum.inr e' => return e' - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | Expr.forallE n t b bi => - optimizeExprAux (.InitOptimizeExpr t :: .ForallWaitForType n bi b :: i_stack) proof + optimizeExprAux (.InitOptimizeExpr t :: .ForallWaitForType n bi b :: i_stack) i_proof | Expr.app .. => let (f, ras) := getAppFnWithArgs e -- check if f is a lambda term if f.isLambda then -- perform beta reduction and apply optimization - optimizeExprAux (.InitOptimizeExpr (betaLambda f ras) :: i_stack) proof + optimizeExprAux (.InitOptimizeExpr (betaLambda f ras) :: i_stack) i_proof else -- set inFunApp flag before optimizing `f` setInFunApp true let i_stack' := .AppWaitForConst ras :: i_stack - optimizeExprAux (.InitOptimizeExpr f :: i_stack') proof + optimizeExprAux (.InitOptimizeExpr f :: i_stack') i_proof - | Expr.lam n t b bi => optimizeExprAux (optimizeLambda n t b bi i_stack) proof + | Expr.lam n t b bi => optimizeExprAux (optimizeLambda n t b bi i_stack) i_proof -- inline let expression - | Expr.letE _n _t v b _ => optimizeExprAux (inlineLet v b i_stack) proof + | Expr.letE _n _t v b _ => optimizeExprAux (inlineLet v b i_stack) i_proof | Expr.mdata d me => if (isTaggedRecursiveCall e) then setNormalizeFunCall false optimizeExprAux - (.InitOptimizeExpr me :: .MDataRecCallWaitForExpr d :: i_stack) proof - else optimizeExprAux (.InitOptimizeExpr me :: i_stack) proof + (.InitOptimizeExpr me :: .MDataRecCallWaitForExpr d :: i_stack) i_proof + else optimizeExprAux (.InitOptimizeExpr me :: i_stack) i_proof | Expr.proj n idx s => let i_stack' := .ProjWaitForExpr n idx :: i_stack - optimizeExprAux (.InitOptimizeExpr s :: i_stack') proof + optimizeExprAux (.InitOptimizeExpr s :: i_stack') i_proof | Expr.mvar .. => throwEnvError "optimizeExpr: unexpected meta variable {e}" | Expr.bvar .. => throwEnvError "optimizeExpr: unexpected bound variable {e}" | Sum.inr (Sum.inr e') => return e' - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present | Sum.inr (Sum.inl (nextStack, nextProof)) => optimizeExprAux nextStack nextProof | s@(.InitOpaqueRecExpr ..) :: xs | s@(.RecFunDefStorage ..) :: xs => match (← normOpaqueAndRecFun s xs proof) with | Sum.inr e => return e - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | .AppOptimizeImplicitArgs f args idx startIdx stopIdx pInfo :: xs => @@ -121,7 +110,8 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := -- apply optimization on remaining explicit parameters before reduction else optimizeExprAux - (.AppOptimizeExplicitArgs f args startIdx args.size pInfo none :: xs) + (.AppOptimizeExplicitArgs f args startIdx args.size pInfo none + args (Array.replicate args.size none) :: xs) proof else if idx < pInfo.paramsInfo.size -- handle case when HOF is the returned type @@ -136,7 +126,7 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := (.AppOptimizeImplicitArgs f args (idx + 1) startIdx stopIdx pInfo :: xs) proof - | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo :: xs => + | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo origArgs argProofs :: xs => if idx ≥ stopIdx then -- normalizing ite/match function application if let some re ← normChoiceApplication? f args then @@ -146,16 +136,14 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := else if let some argInfo := mInfo then match (← optimizeMatch f args argInfo xs) with | Sum.inr e' => return e' - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present - | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof + | Sum.inl (nextStack, nextProof) => + optimizeExprAux nextStack (← composeProofs? proof nextProof) -- apply ite normalization rules only when fully applied else if isBlasterDiteConst f && args.size == 4 then match (← optimizeIfThenElse? f args xs) with | Sum.inr e' => return e' - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present - | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof + | Sum.inl (nextStack, nextProof) => + optimizeExprAux nextStack (← composeProofs? proof nextProof) -- try to reduce app if all params are constructors else if let some re ← reduceApp? f args then -- trace[Optimize.reduceApp] "application reduction {reprStr f} {reprStr args} => {reprStr re}" @@ -180,7 +168,8 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := proof else optimizeExprAux - (.AppOptimizeExplicitArgs f args (idx + 1) stopIdx pInfo mInfo :: xs) + (.AppOptimizeExplicitArgs + f args (idx + 1) stopIdx pInfo mInfo origArgs argProofs :: xs) proof else optimizeExprAux (.InitOptimizeExpr args[idx]! :: stack) proof @@ -192,7 +181,8 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := -- apply optimization on remaining explicit parameters before reduction -- keep matchInfo to avoid unnecessary query and to avoid optimizing discriminators again optimizeExprAux - (.AppOptimizeExplicitArgs f args startArgIdx args.size pInfo mInfo :: xs) + (.AppOptimizeExplicitArgs f args startArgIdx args.size pInfo mInfo + args (Array.replicate args.size none) :: xs) proof else optimizeExprAux (.InitOptimizeExpr args[idx]! :: stack) proof @@ -203,8 +193,6 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := | _ => -- header on xs is expected to be .MatchRhsLambdaWaitForBody match (← stackContinuity xs next proof) with - -- TODO: compose proof certificates with Eq.trans - -- when both proof and nextProof are present | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof | _ => throwEnvError "optimizeExprAux: continuity expected for MatchRhsLambdaNext !!!" @@ -219,10 +207,11 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := `return `stackContinuity stack (← mkExpr e)` -/ @[always_inline, inline] - normFVar (e : Expr) (stack : List OptimizeStack) : TranslateEnvT OptimizeContinuity := + normFVar (e : Expr) (stack : List OptimizeStack) (p : Option Expr) : + TranslateEnvT OptimizeContinuity := withLocalContext $ do match ← e.fvarId!.getValue? with - | none => stackContinuity stack (← mkExpr e) proof + | none => stackContinuity stack (← mkExpr e) p | some v => return Sum.inl (.InitOptimizeExpr (← instantiateMVars v) :: stack, none) @@ -268,14 +257,16 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := if isBlasterDiteConst f then if idx == 1 then -- skipping optimization for Blaster.dite' cond as already performed by choice reduction on dite - return (.AppOptimizeExplicitArgs f args (idx + 1) stopIdx pInfo mInfo :: nxtStack) + return (.AppOptimizeExplicitArgs f args (idx + 1) stopIdx pInfo mInfo + args (Array.replicate args.size none) :: nxtStack) else if idx == 2 || idx == 3 then return optimizeDiteArg args[idx]! stack else return (.InitOptimizeExpr args[idx]! :: stack) else if let some argInfo := mInfo then if idx >= argInfo.getFirstDiscrPos && idx < argInfo.getFirstAltPos then -- skipping optimization for discriminators as already performed by choice reduction on match - return (.AppOptimizeExplicitArgs f args (idx + 1) stopIdx pInfo mInfo :: nxtStack) + return (.AppOptimizeExplicitArgs f args (idx + 1) stopIdx pInfo mInfo + args (Array.replicate args.size none) :: nxtStack) else if idx >= argInfo.getFirstAltPos && idx < argInfo.arity then optimizeMatchAlt args argInfo idx args[idx]! stack else return (.InitOptimizeExpr args[idx]! :: stack) diff --git a/Blaster/Optimize/Env.lean b/Blaster/Optimize/Env.lean index ce38004..32f1624 100644 --- a/Blaster/Optimize/Env.lean +++ b/Blaster/Optimize/Env.lean @@ -1,9 +1,10 @@ import Lean +import Blaster.Command.Options import Blaster.Optimize.Expr import Blaster.Optimize.MatchInfo import Blaster.Optimize.Opaque +import Blaster.Optimize.Types import Blaster.Smt.Term -import Blaster.Command.Options open Lean Meta Blaster.Smt Blaster.Options @@ -83,7 +84,7 @@ inductive MatchEntry where deriving Repr abbrev HypothesisMap := Std.HashMap Lean.Expr Lean.Expr -abbrev RewriteCacheMap := Std.HashMap Lean.Expr Lean.Expr +abbrev RewriteCacheMap := Std.HashMap Lean.Expr OptimizeResult abbrev MatchEntryMap := Std.HashMap Lean.Expr MatchEntry -- with key corresponding to a match pattern abbrev MatchContextMap := Std.HashMap Lean.Expr MatchEntryMap -- with key corresponding to a match discriminator abbrev EqualityMap := Std.HashMap Lean.Expr Lean.Expr -- with key corresponding to expression to be replaced. @@ -612,20 +613,20 @@ def isInFunApp : TranslateEnvT Bool := return (← get).optEnv.options.inFunApp @[always_inline, inline] -def findGlobalCache (a : Expr) : TranslateEnvT (Option Expr) := do +def findGlobalCache (a : Expr) : TranslateEnvT (Option OptimizeResult) := do return (← get).optEnv.globalRewriteCache.get? a @[always_inline, inline] -def findLocalCache (a : Expr) : TranslateEnvT (Option Expr) := do +def findLocalCache (a : Expr) : TranslateEnvT (Option OptimizeResult) := do return (← get).optEnv.localRewriteCache.get? a /-- Update global rewrite cache with `a := b`. -/ -def updateGlobalRewriteCache (a : Expr) (b : Expr) : TranslateEnvT Unit := do - modify (fun env => { env with optEnv.globalRewriteCache := env.optEnv.globalRewriteCache.insert a b }) +def updateGlobalRewriteCache (a : Expr) (r : OptimizeResult) : TranslateEnvT Unit := do + modify (fun env => { env with optEnv.globalRewriteCache := env.optEnv.globalRewriteCache.insert a r }) /-- Update local rewrite cache with `a := b`. -/ -def updateLocalRewriteCache (a : Expr) (b : Expr) : TranslateEnvT Unit := do - modify (fun env => { env with optEnv.localRewriteCache := env.optEnv.localRewriteCache.insert a b }) +def updateLocalRewriteCache (a : Expr) (r : OptimizeResult) : TranslateEnvT Unit := do + modify (fun env => { env with optEnv.localRewriteCache := env.optEnv.localRewriteCache.insert a r }) /-- Update synthesize decidable instance cache with `a := b`. -/ @[always_inline, inline] @@ -655,9 +656,9 @@ def withSynthInstanceCache (a : Expr) (f: Unit → TranslateEnvT (Option Expr)) @[always_inline, inline] def mkExpr (a : Expr) (cacheResult := true) : TranslateEnvT Expr := do match (← findGlobalCache a) with - | some a' => return a' + | some r => return r.optExpr | none => do - if cacheResult then updateGlobalRewriteCache a a + if cacheResult then updateGlobalRewriteCache a ⟨a, none⟩ return a /-- Return `true` only when both hypothesisMap and matchInContext are empty and isRefHyp flag is not set -/ @@ -668,16 +669,17 @@ def isGlobalContext : TranslateEnvT Bool := do /-- Perform the following: - When isGlobal - - Add entry `a := b` to `globalRewriteCache` + - Add entry `a := r` to `globalRewriteCache` - Otherwise - - Add entry `a := b` to `localRewriteCache` + - Add entry `a := r` to `localRewriteCache` -/ @[always_inline, inline] -def updateOptimizeEnvCache (a : Expr) (b : Expr) (isGlobal : Bool) : TranslateEnvT Unit := do - -- trace[Optimize.cacheExpr] "cacheExpr {← ppExpr a} ===> {← ppExpr b}" +def updateOptimizeEnvCache (a : Expr) (r : OptimizeResult) (isGlobal : Bool) : + TranslateEnvT Unit := do + -- trace[Optimize.cacheExpr] "cacheExpr {← ppExpr a} ===> {← ppExpr r.optExpr}" if isGlobal - then updateGlobalRewriteCache a b - else updateLocalRewriteCache a b + then updateGlobalRewriteCache a r + else updateLocalRewriteCache a r /-- Perform the following: - When isGlobal @@ -690,7 +692,7 @@ def updateOptimizeEnvCache (a : Expr) (b : Expr) (isGlobal : Bool) : TranslateEn - Otherwise `none` -/ @[always_inline, inline] -def isInOptimizeCache? (a : Expr) (isGlobal : Bool) : TranslateEnvT (Option Expr) := do +def isInOptimizeCache? (a : Expr) (isGlobal : Bool) : TranslateEnvT (Option OptimizeResult) := do if isGlobal then findGlobalCache a else findLocalCache a diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index bf8754f..3e0c061 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -2,8 +2,9 @@ import Lean import Blaster.Optimize.Rewriting.OptimizeITE import Blaster.Optimize.Rewriting.OptimizeProjection import Blaster.Optimize.Telescope +import Blaster.Reconstruct.Basic -open Lean Meta +open Lean Meta Blaster.Reconstruct namespace Blaster.Optimize @@ -49,6 +50,8 @@ inductive OptimizeStack where (startArgIdx : Nat) (stopIdx : Nat) (pInfo : FunEnvInfo) | AppOptimizeExplicitArgs (f : Expr) (args : Array Expr) (idx : Nat) (stopIdx : Nat) (pInfo : FunEnvInfo) (mInfo : Option MatchInfo) + (origArgs : Array Expr) + (argProofs : Array (Option Expr)) -- reserved for future use | DiteChoiceWaitForCond (f : Expr) (args : Array Expr) (pInfo : FunEnvInfo) (startArgIdx : Nat) | MatchChoiceOptimizeDiscrs (f : Expr) (args : Array Expr) (pInfo : FunEnvInfo) (startArgIdx : Nat) (idx : Nat) (mInfo : MatchInfo) @@ -108,7 +111,12 @@ def stackContinuity | [] => return Sum.inr ⟨optExpr, proof⟩ | .InitOptimizeReturn e isGlobal :: xs => - if !skipCache then updateOptimizeEnvCache e optExpr isGlobal + if !skipCache then + -- only cache the proof certificate if the expression contains no free variables, + -- as certificates with fvars are only valid within the local scope where those + -- variables were introduced + let cachedProof := if e.hasFVar then none else proof + updateOptimizeEnvCache e ⟨optExpr, cachedProof⟩ isGlobal match xs with | [] => return Sum.inr ⟨optExpr, proof⟩ | _ => stackContinuity xs optExpr proof @@ -199,12 +207,12 @@ def stackContinuity (.AppOptimizeImplicitArgs f (args.set! idx optExpr) (idx + 1) startArgIdx stopIdx pInfo :: xs, proof) - | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo :: xs => + | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo origArgs argProofs :: xs => -- optExpr corresponds to the optimized explicit argument referenced by idx. -- continuity with optimizing the next explicit argument. return Sum.inl (.AppOptimizeExplicitArgs f (args.set! idx optExpr) - (idx + 1) stopIdx pInfo mInfo :: xs, proof) + (idx + 1) stopIdx pInfo mInfo origArgs argProofs :: xs, proof) | .DiteChoiceWaitForCond f args pInfo startArgIdx :: xs => -- optExpr corresponds to the optimized Blaster.dite' conditional, i.e., referenced by index 1. @@ -219,7 +227,7 @@ def stackContinuity -- NOTE: keep matchInfo to avoid unnecessary query and to avoid optimizing discriminators again return Sum.inl (.AppOptimizeExplicitArgs f (args.set! 1 optExpr) - startArgIdx args.size pInfo none :: xs, proof) + startArgIdx args.size pInfo none args (Array.replicate args.size none) :: xs, proof) | .MatchChoiceOptimizeDiscrs f args pInfo startArgIdx idx mInfo :: xs => -- optExpr corresponds to the optimized match discriminator referenced by idx. @@ -332,12 +340,16 @@ def optimizeIfThenElse? (f : Expr) (args : Array Expr) (stack : List OptimizeSta @[always_inline, inline] def isInOptimizeEnvCache (expr : Expr) (proof : Option Expr) (stack : List OptimizeStack) : - TranslateEnvT (Sum (List OptimizeStack) OptimizeContinuity) := do + TranslateEnvT (Sum (List OptimizeStack × Option Expr) OptimizeContinuity) := do -- NOTE: Always consider global context when `a` does not contain any FVar. let isGlobal := !expr.hasFVar || (← isGlobalContext) match (← isInOptimizeCache? expr isGlobal) with - | some b => Sum.inr <$> stackContinuity stack b proof - | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack) + | some r => + if r.proof.isNone && expr.hasFVar then + return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) + else + Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) + | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, none) end Blaster.Optimize diff --git a/Blaster/Optimize/Rewriting/OptimizeApp.lean b/Blaster/Optimize/Rewriting/OptimizeApp.lean index 9d293fd..91dfb04 100644 --- a/Blaster/Optimize/Rewriting/OptimizeApp.lean +++ b/Blaster/Optimize/Rewriting/OptimizeApp.lean @@ -10,8 +10,9 @@ import Blaster.Optimize.Rewriting.OptimizeITE import Blaster.Optimize.Rewriting.OptimizeNat import Blaster.Optimize.Rewriting.OptimizeString import Blaster.Optimize.OptimizeStack +import Blaster.Reconstruct.Basic -open Lean Meta +open Lean Meta Blaster.Reconstruct namespace Blaster.Optimize @@ -88,7 +89,9 @@ def optimizeAppAux (f : Expr) (args: Array Expr) : TranslateEnvT OptimizeResult - proceed with stack continuity `incomingProof` is an optional proof certificate threaded through the optimizer stack. - NOTE: proof certificate composition via `Eq.trans` is not yet implemented (see TODO comments). + + When an argument was rewritten, the incoming proof is lifted via `congrArg` + and composed with the application-level proof via `Eq.trans` in `optimizeApp`. NOTE: skipPropCheck is set to `true` only when it is known beforehand that `f` is a recursive function for which `allExplicitParamsAreCtor f args (funPropagation := true)` @@ -99,7 +102,15 @@ def optimizeApp (stack : List OptimizeStack) (incomingProof : Option Expr := none) (skipPropCheck := false) : TranslateEnvT OptimizeContinuity := do let ⟨e, newProof⟩ ← optimizeAppAux f args - let proof := newProof.orElse (λ _ => incomingProof) + let proof ← match incomingProof, newProof with + | some inP, some np => do + -- inP : origArg = optArg (an argument was rewritten) + -- np : f(optArgs) = result (the application-level rewrite on optimized args) + -- build congrArg to lift the arg rewrite to application level, then compose + match ← buildCongrArgFromProof f args inP with + | some congrP => composeProofs? (some congrP) (some np) + | none => pure (some np) + | _, _ => composeProofs? incomingProof newProof if ← isRestart then resetRestart return Sum.inl (.InitOptimizeExpr e :: stack, none) diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index cdf9b3b..0ce7c24 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -1,7 +1,40 @@ import Lean -open Lean +open Lean Meta namespace Blaster.Reconstruct +/-- Compose two proof certificates `p₁ : a = b` and `p₂ : b = c` into `p : a = c` via `Eq.trans`. -/ +def composeProofs (p₁ p₂ : Expr) : MetaM Expr := + mkAppM ``Eq.trans #[p₁, p₂] + +/-- Compose two optional proof certificates via `Eq.trans`. + If either is `none`, the other is returned unchanged. -/ +def composeProofs? (opt_p₁ opt_p₂ : Option Expr) : MetaM (Option Expr) := + match opt_p₁, opt_p₂ with + | none, p => return p + | p, none => return p + | some p₁, some p₂ => return some (← composeProofs p₁ p₂) + +/-- Given a function application f(args) and a proof that one argument was rewritten + (argProof : origArg = optArg), build a congruence proof that lifts the argument + rewrite to the application level. + + Finds i such that args[i] = optArg, then builds: + congrArg (f args[0] ... args[i-1]) argProof : f(..,origArg,..) = f(..,optArg,..) + + Returns none if the rewritten argument cannot be identified. -/ +def buildCongrArgFromProof (f : Expr) (args : Array Expr) + (argProof : Expr) : MetaM (Option Expr) := do + let proofType ← inferType argProof + let some (_, _, optArg) := proofType.eq? | return none + let mut idx : Option Nat := none + for i in [:args.size] do + if ← isDefEq args[i]! optArg then + idx := some i + break + let some i := idx | return none + let partialApp := mkAppN f (args.extract 0 i) + return some (← mkAppM ``congrArg #[partialApp, argProof]) + end Blaster.Reconstruct diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 625e86f..679c979 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -3,3 +3,5 @@ import Blaster example : ∀ {x : Nat}, 0 + x = x := by blaster example : ∀ {x : Nat}, 0 + (0 + x) = x := by blaster + +example : ∀ {x : Nat}, 0 + (0 + (0 + x)) = x := by blaster From 2344b53a0885226dbfd6321c55f3fbd1bc6934ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:31:00 -0300 Subject: [PATCH 10/31] ci: re-run tests From b5ce0bc5d09b45c0908e44503893aa7f8c2c5bbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Wed, 11 Mar 2026 12:14:30 -0300 Subject: [PATCH 11/31] test: simplify cache lookup --- Blaster/Optimize/OptimizeStack.lean | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 3e0c061..916f4b6 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -345,11 +345,7 @@ def isInOptimizeEnvCache (expr : Expr) (proof : Option Expr) (stack : List Optim let isGlobal := !expr.hasFVar || (← isGlobalContext) match (← isInOptimizeCache? expr isGlobal) with | some r => - if r.proof.isNone && expr.hasFVar then - return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) - else - Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) + Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, none) - end Blaster.Optimize From 3ea694291ffd7a40526406e04978e95ca93ebea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Wed, 11 Mar 2026 17:18:06 -0300 Subject: [PATCH 12/31] selective cache bypass for proof reconstruction --- Blaster/Optimize/OptimizeStack.lean | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 916f4b6..c120bf3 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -345,7 +345,10 @@ def isInOptimizeEnvCache (expr : Expr) (proof : Option Expr) (stack : List Optim let isGlobal := !expr.hasFVar || (← isGlobalContext) match (← isInOptimizeCache? expr isGlobal) with | some r => - Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) + if r.proof.isNone && expr.hasFVar && !Lean.Expr.equal r.optExpr expr then + return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) + else + Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, none) end Blaster.Optimize From 5777d526ddcceb8badf08d6d2229cf0aa513e50f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Wed, 11 Mar 2026 20:46:18 -0300 Subject: [PATCH 13/31] track stripped proofs explicitly to reduce unnecessary re-optimization --- Blaster/Optimize/Env.lean | 10 ++++++++-- Blaster/Optimize/OptimizeStack.lean | 15 +++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/Blaster/Optimize/Env.lean b/Blaster/Optimize/Env.lean index 32f1624..036002a 100644 --- a/Blaster/Optimize/Env.lean +++ b/Blaster/Optimize/Env.lean @@ -290,6 +290,11 @@ structure OptimizeEnv where -/ restart : Bool + /-- Set of expressions whose proof certificates were stripped during caching + due to containing free variables in a global context. + -/ + strippedProofExprs : Std.HashSet Lean.Expr + /-- local declaration context -/ ctx : LocalDeclContext @@ -308,6 +313,7 @@ instance : Inhabited OptimizeEnv where memCache := default, options := default, restart := false, + strippedProofExprs := Std.HashSet.emptyWithCapacity, ctx := default } @@ -664,7 +670,7 @@ def mkExpr (a : Expr) (cacheResult := true) : TranslateEnvT Expr := do /-- Return `true` only when both hypothesisMap and matchInContext are empty and isRefHyp flag is not set -/ @[always_inline, inline] def isGlobalContext : TranslateEnvT Bool := do - let ⟨_, ⟨_, _, _, _, _, _, _, _, hypothesisContext, matchInContext, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, _, _, _, _, _, _, _, hypothesisContext, matchInContext, _, _, _, _, _⟩⟩ ← get return hypothesisContext.hypothesisMap.size == 0 && matchInContext.size == 0 /-- Perform the following: @@ -1714,7 +1720,7 @@ where An error is triggered if no corresponding entry can be found in `recFunMap`. -/ def hasRecFunInst? (instApp : Expr) : TranslateEnvT (Option Expr) := do - let ⟨_, ⟨_, _, _, _, _,recFunInstCache,_,recFunMap, _, _, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, _, _, _, _,recFunInstCache,_,recFunMap, _, _, _, _, _, _, _⟩⟩ ← get match recFunInstCache.get? instApp with | some fbody => -- retrieve function application from recFunMap diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index c120bf3..259987d 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -70,7 +70,7 @@ abbrev OptimizeContinuity := Sum (List OptimizeStack × Option Expr) OptimizeRes @[always_inline, inline] def mkHypStackContext (h : UpdatedHypContext) : TranslateEnvT HypsStackContext := do - let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, hypothesisContext, _, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, hypothesisContext, _, _, _, _, _, _⟩⟩ ← get if h.1 then updateHypothesis h.2 Std.HashMap.emptyWithCapacity return {newHCtx := h, oldHCtx := some hypothesisContext, oldCache := some localRewriteCache} @@ -85,7 +85,7 @@ def resetHypContext (h : HypsStackContext) : TranslateEnvT Unit := do @[always_inline, inline] def mkMatchStackContext (h : MatchContextMap) : TranslateEnvT MatchStackContext := do - let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, _, matchInContext, _, _, _, _⟩⟩ ← get + let ⟨_, ⟨_, localRewriteCache, _, _, _, _, _, _, _, matchInContext, _, _, _, _, _⟩⟩ ← get updateMatchContext h Std.HashMap.emptyWithCapacity return {oldMatchCtx := matchInContext, oldCache := localRewriteCache} @@ -112,10 +112,13 @@ def stackContinuity | .InitOptimizeReturn e isGlobal :: xs => if !skipCache then - -- only cache the proof certificate if the expression contains no free variables, - -- as certificates with fvars are only valid within the local scope where those - -- variables were introduced + -- Strip proof certificates for expressions with free variables in global cache, + -- as fvars are only valid within the local scope where they were introduced. + -- Track stripped proofs so they can be selectively re-derived on cache hit. let cachedProof := if e.hasFVar then none else proof + if proof.isSome && cachedProof.isNone then + modify (fun env => { env with + optEnv.strippedProofExprs := env.optEnv.strippedProofExprs.insert e }) updateOptimizeEnvCache e ⟨optExpr, cachedProof⟩ isGlobal match xs with | [] => return Sum.inr ⟨optExpr, proof⟩ @@ -345,7 +348,7 @@ def isInOptimizeEnvCache (expr : Expr) (proof : Option Expr) (stack : List Optim let isGlobal := !expr.hasFVar || (← isGlobalContext) match (← isInOptimizeCache? expr isGlobal) with | some r => - if r.proof.isNone && expr.hasFVar && !Lean.Expr.equal r.optExpr expr then + if r.proof.isNone && expr.hasFVar && (← get).optEnv.strippedProofExprs.contains expr then return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) else Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) From 67f7963b4ca5654c436081297ae0f7be84d13cd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Fri, 13 Mar 2026 14:06:15 -0300 Subject: [PATCH 14/31] fix ambiguity when searching for argument index at buildCongrArgFromProof --- Blaster/Optimize/OptimizeStack.lean | 2 +- Blaster/Optimize/Rewriting/OptimizeNat.lean | 19 +++++++++++-------- Blaster/Reconstruct/Basic.lean | 5 +++-- Tests/Reconstruct/Basic.lean | 9 +++++++-- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 259987d..7ce6d37 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -352,6 +352,6 @@ def isInOptimizeEnvCache (expr : Expr) (proof : Option Expr) (stack : List Optim return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) else Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) - | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, none) + | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) end Blaster.Optimize diff --git a/Blaster/Optimize/Rewriting/OptimizeNat.lean b/Blaster/Optimize/Rewriting/OptimizeNat.lean index 119f116..12eeab4 100644 --- a/Blaster/Optimize/Rewriting/OptimizeNat.lean +++ b/Blaster/Optimize/Rewriting/OptimizeNat.lean @@ -138,18 +138,21 @@ def optimizeNatPow (f : Expr) (args : Array Expr) : TranslateEnvT Expr := do Assume that f = Expr.const ``Nat.mul. An error is triggered when args.size ≠ 2 (i.e., only fully applied `Nat.mul` expected at this stage) -/ -def optimizeNatMul (f : Expr) (args : Array Expr) : TranslateEnvT Expr := do +def optimizeNatMul (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult := do if args.size != 2 then throwEnvError "optimizeNatMul: exactly two arguments expected" let op1 := args[0]! let op2 := args[1]! match isNatValue? op1, isNatValue? op2 with - | some 0, _ => return op1 - | some 1, _ => return op2 - | some n1, some n2 => evalBinNatOp Nat.mul n1 n2 + | some 0, _ => + let proof := mkApp (mkConst ``Nat.zero_mul) op2 + trace[Optimize.expr] "optimizeNatMul : {proof}" + return ⟨op1, some proof⟩ + | some 1, _ => return ⟨op2, none⟩ + | some n1, some n2 => return ⟨← evalBinNatOp Nat.mul n1 n2, none⟩ | nv1, _ => - if let some r ← cstMulProp? nv1 op2 then return r - if let some r ← mulPowReduceExpr? op1 op2 then return r - return (mkApp2 f op1 op2) + if let some r ← cstMulProp? nv1 op2 then return ⟨r, none⟩ + if let some r ← mulPowReduceExpr? op1 op2 then return ⟨r, none⟩ + return ⟨(mkApp2 f op1 op2), none⟩ where /- Given `mv1` and `op2`, return `some ((N1 "*" N2) * n)` @@ -342,7 +345,7 @@ def optimizeNat? (f : Expr) (args : Array Expr) : TranslateEnvT (Option Optimize match n with | ``Nat.add => optimizeNatAdd f args | ``Nat.sub => return some ⟨← optimizeNatSub f args, none⟩ - | ``Nat.mul => return some ⟨← optimizeNatMul f args, none⟩ + | ``Nat.mul => optimizeNatMul f args | ``Nat.div => return some ⟨← optimizeNatDiv f args, none⟩ | ``Nat.mod => return some ⟨← optimizeNatMod f args, none⟩ | ``Nat.beq => return some ⟨← optimizeNatBeq f args, none⟩ diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index 0ce7c24..ae3f24e 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -30,8 +30,9 @@ def buildCongrArgFromProof (f : Expr) (args : Array Expr) let some (_, _, optArg) := proofType.eq? | return none let mut idx : Option Nat := none for i in [:args.size] do - if ← isDefEq args[i]! optArg then - idx := some i + let i' := args.size - 1 - i + if ← isDefEq args[i']! optArg then + idx := some i' break let some i := idx | return none let partialApp := mkAppN f (args.extract 0 i) diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 679c979..8e612ba 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -1,7 +1,12 @@ import Blaster +-- Nat.add example : ∀ {x : Nat}, 0 + x = x := by blaster - example : ∀ {x : Nat}, 0 + (0 + x) = x := by blaster - example : ∀ {x : Nat}, 0 + (0 + (0 + x)) = x := by blaster + +-- Nat.mul +example : ∀ {x : Nat}, 0 * x = 0 := by blaster +example : ∀ {x : Nat}, 0 * (0 * x) = 0 := by blaster + +example : ∀ {x : Nat}, (0 + x) * x = x * x := by blaster From ede8eefb4b525d4013749c57707fc9b50c27374a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Fri, 13 Mar 2026 19:45:07 -0300 Subject: [PATCH 15/31] fix: annotate proof argument position to survive unfolding --- Blaster/Optimize/Basic.lean | 2 + Blaster/Optimize/OptimizeStack.lean | 20 +++-- Blaster/Optimize/Rewriting/OptimizeNat.lean | 1 - Blaster/Reconstruct/Basic.lean | 84 ++++++++++++++++----- Tests/Reconstruct/Basic.lean | 4 +- 5 files changed, 83 insertions(+), 28 deletions(-) diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index d2ab290..12e1bf5 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -128,6 +128,8 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo origArgs argProofs :: xs => if idx ≥ stopIdx then + -- annotating proof with position-from-end so it survives unfolding + let proof ← annotateProofWithPosFromEnd args origArgs argProofs proof -- normalizing ite/match function application if let some re ← normChoiceApplication? f args then -- trace[Optimize.normChoiceApp] "normalizing choice application {reprStr f} {reprStr args} => {reprStr re}" diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 7ce6d37..cee338c 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -112,10 +112,12 @@ def stackContinuity | .InitOptimizeReturn e isGlobal :: xs => if !skipCache then - -- Strip proof certificates for expressions with free variables in global cache, - -- as fvars are only valid within the local scope where they were introduced. - -- Track stripped proofs so they can be selectively re-derived on cache hit. - let cachedProof := if e.hasFVar then none else proof + -- Strip proofs containing fvars before caching, as fvars are scope-local. + -- Track stripped entries for re-derivation on cache hit. + let cachedProof := if e.hasFVar then none + else match proof with + | some p => if p.hasFVar then none else some p + | none => none if proof.isSome && cachedProof.isNone then modify (fun env => { env with optEnv.strippedProofExprs := env.optEnv.strippedProofExprs.insert e }) @@ -213,9 +215,13 @@ def stackContinuity | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo origArgs argProofs :: xs => -- optExpr corresponds to the optimized explicit argument referenced by idx. -- continuity with optimizing the next explicit argument. + -- Only store proof when the argument was actually rewritten, + -- to avoid contamination from carried-over proofs of previous arguments. + let argChanged := !exprEq optExpr origArgs[idx]! + let argProofs' := if proof.isSome && argChanged then argProofs.set! idx proof else argProofs return Sum.inl (.AppOptimizeExplicitArgs f (args.set! idx optExpr) - (idx + 1) stopIdx pInfo mInfo origArgs argProofs :: xs, proof) + (idx + 1) stopIdx pInfo mInfo origArgs argProofs' :: xs, proof) | .DiteChoiceWaitForCond f args pInfo startArgIdx :: xs => -- optExpr corresponds to the optimized Blaster.dite' conditional, i.e., referenced by index 1. @@ -344,11 +350,11 @@ def optimizeIfThenElse? (f : Expr) (args : Array Expr) (stack : List OptimizeSta @[always_inline, inline] def isInOptimizeEnvCache (expr : Expr) (proof : Option Expr) (stack : List OptimizeStack) : TranslateEnvT (Sum (List OptimizeStack × Option Expr) OptimizeContinuity) := do - -- NOTE: Always consider global context when `a` does not contain any FVar. + -- NOTE: Always consider global context when `expr` does not contain any FVar. let isGlobal := !expr.hasFVar || (← isGlobalContext) match (← isInOptimizeCache? expr isGlobal) with | some r => - if r.proof.isNone && expr.hasFVar && (← get).optEnv.strippedProofExprs.contains expr then + if r.proof.isNone && (← get).optEnv.strippedProofExprs.contains expr then return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) else Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) diff --git a/Blaster/Optimize/Rewriting/OptimizeNat.lean b/Blaster/Optimize/Rewriting/OptimizeNat.lean index 12eeab4..0ecc936 100644 --- a/Blaster/Optimize/Rewriting/OptimizeNat.lean +++ b/Blaster/Optimize/Rewriting/OptimizeNat.lean @@ -145,7 +145,6 @@ def optimizeNatMul (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult match isNatValue? op1, isNatValue? op2 with | some 0, _ => let proof := mkApp (mkConst ``Nat.zero_mul) op2 - trace[Optimize.expr] "optimizeNatMul : {proof}" return ⟨op1, some proof⟩ | some 1, _ => return ⟨op2, none⟩ | some n1, some n2 => return ⟨← evalBinNatOp Nat.mul n1 n2, none⟩ diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index ae3f24e..22736fa 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -1,6 +1,7 @@ import Lean +import Blaster.Optimize.Env -open Lean Meta +open Lean Meta Blaster.Optimize namespace Blaster.Reconstruct @@ -16,26 +17,71 @@ def composeProofs? (opt_p₁ opt_p₂ : Option Expr) : MetaM (Option Expr) := | p, none => return p | some p₁, some p₂ => return some (← composeProofs p₁ p₂) -/-- Given a function application f(args) and a proof that one argument was rewritten - (argProof : origArg = optArg), build a congruence proof that lifts the argument - rewrite to the application level. +/-- Tag for annotating the argument position from the end. -/ +def argPosFromEndKey : Name := `_blaster.argPosFromEnd - Finds i such that args[i] = optArg, then builds: - congrArg (f args[0] ... args[i-1]) argProof : f(..,origArg,..) = f(..,optArg,..) +/-- Annotate the proof with the position relative to the end, so it survives + unfolding (which strips implicit args from the front). + Compares args against origArgs via isDefEq to find which argument + was actually rewritten, ignoring definitionally equal changes. -/ +def annotateProofWithPosFromEnd + (args : Array Expr) (origArgs : Array Expr) (argProofs : Array (Option Expr)) + (proof : Option Expr) : TranslateEnvT (Option Expr) := do + match proof with + | none => return none + | some p => + let mut proofIdx? : Option Nat := none + for i in [:argProofs.size] do + if (argProofs[i]!).isSome then + if !(← withLocalContext $ + withNewMCtxDepth $ + withReducible $ + isDefEq args[i]! origArgs[i]!) then + proofIdx? := some i + match proofIdx? with + | some proofIdx => + let posFromEnd := args.size - 1 - proofIdx + return some (Expr.mdata (MData.empty.setNat argPosFromEndKey posFromEnd) p) + | none => return some p +/-- Given a function application f(args) and a proof that one argument was rewritten + (argProof : origArg = optArg), build a congruence proof that lifts the rewrite + to the full application level. + Finds i such that args[i] was rewritten, using either an MData annotation + encoding position-from-end, or a reverse isDefEq search as fallback. + Then builds: + congrFun (... (congrFun (congrArg (f a₀..a_{i-1}) proof) a_{i+1}) ...) a_{n-1} Returns none if the rewritten argument cannot be identified. -/ -def buildCongrArgFromProof (f : Expr) (args : Array Expr) - (argProof : Expr) : MetaM (Option Expr) := do - let proofType ← inferType argProof - let some (_, _, optArg) := proofType.eq? | return none - let mut idx : Option Nat := none - for i in [:args.size] do - let i' := args.size - 1 - i - if ← isDefEq args[i']! optArg then - idx := some i' - break - let some i := idx | return none - let partialApp := mkAppN f (args.extract 0 i) - return some (← mkAppM ``congrArg #[partialApp, argProof]) +def buildCongrArgFromProof (f : Expr) (args : Array Expr) (argProof : Expr) + : MetaM (Option Expr) := do + let (proof, annotatedIdx?) := match argProof with + | Expr.mdata d p => + let posFromEnd := d.getNat argPosFromEndKey args.size + if posFromEnd < args.size then + let idx := args.size - 1 - posFromEnd + (p, some idx) + else + (argProof, none) + | _ => (argProof, none) + let proofType ← inferType proof + let some (_, _origArg, optArg) := proofType.eq? | return none + let idx? ← match annotatedIdx? with + | some idx => pure (some idx) + | none => + let mut found := none + for i in [:args.size] do + let i' := args.size - 1 - i + if ← isDefEq args[i']! optArg then + found := some i' + break + pure found + match idx? with + | some idx => + let partialApp := mkAppN f (args[:idx]) + let mut p ← mkCongrArg partialApp proof + for j in [idx + 1 : args.size] do + p ← mkCongrFun p args[j]! + return some p + | none => return none end Blaster.Reconstruct diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 8e612ba..e0df0e2 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -8,5 +8,7 @@ example : ∀ {x : Nat}, 0 + (0 + (0 + x)) = x := by blaster -- Nat.mul example : ∀ {x : Nat}, 0 * x = 0 := by blaster example : ∀ {x : Nat}, 0 * (0 * x) = 0 := by blaster +example : ∀ {x : Nat}, 0 * (0 * (0 * x)) = 0 := by blaster -example : ∀ {x : Nat}, (0 + x) * x = x * x := by blaster +-- Combination +example : ∀ {x : Nat}, (0 * (0 * (0 + x))) + x = x := by blaster From 7d3bf1bc33663ba78ebf8cc3baa679b76b3fa51d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Fri, 13 Mar 2026 23:07:09 -0300 Subject: [PATCH 16/31] fix: restrict proof propagation to fvar-free expressions in cache --- Blaster/Optimize/OptimizeStack.lean | 5 +++-- Blaster/Reconstruct/Basic.lean | 11 +++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index cee338c..77c0b7e 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -354,10 +354,11 @@ def isInOptimizeEnvCache (expr : Expr) (proof : Option Expr) (stack : List Optim let isGlobal := !expr.hasFVar || (← isGlobalContext) match (← isInOptimizeCache? expr isGlobal) with | some r => - if r.proof.isNone && (← get).optEnv.strippedProofExprs.contains expr then + if r.proof.isNone && proof.isSome && (← get).optEnv.strippedProofExprs.contains expr then return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) else Sum.inr <$> stackContinuity stack r.optExpr (← composeProofs? proof r.proof) - | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, proof) + | none => return Sum.inl (.InitOptimizeReturn expr isGlobal :: stack, + if !expr.hasFVar then proof else none) end Blaster.Optimize diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index 22736fa..2193437 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -33,10 +33,13 @@ def annotateProofWithPosFromEnd let mut proofIdx? : Option Nat := none for i in [:argProofs.size] do if (argProofs[i]!).isSome then - if !(← withLocalContext $ - withNewMCtxDepth $ - withReducible $ - isDefEq args[i]! origArgs[i]!) then + let unchanged ← try + withLocalContext $ + withNewMCtxDepth $ + withReducible $ + isDefEq args[i]! origArgs[i]! + catch _ => pure false + if !unchanged then proofIdx? := some i match proofIdx? with | some proofIdx => From a93ba4bd0c7d34835d657fa368c989bf213221f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Fri, 13 Mar 2026 23:30:30 -0300 Subject: [PATCH 17/31] remove unnecessary try/catch in annotateProofWithPosFromEnd --- Blaster/Reconstruct/Basic.lean | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index 2193437..22736fa 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -33,13 +33,10 @@ def annotateProofWithPosFromEnd let mut proofIdx? : Option Nat := none for i in [:argProofs.size] do if (argProofs[i]!).isSome then - let unchanged ← try - withLocalContext $ - withNewMCtxDepth $ - withReducible $ - isDefEq args[i]! origArgs[i]! - catch _ => pure false - if !unchanged then + if !(← withLocalContext $ + withNewMCtxDepth $ + withReducible $ + isDefEq args[i]! origArgs[i]!) then proofIdx? := some i match proofIdx? with | some proofIdx => From 21cc1b4aa3201bd92379d7f2d44b307f03b87e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Sun, 15 Mar 2026 11:00:00 -0300 Subject: [PATCH 18/31] bridge for proof certificates composition --- Blaster/Command/Tactic.lean | 7 +- Blaster/Optimize/Basic.lean | 36 ++++++-- Blaster/Optimize/OptimizeStack.lean | 24 +++++- Blaster/Optimize/Rewriting/OptimizeNat.lean | 4 +- Blaster/Reconstruct/Basic.lean | 91 +++++++++++---------- Tests/Reconstruct/Basic.lean | 22 +++-- 6 files changed, 118 insertions(+), 66 deletions(-) diff --git a/Blaster/Command/Tactic.lean b/Blaster/Command/Tactic.lean index dbadc94..4444b28 100644 --- a/Blaster/Command/Tactic.lean +++ b/Blaster/Command/Tactic.lean @@ -37,9 +37,10 @@ def blasterTacticImp : Tactic := fun stx => let (goal, nbQuantifiers) ← revertHypotheses (← getMainGoal) let env := {(default : TranslateEnv) with optEnv.options.solverOptions := sOpts} let ((result, (optExpr, proof)), _) ← - withTheReader Core.Context (fun ctx => { ctx with maxHeartbeats := 0 }) $ do - IO.setNumHeartbeats 0 - Translate.main (← goal.getType >>= instantiateMVars') (logUndetermined := false) |>.run env + withTheReader Core.Context + (fun ctx => { ctx with maxHeartbeats := 0, maxRecDepth := max ctx.maxRecDepth 4096 }) $ do + IO.setNumHeartbeats 0 + Translate.main (← goal.getType >>= instantiateMVars') (logUndetermined := false) |>.run env match result with | .Valid => match proof with diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index 12e1bf5..380e06b 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -128,6 +128,10 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo origArgs argProofs :: xs => if idx ≥ stopIdx then + -- recover proof from argProofs if it was lost during arg processing + let proof := match proof with + | some _ => proof + | none => argProofs.findSome? id -- annotating proof with position-from-end so it survives unfolding let proof ← annotateProofWithPosFromEnd args origArgs argProofs proof -- normalizing ite/match function application @@ -155,7 +159,22 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := -- NOTE: we can only unfold once all parameters have been optimized. else if let some fdef ← getUnfoldFunDef? f args then -- trace[Optimize.unfoldDef] "unfolding function definition {reprStr f} {reprStr args} => {reprStr fdef}" - optimizeExprAux (.InitOptimizeExpr fdef :: xs) proof + match proof with + | some p => + let isGlobal := !fdef.hasFVar || (← isGlobalContext) + if (← isInOptimizeCache? fdef isGlobal).isSome then + optimizeExprAux (.InitOptimizeExpr fdef :: xs) proof + else + match ← buildCongrArgFromProof f args p with + | some liftedProof => + if liftedProof.hasFVar && !fdef.hasFVar then + optimizeExprAux (.InitOptimizeExpr fdef :: xs) proof + else + optimizeExprAux (.InitOptimizeExpr fdef :: .ProofBridge liftedProof :: xs) none + | none => + optimizeExprAux (.InitOptimizeExpr fdef :: xs) proof + | none => + optimizeExprAux (.InitOptimizeExpr fdef :: xs) none -- normalizing partially apply function after unfolding non-opaque functions else if let some pe ← normPartialFun? f args then -- trace[Optimize.normPartial] "normalizing partial function {reprStr f} {reprStr args} => {reprStr pe}" @@ -324,13 +343,14 @@ def Optimize.main (e : Expr) : TranslateEnvT OptimizeResult := do NOTE: This function is to be used only by callOptimize in package Test. -/ def command (sOpts: BlasterOptions) (e : Expr) : MetaM (Expr × TranslateEnv) := do - -- keep the current name generator and restore it afterwards - let ngen ← getNGen - let env := {(default : TranslateEnv) with optEnv.options.solverOptions := sOpts} - let (⟨optExpr, _proof⟩, translateEnv) ← Optimize.main e|>.run env - -- restore name generator - setNGen ngen - return (optExpr, translateEnv) + withTheReader Core.Context (fun ctx => { ctx with maxRecDepth := max ctx.maxRecDepth 4096 }) do + -- keep the current name generator and restore it afterwards + let ngen ← getNGen + let env := {(default : TranslateEnv) with optEnv.options.solverOptions := sOpts} + let (⟨optExpr, _proof⟩, translateEnv) ← Optimize.main e|>.run env + -- restore name generator + setNGen ngen + return (optExpr, translateEnv) initialize diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 77c0b7e..05ca4c7 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -36,6 +36,7 @@ instance : Repr LocalContext where inductive OptimizeStack where | InitOptimizeExpr (e : Expr) | InitOptimizeReturn (e : Expr) (isGlobal : Bool) + | ProofBridge (proof : Expr) | InitOpaqueRecExpr (f : Expr) (args : Array Expr) | RecFunDefWaitForStorage (args : Array Expr) (instApp : Expr) (subsInts : Expr) (params : ImplicitParameters) @@ -126,6 +127,9 @@ def stackContinuity | [] => return Sum.inr ⟨optExpr, proof⟩ | _ => stackContinuity xs optExpr proof + | .ProofBridge storedProof :: xs => + stackContinuity xs optExpr (← composeProofs? (some storedProof) proof) + | .RecFunDefWaitForStorage args instApp subsInst params :: xs => -- optExpr corresponds to optimized rec fun body -- continuity with normOpaqueAndRecFun @@ -219,9 +223,10 @@ def stackContinuity -- to avoid contamination from carried-over proofs of previous arguments. let argChanged := !exprEq optExpr origArgs[idx]! let argProofs' := if proof.isSome && argChanged then argProofs.set! idx proof else argProofs + let proof' := if proof.isSome && argChanged then none else proof return Sum.inl (.AppOptimizeExplicitArgs f (args.set! idx optExpr) - (idx + 1) stopIdx pInfo mInfo origArgs argProofs' :: xs, proof) + (idx + 1) stopIdx pInfo mInfo origArgs argProofs' :: xs, proof') | .DiteChoiceWaitForCond f args pInfo startArgIdx :: xs => -- optExpr corresponds to the optimized Blaster.dite' conditional, i.e., referenced by index 1. @@ -263,7 +268,10 @@ def stackContinuity if let some h := hctx then resetHypContext h let e ← withLocalContext $ do mkLambdaExpr x optExpr resetLocalDeclContext lctx - stackContinuity xs e proof + let proof' := match proof with + | some p => if p.containsFVar x.fvarId! then none else some p + | none => none + stackContinuity xs e proof' | .MatchRhsLambdaWaitForType n bi body :: xs => -- optExpr corresponds to optimized lambda type @@ -281,7 +289,10 @@ def stackContinuity -- rhs has not been optimized yet. let e ← withLocalContext $ do mkLambdaFVar x optExpr resetLocalDeclContext lctx - stackContinuity xs e proof + let proof' := match proof with + | some p => if p.containsFVar x.fvarId! then none else some p + | none => none + stackContinuity xs e proof' | .MatchAltWaitForExpr params lctx mctx :: xs => -- optExpr corresponds to the optimized match rhs @@ -289,7 +300,12 @@ def stackContinuity let e ← withLocalContext $ do mkExpr (← mkLambdaFVars' params optExpr) resetMatchContext mctx resetLocalDeclContext lctx - stackContinuity xs e proof + let proof' := match proof with + | some p => + if params.any (fun param => p.containsFVar param.fvarId!) then none + else some p + | none => none + stackContinuity xs e proof' | .LetWaitForValue body :: xs => -- optExpr corresponds to the optimized let value diff --git a/Blaster/Optimize/Rewriting/OptimizeNat.lean b/Blaster/Optimize/Rewriting/OptimizeNat.lean index 0ecc936..00fb380 100644 --- a/Blaster/Optimize/Rewriting/OptimizeNat.lean +++ b/Blaster/Optimize/Rewriting/OptimizeNat.lean @@ -146,7 +146,9 @@ def optimizeNatMul (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult | some 0, _ => let proof := mkApp (mkConst ``Nat.zero_mul) op2 return ⟨op1, some proof⟩ - | some 1, _ => return ⟨op2, none⟩ + | some 1, _ => + let proof := mkApp (mkConst ``Nat.one_mul) op2 + return ⟨op2, some proof⟩ | some n1, some n2 => return ⟨← evalBinNatOp Nat.mul n1 n2, none⟩ | nv1, _ => if let some r ← cstMulProp? nv1 op2 then return ⟨r, none⟩ diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index 22736fa..74463aa 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -15,7 +15,9 @@ def composeProofs? (opt_p₁ opt_p₂ : Option Expr) : MetaM (Option Expr) := match opt_p₁, opt_p₂ with | none, p => return p | p, none => return p - | some p₁, some p₂ => return some (← composeProofs p₁ p₂) + | some p₁, some p₂ => + try return some (← composeProofs p₁ p₂) + catch _ => return none /-- Tag for annotating the argument position from the end. -/ def argPosFromEndKey : Name := `_blaster.argPosFromEnd @@ -30,19 +32,22 @@ def annotateProofWithPosFromEnd match proof with | none => return none | some p => - let mut proofIdx? : Option Nat := none - for i in [:argProofs.size] do - if (argProofs[i]!).isSome then - if !(← withLocalContext $ - withNewMCtxDepth $ - withReducible $ - isDefEq args[i]! origArgs[i]!) then - proofIdx? := some i - match proofIdx? with - | some proofIdx => - let posFromEnd := args.size - 1 - proofIdx - return some (Expr.mdata (MData.empty.setNat argPosFromEndKey posFromEnd) p) - | none => return some p + let mut proofIdx? : Option Nat := none + for i in [:argProofs.size] do + if (argProofs[i]!).isSome then + let unchanged ← try + withLocalContext $ + withNewMCtxDepth $ + withReducible $ + isDefEq args[i]! origArgs[i]! + catch _ => pure false + if !unchanged then + proofIdx? := some i + match proofIdx? with + | some proofIdx => + let posFromEnd := args.size - 1 - proofIdx + return some (Expr.mdata (MData.empty.setNat argPosFromEndKey posFromEnd) p) + | none => return some p /-- Given a function application f(args) and a proof that one argument was rewritten (argProof : origArg = optArg), build a congruence proof that lifts the rewrite @@ -54,34 +59,34 @@ def annotateProofWithPosFromEnd Returns none if the rewritten argument cannot be identified. -/ def buildCongrArgFromProof (f : Expr) (args : Array Expr) (argProof : Expr) : MetaM (Option Expr) := do - let (proof, annotatedIdx?) := match argProof with - | Expr.mdata d p => - let posFromEnd := d.getNat argPosFromEndKey args.size - if posFromEnd < args.size then - let idx := args.size - 1 - posFromEnd - (p, some idx) - else - (argProof, none) - | _ => (argProof, none) - let proofType ← inferType proof - let some (_, _origArg, optArg) := proofType.eq? | return none - let idx? ← match annotatedIdx? with - | some idx => pure (some idx) - | none => - let mut found := none - for i in [:args.size] do - let i' := args.size - 1 - i - if ← isDefEq args[i']! optArg then - found := some i' - break - pure found - match idx? with - | some idx => - let partialApp := mkAppN f (args[:idx]) - let mut p ← mkCongrArg partialApp proof - for j in [idx + 1 : args.size] do - p ← mkCongrFun p args[j]! - return some p - | none => return none + try + let (proof, annotatedIdx?) := match argProof with + | Expr.mdata d p => + let posFromEnd := d.getNat argPosFromEndKey args.size + if posFromEnd < args.size then + (p, some (args.size - 1 - posFromEnd)) + else (argProof, none) + | _ => (argProof, none) + let idx? ← match annotatedIdx? with + | some idx => pure (some idx) + | none => + let proofType ← inferType proof + let some (_, _origArg, optArg) := proofType.eq? | return none + let mut found := none + for i in [:args.size] do + let i' := args.size - 1 - i + if ← isDefEq args[i']! optArg then + found := some i' + break + pure found + match idx? with + | some idx => + let partialApp := mkAppN f (args[:idx]) + let mut p ← mkCongrArg partialApp proof + for j in [idx + 1 : args.size] do + p ← mkCongrFun p args[j]! + return some p + | none => return none + catch _ => return none end Blaster.Reconstruct diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index e0df0e2..56c5db7 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -1,14 +1,22 @@ import Blaster -- Nat.add -example : ∀ {x : Nat}, 0 + x = x := by blaster -example : ∀ {x : Nat}, 0 + (0 + x) = x := by blaster -example : ∀ {x : Nat}, 0 + (0 + (0 + x)) = x := by blaster +example : 1 + 2 = 3 := by blaster +example : ∀ {n : Nat}, 0 + n = n := by blaster +example : ∀ {n : Nat}, 0 + (0 + n) = n := by blaster +example : ∀ {n : Nat}, 0 + (0 + (0 + n)) = n := by blaster -- Nat.mul -example : ∀ {x : Nat}, 0 * x = 0 := by blaster -example : ∀ {x : Nat}, 0 * (0 * x) = 0 := by blaster -example : ∀ {x : Nat}, 0 * (0 * (0 * x)) = 0 := by blaster +example : 2 * 3 = 6 := by blaster +example : ∀ {n : Nat}, 0 * n = 0 := by blaster +example : ∀ {n : Nat}, 0 * (0 * n) = 0 := by blaster +example : ∀ {n : Nat}, 0 * (0 * (0 * n)) = 0 := by blaster +example : ∀ {n : Nat}, 1 * n = n := by blaster +example : ∀ {n : Nat}, 1 * (1 * n) = n := by blaster +example : ∀ {n : Nat}, 1 * (1 * (1 * n)) = n := by blaster -- Combination -example : ∀ {x : Nat}, (0 * (0 * (0 + x))) + x = x := by blaster +example : (2 * 3) + 1 = 7 := by blaster +example : ∀ {n : Nat}, 0 + ((0 * (0 * (0 + n))) + n) = n := by blaster +example : ∀ {n : Nat}, (0 + n) + 1 = n + 1 := by blaster +example : ∀ {n : Nat}, 1 + (0 + n) = n + 1 := by blaster From 8532e658b4cc12e048f27b01f692e891b6bb8a7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Sun, 15 Mar 2026 11:12:36 -0300 Subject: [PATCH 19/31] test: removing try-catch fallback in proof reconstruction --- Blaster/Reconstruct/Basic.lean | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index 74463aa..15b4c24 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -35,14 +35,11 @@ def annotateProofWithPosFromEnd let mut proofIdx? : Option Nat := none for i in [:argProofs.size] do if (argProofs[i]!).isSome then - let unchanged ← try - withLocalContext $ - withNewMCtxDepth $ - withReducible $ - isDefEq args[i]! origArgs[i]! - catch _ => pure false - if !unchanged then - proofIdx? := some i + if !(← withLocalContext $ + withNewMCtxDepth $ + withReducible $ + isDefEq args[i]! origArgs[i]!) then + proofIdx? := some i match proofIdx? with | some proofIdx => let posFromEnd := args.size - 1 - proofIdx From 0d7a59e8aebe51a30424427aea85b8f0d39035e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:29:38 -0300 Subject: [PATCH 20/31] testOptimize support for proof reconstruction --- Blaster/Optimize/Basic.lean | 10 +-- Blaster/Optimize/Rewriting/OptimizeNat.lean | 36 +++++---- Blaster/Optimize/Rewriting/Utils.lean | 2 +- .../Optimize/OptimizeNat/OptimizeNatAdd.lean | 34 ++++---- Tests/Reconstruct/Basic.lean | 8 ++ Tests/Utils.lean | 78 ++++++++++++------- 6 files changed, 101 insertions(+), 67 deletions(-) diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index 380e06b..e75677e 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -339,19 +339,19 @@ def Optimize.main (e : Expr) : TranslateEnvT OptimizeResult := do - `sOpts`: The solver options to use for optimization. - `expr`: The expression to be optimized. ### Returns - - A tuple containing the optimized expression and the optimization environment. + - A tuple containing the optimized expression, an optional proof certificate, + and the optimization environment. NOTE: This function is to be used only by callOptimize in package Test. -/ -def command (sOpts: BlasterOptions) (e : Expr) : MetaM (Expr × TranslateEnv) := do +def command (sOpts: BlasterOptions) (e : Expr) : MetaM (Expr × Option Expr × TranslateEnv) := do withTheReader Core.Context (fun ctx => { ctx with maxRecDepth := max ctx.maxRecDepth 4096 }) do -- keep the current name generator and restore it afterwards let ngen ← getNGen let env := {(default : TranslateEnv) with optEnv.options.solverOptions := sOpts} - let (⟨optExpr, _proof⟩, translateEnv) ← Optimize.main e|>.run env + let (⟨optExpr, proof⟩, translateEnv) ← Optimize.main e|>.run env -- restore name generator setNGen ngen - return (optExpr, translateEnv) - + return (optExpr, proof, translateEnv) initialize registerTraceClass `Optimize.expr diff --git a/Blaster/Optimize/Rewriting/OptimizeNat.lean b/Blaster/Optimize/Rewriting/OptimizeNat.lean index 00fb380..50258d6 100644 --- a/Blaster/Optimize/Rewriting/OptimizeNat.lean +++ b/Blaster/Optimize/Rewriting/OptimizeNat.lean @@ -9,7 +9,7 @@ open Lean Meta namespace Blaster.Optimize /-- Apply the following simplification/normalization rules on `Nat.add` : - - 0 + n ==> n [proof: Nat.zero_add] + - 0 + n ==> n [proof: Nat.zero_add] - N1 + N2 ===> N1 "+" N2 - N1 + (N2 + n) ==> (N1 "+" N2) + n - n1 + n2 ==> n2 + n1 (if n2 <ₒ n1) @@ -43,9 +43,9 @@ def optimizeNatAdd (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult /-- Apply the following simplification/normalization rules on `Nat.sub` : - - n1 - n2 ==> 0 (if n1 =ₚₜᵣ n2) - - 0 - n ==> 0 - - n - 0 ==> n + - n1 - n2 ==> 0 (if n1 =ₚₜᵣ n2) [proof: Nat.sub_self] + - 0 - n ==> 0 [proof: Nat.zero_sub] + - n - 0 ==> n [proof: Nat.sub_zero] - N1 - N2 ==> N1 "-" N2 - N1 - (N2 + n) ==> (N1 "-" N2) - n - (N1 - n) - N2 ==> (N1 "-" N2) - n @@ -54,19 +54,25 @@ def optimizeNatAdd (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult Assume that f = Expr.const ``Nat.sub. An error is triggered when args.size ≠ 2 (i.e., only fully applied `Nat.sub` expected at this stage) -/ -def optimizeNatSub (f : Expr) (args : Array Expr) : TranslateEnvT Expr := do +def optimizeNatSub (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult := do if args.size != 2 then throwEnvError "optimizeNatSub: exactly two arguments expected" let op1 := args[0]! let op2 := args[1]! - if exprEq op1 op2 then return (← mkNatLitExpr 0) + if exprEq op1 op2 then + let proof := mkApp (mkConst ``Nat.sub_self) op1 + return ⟨← mkNatLitExpr 0, some proof⟩ match isNatValue? op1, isNatValue? op2 with - | some 0, _ - | _, some 0 => return op1 - | some n1, some n2 => evalBinNatOp Nat.sub n1 n2 + | some 0, _ => + let proof := mkApp (mkConst ``Nat.zero_sub) op2 + return ⟨op1, some proof⟩ + | _, some 0 => + let proof := mkApp (mkConst ``Nat.sub_zero) op1 + return ⟨op1, some proof⟩ + | some n1, some n2 => return ⟨← evalBinNatOp Nat.sub n1 n2, none⟩ | nv1, nv2 => - if let some r ← cstSubPropRight? nv1 op2 then return r - if let some r ← cstSubPropLeft? op1 nv2 then return r - return (mkApp2 f op1 op2) + if let some r ← cstSubPropRight? nv1 op2 then return ⟨r, none⟩ + if let some r ← cstSubPropLeft? op1 nv2 then return ⟨r, none⟩ + return ⟨mkApp2 f op1 op2, none⟩ where /- Given `mv1` and `op2` return `some ((N1 "-" N2) - n)` when @@ -129,8 +135,8 @@ def optimizeNatPow (f : Expr) (args : Array Expr) : TranslateEnvT Expr := do | _, _ => return (mkApp2 f op1 op2) /-- Apply the following simplification/normalization rules on `Nat.mul` : - - 0 * n ==> 0 - - 1 * n ==> n + - 0 * n ==> 0 [proof: Nat.zero_mul] + - 1 * n ==> n [proof: Nat.one_mul] - N1 + N2 ==> N1 "*" N2 - N1 * (N2 * n) ==> (N1 "*" N2) * n - n1 * n2 ==> n2 * n1 (if n2 <ₒ n1) @@ -345,7 +351,7 @@ def optimizeNat? (f : Expr) (args : Array Expr) : TranslateEnvT (Option Optimize let Expr.const n _ := f | return none match n with | ``Nat.add => optimizeNatAdd f args - | ``Nat.sub => return some ⟨← optimizeNatSub f args, none⟩ + | ``Nat.sub => optimizeNatSub f args | ``Nat.mul => optimizeNatMul f args | ``Nat.div => return some ⟨← optimizeNatDiv f args, none⟩ | ``Nat.mod => return some ⟨← optimizeNatMod f args, none⟩ diff --git a/Blaster/Optimize/Rewriting/Utils.lean b/Blaster/Optimize/Rewriting/Utils.lean index a0a7476..9420d41 100644 --- a/Blaster/Optimize/Rewriting/Utils.lean +++ b/Blaster/Optimize/Rewriting/Utils.lean @@ -317,7 +317,7 @@ def reorderEq (args : Array Expr) : TranslateEnvT (Expr × Expr) := do if isBoolNotExprOf e1 e2 then return (e2, e1) return r -/-- Reorder operands for commutative `Int` operators as follows: +/-- Reorder operands for commutative `Nat` operators as follows: - #[N1, N2] ===> args - #[N, e] ===> args - #[e, N] ===> #[N, e] diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean index d78bc89..59105e4 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean @@ -13,19 +13,19 @@ namespace Test.OptimizeNatAdd def natAddCst_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 10) elab "natAddCst_1" : term => return natAddCst_1 -#testOptimize [ "NatAddCst_1" ] (0 : Nat) + 10 ===> natAddCst_1 +#testOptimize [ "NatAddCst_1", proof ] (0 : Nat) + 10 ===> natAddCst_1 -- 12 + 0 ===> 12 def natAddCst_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 12) elab "natAddCst_2" : term => return natAddCst_2 -#testOptimize [ "NatAddCst_2" ] (12 : Nat) + 0 ===> natAddCst_2 +#testOptimize [ "NatAddCst_2", proof ] (12 : Nat) + 0 ===> natAddCst_2 -- 123 + 50 ===> 173 def natAddCst_3 : Expr := Lean.Expr.lit (Lean.Literal.natVal 173) elab "natAddCst_3" : term => return natAddCst_3 -#testOptimize [ "NatAddCst_3" ] (123 : Nat) + 50 ===> natAddCst_3 +#testOptimize [ "NatAddCst_3", proof ] (123 : Nat) + 50 ===> natAddCst_3 /-! Test cases for simplification rule `0 + n ===> n`. -/ @@ -71,13 +71,13 @@ def natAddZeroUnchanged_1 : Expr := elab "natAddZeroUnchanged_1" : term => return natAddZeroUnchanged_1 -#testOptimize [ "NatAddZeroUnchanged_1" ] ∀ (x y : Nat), (1 + x) < y ===> natAddZeroUnchanged_1 +#testOptimize [ "NatAddZeroUnchanged_1", proof ] ∀ (x y : Nat), (1 + x) < y ===> natAddZeroUnchanged_1 -- (27 - 26) + x ===> 1 + x -#testOptimize [ "NatAddZeroUnchanged_2" ] ∀ (x y : Nat), (27 - 26) + x < y ===> natAddZeroUnchanged_1 +#testOptimize [ "NatAddZeroUnchanged_2", proof ] ∀ (x y : Nat), (27 - 26) + x < y ===> natAddZeroUnchanged_1 -- (Nat.zero + 1) + x ===> 1 + x -#testOptimize [ "NatAddZeroUnchanged_3" ] ∀ (x y : Nat), (Nat.zero + 1) + x < y ===> natAddZeroUnchanged_1 +#testOptimize [ "NatAddZeroUnchanged_3" , proof ] ∀ (x y : Nat), (Nat.zero + 1) + x < y ===> natAddZeroUnchanged_1 -- (127 + 40) + x ===> 167 + x def natAddZeroUnchanged_4 : Expr := @@ -99,7 +99,7 @@ def natAddZeroUnchanged_4 : Expr := elab "natAddZeroUnchanged_4" : term => return natAddZeroUnchanged_4 -#testOptimize [ "NatAddZeroUnchanged_4" ] ∀ (x y : Nat), (127 + 40) + x < y ===> natAddZeroUnchanged_4 +#testOptimize [ "NatAddZeroUnchanged_4", proof ] ∀ (x y : Nat), (127 + 40) + x < y ===> natAddZeroUnchanged_4 /-! Test cases for simplification rule `N1 + (N2 + n) ===> (N1 "+" N2) + n`. -/ @@ -154,7 +154,7 @@ elab "natAddCstProp_5" : term => return natAddCstProp_5 #testOptimize [ "NatAddCstProp_10" ] ∀ (x : Nat), 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True -- 100 + ((180 - (x + 40)) - 150) = 100 -#testOptimize [ "NatAddCstProp_11" ] ∀ (x : Nat), 100 + ((180 - (x + 40)) - 150) = 100 ===> True +#testOptimize [ "NatAddCstProp_11", proof ] ∀ (x : Nat), 100 + ((180 - (x + 40)) - 150) = 100 ===> True /-! Test cases to ensure that simplification rule `N1 + (N2 + n) ===> (N1 "+" N2) + n` @@ -185,7 +185,7 @@ def natAddCstPropUnchanged_1 : Expr := elab "natAddCstPropUnchanged_1" : term => return natAddCstPropUnchanged_1 -#testOptimize [ "NatAddCstPropUnchanged_1" ] ∀ (x y : Nat), 40 + (x + y) < y ===> natAddCstPropUnchanged_1 +#testOptimize [ "NatAddCstPropUnchanged_1", proof ] ∀ (x y : Nat), 40 + (x + y) < y ===> natAddCstPropUnchanged_1 -- 40 + (x - y) ===> 40 + (x - y) @@ -212,7 +212,7 @@ def natAddCstPropUnchanged_2 : Expr := elab "natAddCstPropUnchanged_2" : term => return natAddCstPropUnchanged_2 -#testOptimize [ "NatAddCstPropUnchanged_2" ] ∀ (x y : Nat), 40 + (x - y) < y ===> natAddCstPropUnchanged_2 +#testOptimize [ "NatAddCstPropUnchanged_2", proof ] ∀ (x y : Nat), 40 + (x - y) < y ===> natAddCstPropUnchanged_2 -- 40 + (x * y) ===> 40 + (x * y) @@ -239,7 +239,7 @@ def natAddCstPropUnchanged_3 : Expr := elab "natAddCstPropUnchanged_3" : term => return natAddCstPropUnchanged_3 -#testOptimize [ "NatAddCstPropUnchanged_3" ] ∀ (x y : Nat), 40 + (x * y) < y ===> natAddCstPropUnchanged_3 +#testOptimize [ "NatAddCstPropUnchanged_3", proof ] ∀ (x y : Nat), 40 + (x * y) < y ===> natAddCstPropUnchanged_3 /-! Test cases for normalization rule `n1 + n2 ==> n2 + n1 (if n2 <ₒ n1)`. -/ @@ -292,13 +292,13 @@ elab "natAddCommut_5" : term => return natAddCommut_5 /-! Test cases to ensure that `Nat.add` is preserved when expected. -/ -- x + (y + 0) = y + x ===> True -#testOptimize [ "NatAddVar_1" ] ∀ (x y : Nat), x + (y + 0) = y + x ===> True +#testOptimize [ "NatAddVar_1", proof ] ∀ (x y : Nat), x + (y + 0) = y + x ===> True -- (x + 0) + y = y + x ===> True -#testOptimize [ "NatAddVar_2" ] ∀ (x y : Nat), (x + 0) + y = y + x ===> True +#testOptimize [ "NatAddVar_2", proof ] ∀ (x y : Nat), (x + 0) + y = y + x ===> True -- (x + 0) + (y + 0) = y + x ===> True -#testOptimize [ "NatAddVar_3" ] ∀ (x y : Nat), (x + 0) + (y + 0) = y + x ===> True +#testOptimize [ "NatAddVar_3", proof ] ∀ (x y : Nat), (x + 0) + (y + 0) = y + x ===> True -- x + y < 10 ===> x + y < 10 def natAddVar_4 : Expr := @@ -320,7 +320,7 @@ def natAddVar_4 : Expr := elab "natAddVar_4" : term => return natAddVar_4 -#testOptimize [ "NatAddVar_4" ] ∀ (x y : Nat), x + y < 10 ===> natAddVar_4 +#testOptimize [ "NatAddVar_4", proof ] ∀ (x y : Nat), x + y < 10 ===> natAddVar_4 /-! Test cases to ensure that constant propagation is properly performed @@ -334,12 +334,12 @@ variable (y : Nat) def natAddReduce_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 100) elab "natAddReduce_1" : term => return natAddReduce_1 -#testOptimize [ "NatAddReduce_1" ] (100 + ((180 - (x + 40)) - 150)) + ((200 - y) - 320) ===> natAddReduce_1 +#testOptimize [ "NatAddReduce_1", proof ] (100 + ((180 - (x + 40)) - 150)) + ((200 - y) - 320) ===> natAddReduce_1 def natAddReduce_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 124) elab "natAddReduce_2" : term => return natAddReduce_2 -- (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> 124 -#testOptimize [ "NatAddReduce_2" ] (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> natAddReduce_2 +#testOptimize [ "NatAddReduce_2", proof ] (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> natAddReduce_2 end Test.OptimizeNatAdd diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 56c5db7..07a1eda 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -6,6 +6,10 @@ example : ∀ {n : Nat}, 0 + n = n := by blaster example : ∀ {n : Nat}, 0 + (0 + n) = n := by blaster example : ∀ {n : Nat}, 0 + (0 + (0 + n)) = n := by blaster +-- Nat.sub +example : ∀ {n : Nat}, n - n = 0 := by blaster +example : ∀ {n : Nat}, 0 - n = 0 := by blaster + -- Nat.mul example : 2 * 3 = 6 := by blaster example : ∀ {n : Nat}, 0 * n = 0 := by blaster @@ -15,8 +19,12 @@ example : ∀ {n : Nat}, 1 * n = n := by blaster example : ∀ {n : Nat}, 1 * (1 * n) = n := by blaster example : ∀ {n : Nat}, 1 * (1 * (1 * n)) = n := by blaster +-- N1 + (N2 + n) ==> (N1 "+" N2) + n +example : ∀ (x : Nat), 10 + (20 + x) = 30 + x := by blaster + -- Combination example : (2 * 3) + 1 = 7 := by blaster example : ∀ {n : Nat}, 0 + ((0 * (0 * (0 + n))) + n) = n := by blaster +example : ∀ {n : Nat}, (1 * ((0 * n) + n)) - 0 = n := by blaster example : ∀ {n : Nat}, (0 + n) + 1 = n + 1 := by blaster example : ∀ {n : Nat}, 1 + (0 + n) = n + 1 := by blaster diff --git a/Tests/Utils.lean b/Tests/Utils.lean index 6e07db6..3586d56 100644 --- a/Tests/Utils.lean +++ b/Tests/Utils.lean @@ -9,31 +9,37 @@ namespace Tests def parseTerm (stx : Syntax) : TermElabM Expr := elabTermAndSynthesize stx none /-- Parse a term syntax and call optimize. -/ -def callOptimize (sOpts : BlasterOptions) (stx : Syntax) : TermElabM Expr := - withTheReader Core.Context (fun ctx => { ctx with maxHeartbeats := 0 }) $ do - let optRes ← (Blaster.Optimize.command sOpts (← parseTerm stx)) - pure optRes.1 +def callOptimize (sOpts : BlasterOptions) (stx : Syntax) : TermElabM (Expr × Option Expr) := + withTheReader Core.Context + (fun ctx => { ctx with maxHeartbeats := 0, maxRecDepth := max ctx.maxRecDepth 4096 }) $ do + let (optExpr, proof, _) ← Blaster.Optimize.command sOpts (← parseTerm stx) + pure (optExpr, proof) /-! ## Definition of #testOptimize command to write unit test for Blaster.optimize The #testOptimize usage is as follows: - #testOptimize [ "TestName" ] (verbose: num)? (norm-result: num)? TermToOptimize ==> OptimizedTerm + #testOptimize [ "TestName" ] (verbose: num)? (norm-result: num)? TermToOptimize ===> OptimizedTerm + #testOptimize [ "TestName", proof ] (verbose: num)? (norm-result: num)? TermToOptimize ===> OptimizedTerm with options: - verbose: activate debug info - norm-result: apply nat literal normalization, beta reduction on lambda application and structure projection normalization on expected result. + - proof: require that the optimizer produces a valid proof certificate + (verified via type check). E.g. - #testOptimize [ "AndSubsumption" ] ∀ (a : Prop), a ∧ a ==> ∀ (a : Prop), a + #testOptimize [ "AndSubsumption" ] ∀ (a : Prop), a ∧ a ===> ∀ (a : Prop), a + #testOptimize [ "NatAddZero", proof ] ∀ (x : Nat), 0 + x = x ===> True -/ -syntax testName := "[" str "]" +syntax testName := "[" str ("," "proof")? "]" syntax termReducedTo := term "===>" term syntax normNatLitOption := ("(norm-result:" num ")")? syntax (name := testOptimize) "#testOptimize" testName solveOption* normNatLitOption termReducedTo : command -def parseTestName : TSyntax `testName -> CommandElabM String - | `(testName| [ $s:str ]) => pure s.getString - | _ => throwUnsupportedSyntax +def parseTestName : TSyntax `testName → CommandElabM (String × Bool) + | `(testName| [ $s:str , proof ]) => pure (s.getString, true) + | `(testName| [ $s:str ]) => pure (s.getString, false) + | _ => throwUnsupportedSyntax def parseTermReducedTo : TSyntax `termReducedTo -> CommandElabM (Syntax × Syntax) |`(termReducedTo | $t1 ===> $t2) => pure (t1.raw, t2.raw) @@ -48,7 +54,6 @@ def parseNormNatLitOption : TSyntax `normNatLitOption -> CommandElabM Bool | `(normNatLitOption| ) => return false | _ => throwUnsupportedSyntax - /-- Remove metadata annotations from `e`. -/ def removeAnnotations (e : Expr) : Expr := let rec visit (e : Expr) (k : Expr → Expr) := @@ -145,23 +150,38 @@ partial def normNatLitAndLambdaBeta (e : Expr) : MetaM Expr := do @[command_elab testOptimize] def testOptimizeImp : CommandElab := fun stx => do - let name ← parseTestName ⟨stx[1]⟩ - let sOpts ← parseVerbose default ⟨stx[2]⟩ - let normNatFlag ← parseNormNatLitOption ⟨stx[3]⟩ - let (t1, t2) ← parseTermReducedTo ⟨stx[4]⟩ - withoutModifyingEnv $ runTermElabM fun _ => do - -- create a local declaration name for the test case - let m ← getMainModule - withDeclName (m ++ name.toName) $ do - let actual ← callOptimize sOpts t1 - let expected' := removeAnnotations (← parseTerm t2) - -- keep the current name generator and restore it afterwards - let ngen ← getNGen - let expected ← if normNatFlag then normNatLitAndLambdaBeta expected' else pure expected' - -- restore name generator - setNGen ngen - if actual == expected - then logInfo f!"{name} ✅ Success!" - else logError f!"{name} ❌ Failure! : expecting {reprStr expected} \nbut got {reprStr actual}" + let (name, requireProof) ← parseTestName ⟨stx[1]⟩ + let sOpts ← parseVerbose default ⟨stx[2]⟩ + let normNatFlag ← parseNormNatLitOption ⟨stx[3]⟩ + let (t1, t2) ← parseTermReducedTo ⟨stx[4]⟩ + withoutModifyingEnv $ runTermElabM fun _ => do + -- create a local declaration name for the test case + let m ← getMainModule + withDeclName (m ++ name.toName) $ do + let (actual, proofCert) ← callOptimize sOpts t1 + let expected' := removeAnnotations (← parseTerm t2) + -- keep the current name generator and restore it afterwards + let ngen ← getNGen + let expected ← if normNatFlag then normNatLitAndLambdaBeta expected' else pure expected' + -- restore name generator + setNGen ngen + if actual == expected then + if requireProof then + match proofCert with + | some p => + if (← try inferType p *> pure true catch _ => pure false) then + logInfo f!"{name} ✅ Success! [proof ✓]" + else + logError f!"{name} ❌ Failure! : proof certificate failed type check" + | none => + let inputExpr ← parseTerm t1 + if (← try isDefEq actual inputExpr catch _ => pure false) then + logInfo f!"{name} ✅ Success! [refl ✓]" + else + logError f!"{name} ❌ Failure! : no proof certificate and refl failed" + else + logInfo f!"{name} ✅ Success!" + else + logError f!"{name} ❌ Failure! : expecting {reprStr expected} \nbut got {reprStr actual}" end Tests From 13555b887316dd81e793884e10b54fe462b31874 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:49:22 -0300 Subject: [PATCH 21/31] test: removing macRecDepth from testOptimize --- Tests/Reconstruct/Basic.lean | 6 +++--- Tests/Utils.lean | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 07a1eda..a805f69 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -19,12 +19,12 @@ example : ∀ {n : Nat}, 1 * n = n := by blaster example : ∀ {n : Nat}, 1 * (1 * n) = n := by blaster example : ∀ {n : Nat}, 1 * (1 * (1 * n)) = n := by blaster --- N1 + (N2 + n) ==> (N1 "+" N2) + n -example : ∀ (x : Nat), 10 + (20 + x) = 30 + x := by blaster - -- Combination example : (2 * 3) + 1 = 7 := by blaster example : ∀ {n : Nat}, 0 + ((0 * (0 * (0 + n))) + n) = n := by blaster example : ∀ {n : Nat}, (1 * ((0 * n) + n)) - 0 = n := by blaster example : ∀ {n : Nat}, (0 + n) + 1 = n + 1 := by blaster example : ∀ {n : Nat}, 1 + (0 + n) = n + 1 := by blaster + +-- N1 + (N2 + n) ==> (N1 "+" N2) + n +example : ∀ (x : Nat), 10 + (20 + x) = 30 + x := by blaster diff --git a/Tests/Utils.lean b/Tests/Utils.lean index 3586d56..8139abe 100644 --- a/Tests/Utils.lean +++ b/Tests/Utils.lean @@ -10,10 +10,9 @@ def parseTerm (stx : Syntax) : TermElabM Expr := elabTermAndSynthesize stx none /-- Parse a term syntax and call optimize. -/ def callOptimize (sOpts : BlasterOptions) (stx : Syntax) : TermElabM (Expr × Option Expr) := - withTheReader Core.Context - (fun ctx => { ctx with maxHeartbeats := 0, maxRecDepth := max ctx.maxRecDepth 4096 }) $ do - let (optExpr, proof, _) ← Blaster.Optimize.command sOpts (← parseTerm stx) - pure (optExpr, proof) + withTheReader Core.Context (fun ctx => { ctx with maxHeartbeats := 0 }) $ do + let (optExpr, proof, _) ← Blaster.Optimize.command sOpts (← parseTerm stx) + pure (optExpr, proof) /-! ## Definition of #testOptimize command to write unit test for Blaster.optimize The #testOptimize usage is as follows: From 982370b2c03b4ceb9d5ee0885d7856615aeaa1a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Mon, 16 Mar 2026 10:14:14 -0300 Subject: [PATCH 22/31] scope maxRecDepth to blaster tactic and specific deep tests --- Blaster/Optimize/Basic.lean | 11 +++++------ Tests/Optimize/OptimizeNat/OptimizeNatMul.lean | 1 + 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index e75677e..4f76452 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -344,14 +344,13 @@ def Optimize.main (e : Expr) : TranslateEnvT OptimizeResult := do NOTE: This function is to be used only by callOptimize in package Test. -/ def command (sOpts: BlasterOptions) (e : Expr) : MetaM (Expr × Option Expr × TranslateEnv) := do - withTheReader Core.Context (fun ctx => { ctx with maxRecDepth := max ctx.maxRecDepth 4096 }) do -- keep the current name generator and restore it afterwards - let ngen ← getNGen - let env := {(default : TranslateEnv) with optEnv.options.solverOptions := sOpts} - let (⟨optExpr, proof⟩, translateEnv) ← Optimize.main e|>.run env + let ngen ← getNGen + let env := {(default : TranslateEnv) with optEnv.options.solverOptions := sOpts} + let (⟨optExpr, proof⟩, translateEnv) ← Optimize.main e|>.run env -- restore name generator - setNGen ngen - return (optExpr, proof, translateEnv) + setNGen ngen + return (optExpr, proof, translateEnv) initialize registerTraceClass `Optimize.expr diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean b/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean index ac465c1..a2c291c 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean @@ -356,6 +356,7 @@ variable (x : Nat) variable (y : Nat) -- (100 * (30 - ((180 - (x * 1)) - 150))) * ((320 - (y + 400)) - y) ===> 0 +set_option maxRecDepth 4096 in #testOptimize [ "NatMulReduce_1" ] (100 * (30 - ((180 - (x * 1)) - 150))) * ((320 - (y + 400)) - y) ===> natMulCst_1 -- (100 * (((180 - (x * 40)) - 150) - 30)) * ((((20 - y) - 50) * 24) + 1) ===> 100 From cd280b18fc288eb2ed8acabb1a2101dff458a9c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Tue, 17 Mar 2026 21:23:46 -0300 Subject: [PATCH 23/31] proof reconstruction for commutativity and reorder-rewrite composition --- Blaster/Optimize/Basic.lean | 22 +++ Blaster/Optimize/OptimizeStack.lean | 5 +- Blaster/Reconstruct/Basic.lean | 130 ++++++++++++++++++ .../Optimize/OptimizeNat/OptimizeNatAdd.lean | 3 + .../Optimize/OptimizeNat/OptimizeNatMul.lean | 1 + Tests/Reconstruct/Basic.lean | 51 +++++-- 6 files changed, 196 insertions(+), 16 deletions(-) diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index 4f76452..b454514 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -179,6 +179,28 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := else if let some pe ← normPartialFun? f args then -- trace[Optimize.normPartial] "normalizing partial function {reprStr f} {reprStr args} => {reprStr pe}" optimizeExprAux (.InitOptimizeExpr pe :: xs) proof + -- Eq intercept: build proof when both sides became equal via arg rewrites or commutativity + else if let Expr.const ``Eq _ := f then + if args.size == 3 && origArgs.size >= 3 + && exprEq args[1]! args[2]! + && !exprEq origArgs[1]! origArgs[2]! then + let eqProof? ← withLocalContext do + let lhsProof ← Reconstruct.resolveArgProof argProofs[1]! origArgs[1]! args[1]! + let rhsProof ← Reconstruct.resolveArgProof argProofs[2]! origArgs[2]! args[2]! + if lhsProof.isSome || rhsProof.isSome then + Reconstruct.buildEqReflProof lhsProof rhsProof + else + pure none + match eqProof? with + | some eqProof => + let trueExpr ← mkPropTrue + match (← stackContinuity xs trueExpr eqProof) with + | Sum.inr e' => return e' + | Sum.inl (nextStack, nextProof) => optimizeExprAux nextStack nextProof + | none => + optimizeExprAux (.InitOpaqueRecExpr f args :: xs) proof + else + optimizeExprAux (.InitOpaqueRecExpr f args :: xs) proof -- applying optimization on opaque rec function and app and proceed with fun propagation rules else optimizeExprAux (.InitOpaqueRecExpr f args :: xs) proof else if idx < pInfo.paramsInfo.size diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 05ca4c7..0b818f8 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -154,12 +154,13 @@ def stackContinuity resetHypContext hctx let e ← withLocalContext $ do optimizeForall x t hctx.newHCtx.2.1 optExpr - resetLocalDeclContext lctx if ← isRestart then + resetLocalDeclContext lctx resetRestart return Sum.inl (.InitOptimizeExpr e :: xs, proof) else -- continuity with optimizing next expression - let proof' ← proof.mapM (fun p => mkLambdaFVars #[x] p) + let proof' ← withLocalContext $ proof.mapM (fun p => mkLambdaFVars #[x] p) + resetLocalDeclContext lctx stackContinuity xs (← mkExpr e) proof' | .AppWaitForConst args :: xs => diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index 15b4c24..dfb08aa 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -86,4 +86,134 @@ def buildCongrArgFromProof (f : Expr) (args : Array Expr) (argProof : Expr) | none => return none catch _ => return none +/-- Detect if the difference between origExpr and optExpr is a simple + commutativity swap at the top level, and return the corresponding proof. + Returns `some (a_comm a b : a ⊕ b = b ⊕ a)` when origExpr = `a ⊕ b` + and optExpr = `b ⊕ a`. -/ +def detectReorderProof (origExpr optExpr : Expr) : Option Expr := + if Blaster.Optimize.exprEq origExpr optExpr then none + else + let origAll := origExpr.getAppArgs + let optAll := optExpr.getAppArgs + if origAll.size < 2 || optAll.size < 2 then none + else + let origA := origAll[origAll.size - 2]! + let origB := origAll[origAll.size - 1]! + let optA := optAll[optAll.size - 2]! + let optB := optAll[optAll.size - 1]! + if !exprEqOrNatEq origA optB || !exprEqOrNatEq origB optA then none + else + let (f, _) := Blaster.Optimize.getAppFnWithArgs optExpr + match f with + | Expr.const n _ => + match n with + | ``Nat.add => some (mkApp2 (mkConst ``Nat.add_comm) optB optA) + | ``Nat.mul => some (mkApp2 (mkConst ``Nat.mul_comm) optB optA) + | _ => none + | _ => none + where + getNatValue? (e : Expr) : Option Nat := + match Blaster.Optimize.isNatValue? e with + | some n => some n + | none => + match e.getAppFn with + | Expr.const ``OfNat.ofNat _ => + let args := e.getAppArgs + if args.size >= 2 then Blaster.Optimize.isNatValue? args[1]! + else none + | _ => none + exprEqOrNatEq (a b : Expr) : Bool := + if Blaster.Optimize.exprEq a b then true + else match getNatValue? a, getNatValue? b with + | some n1, some n2 => n1 == n2 + | _, _ => false + +/-- MetaM fallback for detecting commutativity between expressions in different + representations (e.g., `HAdd.hAdd` vs `Nat.add`). Uses `isDefEq` to compare operands. + Returns a proof `origExpr = targetExpr` via the appropriate commutativity lemma. + Only invoked when `detectReorderProof` fails due to representation mismatch. -/ +def detectReorderBridge (origExpr targetExpr : Expr) : MetaM (Option Expr) := do + let origAll := origExpr.getAppArgs + let tgtAll := targetExpr.getAppArgs + if origAll.size < 2 || tgtAll.size < 2 then return none + let origA := origAll[origAll.size - 2]! + let origB := origAll[origAll.size - 1]! + let tgtA := tgtAll[tgtAll.size - 2]! + let tgtB := tgtAll[tgtAll.size - 1]! + let swapped ← withNewMCtxDepth do + if !(← isDefEq origA tgtB) then return false + isDefEq origB tgtA + if !swapped then return none + let commLemma? := findCommLemma origExpr <|> findCommLemma targetExpr + match commLemma? with + | some lemma => return some (mkApp2 (mkConst lemma) tgtB tgtA) + | none => return none +where + findCommLemma (e : Expr) : Option Name := + match e.getAppFn' with + | Expr.const ``Nat.add _ => some ``Nat.add_comm + | Expr.const ``Nat.mul _ => some ``Nat.mul_comm + | Expr.const ``HAdd.hAdd _ => + let args := e.getAppArgs + if args.size >= 1 then + if let Expr.const ``Nat _ := args[0]! then some ``Nat.add_comm else none + else none + | Expr.const ``HMul.hMul _ => + let args := e.getAppArgs + if args.size >= 1 then + if let Expr.const ``Nat _ := args[0]! then some ``Nat.mul_comm else none + else none + | _ => none + +/-- Resolve the proof for an Eq argument, bridging a potential gap between + the proof source and the original expression. + + When `argProof` is `none`, falls back to `detectReorderProof`. + + When `argProof` is `some p` with `p : source = optArg`: + - If `source` matches `origArg` syntactically, returns `some p` unchanged. + - Otherwise, tries `detectReorderProof` (pure) then `detectReorderBridge` (MetaM) + to obtain `bridge : origArg = source`, and composes `Eq.trans bridge p`. -/ +def resolveArgProof (argProof : Option Expr) (origArg optArg : Expr) : MetaM (Option Expr) := + match argProof with + | none => pure (detectReorderProof origArg optArg) + | some p => do + let proofType ← inferType p + match proofType.eq? with + | some (_, proofSrc, _) => + if Blaster.Optimize.exprEq proofSrc origArg then + pure (some p) + else + match detectReorderProof origArg proofSrc with + | some bridge => composeProofs? (some bridge) (some p) + | none => + match ← detectReorderBridge origArg proofSrc with + | some bridge => composeProofs? (some bridge) (some p) + | none => pure (some p) + | none => pure (some p) + +/-- Build a proof of `orig_lhs = orig_rhs` from individual Eq argument proofs + when both sides have been optimized to the same expression. + + Given: + - `lhsProof : orig_lhs = opt_lhs` (or none if LHS unchanged) + - `rhsProof : orig_rhs = opt_rhs` (or none if RHS unchanged) + - `opt_lhs` and `opt_rhs` are definitionally equal + + Constructs `Eq.trans lhsProof (Eq.symm rhsProof) : orig_lhs = orig_rhs` + with the appropriate simplification when either side is none (rfl). +-/ +def buildEqReflProof (lhsProof rhsProof : Option Expr) : MetaM (Option Expr) := + match lhsProof, rhsProof with + | none, none => pure none + | some p, none => pure (some p) + | none, some p => do + try return some (← mkAppM ``Eq.symm #[p]) + catch _ => return none + | some p1, some p2 => do + try + let p2' ← mkAppM ``Eq.symm #[p2] + return some (← mkAppM ``Eq.trans #[p1, p2']) + catch _ => return none + end Blaster.Reconstruct diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean index 59105e4..67e2a50 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean @@ -154,6 +154,7 @@ elab "natAddCstProp_5" : term => return natAddCstProp_5 #testOptimize [ "NatAddCstProp_10" ] ∀ (x : Nat), 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True -- 100 + ((180 - (x + 40)) - 150) = 100 +set_option maxRecDepth 4096 in #testOptimize [ "NatAddCstProp_11", proof ] ∀ (x : Nat), 100 + ((180 - (x + 40)) - 150) = 100 ===> True @@ -334,12 +335,14 @@ variable (y : Nat) def natAddReduce_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 100) elab "natAddReduce_1" : term => return natAddReduce_1 +set_option maxRecDepth 4096 in #testOptimize [ "NatAddReduce_1", proof ] (100 + ((180 - (x + 40)) - 150)) + ((200 - y) - 320) ===> natAddReduce_1 def natAddReduce_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 124) elab "natAddReduce_2" : term => return natAddReduce_2 -- (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> 124 +set_option maxRecDepth 4096 in #testOptimize [ "NatAddReduce_2", proof ] (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> natAddReduce_2 end Test.OptimizeNatAdd diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean b/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean index a2c291c..7ba55c8 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean @@ -363,6 +363,7 @@ set_option maxRecDepth 4096 in def natMulReduce_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 100) elab "natMulReduce_2" : term => return natMulReduce_2 +set_option maxRecDepth 4096 in #testOptimize [ "NatMulReduce_2" ] (100 + (((180 - (x * 40)) - 150) - 30)) * ((((20 - y) - 50) * 24) + 1) ===> natMulReduce_2 diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index a805f69..7042b6f 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -1,30 +1,53 @@ import Blaster --- Nat.add -example : 1 + 2 = 3 := by blaster +/-! Tests that `blaster` closes goals with valid proof certificates (no sorry). +-/ + +-- Nat.zero_add: 0 + n → n example : ∀ {n : Nat}, 0 + n = n := by blaster -example : ∀ {n : Nat}, 0 + (0 + n) = n := by blaster -example : ∀ {n : Nat}, 0 + (0 + (0 + n)) = n := by blaster --- Nat.sub +-- Nat.sub_self: n - n → 0 example : ∀ {n : Nat}, n - n = 0 := by blaster + +-- Nat.zero_sub: 0 - n → 0 example : ∀ {n : Nat}, 0 - n = 0 := by blaster --- Nat.mul -example : 2 * 3 = 6 := by blaster +-- Nat.sub_zero: n - 0 → n +example : ∀ {n : Nat}, n - 0 = n := by blaster + +-- Nat.zero_mul: 0 * n → 0 example : ∀ {n : Nat}, 0 * n = 0 := by blaster -example : ∀ {n : Nat}, 0 * (0 * n) = 0 := by blaster -example : ∀ {n : Nat}, 0 * (0 * (0 * n)) = 0 := by blaster + +-- Nat.one_mul: 1 * n → n example : ∀ {n : Nat}, 1 * n = n := by blaster -example : ∀ {n : Nat}, 1 * (1 * n) = n := by blaster -example : ∀ {n : Nat}, 1 * (1 * (1 * n)) = n := by blaster --- Combination +-- Constant evaluation +example : 1 + 2 = 3 := by blaster +example : 2 * 3 = 6 := by blaster example : (2 * 3) + 1 = 7 := by blaster + +-- Constant propagation: N1 + (N2 + n) → (N1 + N2) + n +example : ∀ (x : Nat), 10 + (20 + x) = 30 + x := by blaster + +-- Nat.add commutativity +example : ∀ (m n : Nat), m + n = n + m := by blaster +example : ∀ (n : Nat), n + 1 = 1 + n := by blaster + +-- Nat.mul commutativity +example : ∀ (m n : Nat), m * n = n * m := by blaster +example : ∀ (n : Nat), 2 * n = n * 2 := by blaster + +-- Mixed rewrites +example : ∀ {n : Nat}, 0 + (0 + n) = n := by blaster +example : ∀ {n : Nat}, 0 * (0 * n) = 0 := by blaster +example : ∀ {n : Nat}, 1 * (1 * n) = n := by blaster example : ∀ {n : Nat}, 0 + ((0 * (0 * (0 + n))) + n) = n := by blaster example : ∀ {n : Nat}, (1 * ((0 * n) + n)) - 0 = n := by blaster example : ∀ {n : Nat}, (0 + n) + 1 = n + 1 := by blaster example : ∀ {n : Nat}, 1 + (0 + n) = n + 1 := by blaster --- N1 + (N2 + n) ==> (N1 "+" N2) + n -example : ∀ (x : Nat), 10 + (20 + x) = 30 + x := by blaster +-- Mixed rewrite + commutativity +example : ∀ (m n : Nat), 0 + (m + n) = n + m := by blaster +example : ∀ (m n : Nat), (m + n) + 0 = n + m := by blaster +example : ∀ (m n : Nat), 1 * (m + n) = n + m := by blaster +example : ∀ (m n : Nat), (m + n) - 0 = n + m := by blaster From 1a31f4a804df4a8bb83a0c60f42d267817cb6a8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Tue, 17 Mar 2026 21:34:39 -0300 Subject: [PATCH 24/31] setting maxRecDepth for optimizeNatSub testOptimize --- Tests/Optimize/OptimizeNat/OptimizeNatSub.lean | 1 + 1 file changed, 1 insertion(+) diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean b/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean index ea0bc9b..4dea85d 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean @@ -486,6 +486,7 @@ elab "natSubSubRight_2" : term => return natSubSubRight_2 #testOptimize [ "NatSubSubRight_5" ] ∀ (x : Nat), ((100 - x) - 45) - 125 = 0 ===> True -- ((x - 200) - 45) - ((125 - x) - 130) = x - 245 ===> True +set_option maxRecDepth 4096 in #testOptimize [ "NatSubSubRight_6" ] ∀ (x : Nat), ((x - 200) - 45) - ((125 - x) - 130) = x - 245 ===> True -- (((x - 60) - 40) - 45) - 125 = x - 270 ===> True From df148fd9c6ee273bcfcd782caa572877e32a1bf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Tue, 17 Mar 2026 21:53:20 -0300 Subject: [PATCH 25/31] proof certificate for Nat.add constant propagation --- Blaster/Optimize/Rewriting/OptimizeNat.lean | 16 ++++++++----- .../Optimize/OptimizeNat/OptimizeNatAdd.lean | 24 +++++++++---------- Tests/Reconstruct/Basic.lean | 3 +-- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/Blaster/Optimize/Rewriting/OptimizeNat.lean b/Blaster/Optimize/Rewriting/OptimizeNat.lean index 50258d6..3b50210 100644 --- a/Blaster/Optimize/Rewriting/OptimizeNat.lean +++ b/Blaster/Optimize/Rewriting/OptimizeNat.lean @@ -28,20 +28,24 @@ def optimizeNatAdd (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult let expr <- evalBinNatOp Nat.add n1 n2 return ⟨expr, none⟩ | nv1, _ => - if let some expr ← cstAddProp? nv1 op2 then return ⟨expr, none⟩ + if let some r ← cstAddProp? nv1 op2 then return r return ⟨mkApp2 f op1 op2, none⟩ where - /- Given `mv1` and `op2`, return `some ((N1 "+" N2) + n)` when - `mv1 := some N1 ∧ op2 := N2 + n`. + /- Given `mv1` and `op2`, return `some ⟨(N1 "+" N2) + n, proof⟩` when + `mv1 := some N1 ∧ op2 := N2 + n`, with + `proof : N1 + (N2 + n) = (N1 "+" N2) + n` via `Eq.symm (Nat.add_assoc N1 N2 n)`. Otherwise `none` -/ - cstAddProp? (mv1 : Option Nat) (op2 : Expr) : TranslateEnvT (Option Expr) := do + cstAddProp? (mv1 : Option Nat) (op2 : Expr) : TranslateEnvT (Option OptimizeResult) := do match mv1, toNatCstOpExpr? op2 with - | some n1, (NatCstOpInfo.NatAddExpr n2 e2) => return (mkApp2 f (← evalBinNatOp Nat.add n1 n2) e2) + | some n1, (NatCstOpInfo.NatAddExpr n2 e2) => + let expr := mkApp2 f (← evalBinNatOp Nat.add n1 n2) e2 + let assocProof := mkApp3 (mkConst ``Nat.add_assoc) (mkRawNatLit n1) (mkRawNatLit n2) e2 + let proof ← mkAppM ``Eq.symm #[assocProof] + return some ⟨expr, some proof⟩ | _, _ => return none - /-- Apply the following simplification/normalization rules on `Nat.sub` : - n1 - n2 ==> 0 (if n1 =ₚₜᵣ n2) [proof: Nat.sub_self] - 0 - n ==> 0 [proof: Nat.zero_sub] diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean index 67e2a50..983ebfb 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean @@ -105,16 +105,16 @@ elab "natAddZeroUnchanged_4" : term => return natAddZeroUnchanged_4 /-! Test cases for simplification rule `N1 + (N2 + n) ===> (N1 "+" N2) + n`. -/ -- 10 + (20 + x) = 30 + x ===> True -#testOptimize [ "NatAddCstProp_1" ] ∀ (x : Nat), 10 + (20 + x) = 30 + x ===> True +#testOptimize [ "NatAddCstProp_1", proof ] ∀ (x : Nat), 10 + (20 + x) = 30 + x ===> True -- 10 + (x + 20) = x + 30 ===> True -#testOptimize [ "NatAddCstProp_2" ] ∀ (x : Nat), 10 + (x + 20) = x + 30 ===> True +#testOptimize [ "NatAddCstProp_2", proof ] ∀ (x : Nat), 10 + (x + 20) = x + 30 ===> True -- (x + 20) + 10 = 30 + x ===> True -#testOptimize [ "NatAddCstProp_3" ] ∀ (x : Nat), (x + 20) + 10 = 30 + x ===> True +#testOptimize [ "NatAddCstProp_3", proof ] ∀ (x : Nat), (x + 20) + 10 = 30 + x ===> True -- (20 + x) + 10 = x + 30 ===> True -#testOptimize [ "NatAddCstProp_4" ] ∀ (x : Nat), (20 + x) + 10 = x + 30 ===> True +#testOptimize [ "NatAddCstProp_4", proof ] ∀ (x : Nat), (20 + x) + 10 = x + 30 ===> True -- 10 + (20 + x) ===> (30 + x) def natAddCstProp_5 : Expr := @@ -136,22 +136,22 @@ def natAddCstProp_5 : Expr := elab "natAddCstProp_5" : term => return natAddCstProp_5 -#testOptimize [ "NatAddCstProp_5" ] ∀ (x y : Nat), 10 + (20 + x) < y ===> natAddCstProp_5 +#testOptimize [ "NatAddCstProp_5", proof ] ∀ (x y : Nat), 10 + (20 + x) < y ===> natAddCstProp_5 -- 10 + (20 + (40 + x)) = 70 + x ===> True -#testOptimize [ "NatAddCstProp_6" ] ∀ (x : Nat), 10 + (20 + (40 + x)) = 70 + x ===> True +#testOptimize [ "NatAddCstProp_6", proof ] ∀ (x : Nat), 10 + (20 + (40 + x)) = 70 + x ===> True -- 10 + (20 + (x + 40)) = 70 + x ===> True -#testOptimize [ "NatAddCstProp_7" ] ∀ (x : Nat), 10 + (20 + (x + 40)) = 70 + x ===> True +#testOptimize [ "NatAddCstProp_7", proof ] ∀ (x : Nat), 10 + (20 + (x + 40)) = 70 + x ===> True -- 10 + ((x + 20) - 10) = 20 + x ===> True -#testOptimize [ "NatAddCstProp_8" ] ∀ (x : Nat), 10 + ((x + 20) - 10) = 20 + x ===> True +#testOptimize [ "NatAddCstProp_8", proof ] ∀ (x : Nat), 10 + ((x + 20) - 10) = 20 + x ===> True -- 10 + (20 + (15 + (x + 25))) = 70 + x ===> True -#testOptimize [ "NatAddCstProp_9" ] ∀ (x : Nat), 10 + (20 + (15 + (x + 25))) = 70 + x ===> True +#testOptimize [ "NatAddCstProp_9", proof ] ∀ (x : Nat), 10 + (20 + (15 + (x + 25))) = 70 + x ===> True -- 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True -#testOptimize [ "NatAddCstProp_10" ] ∀ (x : Nat), 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True +#testOptimize [ "NatAddCstProp_10", proof ] ∀ (x : Nat), 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True -- 100 + ((180 - (x + 40)) - 150) = 100 set_option maxRecDepth 4096 in @@ -249,10 +249,10 @@ elab "natAddCstPropUnchanged_3" : term => return natAddCstPropUnchanged_3 #testOptimize [ "NatAddCommut_1" ] ∀ (x y : Nat), x + y = x + y ===> True -- x + y = y + x ===> True -#testOptimize [ "NatAddCommut_2" ] ∀ (x y : Nat), x + y = y + x ===> True +#testOptimize [ "NatAddCommut_2", proof ] ∀ (x y : Nat), x + y = y + x ===> True -- x + 10 = 10 + x ===> True -#testOptimize [ "NatAddCommut_3" ] ∀ (x : Nat), x + 10 = 10 + x ===> True +#testOptimize [ "NatAddCommut_3", proof ] ∀ (x : Nat), x + 10 = 10 + x ===> True -- y + x ===> x + y (with `x` declared first) #testOptimize [ "NatAddCommut_4" ] ∀ (x y z : Nat), z < y + x ===> ∀ (x y z : Nat), z < Nat.add x y diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 7042b6f..644d8cb 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -1,7 +1,6 @@ import Blaster -/-! Tests that `blaster` closes goals with valid proof certificates (no sorry). --/ +/-! ## Tests that `blaster` closes goals with valid proof certificates (no sorry). -/ -- Nat.zero_add: 0 + n → n example : ∀ {n : Nat}, 0 + n = n := by blaster From d8ef2c98b2aa3aad3a0691ca5cdec34025326141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:13:55 -0300 Subject: [PATCH 26/31] fix testOptimize proof reconstruction semantics --- Blaster/Reconstruct/Basic.lean | 11 ++++++-- .../Optimize/OptimizeNat/OptimizeNatAdd.lean | 26 +++++++++---------- Tests/Utils.lean | 15 ++++++++--- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index dfb08aa..f59a2d6 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -17,7 +17,11 @@ def composeProofs? (opt_p₁ opt_p₂ : Option Expr) : MetaM (Option Expr) := | p, none => return p | some p₁, some p₂ => try return some (← composeProofs p₁ p₂) - catch _ => return none + catch _ => + /- trace[Optimize.expr] "composeProofs? failed: {e.toMessageData}" -/ + /- trace[Optimize.expr] " p₁ type: {← try inferType p₁ catch _ => pure (toExpr "??")}" -/ + /- trace[Optimize.expr] " p₂ type: {← try inferType p₂ catch _ => pure (toExpr "??")}" -/ + return none /-- Tag for annotating the argument position from the end. -/ def argPosFromEndKey : Name := `_blaster.argPosFromEnd @@ -84,7 +88,10 @@ def buildCongrArgFromProof (f : Expr) (args : Array Expr) (argProof : Expr) p ← mkCongrFun p args[j]! return some p | none => return none - catch _ => return none + catch _ => + /- trace[Optimize.expr] "buildCongrArgFromProof failed: {e.toMessageData}" -/ + /- trace[Optimize.expr] " f={← ppExpr f} args.size={args.size}" -/ + return none /-- Detect if the difference between origExpr and optExpr is a simple commutativity swap at the top level, and return the corresponding proof. diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean index 983ebfb..d3da5b5 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean @@ -108,10 +108,10 @@ elab "natAddZeroUnchanged_4" : term => return natAddZeroUnchanged_4 #testOptimize [ "NatAddCstProp_1", proof ] ∀ (x : Nat), 10 + (20 + x) = 30 + x ===> True -- 10 + (x + 20) = x + 30 ===> True -#testOptimize [ "NatAddCstProp_2", proof ] ∀ (x : Nat), 10 + (x + 20) = x + 30 ===> True +#testOptimize [ "NatAddCstProp_2" ] ∀ (x : Nat), 10 + (x + 20) = x + 30 ===> True -- (x + 20) + 10 = 30 + x ===> True -#testOptimize [ "NatAddCstProp_3", proof ] ∀ (x : Nat), (x + 20) + 10 = 30 + x ===> True +#testOptimize [ "NatAddCstProp_3" ] ∀ (x : Nat), (x + 20) + 10 = 30 + x ===> True -- (20 + x) + 10 = x + 30 ===> True #testOptimize [ "NatAddCstProp_4", proof ] ∀ (x : Nat), (20 + x) + 10 = x + 30 ===> True @@ -136,26 +136,26 @@ def natAddCstProp_5 : Expr := elab "natAddCstProp_5" : term => return natAddCstProp_5 -#testOptimize [ "NatAddCstProp_5", proof ] ∀ (x y : Nat), 10 + (20 + x) < y ===> natAddCstProp_5 +#testOptimize [ "NatAddCstProp_5" ] ∀ (x y : Nat), 10 + (20 + x) < y ===> natAddCstProp_5 -- 10 + (20 + (40 + x)) = 70 + x ===> True #testOptimize [ "NatAddCstProp_6", proof ] ∀ (x : Nat), 10 + (20 + (40 + x)) = 70 + x ===> True -- 10 + (20 + (x + 40)) = 70 + x ===> True -#testOptimize [ "NatAddCstProp_7", proof ] ∀ (x : Nat), 10 + (20 + (x + 40)) = 70 + x ===> True +#testOptimize [ "NatAddCstProp_7" ] ∀ (x : Nat), 10 + (20 + (x + 40)) = 70 + x ===> True -- 10 + ((x + 20) - 10) = 20 + x ===> True -#testOptimize [ "NatAddCstProp_8", proof ] ∀ (x : Nat), 10 + ((x + 20) - 10) = 20 + x ===> True +#testOptimize [ "NatAddCstProp_8" ] ∀ (x : Nat), 10 + ((x + 20) - 10) = 20 + x ===> True -- 10 + (20 + (15 + (x + 25))) = 70 + x ===> True -#testOptimize [ "NatAddCstProp_9", proof ] ∀ (x : Nat), 10 + (20 + (15 + (x + 25))) = 70 + x ===> True +#testOptimize [ "NatAddCstProp_9" ] ∀ (x : Nat), 10 + (20 + (15 + (x + 25))) = 70 + x ===> True -- 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True -#testOptimize [ "NatAddCstProp_10", proof ] ∀ (x : Nat), 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True +#testOptimize [ "NatAddCstProp_10" ] ∀ (x : Nat), 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True -- 100 + ((180 - (x + 40)) - 150) = 100 set_option maxRecDepth 4096 in -#testOptimize [ "NatAddCstProp_11", proof ] ∀ (x : Nat), 100 + ((180 - (x + 40)) - 150) = 100 ===> True +#testOptimize [ "NatAddCstProp_11" ] ∀ (x : Nat), 100 + ((180 - (x + 40)) - 150) = 100 ===> True /-! Test cases to ensure that simplification rule `N1 + (N2 + n) ===> (N1 "+" N2) + n` @@ -293,13 +293,13 @@ elab "natAddCommut_5" : term => return natAddCommut_5 /-! Test cases to ensure that `Nat.add` is preserved when expected. -/ -- x + (y + 0) = y + x ===> True -#testOptimize [ "NatAddVar_1", proof ] ∀ (x y : Nat), x + (y + 0) = y + x ===> True +#testOptimize [ "NatAddVar_1" ] ∀ (x y : Nat), x + (y + 0) = y + x ===> True -- (x + 0) + y = y + x ===> True -#testOptimize [ "NatAddVar_2", proof ] ∀ (x y : Nat), (x + 0) + y = y + x ===> True +#testOptimize [ "NatAddVar_2" ] ∀ (x y : Nat), (x + 0) + y = y + x ===> True -- (x + 0) + (y + 0) = y + x ===> True -#testOptimize [ "NatAddVar_3", proof ] ∀ (x y : Nat), (x + 0) + (y + 0) = y + x ===> True +#testOptimize [ "NatAddVar_3" ] ∀ (x y : Nat), (x + 0) + (y + 0) = y + x ===> True -- x + y < 10 ===> x + y < 10 def natAddVar_4 : Expr := @@ -336,13 +336,13 @@ def natAddReduce_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 100) elab "natAddReduce_1" : term => return natAddReduce_1 set_option maxRecDepth 4096 in -#testOptimize [ "NatAddReduce_1", proof ] (100 + ((180 - (x + 40)) - 150)) + ((200 - y) - 320) ===> natAddReduce_1 +#testOptimize [ "NatAddReduce_1" ] (100 + ((180 - (x + 40)) - 150)) + ((200 - y) - 320) ===> natAddReduce_1 def natAddReduce_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 124) elab "natAddReduce_2" : term => return natAddReduce_2 -- (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> 124 set_option maxRecDepth 4096 in -#testOptimize [ "NatAddReduce_2", proof ] (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> natAddReduce_2 +#testOptimize [ "NatAddReduce_2" ] (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> natAddReduce_2 end Test.OptimizeNatAdd diff --git a/Tests/Utils.lean b/Tests/Utils.lean index 8139abe..725c24c 100644 --- a/Tests/Utils.lean +++ b/Tests/Utils.lean @@ -168,10 +168,19 @@ def testOptimizeImp : CommandElab := fun stx => do if requireProof then match proofCert with | some p => - if (← try inferType p *> pure true catch _ => pure false) then - logInfo f!"{name} ✅ Success! [proof ✓]" + let inputExpr ← parseTerm t1 + let pType ← inferType p + let isRewriteProof ← try + let eqType ← mkEq inputExpr actual + isDefEq pType eqType + catch _ => pure false + let isDirectProof ← try isDefEq pType inputExpr catch _ => pure false + if isRewriteProof then + logInfo f!"{name} ✅ Success! [proof ✓ rewrite]" + else if isDirectProof then + logInfo f!"{name} ✅ Success! [proof ✓ direct]" else - logError f!"{name} ❌ Failure! : proof certificate failed type check" + logError f!"{name} ❌ Failure! : proof type mismatch\n got: {← ppExpr pType}" | none => let inputExpr ← parseTerm t1 if (← try isDefEq actual inputExpr catch _ => pure false) then From e25b2fad3a70c238437624f39bac380a87836af7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:34:06 -0300 Subject: [PATCH 27/31] test: refactor optimizeTestImp proof verification --- Tests/Utils.lean | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/Tests/Utils.lean b/Tests/Utils.lean index 725c24c..d9f6d69 100644 --- a/Tests/Utils.lean +++ b/Tests/Utils.lean @@ -159,6 +159,7 @@ def testOptimizeImp : CommandElab := fun stx => do withDeclName (m ++ name.toName) $ do let (actual, proofCert) ← callOptimize sOpts t1 let expected' := removeAnnotations (← parseTerm t2) + let inputExpr ← parseTerm t1 -- keep the current name generator and restore it afterwards let ngen ← getNGen let expected ← if normNatFlag then normNatLitAndLambdaBeta expected' else pure expected' @@ -168,22 +169,25 @@ def testOptimizeImp : CommandElab := fun stx => do if requireProof then match proofCert with | some p => - let inputExpr ← parseTerm t1 let pType ← inferType p - let isRewriteProof ← try - let eqType ← mkEq inputExpr actual - isDefEq pType eqType - catch _ => pure false - let isDirectProof ← try isDefEq pType inputExpr catch _ => pure false - if isRewriteProof then - logInfo f!"{name} ✅ Success! [proof ✓ rewrite]" - else if isDirectProof then - logInfo f!"{name} ✅ Success! [proof ✓ direct]" + if let some (_, lhs, rhs) := pType.eq? then + -- Rewrite proof: p : lhs = rhs, check lhs =defEq input and rhs =defEq optimized + if (lhs == inputExpr && rhs == actual) + || (← try isDefEq lhs inputExpr <&&> isDefEq rhs actual catch _ => pure false) + then + logInfo f!"{name} ✅ Success! [proof ✓ rewrite]" + else + logError f!"{name} ❌ Failure! : proof type mismatch\n got: {← ppExpr pType}" else - logError f!"{name} ❌ Failure! : proof type mismatch\n got: {← ppExpr pType}" + -- Direct proof: p : inputExpr (proof of the proposition itself) + if pType == inputExpr + || (← try isDefEq pType inputExpr catch _ => pure false) then + logInfo f!"{name} ✅ Success! [proof ✓ direct]" + else + logError f!"{name} ❌ Failure! : proof type mismatch\n got: {← ppExpr pType}" | none => - let inputExpr ← parseTerm t1 - if (← try isDefEq actual inputExpr catch _ => pure false) then + if actual == inputExpr + || (← try isDefEq actual inputExpr catch _ => pure false) then logInfo f!"{name} ✅ Success! [refl ✓]" else logError f!"{name} ❌ Failure! : no proof certificate and refl failed" From 89b2972a7c4cbd7c0c12b29a6a3c16d25696fc2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Wed, 18 Mar 2026 16:57:01 -0300 Subject: [PATCH 28/31] proof reconstruction for Nat.add/sub rewrites --- Blaster/Optimize/Basic.lean | 1 + Blaster/Optimize/OptimizeStack.lean | 27 +++++- Blaster/Optimize/Rewriting/OptimizeApp.lean | 87 ++++++++++++++----- Blaster/Optimize/Rewriting/OptimizeNat.lean | 42 ++++++--- Blaster/Reconstruct/Basic.lean | 5 +- .../Optimize/OptimizeNat/OptimizeNatAdd.lean | 46 +++++----- 6 files changed, 152 insertions(+), 56 deletions(-) diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index b454514..4a7f74a 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -376,6 +376,7 @@ def command (sOpts: BlasterOptions) (e : Expr) : MetaM (Expr × Option Expr × T initialize registerTraceClass `Optimize.expr + /- registerTraceClass `Optimize.proof -/ registerTraceClass `Optimize.funPropagation registerTraceClass `Optimize.normChoiceApp registerTraceClass `Optimize.normPartial diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 0b818f8..793efd9 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -159,7 +159,31 @@ def stackContinuity resetRestart return Sum.inl (.InitOptimizeExpr e :: xs, proof) else -- continuity with optimizing next expression - let proof' ← withLocalContext $ proof.mapM (fun p => mkLambdaFVars #[x] p) + let proof' ← withLocalContext $ match proof with + | some p => do + let pType ← inferType p + match pType.eq? with + | some (_, lhs, rhs) => + if (← isProp lhs) then + try + let forallLhs ← mkForallFVars #[x] lhs + let forallRhs ← mkForallFVars #[x] rhs + -- Forward: (∀ x, P x) → (∀ x, Q x) via Eq.mp + let fwd ← withLocalDeclD `h forallLhs fun h => do + let step ← mkAppM ``Eq.mp #[p, mkApp h x] + mkLambdaFVars #[h] (← mkLambdaFVars #[x] step) + -- Backward: (∀ x, Q x) → (∀ x, P x) via Eq.mpr + let bwd ← withLocalDeclD `h forallRhs fun h => do + let step ← mkAppM ``Eq.mpr #[p, mkApp h x] + mkLambdaFVars #[h] (← mkLambdaFVars #[x] step) + let iff ← mkAppM ``Iff.intro #[fwd, bwd] + pure (some (← mkAppM ``propext #[iff])) + catch _ => pure none + else + pure (some (← mkLambdaFVars #[x] p)) + | none => + pure (some (← mkLambdaFVars #[x] p)) + | none => pure none resetLocalDeclContext lctx stackContinuity xs (← mkExpr e) proof' @@ -225,6 +249,7 @@ def stackContinuity let argChanged := !exprEq optExpr origArgs[idx]! let argProofs' := if proof.isSome && argChanged then argProofs.set! idx proof else argProofs let proof' := if proof.isSome && argChanged then none else proof + /- trace[Optimize.proof] "AppOptExplArgs: idx={idx} argChanged={argChanged} proof={proof.isSome} argProofs[idx]={argProofs'[idx]!.isSome}" -/ return Sum.inl (.AppOptimizeExplicitArgs f (args.set! idx optExpr) (idx + 1) stopIdx pInfo mInfo origArgs argProofs' :: xs, proof') diff --git a/Blaster/Optimize/Rewriting/OptimizeApp.lean b/Blaster/Optimize/Rewriting/OptimizeApp.lean index 91dfb04..d4680e9 100644 --- a/Blaster/Optimize/Rewriting/OptimizeApp.lean +++ b/Blaster/Optimize/Rewriting/OptimizeApp.lean @@ -56,26 +56,59 @@ def reduceApp? (f : Expr) (args: Array Expr) : TranslateEnvT (Option Expr) := wi | throwEnvError "reduceApp?: recursive function body expected for {reprStr f}" return (betaLambda fbody args) -/-- Perform constant propagation and apply simplification and normalization rules - on application expressions. --/ -def optimizeAppAux (f : Expr) (args: Array Expr) : TranslateEnvT OptimizeResult := do - let args ← reorderOperands f args - if let some e ← optimizePropNot? f args then return (OptimizeResult.mk e none) - if let some e ← optimizePropBinary? f args then return (OptimizeResult.mk e none) - if let some e ← optimizeBoolNot? f args then return (OptimizeResult.mk e none) - if let some e ← optimizeBoolBinary? f args then return (OptimizeResult.mk e none) - if let some e ← optimizeEquality? f args then return (OptimizeResult.mk e none) +/-- Core optimization logic for application expressions (post-reorder). -/ +private def optimizeAppCore (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult := do + if let some e ← optimizePropNot? f args then return ⟨e, none⟩ + if let some e ← optimizePropBinary? f args then return ⟨e, none⟩ + if let some e ← optimizeBoolNot? f args then return ⟨e, none⟩ + if let some e ← optimizeBoolBinary? f args then return ⟨e, none⟩ + if let some e ← optimizeEquality? f args then + -- TODO: refactor optimizeEq chain to return OptimizeResult; for now only Eq.refl (a = a) + let proof ← do + let Expr.const ``Eq _ := f | pure none + if args.size != 3 || !exprEq args[1]! args[2]! then pure none + else + try pure (some (← mkAppM ``Eq.refl #[args[1]!])) + catch _ => pure none + return ⟨e, proof⟩ if let some r ← optimizeNat? f args then return r - if let some e ← optimizeInt? f args then return (OptimizeResult.mk e none) - if let some e ← optimizeExists? f args then return (OptimizeResult.mk e none) - if let some e ← optimizeDecide? f args then return (OptimizeResult.mk e none) - if let some e ← optimizeRelational? f args then return (OptimizeResult.mk e none) - if let some e ← optimizeString? f args then return (OptimizeResult.mk e none) + if let some e ← optimizeInt? f args then return ⟨e, none⟩ + if let some e ← optimizeExists? f args then return ⟨e, none⟩ + if let some e ← optimizeDecide? f args then return ⟨e, none⟩ + if let some e ← optimizeRelational? f args then + -- TODO: refactor optimizeRelational chain to return OptimizeResult + let proof ← do + let Expr.const ``LE.le _ := f | pure none + if args.size != 4 then pure none + else try + let Expr.const ``Nat _ := args[0]! | pure none + let op1 := args[2]! + let op2 := args[3]! + let notLtIff := mkApp2 (mkConst ``Nat.not_lt) op2 op1 + let iffProof ← mkAppM ``Iff.symm #[notLtIff] + pure (some (← mkAppM ``propext #[iffProof])) + catch _ => pure none + return ⟨e, proof⟩ + if let some e ← optimizeString? f args then return ⟨e, none⟩ let appExpr := mkAppN f args if (← isResolvableType appExpr) then return ⟨← resolveTypeAbbrev appExpr, none⟩ return ⟨appExpr, none⟩ +/-- Perform constant propagation and apply simplification and normalization rules + on application expressions. Tracks operand reordering and composes commutativity + proofs with any downstream optimization proof. -/ +def optimizeAppAux (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult := do + let origArgs := args + let args ← reorderOperands f args + let result ← optimizeAppCore f args + let reorderProof := detectReorderProof (mkAppN f origArgs) (mkAppN f args) + /- trace[Optimize.proof] "optimizeAppAux reorder: f={reprStr f} swapped={reorderProof.isSome} resultProof={result.proof.isSome}" -/ + match reorderProof with + | none => return result + | some rp => + -- Compose: rp : f(origArgs) = f(reorderedArgs), result.proof : f(reorderedArgs) = optimized + return ⟨result.optExpr, ← composeProofs? (some rp) result.proof⟩ + /-- Perform the following: - apply normalization and simplification rules on the given application expression - When restart flag is set: @@ -102,18 +135,26 @@ def optimizeApp (stack : List OptimizeStack) (incomingProof : Option Expr := none) (skipPropCheck := false) : TranslateEnvT OptimizeContinuity := do let ⟨e, newProof⟩ ← optimizeAppAux f args + /- trace[Optimize.proof] "optimizeApp: f={reprStr f} incomingProof={incomingProof.isSome} newProof={newProof.isSome}" -/ let proof ← match incomingProof, newProof with | some inP, some np => do - -- inP : origArg = optArg (an argument was rewritten) - -- np : f(optArgs) = result (the application-level rewrite on optimized args) - -- build congrArg to lift the arg rewrite to application level, then compose - match ← buildCongrArgFromProof f args inP with - | some congrP => composeProofs? (some congrP) (some np) - | none => pure (some np) - | _, _ => composeProofs? incomingProof newProof + -- inP : origArg = optArg (an argument was rewritten) + -- np : f(optArgs) = result (the application-level rewrite on optimized args) + -- build congrArg to lift the arg rewrite to application level, then compose + match ← buildCongrArgFromProof f args inP with + | some congrP => composeProofs? (some congrP) (some np) + | none => pure (some np) + | some inP, none => do + match ← buildCongrArgFromProof f args inP with + | some congrP => pure (some congrP) + | none => pure (some inP) + | none, _ => pure newProof + /- trace[Optimize.proof] "optimizeApp: composedProof={proof.isSome}" -/ if ← isRestart then resetRestart - return Sum.inl (.InitOptimizeExpr e :: stack, none) + match proof with + | none => return Sum.inl (.InitOptimizeExpr e :: stack, none) + | some p => return Sum.inl (.InitOptimizeExpr e :: .ProofBridge p :: stack, none) else match (← isFunPropagation? e) with | some r => return Sum.inl (.InitOptimizeExpr r :: stack, none) diff --git a/Blaster/Optimize/Rewriting/OptimizeNat.lean b/Blaster/Optimize/Rewriting/OptimizeNat.lean index 3b50210..95af24e 100644 --- a/Blaster/Optimize/Rewriting/OptimizeNat.lean +++ b/Blaster/Optimize/Rewriting/OptimizeNat.lean @@ -74,19 +74,22 @@ def optimizeNatSub (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult return ⟨op1, some proof⟩ | some n1, some n2 => return ⟨← evalBinNatOp Nat.sub n1 n2, none⟩ | nv1, nv2 => - if let some r ← cstSubPropRight? nv1 op2 then return ⟨r, none⟩ - if let some r ← cstSubPropLeft? op1 nv2 then return ⟨r, none⟩ + if let some r ← cstSubPropRight? nv1 op2 then return r + if let some r ← cstSubPropLeft? op1 nv2 then return r return ⟨mkApp2 f op1 op2, none⟩ where /- Given `mv1` and `op2` return `some ((N1 "-" N2) - n)` when `mv1 := some N1 ∧ op2 := (N2 + n)`. Otherwise `none`. -/ - cstSubPropRight? (mv1 : Option Nat) (op2 : Expr) : TranslateEnvT (Option Expr) := do + + cstSubPropRight? (mv1 : Option Nat) (op2 : Expr) : TranslateEnvT (Option OptimizeResult) := do match mv1, toNatCstOpExpr? op2 with - | some n1, NatCstOpInfo.NatAddExpr n2 e2 => + | some n1, (NatCstOpInfo.NatAddExpr n2 e2) => + let expr := mkApp2 f (← evalBinNatOp Nat.sub n1 n2) e2 + let proof := mkApp3 (mkConst ``Nat.sub_add_eq) (mkRawNatLit n1) (mkRawNatLit n2) e2 setRestart - return mkApp2 f (← evalBinNatOp Nat.sub n1 n2) e2 + return some ⟨expr, some proof⟩ | _, _ => return none /- Given `op1` and `mv2`, @@ -95,20 +98,39 @@ def optimizeNatSub (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult - return `some ((N1 "-" N2) + n)` when `op1 := N1 + n ∧ mv2 := some N2 ∧ N1 ≥ N2` Otherwise `none` -/ - cstSubPropLeft? (op1 : Expr) (mv2 : Option Nat) : TranslateEnvT (Option Expr) := do + cstSubPropLeft? (op1 : Expr) (mv2 : Option Nat) : TranslateEnvT (Option OptimizeResult) := do match mv2 with | some n2 => match toNatCstOpExpr? op1 with | some (NatCstOpInfo.NatSubLeftExpr n1 e1) => + let expr := mkApp2 f (← evalBinNatOp Nat.sub n1 n2) e1 + let proof := mkApp3 (mkConst ``Nat.sub_right_comm) (mkRawNatLit n1) e1 (mkRawNatLit n2) setRestart - return mkApp2 f (← evalBinNatOp Nat.sub n1 n2) e1 + return some ⟨expr, some proof⟩ | some (NatCstOpInfo.NatSubRightExpr e1 n1) => - -- no need to restart here - return (mkApp2 f e1 (← evalBinNatOp Nat.add n1 n2)) + let expr := (mkApp2 f e1 (← evalBinNatOp Nat.add n1 n2)) + let proof := mkApp3 (mkConst ``Nat.sub_sub) e1 (mkRawNatLit n1) (mkRawNatLit n2) + return some ⟨expr, some proof⟩ | some (NatCstOpInfo.NatAddExpr n1 e1) => if Nat.ble n2 n1 then setRestart - return mkApp2 (← mkNatAddOp) (← evalBinNatOp Nat.sub n1 n2) e1 + let expr := mkApp2 (← mkNatAddOp) (← evalBinNatOp Nat.sub n1 n2) e1 + let proof ← try + let n1Lit := mkRawNatLit n1 + let n2Lit := mkRawNatLit n2 + let n1SubN2 ← evalBinNatOp Nat.sub n1 n2 + let leType ← mkAppM ``LE.le #[n2Lit, n1Lit] + let hLE ← mkDecideProof leType + let comm := mkApp2 (mkConst ``Nat.add_comm) n1Lit e1 + let subFn := mkLambda `x .default (mkConst ``Nat) (mkApp2 (mkConst ``Nat.sub) (mkBVar 0) n2Lit) + let step1 ← mkCongrArg subFn comm + let step2 ← mkAppM ``Nat.add_sub_assoc #[hLE, e1] + let step3 := mkApp2 (mkConst ``Nat.add_comm) e1 n1SubN2 + let step12 ← mkAppM ``Eq.trans #[step1, step2] + let proof ← mkAppM ``Eq.trans #[step12, step3] + pure (some proof) + catch _ => pure none + return some ⟨expr, proof⟩ else return none | _ => return none | _ => return none diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index f59a2d6..043e980 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -183,9 +183,12 @@ where to obtain `bridge : origArg = source`, and composes `Eq.trans bridge p`. -/ def resolveArgProof (argProof : Option Expr) (origArg optArg : Expr) : MetaM (Option Expr) := match argProof with - | none => pure (detectReorderProof origArg optArg) + | none => do + /- trace[Optimize.expr] "resolveArgProof: none → detectReorderProof orig={reprStr origArg} opt={reprStr optArg}" -/ + pure (detectReorderProof origArg optArg) | some p => do let proofType ← inferType p + /- trace[Optimize.proof] "resolveArgProof: some p, type={reprStr proofType}" -/ match proofType.eq? with | some (_, proofSrc, _) => if Blaster.Optimize.exprEq proofSrc origArg then diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean index d3da5b5..7b2ab48 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean @@ -31,22 +31,22 @@ elab "natAddCst_3" : term => return natAddCst_3 /-! Test cases for simplification rule `0 + n ===> n`. -/ -- x + 0 ===> x -#testOptimize [ "NatAddZero_1" ] ∀ (x y : Nat), x + 0 ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatAddZero_1", proof ] ∀ (x y : Nat), x + 0 ≤ y ===> ∀ (x y : Nat), ¬ y < x -- 0 + x ===> x -#testOptimize [ "NatAddZero_2" ] ∀ (x y : Nat), 0 + x ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatAddZero_2", proof ] ∀ (x y : Nat), 0 + x ≤ y ===> ∀ (x y : Nat), ¬ y < x -- x + Nat.zero ===> x -#testOptimize [ "NatAddZero_3" ] ∀ (x y : Nat), x + Nat.zero ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatAddZero_3", proof ] ∀ (x y : Nat), x + Nat.zero ≤ y ===> ∀ (x y : Nat), ¬ y < x -- Nat.zero + x ===> x -#testOptimize [ "NatAddZero_4" ] ∀ (x y : Nat), Nat.zero + x ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatAddZero_4", proof ] ∀ (x y : Nat), Nat.zero + x ≤ y ===> ∀ (x y : Nat), ¬ y < x -- (10 - 10) + x ===> x -#testOptimize [ "NatAddZero_5" ] ∀ (x y : Nat), (10 - 10) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatAddZero_5", proof ] ∀ (x y : Nat), (10 - 10) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x -- x + (10 - 123) ===> x -#testOptimize [ "NatAddZero_6" ] ∀ (x y : Nat), x + (10 - 123) ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatAddZero_6", proof ] ∀ (x y : Nat), x + (10 - 123) ≤ y ===> ∀ (x y : Nat), ¬ y < x /-! Test cases to ensure that simplification rules `0 + n ===> n` is not wrongly applied. -/ @@ -108,10 +108,10 @@ elab "natAddZeroUnchanged_4" : term => return natAddZeroUnchanged_4 #testOptimize [ "NatAddCstProp_1", proof ] ∀ (x : Nat), 10 + (20 + x) = 30 + x ===> True -- 10 + (x + 20) = x + 30 ===> True -#testOptimize [ "NatAddCstProp_2" ] ∀ (x : Nat), 10 + (x + 20) = x + 30 ===> True +#testOptimize [ "NatAddCstProp_2", proof ] ∀ (x : Nat), 10 + (x + 20) = x + 30 ===> True -- (x + 20) + 10 = 30 + x ===> True -#testOptimize [ "NatAddCstProp_3" ] ∀ (x : Nat), (x + 20) + 10 = 30 + x ===> True +#testOptimize [ "NatAddCstProp_3", proof ] ∀ (x : Nat), (x + 20) + 10 = 30 + x ===> True -- (20 + x) + 10 = x + 30 ===> True #testOptimize [ "NatAddCstProp_4", proof ] ∀ (x : Nat), (20 + x) + 10 = x + 30 ===> True @@ -136,26 +136,26 @@ def natAddCstProp_5 : Expr := elab "natAddCstProp_5" : term => return natAddCstProp_5 -#testOptimize [ "NatAddCstProp_5" ] ∀ (x y : Nat), 10 + (20 + x) < y ===> natAddCstProp_5 +#testOptimize [ "NatAddCstProp_5", proof ] ∀ (x y : Nat), 10 + (20 + x) < y ===> natAddCstProp_5 -- 10 + (20 + (40 + x)) = 70 + x ===> True #testOptimize [ "NatAddCstProp_6", proof ] ∀ (x : Nat), 10 + (20 + (40 + x)) = 70 + x ===> True -- 10 + (20 + (x + 40)) = 70 + x ===> True -#testOptimize [ "NatAddCstProp_7" ] ∀ (x : Nat), 10 + (20 + (x + 40)) = 70 + x ===> True +#testOptimize [ "NatAddCstProp_7", proof ] ∀ (x : Nat), 10 + (20 + (x + 40)) = 70 + x ===> True -- 10 + ((x + 20) - 10) = 20 + x ===> True -#testOptimize [ "NatAddCstProp_8" ] ∀ (x : Nat), 10 + ((x + 20) - 10) = 20 + x ===> True +#testOptimize [ "NatAddCstProp_8", proof ] ∀ (x : Nat), 10 + ((x + 20) - 10) = 20 + x ===> True -- 10 + (20 + (15 + (x + 25))) = 70 + x ===> True -#testOptimize [ "NatAddCstProp_9" ] ∀ (x : Nat), 10 + (20 + (15 + (x + 25))) = 70 + x ===> True +#testOptimize [ "NatAddCstProp_9", proof ] ∀ (x : Nat), 10 + (20 + (15 + (x + 25))) = 70 + x ===> True -- 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True -#testOptimize [ "NatAddCstProp_10" ] ∀ (x : Nat), 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True +#testOptimize [ "NatAddCstProp_10", proof ] ∀ (x : Nat), 10 + (20 + ((x + 10) - 7)) = 33 + x ===> True -- 100 + ((180 - (x + 40)) - 150) = 100 set_option maxRecDepth 4096 in -#testOptimize [ "NatAddCstProp_11" ] ∀ (x : Nat), 100 + ((180 - (x + 40)) - 150) = 100 ===> True +#testOptimize [ "NatAddCstProp_11", proof ] ∀ (x : Nat), 100 + ((180 - (x + 40)) - 150) = 100 ===> True /-! Test cases to ensure that simplification rule `N1 + (N2 + n) ===> (N1 "+" N2) + n` @@ -246,7 +246,7 @@ elab "natAddCstPropUnchanged_3" : term => return natAddCstPropUnchanged_3 /-! Test cases for normalization rule `n1 + n2 ==> n2 + n1 (if n2 <ₒ n1)`. -/ -- x + y = x + y ===> True -#testOptimize [ "NatAddCommut_1" ] ∀ (x y : Nat), x + y = x + y ===> True +#testOptimize [ "NatAddCommut_1", proof ] ∀ (x y : Nat), x + y = x + y ===> True -- x + y = y + x ===> True #testOptimize [ "NatAddCommut_2", proof ] ∀ (x y : Nat), x + y = y + x ===> True @@ -255,7 +255,7 @@ elab "natAddCstPropUnchanged_3" : term => return natAddCstPropUnchanged_3 #testOptimize [ "NatAddCommut_3", proof ] ∀ (x : Nat), x + 10 = 10 + x ===> True -- y + x ===> x + y (with `x` declared first) -#testOptimize [ "NatAddCommut_4" ] ∀ (x y z : Nat), z < y + x ===> ∀ (x y z : Nat), z < Nat.add x y +#testOptimize [ "NatAddCommut_4", proof ] ∀ (x y z : Nat), z < y + x ===> ∀ (x y z : Nat), z < Nat.add x y -- x + 40 ===> 40 + x def natAddCommut_5 : Expr := @@ -277,28 +277,30 @@ def natAddCommut_5 : Expr := elab "natAddCommut_5" : term => return natAddCommut_5 -#testOptimize [ "NatAddCommut_5" ] ∀ (x y : Nat), y < x + 40 ===> natAddCommut_5 +#testOptimize [ "NatAddCommut_5", proof ] ∀ (x y : Nat), y < x + 40 ===> natAddCommut_5 -- (x + (y + 20)) + z = z + ((y + 20) + x) ===> True +/- set_option trace.Optimize.proof true in -/ #testOptimize [ "NatAddCommut_6" ] ∀ (x y z : Nat), (x + (y + 20)) + z = z + ((y + 20) + x) ===> True --- (x - y) + (p + q) ===> (p + q) + (x - y) -#testOptimize [ "NatAddCommut_7" ] ∀ (x y z p q : Nat), (x - y) + (p + q) < z ===> +#testOptimize [ "NatAddCommut_7", proof ] ∀ (x y z p q : Nat), (x - y) + (p + q) < z ===> ∀ (x y z p q : Nat), Nat.add (Nat.add p q) (Nat.sub x y) < z --- (x - y) + (p + q) = (p + q) + (x - y) ===> True -#testOptimize [ "NatAddCommut_8" ] ∀ (x y p q : Nat), (x - y) + (p + q) = (p + q) + (x - y) ===> True +#testOptimize [ "NatAddCommut_8", proof ] ∀ (x y p q : Nat), (x - y) + (p + q) = (p + q) + (x - y) ===> True /-! Test cases to ensure that `Nat.add` is preserved when expected. -/ -- x + (y + 0) = y + x ===> True -#testOptimize [ "NatAddVar_1" ] ∀ (x y : Nat), x + (y + 0) = y + x ===> True +#testOptimize [ "NatAddVar_1", proof ] ∀ (x y : Nat), x + (y + 0) = y + x ===> True -- (x + 0) + y = y + x ===> True -#testOptimize [ "NatAddVar_2" ] ∀ (x y : Nat), (x + 0) + y = y + x ===> True +#testOptimize [ "NatAddVar_2", proof ] ∀ (x y : Nat), (x + 0) + y = y + x ===> True -- (x + 0) + (y + 0) = y + x ===> True +/- set_option trace.Optimize.proof true in -/ #testOptimize [ "NatAddVar_3" ] ∀ (x y : Nat), (x + 0) + (y + 0) = y + x ===> True -- x + y < 10 ===> x + y < 10 @@ -336,6 +338,7 @@ def natAddReduce_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 100) elab "natAddReduce_1" : term => return natAddReduce_1 set_option maxRecDepth 4096 in +/- set_option trace.Optimize.proof true in -/ #testOptimize [ "NatAddReduce_1" ] (100 + ((180 - (x + 40)) - 150)) + ((200 - y) - 320) ===> natAddReduce_1 def natAddReduce_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 124) @@ -343,6 +346,7 @@ elab "natAddReduce_2" : term => return natAddReduce_2 -- (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> 124 set_option maxRecDepth 4096 in +/- set_option trace.Optimize.proof true in -/ #testOptimize [ "NatAddReduce_2" ] (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> natAddReduce_2 end Test.OptimizeNatAdd From 1c53990bd61eeb4910e37695b6b84bfb542d33ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Wed, 18 Mar 2026 17:16:06 -0300 Subject: [PATCH 29/31] test: avoid unnecessary allocations in optimizeAppAux and Eq.refl --- Blaster/Optimize/Rewriting/OptimizeApp.lean | 7 ++++--- Tests/Optimize/OptimizeNat.lean | 1 - Tests/Optimize/OptimizeNat/OptimizeNatMul.lean | 1 + Tests/Optimize/OptimizeNat/OptimizeNatSub.lean | 4 ++++ 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/Blaster/Optimize/Rewriting/OptimizeApp.lean b/Blaster/Optimize/Rewriting/OptimizeApp.lean index d4680e9..ef07164 100644 --- a/Blaster/Optimize/Rewriting/OptimizeApp.lean +++ b/Blaster/Optimize/Rewriting/OptimizeApp.lean @@ -67,9 +67,7 @@ private def optimizeAppCore (f : Expr) (args : Array Expr) : TranslateEnvT Optim let proof ← do let Expr.const ``Eq _ := f | pure none if args.size != 3 || !exprEq args[1]! args[2]! then pure none - else - try pure (some (← mkAppM ``Eq.refl #[args[1]!])) - catch _ => pure none + else pure (some (mkApp2 (mkConst ``Eq.refl [.succ .zero]) args[0]! args[1]!)) return ⟨e, proof⟩ if let some r ← optimizeNat? f args then return r if let some e ← optimizeInt? f args then return ⟨e, none⟩ @@ -101,6 +99,9 @@ def optimizeAppAux (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult let origArgs := args let args ← reorderOperands f args let result ← optimizeAppCore f args + let reordered := origArgs.size != args.size || + (origArgs.zip args).any fun (a, b) => !exprEq a b + if !reordered then return result let reorderProof := detectReorderProof (mkAppN f origArgs) (mkAppN f args) /- trace[Optimize.proof] "optimizeAppAux reorder: f={reprStr f} swapped={reorderProof.isSome} resultProof={result.proof.isSome}" -/ match reorderProof with diff --git a/Tests/Optimize/OptimizeNat.lean b/Tests/Optimize/OptimizeNat.lean index 5f2f5af..5c43527 100644 --- a/Tests/Optimize/OptimizeNat.lean +++ b/Tests/Optimize/OptimizeNat.lean @@ -1,4 +1,3 @@ - import Tests.Optimize.OptimizeNat.OptimizeNatAdd import Tests.Optimize.OptimizeNat.OptimizeNatDiv import Tests.Optimize.OptimizeNat.OptimizeNatMod diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean b/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean index 7ba55c8..52be59a 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean @@ -178,6 +178,7 @@ elab "natAddCstProp_5" : term => return natAddCstProp_5 #testOptimize [ "NatMulCstProp_11" ] ∀ (x : Nat), 10 * (20 * ((x - 3) - 7)) = 200 * (x - 10) ===> True -- 10 * (20 * (100 - (x + 190))) = 0 ===> True +set_option maxRecDepth 4096 in #testOptimize [ "NatMulCstProp_12" ] ∀ (x : Nat), 10 * (20 * (100 - (x + 190))) = 0 ===> True diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean b/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean index 4dea85d..c0a1463 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean @@ -477,6 +477,7 @@ elab "natSubSubRight_2" : term => return natSubSubRight_2 #testOptimize [ "NatSubSubRight_2" ] ∀ (x y : Nat), (x - 20) - 10 < y ===> natSubSubRight_2 -- ((x - 100) - 45) - 125 = x - 270 ===> True +set_option maxRecDepth 4096 in #testOptimize [ "NatSubSubRight_3" ] ∀ (x : Nat), ((x - 100) - 45) - 125 = x - 270 ===> True -- ((200 - x) - 45) - 125 = 30 - x ===> True @@ -490,9 +491,11 @@ set_option maxRecDepth 4096 in #testOptimize [ "NatSubSubRight_6" ] ∀ (x : Nat), ((x - 200) - 45) - ((125 - x) - 130) = x - 245 ===> True -- (((x - 60) - 40) - 45) - 125 = x - 270 ===> True +set_option maxRecDepth 4096 in #testOptimize [ "NatSubSubRight_7" ] ∀ (x : Nat), (((x - 60) - 40) - 45) - 125 = x - 270 ===> True -- (100 - ((x - 100) - 45)) = 100 - (x - 145) ===> True +set_option maxRecDepth 4096 in #testOptimize [ "NatSubSubRight_8" ] ∀ (x : Nat), (100 - ((x - 100) - 45)) = 100 - (x - 145) ===> True @@ -536,6 +539,7 @@ elab "natAddSub_2" : term => return natAddSub_2 #testOptimize [ "NatAddSub_6" ] ∀ (x : Nat), (50 + (40 + (x + 60))) - 120 = 30 + x ===> True -- (((230 + x) - 20) - 120) - 40 = 50 + x ===> True +set_option maxRecDepth 4096 in #testOptimize [ "NatAddSub_7" ] ∀ (x : Nat), (((230 + x) - 20) - 120) - 40 = 50 + x ===> True -- (((x + 180) - 100) - 20) + 120 = 180 + x ===> True From 30a4f8baf9bc55002b8789be06cc99c442c2a1e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:47:15 -0300 Subject: [PATCH 30/31] proof reconstruction support for multi-arg rewrites --- Blaster/Optimize/Basic.lean | 14 +- Blaster/Optimize/OptimizeStack.lean | 1 - Blaster/Optimize/Rewriting/OptimizeApp.lean | 8 +- Blaster/Reconstruct/Basic.lean | 184 ++++++++--------- .../Optimize/OptimizeNat/OptimizeNatAdd.lean | 10 +- .../Optimize/OptimizeNat/OptimizeNatSub.lean | 190 +++++++++--------- Tests/Reconstruct/Basic.lean | 6 + 7 files changed, 209 insertions(+), 204 deletions(-) diff --git a/Blaster/Optimize/Basic.lean b/Blaster/Optimize/Basic.lean index 4a7f74a..6d0dbc6 100644 --- a/Blaster/Optimize/Basic.lean +++ b/Blaster/Optimize/Basic.lean @@ -128,12 +128,13 @@ partial def optimizeExprAux (stack : List OptimizeStack) (proof : Option Expr := | .AppOptimizeExplicitArgs f args idx stopIdx pInfo mInfo origArgs argProofs :: xs => if idx ≥ stopIdx then - -- recover proof from argProofs if it was lost during arg processing - let proof := match proof with - | some _ => proof - | none => argProofs.findSome? id - -- annotating proof with position-from-end so it survives unfolding - let proof ← annotateProofWithPosFromEnd args origArgs argProofs proof + let mut rewrittenCount : Nat := 0 + for i in [:argProofs.size] do + if argProofs[i]!.isSome && !exprEq origArgs[i]! args[i]! then + rewrittenCount := rewrittenCount + 1 + let proof ← if rewrittenCount >= 1 then + buildMultiArgCongrProof f origArgs args argProofs proof + else pure proof -- normalizing ite/match function application if let some re ← normChoiceApplication? f args then -- trace[Optimize.normChoiceApp] "normalizing choice application {reprStr f} {reprStr args} => {reprStr re}" @@ -376,7 +377,6 @@ def command (sOpts: BlasterOptions) (e : Expr) : MetaM (Expr × Option Expr × T initialize registerTraceClass `Optimize.expr - /- registerTraceClass `Optimize.proof -/ registerTraceClass `Optimize.funPropagation registerTraceClass `Optimize.normChoiceApp registerTraceClass `Optimize.normPartial diff --git a/Blaster/Optimize/OptimizeStack.lean b/Blaster/Optimize/OptimizeStack.lean index 793efd9..261c3f4 100644 --- a/Blaster/Optimize/OptimizeStack.lean +++ b/Blaster/Optimize/OptimizeStack.lean @@ -249,7 +249,6 @@ def stackContinuity let argChanged := !exprEq optExpr origArgs[idx]! let argProofs' := if proof.isSome && argChanged then argProofs.set! idx proof else argProofs let proof' := if proof.isSome && argChanged then none else proof - /- trace[Optimize.proof] "AppOptExplArgs: idx={idx} argChanged={argChanged} proof={proof.isSome} argProofs[idx]={argProofs'[idx]!.isSome}" -/ return Sum.inl (.AppOptimizeExplicitArgs f (args.set! idx optExpr) (idx + 1) stopIdx pInfo mInfo origArgs argProofs' :: xs, proof') diff --git a/Blaster/Optimize/Rewriting/OptimizeApp.lean b/Blaster/Optimize/Rewriting/OptimizeApp.lean index ef07164..c24647b 100644 --- a/Blaster/Optimize/Rewriting/OptimizeApp.lean +++ b/Blaster/Optimize/Rewriting/OptimizeApp.lean @@ -103,7 +103,6 @@ def optimizeAppAux (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult (origArgs.zip args).any fun (a, b) => !exprEq a b if !reordered then return result let reorderProof := detectReorderProof (mkAppN f origArgs) (mkAppN f args) - /- trace[Optimize.proof] "optimizeAppAux reorder: f={reprStr f} swapped={reorderProof.isSome} resultProof={result.proof.isSome}" -/ match reorderProof with | none => return result | some rp => @@ -136,7 +135,6 @@ def optimizeApp (stack : List OptimizeStack) (incomingProof : Option Expr := none) (skipPropCheck := false) : TranslateEnvT OptimizeContinuity := do let ⟨e, newProof⟩ ← optimizeAppAux f args - /- trace[Optimize.proof] "optimizeApp: f={reprStr f} incomingProof={incomingProof.isSome} newProof={newProof.isSome}" -/ let proof ← match incomingProof, newProof with | some inP, some np => do -- inP : origArg = optArg (an argument was rewritten) @@ -144,13 +142,15 @@ def optimizeApp -- build congrArg to lift the arg rewrite to application level, then compose match ← buildCongrArgFromProof f args inP with | some congrP => composeProofs? (some congrP) (some np) - | none => pure (some np) + | none => + match ← composeProofs? (some inP) (some np) with + | some composed => pure (some composed) + | none => pure (some np) | some inP, none => do match ← buildCongrArgFromProof f args inP with | some congrP => pure (some congrP) | none => pure (some inP) | none, _ => pure newProof - /- trace[Optimize.proof] "optimizeApp: composedProof={proof.isSome}" -/ if ← isRestart then resetRestart match proof with diff --git a/Blaster/Reconstruct/Basic.lean b/Blaster/Reconstruct/Basic.lean index 043e980..99efb9f 100644 --- a/Blaster/Reconstruct/Basic.lean +++ b/Blaster/Reconstruct/Basic.lean @@ -6,8 +6,13 @@ open Lean Meta Blaster.Optimize namespace Blaster.Reconstruct /-- Compose two proof certificates `p₁ : a = b` and `p₂ : b = c` into `p : a = c` via `Eq.trans`. -/ -def composeProofs (p₁ p₂ : Expr) : MetaM Expr := - mkAppM ``Eq.trans #[p₁, p₂] +def composeProofs (p₁ p₂ : Expr) : MetaM Expr := do + let t₁ ← inferType p₁ + let some (α, a, b) := t₁.eq? | throwError "composeProofs: p₁ is not an equality proof" + let t₂ ← inferType p₂ + let some (_, _, c) := t₂.eq? | throwError "composeProofs: p₂ is not an equality proof" + let u ← getLevel α + return mkApp6 (mkConst ``Eq.trans [u]) α a b c p₁ p₂ /-- Compose two optional proof certificates via `Eq.trans`. If either is `none`, the other is returned unchanged. -/ @@ -16,87 +21,14 @@ def composeProofs? (opt_p₁ opt_p₂ : Option Expr) : MetaM (Option Expr) := | none, p => return p | p, none => return p | some p₁, some p₂ => - try return some (← composeProofs p₁ p₂) - catch _ => - /- trace[Optimize.expr] "composeProofs? failed: {e.toMessageData}" -/ - /- trace[Optimize.expr] " p₁ type: {← try inferType p₁ catch _ => pure (toExpr "??")}" -/ - /- trace[Optimize.expr] " p₂ type: {← try inferType p₂ catch _ => pure (toExpr "??")}" -/ - return none - -/-- Tag for annotating the argument position from the end. -/ -def argPosFromEndKey : Name := `_blaster.argPosFromEnd - -/-- Annotate the proof with the position relative to the end, so it survives - unfolding (which strips implicit args from the front). - Compares args against origArgs via isDefEq to find which argument - was actually rewritten, ignoring definitionally equal changes. -/ -def annotateProofWithPosFromEnd - (args : Array Expr) (origArgs : Array Expr) (argProofs : Array (Option Expr)) - (proof : Option Expr) : TranslateEnvT (Option Expr) := do - match proof with - | none => return none - | some p => - let mut proofIdx? : Option Nat := none - for i in [:argProofs.size] do - if (argProofs[i]!).isSome then - if !(← withLocalContext $ - withNewMCtxDepth $ - withReducible $ - isDefEq args[i]! origArgs[i]!) then - proofIdx? := some i - match proofIdx? with - | some proofIdx => - let posFromEnd := args.size - 1 - proofIdx - return some (Expr.mdata (MData.empty.setNat argPosFromEndKey posFromEnd) p) - | none => return some p - -/-- Given a function application f(args) and a proof that one argument was rewritten - (argProof : origArg = optArg), build a congruence proof that lifts the rewrite - to the full application level. - Finds i such that args[i] was rewritten, using either an MData annotation - encoding position-from-end, or a reverse isDefEq search as fallback. - Then builds: - congrFun (... (congrFun (congrArg (f a₀..a_{i-1}) proof) a_{i+1}) ...) a_{n-1} - Returns none if the rewritten argument cannot be identified. -/ -def buildCongrArgFromProof (f : Expr) (args : Array Expr) (argProof : Expr) - : MetaM (Option Expr) := do - try - let (proof, annotatedIdx?) := match argProof with - | Expr.mdata d p => - let posFromEnd := d.getNat argPosFromEndKey args.size - if posFromEnd < args.size then - (p, some (args.size - 1 - posFromEnd)) - else (argProof, none) - | _ => (argProof, none) - let idx? ← match annotatedIdx? with - | some idx => pure (some idx) - | none => - let proofType ← inferType proof - let some (_, _origArg, optArg) := proofType.eq? | return none - let mut found := none - for i in [:args.size] do - let i' := args.size - 1 - i - if ← isDefEq args[i']! optArg then - found := some i' - break - pure found - match idx? with - | some idx => - let partialApp := mkAppN f (args[:idx]) - let mut p ← mkCongrArg partialApp proof - for j in [idx + 1 : args.size] do - p ← mkCongrFun p args[j]! - return some p - | none => return none - catch _ => - /- trace[Optimize.expr] "buildCongrArgFromProof failed: {e.toMessageData}" -/ - /- trace[Optimize.expr] " f={← ppExpr f} args.size={args.size}" -/ - return none + try + return some (← composeProofs p₁ p₂) + catch _ => return none -/-- Detect if the difference between origExpr and optExpr is a simple +/-- Detect if the difference between `origExpr` and `optExpr` is a simple commutativity swap at the top level, and return the corresponding proof. - Returns `some (a_comm a b : a ⊕ b = b ⊕ a)` when origExpr = `a ⊕ b` - and optExpr = `b ⊕ a`. -/ + Returns `some (a_comm a b : a ⊕ b = b ⊕ a)` when `origExpr = a ⊕ b` + and `optExpr = b ⊕ a`. -/ def detectReorderProof (origExpr optExpr : Expr) : Option Expr := if Blaster.Optimize.exprEq origExpr optExpr then none else @@ -135,7 +67,7 @@ def detectReorderProof (origExpr optExpr : Expr) : Option Expr := | some n1, some n2 => n1 == n2 | _, _ => false -/-- MetaM fallback for detecting commutativity between expressions in different +/-- Fallback for detecting commutativity between expressions in different representations (e.g., `HAdd.hAdd` vs `Nat.add`). Uses `isDefEq` to compare operands. Returns a proof `origExpr = targetExpr` via the appropriate commutativity lemma. Only invoked when `detectReorderProof` fails due to representation mismatch. -/ @@ -172,7 +104,79 @@ where else none | _ => none -/-- Resolve the proof for an Eq argument, bridging a potential gap between +/-- Tag for annotating proofs already at the application level -/ +def appLevelProofKey : Name := `_blaster.appLevelProof + +/-- Given a function application `f(args)` and a proof that one argument was rewritten + (`argProof : origArg = optArg`), build a congruence proof that lifts the rewrite + to the full application level. + + Uses a reverse `isDefEq` search to find `i` such that `args[i]` matches `optArg`. + Then builds: + `congrFun (... (congrFun (congrArg (f a₀..a_{i-1}) proof) a_{i+1}) ...) a_{n-1}` + + If `argProof` is annotated with `appLevelProofKey` (i.e., it is already an app-level + proof from `buildMultiArgCongrProof`), it is returned as-is to avoid double-lifting. + + Returns `none` if the rewritten argument cannot be identified. -/ +def buildCongrArgFromProof (f : Expr) (args : Array Expr) (argProof : Expr) + : MetaM (Option Expr) := do + if let Expr.mdata d _ := argProof then + if d.getBool appLevelProofKey false then + return some argProof + try + let proofType ← inferType argProof + let some (_, _origArg, optArg) := proofType.eq? | return none + let mut idx? : Option Nat := none + for i in [:args.size] do + let i' := args.size - 1 - i + if ← isDefEq args[i']! optArg then + idx? := some i' + break + match idx? with + | some idx => + let partialApp := mkAppN f (args[:idx]) + let mut p ← mkCongrArg partialApp argProof + for j in [idx + 1 : args.size] do + p ← mkCongrFun p args[j]! + return some p + | none => return none + catch _ => return none + +/-- Build a combined congruence proof when multiple arguments were rewritten. + + Given `f(origArgs)` where some `origArgs[i]` were rewritten to `args[i]` with + `argProofs[i]`, composes individual congruence steps: + `f(orig₀, orig₁, ...) = f(opt₀, orig₁, ...) = f(opt₀, opt₁, ...) = ...` + + The result is annotated with `appLevelProofKey` so that downstream calls to + `buildCongrArgFromProof` return it as-is rather than attempting to double-lift. -/ +def buildMultiArgCongrProof (f : Expr) (origArgs args : Array Expr) + (argProofs : Array (Option Expr)) (carriedProof : Option Expr) + : MetaM (Option Expr) := do + let mut rewrittenIndices := #[] + for i in [:argProofs.size] do + if let some _ := argProofs[i]! then + if !exprEq origArgs[i]! args[i]! then + rewrittenIndices := rewrittenIndices.push i + if rewrittenIndices.isEmpty then return carriedProof + let mut composedProof : Option Expr := carriedProof + let mut currentArgs := origArgs + for i in rewrittenIndices do + if let some ap := argProofs[i]! then + try + let partialApp := mkAppN f (currentArgs[:i]) + let mut step ← mkCongrArg partialApp ap + for j in [i + 1 : currentArgs.size] do + step ← mkCongrFun step currentArgs[j]! + currentArgs := currentArgs.set! i args[i]! + composedProof ← composeProofs? composedProof (some step) + catch _ => currentArgs := currentArgs.set! i args[i]! + return match composedProof with + | some p => some (Expr.mdata (MData.empty.setBool appLevelProofKey true) p) + | none => none + +/-- Resolve the proof for an `Eq` argument, bridging a potential gap between the proof source and the original expression. When `argProof` is `none`, falls back to `detectReorderProof`. @@ -183,12 +187,9 @@ where to obtain `bridge : origArg = source`, and composes `Eq.trans bridge p`. -/ def resolveArgProof (argProof : Option Expr) (origArg optArg : Expr) : MetaM (Option Expr) := match argProof with - | none => do - /- trace[Optimize.expr] "resolveArgProof: none → detectReorderProof orig={reprStr origArg} opt={reprStr optArg}" -/ - pure (detectReorderProof origArg optArg) + | none => pure (detectReorderProof origArg optArg) | some p => do let proofType ← inferType p - /- trace[Optimize.proof] "resolveArgProof: some p, type={reprStr proofType}" -/ match proofType.eq? with | some (_, proofSrc, _) => if Blaster.Optimize.exprEq proofSrc origArg then @@ -202,17 +203,16 @@ def resolveArgProof (argProof : Option Expr) (origArg optArg : Expr) : MetaM (Op | none => pure (some p) | none => pure (some p) -/-- Build a proof of `orig_lhs = orig_rhs` from individual Eq argument proofs +/-- Build a proof of `orig_lhs = orig_rhs` from individual `Eq` argument proofs when both sides have been optimized to the same expression. Given: - - `lhsProof : orig_lhs = opt_lhs` (or none if LHS unchanged) - - `rhsProof : orig_rhs = opt_rhs` (or none if RHS unchanged) + - `lhsProof : orig_lhs = opt_lhs` (or `none` if LHS unchanged) + - `rhsProof : orig_rhs = opt_rhs` (or `none` if RHS unchanged) - `opt_lhs` and `opt_rhs` are definitionally equal Constructs `Eq.trans lhsProof (Eq.symm rhsProof) : orig_lhs = orig_rhs` - with the appropriate simplification when either side is none (rfl). --/ + with the appropriate simplification when either side is `none` (rfl). -/ def buildEqReflProof (lhsProof rhsProof : Option Expr) : MetaM (Option Expr) := match lhsProof, rhsProof with | none, none => pure none diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean index 7b2ab48..4c58f21 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatAdd.lean @@ -280,7 +280,6 @@ elab "natAddCommut_5" : term => return natAddCommut_5 #testOptimize [ "NatAddCommut_5", proof ] ∀ (x y : Nat), y < x + 40 ===> natAddCommut_5 -- (x + (y + 20)) + z = z + ((y + 20) + x) ===> True -/- set_option trace.Optimize.proof true in -/ #testOptimize [ "NatAddCommut_6" ] ∀ (x y z : Nat), (x + (y + 20)) + z = z + ((y + 20) + x) ===> True --- (x - y) + (p + q) ===> (p + q) + (x - y) @@ -300,8 +299,7 @@ elab "natAddCommut_5" : term => return natAddCommut_5 #testOptimize [ "NatAddVar_2", proof ] ∀ (x y : Nat), (x + 0) + y = y + x ===> True -- (x + 0) + (y + 0) = y + x ===> True -/- set_option trace.Optimize.proof true in -/ -#testOptimize [ "NatAddVar_3" ] ∀ (x y : Nat), (x + 0) + (y + 0) = y + x ===> True +#testOptimize [ "NatAddVar_3", proof ] ∀ (x y : Nat), (x + 0) + (y + 0) = y + x ===> True -- x + y < 10 ===> x + y < 10 def natAddVar_4 : Expr := @@ -338,15 +336,13 @@ def natAddReduce_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 100) elab "natAddReduce_1" : term => return natAddReduce_1 set_option maxRecDepth 4096 in -/- set_option trace.Optimize.proof true in -/ -#testOptimize [ "NatAddReduce_1" ] (100 + ((180 - (x + 40)) - 150)) + ((200 - y) - 320) ===> natAddReduce_1 +#testOptimize [ "NatAddReduce_1", proof ] (100 + ((180 - (x + 40)) - 150)) + ((200 - y) - 320) ===> natAddReduce_1 def natAddReduce_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 124) elab "natAddReduce_2" : term => return natAddReduce_2 -- (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> 124 set_option maxRecDepth 4096 in -/- set_option trace.Optimize.proof true in -/ -#testOptimize [ "NatAddReduce_2" ] (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> natAddReduce_2 +#testOptimize [ "NatAddReduce_2", proof ] (100 + ((180 - (x + 40)) - 150)) + (((20 - y) - 50) + 24) ===> natAddReduce_2 end Test.OptimizeNatAdd diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean b/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean index c0a1463..dfe5a30 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatSub.lean @@ -14,27 +14,27 @@ namespace Test.OptimizeNatSub def natSubCst_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 73) elab "natSubCst_1" : term => return natSubCst_1 -#testOptimize [ "NatSubCst_1" ] (123 : Nat) - 50 ===> natSubCst_1 +#testOptimize [ "NatSubCst_1", proof ] (123 : Nat) - 50 ===> natSubCst_1 -- 123 - 0 ===> 123 def natSubCst_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 123) elab "natSubCst_2" : term => return natSubCst_2 -#testOptimize [ "NatSubCst_2" ] (123 : Nat) - 0 ===> natSubCst_2 +#testOptimize [ "NatSubCst_2", proof ] (123 : Nat) - 0 ===> natSubCst_2 def natSubCst_3 : Expr := Lean.Expr.lit (Lean.Literal.natVal 0) elab "natSubCst_3" : term => return natSubCst_3 -- 0 - 1 ===> 0 -#testOptimize [ "NatSubCst_3" ] (0 : Nat) - 1 ===> natSubCst_3 +#testOptimize [ "NatSubCst_3", proof ] (0 : Nat) - 1 ===> natSubCst_3 -- 0 - 5 ===> 0 -#testOptimize [ "NatSubCst_4" ] (0 : Nat) - 5 ===> natSubCst_3 +#testOptimize [ "NatSubCst_4", proof ] (0 : Nat) - 5 ===> natSubCst_3 -- 123 - 124 ===> 0 -#testOptimize [ "NatSubCst_5" ] (123 : Nat) - 124 ===> natSubCst_3 +#testOptimize [ "NatSubCst_5", proof ] (123 : Nat) - 124 ===> natSubCst_3 -- 123 - 300 ===> 0 -#testOptimize [ "NatSubCst_6" ] (123 : Nat) - 300 ===> natSubCst_3 +#testOptimize [ "NatSubCst_6", proof ] (123 : Nat) - 300 ===> natSubCst_3 /-! Test cases for simplification rule `n1 - n2 ==> 0 (if n1 =ₚₜᵣ n2)`. -/ @@ -57,61 +57,61 @@ elab "natSubReduceZero_1" : term => return natSubReduceZero_1 #testOptimize [ "NatSubReduceZero_1" ] ∀ (x y : Nat), y > x - x ===> natSubReduceZero_1 -- x - x = 0 ===> True -#testOptimize [ "NatSubReduceZero_2" ] ∀ (x : Nat), x - x = 0 ===> True +#testOptimize [ "NatSubReduceZero_2", proof ] ∀ (x : Nat), x - x = 0 ===> True -- (x + y) - (y + x) = 0 ===> True -#testOptimize [ "NatSubReduceZero_3" ] ∀ (x y : Nat), (x + y) - (y + x) = 0 ===> True +#testOptimize [ "NatSubReduceZero_3", proof ] ∀ (x y : Nat), (x + y) - (y + x) = 0 ===> True -- (x - y) - (x - y) = 0 ===> True -#testOptimize [ "NatSubReduceZero_4" ] ∀ (x y : Nat), (x - y) - (x - y) = 0 ===> True +#testOptimize [ "NatSubReduceZero_4", proof ] ∀ (x y : Nat), (x - y) - (x - y) = 0 ===> True -- (x + 0) - x = 0 ===> True -#testOptimize [ "NatSubReduceZero_5" ] ∀ (x : Nat), (x + 0) - x = 0 ===> True +#testOptimize [ "NatSubReduceZero_5", proof ] ∀ (x : Nat), (x + 0) - x = 0 ===> True -- x - (x + 0) = 0 ===> True -#testOptimize [ "NatSubReduceZero_6" ] ∀ (x : Nat), x - (x + 0) = 0 ===> True +#testOptimize [ "NatSubReduceZero_6", proof ] ∀ (x : Nat), x - (x + 0) = 0 ===> True -- x - (x + (100 - (200 + x))) = 0 ===> True -#testOptimize [ "NatSubReduceZero_7" ] ∀ (x : Nat), x - (x + (100 - (200 + x))) = 0 ===> True +#testOptimize [ "NatSubReduceZero_7", proof ] ∀ (x : Nat), x - (x + (100 - (200 + x))) = 0 ===> True /-! Test cases to ensure that simplification rule `n1 - n2 ==> 0 (if n1 =ₚₜᵣ n2)` is not wrongly applied. -/ -- x - y ===> Nat.sub x y -#testOptimize [ "NatSubReduceZeroUnchanged_1" ] ∀ (x y z : Nat), z > x - y ===> +#testOptimize [ "NatSubReduceZeroUnchanged_1", proof ] ∀ (x y z : Nat), z > x - y ===> ∀ (x y z : Nat), Nat.sub x y < z -- (x + y) - z ===> Nat.sub (Nat.add x y) z -#testOptimize [ "NatSubReduceZeroUnchanged_2" ] ∀ (x y z m : Nat), (x + y) - z < m ===> +#testOptimize [ "NatSubReduceZeroUnchanged_2", proof ] ∀ (x y z m : Nat), (x + y) - z < m ===> ∀ (x y z m : Nat), Nat.sub (Nat.add x y) z < m -- (x - y) - z ===> Nat.sub (Nat.sub x y) z -#testOptimize [ "NatSubReduceZeroUnchanged_3" ] ∀ (x y z m : Nat), (x - y) - z < m ===> +#testOptimize [ "NatSubReduceZeroUnchanged_3", proof ] ∀ (x y z m : Nat), (x - y) - z < m ===> ∀ (x y z m : Nat), Nat.sub (Nat.sub x y) z < m -- (y + (100 - (200 + x))) - x ===> Nat.sub y x -#testOptimize [ "NatSubReduceZeroUnchanged_4" ] ∀ (x y z : Nat), (y + (100 - (200 + x))) - x < z ===> +#testOptimize [ "NatSubReduceZeroUnchanged_4", proof ] ∀ (x y z : Nat), (y + (100 - (200 + x))) - x < z ===> ∀ (x y z : Nat), Nat.sub y x < z /-! Test cases for simplification rule `0 - n ===> 0`. -/ -- 0 - x = 0 ===> True -#testOptimize [ "NatSubLeftZero_1" ] ∀ (x : Nat), 0 - x = 0 ===> True +#testOptimize [ "NatSubLeftZero_1", proof ] ∀ (x : Nat), 0 - x = 0 ===> True -- 0 - x ===> 0 -#testOptimize [ "NatSubLeftZero_2" ] ∀ (x y : Nat), (0 - x) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubLeftZero_2", proof ] ∀ (x y : Nat), (0 - x) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x -- Nat.zero - x = 0 ===> True -#testOptimize [ "NatSubLeftZero_3" ] ∀ (x : Nat), Nat.zero - x = 0 ===> True +#testOptimize [ "NatSubLeftZero_3", proof ] ∀ (x : Nat), Nat.zero - x = 0 ===> True -- Nat.zero - x ===> 0 -#testOptimize [ "NatSubLeftZero_4" ] ∀ (x y : Nat), (Nat.zero - x) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubLeftZero_4", proof ] ∀ (x y : Nat), (Nat.zero - x) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x -- (27 - 27) - x ===> 0 -#testOptimize [ "NatSubLeftZero_5" ] ∀ (x y : Nat), ((27 - 27) - x) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubLeftZero_5", proof ] ∀ (x y : Nat), ((27 - 27) - x) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x -- (127 - 145) - x ===> 0 -#testOptimize [ "NatSubLeftZero_6" ] ∀ (x y : Nat), ((127 - 145) - x) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubLeftZero_6", proof ] ∀ (x y : Nat), ((127 - 145) - x) + x ≤ y ===> ∀ (x y : Nat), ¬ y < x -- 1 - x ===> 1 - x def natSubLeftZeroUnchanged_1 : Expr := @@ -133,13 +133,13 @@ def natSubLeftZeroUnchanged_1 : Expr := elab "natSubLeftZeroUnchanged_1" : term => return natSubLeftZeroUnchanged_1 -#testOptimize [ "NatSubLeftZeroUnchanged_1" ] ∀ (x y : Nat), (1 - x) < y ===> natSubLeftZeroUnchanged_1 +#testOptimize [ "NatSubLeftZeroUnchanged_1", proof ] ∀ (x y : Nat), (1 - x) < y ===> natSubLeftZeroUnchanged_1 -- (27 - 26) - x ===> 1 - x -#testOptimize [ "NatSubLeftZeroUnchanged_2" ] ∀ (x y : Nat), (27 - 26) - x < y ===> natSubLeftZeroUnchanged_1 +#testOptimize [ "NatSubLeftZeroUnchanged_2", proof ] ∀ (x y : Nat), (27 - 26) - x < y ===> natSubLeftZeroUnchanged_1 -- (Nat.zero + 1) - x ===> 1 - x -#testOptimize [ "NatSubLeftZeroUnchanged_3" ] ∀ (x y : Nat), (Nat.zero + 1) - x < y ===> natSubLeftZeroUnchanged_1 +#testOptimize [ "NatSubLeftZeroUnchanged_3", proof ] ∀ (x y : Nat), (Nat.zero + 1) - x < y ===> natSubLeftZeroUnchanged_1 -- (127 - 40) - x ===> 87 - x def natSubLeftZeroUnchanged_4 : Expr := @@ -161,31 +161,31 @@ def natSubLeftZeroUnchanged_4 : Expr := elab "natSubLeftZeroUnchanged_4" : term => return natSubLeftZeroUnchanged_4 -#testOptimize [ "NatSubLeftZeroUnchanged_4" ] ∀ (x y : Nat), (127 - 40) - x < y ===> natSubLeftZeroUnchanged_4 +#testOptimize [ "NatSubLeftZeroUnchanged_4", proof ] ∀ (x y : Nat), (127 - 40) - x < y ===> natSubLeftZeroUnchanged_4 /-! Test cases for simplification rule `n - 0 ===> n`. -/ -- x - 0 = x ===> x -#testOptimize [ "NatSubRightZero_1" ] ∀ (x : Nat), x - 0 = x ===> True +#testOptimize [ "NatSubRightZero_1", proof ] ∀ (x : Nat), x - 0 = x ===> True -- x - Nat.zero = x ===> x -#testOptimize [ "NatSubRightZero_2" ] ∀ (x : Nat), x - Nat.zero = x ===> True +#testOptimize [ "NatSubRightZero_2", proof ] ∀ (x : Nat), x - Nat.zero = x ===> True -- x - 0 ===> x -#testOptimize [ "NatSubRightZero_3" ] ∀ (x y : Nat), x - 0 ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubRightZero_3", proof ] ∀ (x y : Nat), x - 0 ≤ y ===> ∀ (x y : Nat), ¬ y < x -- x - Nat.zero ===> x -#testOptimize [ "NatSubRightZero_4" ] ∀ (x y : Nat), x - Nat.zero ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubRightZero_4", proof ] ∀ (x y : Nat), x - Nat.zero ≤ y ===> ∀ (x y : Nat), ¬ y < x -- x - 0 ===> x -#testOptimize [ "NatSubRightZero_5" ] ∀ (x y : Nat), x - 0 ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubRightZero_5", proof ] ∀ (x y : Nat), x - 0 ≤ y ===> ∀ (x y : Nat), ¬ y < x -- x - (27 - 27) ===> x -#testOptimize [ "NatSubRightZero_6" ] ∀ (x y : Nat), x - (27 - 27) ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubRightZero_6", proof ] ∀ (x y : Nat), x - (27 - 27) ≤ y ===> ∀ (x y : Nat), ¬ y < x -- x - (27 - 145) ===> x -#testOptimize [ "NatSubRightZero_7" ] ∀ (x y : Nat), x - (27 - 145) ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatSubRightZero_7", proof ] ∀ (x y : Nat), x - (27 - 145) ≤ y ===> ∀ (x y : Nat), ¬ y < x -- x - 1 ===> x - 1 def natSubRightZeroUnchanged_1 : Expr := @@ -207,13 +207,13 @@ def natSubRightZeroUnchanged_1 : Expr := elab "natSubRightZeroUnchanged_1" : term => return natSubRightZeroUnchanged_1 -#testOptimize [ "NatSubRightZeroUnchanged_1" ] ∀ (x y : Nat), (x - 1) < y ===> natSubRightZeroUnchanged_1 +#testOptimize [ "NatSubRightZeroUnchanged_1", proof ] ∀ (x y : Nat), (x - 1) < y ===> natSubRightZeroUnchanged_1 -- x - (127 - 126) ===> x - 1 -#testOptimize [ "NatSubRightZeroUnchanged_2" ] ∀ (x y : Nat), x - (127 - 126) < y ===> natSubRightZeroUnchanged_1 +#testOptimize [ "NatSubRightZeroUnchanged_2", proof ] ∀ (x y : Nat), x - (127 - 126) < y ===> natSubRightZeroUnchanged_1 -- x - (Nat.zero + 1) ===> x - 1 -#testOptimize [ "NatSubRightZeroUnchanged_3" ] ∀ (x y : Nat), x - (Nat.zero + 1) < y ===> natSubRightZeroUnchanged_1 +#testOptimize [ "NatSubRightZeroUnchanged_3", proof ] ∀ (x y : Nat), x - (Nat.zero + 1) < y ===> natSubRightZeroUnchanged_1 -- x - (127 - 40) ===> x - 87 def natSubRightZeroUnchanged_4 : Expr := @@ -235,13 +235,13 @@ def natSubRightZeroUnchanged_4 : Expr := elab "natSubRightZeroUnchanged_4" : term => return natSubRightZeroUnchanged_4 -#testOptimize [ "NatSubRightZeroUnchanged_4" ] ∀ (x y : Nat), x - (127 - 40) < y ===> natSubRightZeroUnchanged_4 +#testOptimize [ "NatSubRightZeroUnchanged_4", proof ] ∀ (x y : Nat), x - (127 - 40) < y ===> natSubRightZeroUnchanged_4 /-! Test cases for simplification rule `N1 - (N2 + n) ===> (N1 "-" N2) - n`. -/ -- 120 - (40 + x) = 80 - x ===> True -#testOptimize [ "NatSubAdd_1" ] ∀ (x : Nat), 120 - (40 + x) = 80 - x ===> True +#testOptimize [ "NatSubAdd_1", proof ] ∀ (x : Nat), 120 - (40 + x) = 80 - x ===> True -- 120 - (40 + x) ===> 80 - x def natSubAdd_2 : Expr := @@ -263,37 +263,37 @@ def natSubAdd_2 : Expr := elab "natSubAdd_2" : term => return natSubAdd_2 -#testOptimize [ "NatSubAdd_2" ] ∀ (x y : Nat), 120 - (40 + x) < y ===> natSubAdd_2 +#testOptimize [ "NatSubAdd_2", proof ] ∀ (x y : Nat), 120 - (40 + x) < y ===> natSubAdd_2 -- 120 - (x + 40) = 80 - x ===> True -#testOptimize [ "NatSubAdd_3" ] ∀ (x : Nat), 120 - (x + 40) = 80 - x ===> True +#testOptimize [ "NatSubAdd_3", proof ] ∀ (x : Nat), 120 - (x + 40) = 80 - x ===> True -- 120 - (140 + x) = 0 ===> True -#testOptimize [ "NatSubAdd_4" ] ∀ (x : Nat), 120 - (140 + x) = 0 ===> True +#testOptimize [ "NatSubAdd_4", proof ] ∀ (x : Nat), 120 - (140 + x) = 0 ===> True -- 120 - (10 + (30 + x)) = 80 - x ===> True -#testOptimize [ "NatSubAdd_5" ] ∀ (x : Nat), 120 - (10 + (30 + x)) = 80 - x ===> True +#testOptimize [ "NatSubAdd_5", proof ] ∀ (x : Nat), 120 - (10 + (30 + x)) = 80 - x ===> True -- 120 - (10 + (x + 30)) = 80 - x ===> True -#testOptimize [ "NatSubAdd_6" ] ∀ (x : Nat), 120 - (10 + (x + 30)) = 80 - x ===> True +#testOptimize [ "NatSubAdd_6", proof ] ∀ (x : Nat), 120 - (10 + (x + 30)) = 80 - x ===> True -- 120 - (10 - (50 + x)) = 120 ===> True -#testOptimize [ "NatSubAdd_7" ] ∀ (x : Nat), 120 - (10 - (50 + x)) = 120 ===> True +#testOptimize [ "NatSubAdd_7", proof ] ∀ (x : Nat), 120 - (10 - (50 + x)) = 120 ===> True -- (250 - (150 + x)) - 20 = 80 - x ===> True -#testOptimize [ "NatSubAdd_8" ] ∀ (x : Nat), (250 - (150 + x)) - 20 = 80 - x ===> True +#testOptimize [ "NatSubAdd_8", proof ] ∀ (x : Nat), (250 - (150 + x)) - 20 = 80 - x ===> True -- 120 - (10 + (5 + (25 + x))) = 80 - x ===> True -#testOptimize [ "NatSubAdd_9" ] ∀ (x : Nat), 120 - (10 + (5 + (25 + x))) = 80 - x ===> True +#testOptimize [ "NatSubAdd_9", proof ] ∀ (x : Nat), 120 - (10 + (5 + (25 + x))) = 80 - x ===> True -- 120 - (10 + (150 - (100 + x))) = 110 - (50 - x) ===> True -#testOptimize [ "NatSubAdd_10" ] ∀ (x : Nat), 120 - (10 + (150 - (100 + x))) = 110 - (50 - x) ===> True +#testOptimize [ "NatSubAdd_10", proof ] ∀ (x : Nat), 120 - (10 + (150 - (100 + x))) = 110 - (50 - x) ===> True -- 120 - (10 + (50 - (x + 100))) = 110 ===> True -#testOptimize [ "NatSubAdd_11" ] ∀ (x : Nat), 120 - (10 + (50 - (x + 100))) = 110 ===> True +#testOptimize [ "NatSubAdd_11", proof ] ∀ (x : Nat), 120 - (10 + (50 - (x + 100))) = 110 ===> True -- 120 - (25 + ((x - 3) - 2)) = 95 - (x - 5) ===> True -#testOptimize [ "NatSubAdd_12" ] ∀ (x : Nat), 120 - (25 + ((x - 3) - 2)) = 95 - (x - 5) ===> True +#testOptimize [ "NatSubAdd_12", proof ] ∀ (x : Nat), 120 - (25 + ((x - 3) - 2)) = 95 - (x - 5) ===> True /-! Test cases to ensure that simplification rule `N1 - (N2 + n) ===> (N1 "-" N2) - n` @@ -323,7 +323,7 @@ def natSubAddUnchanged_1 : Expr := elab "natSubAddUnchanged_1" : term => return natSubAddUnchanged_1 -#testOptimize [ "NatSubAddUnchanged_1" ] ∀ (x y : Nat), 120 - (x - y) < y ===> natSubAddUnchanged_1 +#testOptimize [ "NatSubAddUnchanged_1", proof ] ∀ (x y : Nat), 120 - (x - y) < y ===> natSubAddUnchanged_1 -- 120 - (x + y) ===> 120 - (x + y) -- Must remain unchanged @@ -348,7 +348,7 @@ def natSubAddUnchanged_2 : Expr := elab "natSubAddUnchanged_2" : term => return natSubAddUnchanged_2 -#testOptimize [ "NatSubAddUnchanged_2" ] ∀ (x y : Nat), 120 - (x + y) < y ===> natSubAddUnchanged_2 +#testOptimize [ "NatSubAddUnchanged_2", proof ] ∀ (x y : Nat), 120 - (x + y) < y ===> natSubAddUnchanged_2 -- 120 - (x * y) ===> 120 - (x * y) -- Must remain unchanged @@ -373,7 +373,7 @@ def natSubAddUnchanged_3 : Expr := elab "natSubAddUnchanged_3" : term => return natSubAddUnchanged_3 -#testOptimize [ "NatSubAddUnchanged_3" ] ∀ (x y : Nat), 120 - (x * y) < y ===> natSubAddUnchanged_3 +#testOptimize [ "NatSubAddUnchanged_3", proof ] ∀ (x y : Nat), 120 - (x * y) < y ===> natSubAddUnchanged_3 -- 120 - (30 - x) ===> 120 - (30 - x) -- Must remain unchanged @@ -398,12 +398,12 @@ def natSubAddUnchanged_4 : Expr := elab "natSubAddUnchanged_4" : term => return natSubAddUnchanged_4 -#testOptimize [ "NatSubAddUnchanged_4" ] ∀ (x y : Nat), 120 - (30 - x) < y ===> natSubAddUnchanged_4 +#testOptimize [ "NatSubAddUnchanged_4", proof ] ∀ (x y : Nat), 120 - (30 - x) < y ===> natSubAddUnchanged_4 /-! Test cases for simplification rule `(N1 - n) - N2 ===> (N1 "-" N2) - n`. -/ -- (20 - x) - 10 = 10 - x ===> True -#testOptimize [ "NatSubSubLeft_1" ] ∀ (x : Nat), (20 - x) - 10 = 10 - x ===> True +#testOptimize [ "NatSubSubLeft_1", proof ] ∀ (x : Nat), (20 - x) - 10 = 10 - x ===> True -- (20 - x) - 10 ===> 10 - x def natSubSubLeft_2 : Expr := @@ -425,34 +425,34 @@ def natSubSubLeft_2 : Expr := elab "natSubSubLeft_2" : term => return natSubSubLeft_2 -#testOptimize [ "NatSubSubLeft_2" ] ∀ (x y : Nat), (20 - x) - 10 < y ===> natSubSubLeft_2 +#testOptimize [ "NatSubSubLeft_2", proof ] ∀ (x y : Nat), (20 - x) - 10 < y ===> natSubSubLeft_2 -- (45 - x) - 125 = 0 ===> True -#testOptimize [ "NatSubSubLeft_3" ] ∀ (x : Nat), (45 - x) - 125 = 0 ===> True +#testOptimize [ "NatSubSubLeft_3", proof ] ∀ (x : Nat), (45 - x) - 125 = 0 ===> True -- ((20 - x) - 5) - 10 = 5 - x ===> True -#testOptimize [ "NatSubSubLeft_4" ] ∀ (x : Nat), ((20 - x) - 5) - 10 = 5 - x ===> True +#testOptimize [ "NatSubSubLeft_4", proof ] ∀ (x : Nat), ((20 - x) - 5) - 10 = 5 - x ===> True -- 10 - ((20 - x) - 5) = 10 - (15 - x) ===> True -#testOptimize [ "NatSubSubLeft_5" ] ∀ (x : Nat), 10 - ((20 - x) - 5) = 10 - (15 - x) ===> True +#testOptimize [ "NatSubSubLeft_5", proof ] ∀ (x : Nat), 10 - ((20 - x) - 5) = 10 - (15 - x) ===> True -- (20 - (3 + x)) - 10 = 7 - x ===> True -#testOptimize [ "NatSubSubLeft_6" ] ∀ (x : Nat), (20 - (3 + x)) - 10 = 7 - x ===> True +#testOptimize [ "NatSubSubLeft_6", proof ] ∀ (x : Nat), (20 - (3 + x)) - 10 = 7 - x ===> True -- 20 - ((15 - (x + 10)) - 2) = 20 - (3 - x) ===> True -#testOptimize [ "NatSubSubLeft_7" ] ∀ (x : Nat), (20 - ((15 - (x + 10)) - 2)) = 20 - (3 - x) ===> True +#testOptimize [ "NatSubSubLeft_7", proof ] ∀ (x : Nat), (20 - ((15 - (x + 10)) - 2)) = 20 - (3 - x) ===> True -- 20 - (((30 - x) - 4) - 10) = 20 - (16 - x) ===> True -#testOptimize [ "NatSubSubLeft_8" ] ∀ (x : Nat), 20 - (((30 - x) - 4) - 10) = 20 - (16 - x) ===> True +#testOptimize [ "NatSubSubLeft_8", proof ] ∀ (x : Nat), 20 - (((30 - x) - 4) - 10) = 20 - (16 - x) ===> True -- ((200 - (150 + x)) - 20) - 10 = 20 - x ===> True -#testOptimize [ "NatSubSubLeft_9" ] ∀ (x : Nat), ((200 - (150 + x)) - 20) - 10 = 20 - x ===> True +#testOptimize [ "NatSubSubLeft_9", proof ] ∀ (x : Nat), ((200 - (150 + x)) - 20) - 10 = 20 - x ===> True /-! Test cases for simplification rule `(n - N1) - N2 ===> n - (N1 "+" N2)`. -/ -- (x - 20) - 10 = x - 30 ===> True -#testOptimize [ "NatSubSubRight_1" ] ∀ (x : Nat), (x - 20) - 10 = x - 30 ===> True +#testOptimize [ "NatSubSubRight_1", proof ] ∀ (x : Nat), (x - 20) - 10 = x - 30 ===> True -- (x - 20) - 10 ===> x - 30 def natSubSubRight_2 : Expr := @@ -474,35 +474,37 @@ def natSubSubRight_2 : Expr := elab "natSubSubRight_2" : term => return natSubSubRight_2 -#testOptimize [ "NatSubSubRight_2" ] ∀ (x y : Nat), (x - 20) - 10 < y ===> natSubSubRight_2 +#testOptimize [ "NatSubSubRight_2", proof ] ∀ (x y : Nat), (x - 20) - 10 < y ===> natSubSubRight_2 -- ((x - 100) - 45) - 125 = x - 270 ===> True set_option maxRecDepth 4096 in -#testOptimize [ "NatSubSubRight_3" ] ∀ (x : Nat), ((x - 100) - 45) - 125 = x - 270 ===> True +#testOptimize [ "NatSubSubRight_3", proof ] ∀ (x : Nat), ((x - 100) - 45) - 125 = x - 270 ===> True -- ((200 - x) - 45) - 125 = 30 - x ===> True -#testOptimize [ "NatSubSubRight_4" ] ∀ (x : Nat), ((200 - x) - 45) - 125 = 30 - x ===> True +set_option maxRecDepth 4096 in +#testOptimize [ "NatSubSubRight_4", proof ] ∀ (x : Nat), ((200 - x) - 45) - 125 = 30 - x ===> True -- ((100 - x) - 45) - 125 = 0 ===> True -#testOptimize [ "NatSubSubRight_5" ] ∀ (x : Nat), ((100 - x) - 45) - 125 = 0 ===> True +set_option maxRecDepth 4096 in +#testOptimize [ "NatSubSubRight_5", proof ] ∀ (x : Nat), ((100 - x) - 45) - 125 = 0 ===> True -- ((x - 200) - 45) - ((125 - x) - 130) = x - 245 ===> True set_option maxRecDepth 4096 in -#testOptimize [ "NatSubSubRight_6" ] ∀ (x : Nat), ((x - 200) - 45) - ((125 - x) - 130) = x - 245 ===> True +#testOptimize [ "NatSubSubRight_6", proof ] ∀ (x : Nat), ((x - 200) - 45) - ((125 - x) - 130) = x - 245 ===> True -- (((x - 60) - 40) - 45) - 125 = x - 270 ===> True set_option maxRecDepth 4096 in -#testOptimize [ "NatSubSubRight_7" ] ∀ (x : Nat), (((x - 60) - 40) - 45) - 125 = x - 270 ===> True +#testOptimize [ "NatSubSubRight_7", proof ] ∀ (x : Nat), (((x - 60) - 40) - 45) - 125 = x - 270 ===> True -- (100 - ((x - 100) - 45)) = 100 - (x - 145) ===> True set_option maxRecDepth 4096 in -#testOptimize [ "NatSubSubRight_8" ] ∀ (x : Nat), (100 - ((x - 100) - 45)) = 100 - (x - 145) ===> True +#testOptimize [ "NatSubSubRight_8", proof ] ∀ (x : Nat), (100 - ((x - 100) - 45)) = 100 - (x - 145) ===> True /-! Test cases for simplification rule `(N1 + n) - N2 ===> (N1 "-" N2) + n` (if N1 ≥ N2). -/ -- (100 + x) - 20 = 80 + x -#testOptimize [ "NatAddSub_1" ] ∀ (x : Nat), (100 + x) - 20 = 80 + x ===> True +#testOptimize [ "NatAddSub_1", proof ] ∀ (x : Nat), (100 + x) - 20 = 80 + x ===> True -- (100 + x) - 20 ===> 80 + x def natAddSub_2 : Expr := @@ -524,26 +526,26 @@ def natAddSub_2 : Expr := elab "natAddSub_2" : term => return natAddSub_2 -#testOptimize [ "NatAddSub_2" ] ∀ (x y : Nat), (100 + x) - 20 < y ===> natAddSub_2 +#testOptimize [ "NatAddSub_2", proof ] ∀ (x y : Nat), (100 + x) - 20 < y ===> natAddSub_2 -- (120 + x) - 120 = x ===> True -#testOptimize [ "NatAddSub_3" ] ∀ (x : Nat), (120 + x) - 120 = x ===> True +#testOptimize [ "NatAddSub_3", proof ] ∀ (x : Nat), (120 + x) - 120 = x ===> True -- ((200 + x) - 120) - 20 = 60 + x ===> True -#testOptimize [ "NatAddSub_4" ] ∀ (x : Nat), ((200 + x) - 120) - 20 = 60 + x ===> True +#testOptimize [ "NatAddSub_4", proof ] ∀ (x : Nat), ((200 + x) - 120) - 20 = 60 + x ===> True -- (50 + (100 + x)) - 120 = 30 + x ===> True -#testOptimize [ "NatAddSub_5" ] ∀ (x : Nat), (50 + (100 + x)) - 120 = 30 + x ===> True +#testOptimize [ "NatAddSub_5", proof ] ∀ (x : Nat), (50 + (100 + x)) - 120 = 30 + x ===> True -- (50 + (40 + (x + 60))) - 120 = 30 + x ===> True -#testOptimize [ "NatAddSub_6" ] ∀ (x : Nat), (50 + (40 + (x + 60))) - 120 = 30 + x ===> True +#testOptimize [ "NatAddSub_6", proof ] ∀ (x : Nat), (50 + (40 + (x + 60))) - 120 = 30 + x ===> True -- (((230 + x) - 20) - 120) - 40 = 50 + x ===> True set_option maxRecDepth 4096 in -#testOptimize [ "NatAddSub_7" ] ∀ (x : Nat), (((230 + x) - 20) - 120) - 40 = 50 + x ===> True +#testOptimize [ "NatAddSub_7", proof ] ∀ (x : Nat), (((230 + x) - 20) - 120) - 40 = 50 + x ===> True -- (((x + 180) - 100) - 20) + 120 = 180 + x ===> True -#testOptimize [ "NatAddSub_8" ] ∀ (x : Nat), (((x + 180) - 100) - 20) + 120 = 180 + x ===> True +#testOptimize [ "NatAddSub_8", proof ] ∀ (x : Nat), (((x + 180) - 100) - 20) + 120 = 180 + x ===> True /-! Test cases to ensure that the following simplification rules are not applied wrongly: - `(N1 - n) - N2 ===> (N1 "-" N2) - n` @@ -575,7 +577,7 @@ def natSubSubUnchanged_1 : Expr := elab "natSubSubUnchanged_1" : term => return natSubSubUnchanged_1 -#testOptimize [ "NatSubSunUnchanged_1" ] ∀ (x y : Nat), (x - y) - 120 < y ===> natSubSubUnchanged_1 +#testOptimize [ "NatSubSunUnchanged_1", proof ] ∀ (x y : Nat), (x - y) - 120 < y ===> natSubSubUnchanged_1 -- (x + y) - 120 ===> (x + y) - 120 -- Must remain unchanged @@ -601,7 +603,7 @@ def natSubSubUnchanged_2 : Expr := elab "natSubSubUnchanged_2" : term => return natSubSubUnchanged_2 -#testOptimize [ "NatSubSunUnchanged_2" ] ∀ (x y : Nat), (x + y) - 120 < y ===> natSubSubUnchanged_2 +#testOptimize [ "NatSubSunUnchanged_2", proof ] ∀ (x y : Nat), (x + y) - 120 < y ===> natSubSubUnchanged_2 -- (x * y) - 120 ===> (x * y) - 120 -- Must remain unchanged @@ -627,7 +629,7 @@ def natSubSubUnchanged_3 : Expr := elab "natSubSubUnchanged_3" : term => return natSubSubUnchanged_3 -#testOptimize [ "NatSubSunUnchanged_3" ] ∀ (x y : Nat), (x * y) - 120 < y ===> natSubSubUnchanged_3 +#testOptimize [ "NatSubSunUnchanged_3", proof ] ∀ (x y : Nat), (x * y) - 120 < y ===> natSubSubUnchanged_3 -- 100 - (10 - x) ===> 100 - (10 - x) -- Must remain unchanged @@ -653,7 +655,7 @@ def natSubSubUnchanged_4 : Expr := elab "natSubSubUnchanged_4" : term => return natSubSubUnchanged_4 -#testOptimize [ "NatSubSunUnchanged_4" ] ∀ (x y : Nat), 100 - (10 - x) < y ===> natSubSubUnchanged_4 +#testOptimize [ "NatSubSunUnchanged_4", proof ] ∀ (x y : Nat), 100 - (10 - x) < y ===> natSubSubUnchanged_4 -- (100 + x) - 101 ===> (100 + x) - 101 @@ -681,7 +683,7 @@ def natSubSubUnchanged_5 : Expr := elab "natSubSubUnchanged_5" : term => return natSubSubUnchanged_5 -#testOptimize [ "NatSubSunUnchanged_5" ] ∀ (x y : Nat), (100 + x) - 101 < y ===> natSubSubUnchanged_5 +#testOptimize [ "NatSubSunUnchanged_5", proof ] ∀ (x y : Nat), (100 + x) - 101 < y ===> natSubSubUnchanged_5 -- (100 + x) - 180 ===> (100 + x) - 180 -- Must remain unchanged @@ -708,28 +710,28 @@ def natSubSubUnchanged_6 : Expr := elab "natSubSubUnchanged_6" : term => return natSubSubUnchanged_6 -#testOptimize [ "NatSubSunUnchanged_6" ] ∀ (x y : Nat), (100 + x) - 180 < y ===> natSubSubUnchanged_6 +#testOptimize [ "NatSubSunUnchanged_6", proof ] ∀ (x y : Nat), (100 + x) - 180 < y ===> natSubSubUnchanged_6 /-! Test cases to ensure that `Nat.sub` is preserved when expected and is not a commutative operator. -/ -- x - y ===> x - y -#testOptimize [ "NatSubVar_1" ] ∀ (x y z : Nat), x - y < z ===> ∀ (x y z : Nat), Nat.sub x y < z +#testOptimize [ "NatSubVar_1", proof ] ∀ (x y z : Nat), x - y < z ===> ∀ (x y z : Nat), Nat.sub x y < z -- x - y = x - y ===> True -#testOptimize [ "NatSubVar_2" ] ∀ (x y : Nat), x - y = x - y ===> True +#testOptimize [ "NatSubVar_2", proof ] ∀ (x y : Nat), x - y = x - y ===> True -- x - (y - 0) ===> x - y -#testOptimize [ "NatSubVar_3" ] ∀ (x y z : Nat), x - (y - 0) < z ===> ∀ (x y z : Nat), Nat.sub x y < z +#testOptimize [ "NatSubVar_3", proof ] ∀ (x y z : Nat), x - (y - 0) < z ===> ∀ (x y z : Nat), Nat.sub x y < z -- (x - 0) - y ===> x - y -#testOptimize [ "NatSubVar_4" ] ∀ (x y z : Nat), (x - 0) - y < z ===> ∀ (x y z : Nat), Nat.sub x y < z +#testOptimize [ "NatSubVar_4", proof ] ∀ (x y z : Nat), (x - 0) - y < z ===> ∀ (x y z : Nat), Nat.sub x y < z -- (x - 0) - (y - 0) = x - y ===> True -#testOptimize [ "NatSubVar_5" ] ∀ (x y : Nat), (x - 0) - (y - 0) = x - y ===> True +#testOptimize [ "NatSubVar_5", proof ] ∀ (x y : Nat), (x - 0) - (y - 0) = x - y ===> True -- (x - 10) - y = (x - 10) - y ===> True -#testOptimize [ "NatSubVar_6" ] ∀ (x y : Nat), (x - 10) - y = (x - 10) - y ===> True +#testOptimize [ "NatSubVar_6", proof ] ∀ (x y : Nat), (x - 10) - y = (x - 10) - y ===> True -- (x - 10) - y ===> (x - 10) - y def natSubVar_7 : Expr := @@ -758,7 +760,7 @@ def natSubVar_7 : Expr := elab "natSubVar_7" : term => return natSubVar_7 -#testOptimize [ "NatSubVar_7" ] ∀ (x y z : Nat), (x - 10) - y < z ===> natSubVar_7 +#testOptimize [ "NatSubVar_7", proof ] ∀ (x y z : Nat), (x - 10) - y < z ===> natSubVar_7 /-! Test cases to ensure that constant propagation is properly performed @@ -772,12 +774,14 @@ variable (y : Nat) def natSubReduce_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 100) elab "natSubReduce_1" : term => return natSubReduce_1 -#testOptimize [ "NatSubReduce_1" ] (100 + ((180 - (x + 40)) - 150)) - ((200 - y) - 320) ===> natSubReduce_1 +set_option maxRecDepth 4096 in +#testOptimize [ "NatSubReduce_1", proof ] (100 + ((180 - (x + 40)) - 150)) - ((200 - y) - 320) ===> natSubReduce_1 def natSubReduce_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 76) elab "natSubReduce_2" : term => return natSubReduce_2 -- (100 + ((180 - (x + 40)) - 150)) - (((20 - y) - 50) + 24) ===> 76 -#testOptimize [ "NatSubReduce_2" ] (100 + ((180 - (x + 40)) - 150)) - (((20 - y) - 50) + 24) ===> natSubReduce_2 +set_option maxRecDepth 4096 in +#testOptimize [ "NatSubReduce_2", proof ] (100 + ((180 - (x + 40)) - 150)) - (((20 - y) - 50) + 24) ===> natSubReduce_2 end Test.OptimizeNatSub diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index 644d8cb..f917868 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -50,3 +50,9 @@ example : ∀ (m n : Nat), 0 + (m + n) = n + m := by blaster example : ∀ (m n : Nat), (m + n) + 0 = n + m := by blaster example : ∀ (m n : Nat), 1 * (m + n) = n + m := by blaster example : ∀ (m n : Nat), (m + n) - 0 = n + m := by blaster + +-- Multi-arg rewrites +example : ∀ {x y : Nat}, (x + 0) + (y + 0) = x + y := by blaster +example : ∀ {x y : Nat}, (0 + x) + (0 + y) = x + y := by blaster +example : ∀ {x y : Nat}, (x + 0) + (y + 0) = y + x := by blaster +example : ∀ {x y : Nat}, (0 + x) + (0 + y) = y + x := by blaster From e2eb89b1267acc5d0151ec011e0fcb9d0adfd3ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20P=C3=A9ret?= <40478095+felipeperet@users.noreply.github.com> Date: Thu, 19 Mar 2026 12:34:35 -0300 Subject: [PATCH 31/31] documentation --- Blaster/Optimize/Rewriting/OptimizeNat.lean | 25 +++---- .../Optimize/OptimizeNat/OptimizeNatMul.lean | 66 +++++++++---------- Tests/Reconstruct/Basic.lean | 8 ++- 3 files changed, 52 insertions(+), 47 deletions(-) diff --git a/Blaster/Optimize/Rewriting/OptimizeNat.lean b/Blaster/Optimize/Rewriting/OptimizeNat.lean index 95af24e..cfe404d 100644 --- a/Blaster/Optimize/Rewriting/OptimizeNat.lean +++ b/Blaster/Optimize/Rewriting/OptimizeNat.lean @@ -9,9 +9,9 @@ open Lean Meta namespace Blaster.Optimize /-- Apply the following simplification/normalization rules on `Nat.add` : - - 0 + n ==> n [proof: Nat.zero_add] + - 0 + n ==> n [proof: Nat.zero_add] - N1 + N2 ===> N1 "+" N2 - - N1 + (N2 + n) ==> (N1 "+" N2) + n + - N1 + (N2 + n) ==> (N1 "+" N2) + n [proof: Eq.symm (Nat.add_assoc N1 N2 n)] - n1 + n2 ==> n2 + n1 (if n2 <ₒ n1) Assume that f = Expr.const ``Nat.add. An error is triggered when args.size ≠ 2 (i.e., only fully applied `Nat.add` expected at this stage) @@ -51,12 +51,12 @@ def optimizeNatAdd (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult - 0 - n ==> 0 [proof: Nat.zero_sub] - n - 0 ==> n [proof: Nat.sub_zero] - N1 - N2 ==> N1 "-" N2 - - N1 - (N2 + n) ==> (N1 "-" N2) - n - - (N1 - n) - N2 ==> (N1 "-" N2) - n - - (n - N1) - N2 ==> n - (N1 "+" N2) - - (N1 + n) - N2 ==> (N1 "-" N2) + n (if N1 ≥ N2) - Assume that f = Expr.const ``Nat.sub. - An error is triggered when args.size ≠ 2 (i.e., only fully applied `Nat.sub` expected at this stage) + - N1 - (N2 + n) ==> (N1 "-" N2) - n [proof: Nat.sub_add_eq N1 N2 n] + - (N1 - n) - N2 ==> (N1 "-" N2) - n [proof: Nat.sub_right_comm N1 n N2] + - (n - N1) - N2 ==> n - (N1 "+" N2) [proof: Nat.sub_sub n N1 N2] + - (N1 + n) - N2 ==> (N1 "-" N2) + n (if N1 ≥ N2) [proof: congrArg (· - N2) (add_comm N1 n) + |> Eq.trans · (Nat.add_sub_assoc hLE n) + |> Eq.trans · (Nat.add_comm n (N1-N2))] -/ def optimizeNatSub (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult := do if args.size != 2 then throwEnvError "optimizeNatSub: exactly two arguments expected" @@ -104,7 +104,8 @@ def optimizeNatSub (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult match toNatCstOpExpr? op1 with | some (NatCstOpInfo.NatSubLeftExpr n1 e1) => let expr := mkApp2 f (← evalBinNatOp Nat.sub n1 n2) e1 - let proof := mkApp3 (mkConst ``Nat.sub_right_comm) (mkRawNatLit n1) e1 (mkRawNatLit n2) + let proof := + mkApp3 (mkConst ``Nat.sub_right_comm) (mkRawNatLit n1) e1 (mkRawNatLit n2) setRestart return some ⟨expr, some proof⟩ | some (NatCstOpInfo.NatSubRightExpr e1 n1) => @@ -122,7 +123,9 @@ def optimizeNatSub (f : Expr) (args : Array Expr) : TranslateEnvT OptimizeResult let leType ← mkAppM ``LE.le #[n2Lit, n1Lit] let hLE ← mkDecideProof leType let comm := mkApp2 (mkConst ``Nat.add_comm) n1Lit e1 - let subFn := mkLambda `x .default (mkConst ``Nat) (mkApp2 (mkConst ``Nat.sub) (mkBVar 0) n2Lit) + let subFn := + mkLambda + `x .default (mkConst ``Nat) (mkApp2 (mkConst ``Nat.sub) (mkBVar 0) n2Lit) let step1 ← mkCongrArg subFn comm let step2 ← mkAppM ``Nat.add_sub_assoc #[hLE, e1] let step3 := mkApp2 (mkConst ``Nat.add_comm) e1 n1SubN2 @@ -163,7 +166,7 @@ def optimizeNatPow (f : Expr) (args : Array Expr) : TranslateEnvT Expr := do /-- Apply the following simplification/normalization rules on `Nat.mul` : - 0 * n ==> 0 [proof: Nat.zero_mul] - 1 * n ==> n [proof: Nat.one_mul] - - N1 + N2 ==> N1 "*" N2 + - N1 * N2 ==> N1 "*" N2 - N1 * (N2 * n) ==> (N1 "*" N2) * n - n1 * n2 ==> n2 * n1 (if n2 <ₒ n1) - n * n^m ===> n ^ (m + 1) diff --git a/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean b/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean index 52be59a..b051e9a 100644 --- a/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean +++ b/Tests/Optimize/OptimizeNat/OptimizeNatMul.lean @@ -13,38 +13,38 @@ namespace Test.OptimizeNatMul def natMulCst_1 : Expr := Lean.Expr.lit (Lean.Literal.natVal 0) elab "natMulCst_1" : term => return natMulCst_1 -#testOptimize [ "NatMulCst_1" ] (0 : Nat) * 432 ===> natMulCst_1 +#testOptimize [ "NatMulCst_1", proof ] (0 : Nat) * 432 ===> natMulCst_1 -- 432 * 0 ===> 0 -#testOptimize [ "NatMulCst_2" ] 432 * (0 : Nat) ===> natMulCst_1 +#testOptimize [ "NatMulCst_2", proof ] 432 * (0 : Nat) ===> natMulCst_1 def natMulCst_3 : Expr := Lean.Expr.lit (Lean.Literal.natVal 432) elab "natMulCst_3" : term => return natMulCst_3 -- 432 * 1 ===> 32 -#testOptimize [ "NatMulCst_3" ] 432 * (1 : Nat) ===> natMulCst_3 +#testOptimize [ "NatMulCst_3", proof ] 432 * (1 : Nat) ===> natMulCst_3 -- 1 * 432 ===> 32 -#testOptimize [ "NatMulCst_4" ] 1 * (432 : Nat) ===> natMulCst_3 +#testOptimize [ "NatMulCst_4", proof ] 1 * (432 : Nat) ===> natMulCst_3 -- 34 * 432 ===> 14688 def natMulCst_5 : Expr := Lean.Expr.lit (Lean.Literal.natVal 14688) elab "natMulCst_5" : term => return natMulCst_5 -#testOptimize [ "NatMulCst_5" ] (34 : Nat) * 432 ===> natMulCst_5 +#testOptimize [ "NatMulCst_5", proof ] (34 : Nat) * 432 ===> natMulCst_5 /-! Test cases for simplification rule `0 * n ==> 0`. -/ variable (x : Nat) -- x * 0 ===> 0 -#testOptimize [ "NatMulZero_1" ] x * 0 ===> natMulCst_1 +#testOptimize [ "NatMulZero_1", proof ] x * 0 ===> natMulCst_1 -- 0 * x ===> 0 -#testOptimize [ "NatMulZero_2" ] 0 * x ===> natMulCst_1 +#testOptimize [ "NatMulZero_2", proof ] 0 * x ===> natMulCst_1 -- 0 * x = 0 ===> True -#testOptimize [ "NatMulZero_3" ] ∀ (x : Nat), 0 * x = 0 ===> True +#testOptimize [ "NatMulZero_3", proof ] ∀ (x : Nat), 0 * x = 0 ===> True -- x * Nat.zero ===> 0 #testOptimize [ "NatMulZero_4" ] ∀ (x y : Nat), x * Nat.zero ≤ y ===> True @@ -53,7 +53,7 @@ variable (x : Nat) #testOptimize [ "NatMulZero_5" ] ∀ (x y : Nat), Nat.zero * x ≤ y ===> True -- Nat.zero * x = 0 ===> True -#testOptimize [ "NatMulZero_6" ] ∀ (x : Nat), Nat.zero * x = 0 ===> True +#testOptimize [ "NatMulZero_6", proof ] ∀ (x : Nat), Nat.zero * x = 0 ===> True -- (10 - 10) * x ===> 0 #testOptimize [ "NatMulZero_7" ] ∀ (x y : Nat), (10 - 10) * x ≤ y ===> True @@ -62,22 +62,22 @@ variable (x : Nat) #testOptimize [ "NatMulZero_8" ] ∀ (x y : Nat), x * (10 - 123) ≤ y ===> True -- x * (y - y) = 0 ===> True -#testOptimize [ "NatMulZero_9" ] ∀ (x y : Nat), x * (y - y) = 0 ===> True +#testOptimize [ "NatMulZero_9", proof ] ∀ (x y : Nat), x * (y - y) = 0 ===> True /-! Test cases for simplification rule `1 * n ==> n`. -/ -- 1 * n ===> n -#testOptimize [ "NatMulOne_1" ] ∀ (x y : Nat), x * 1 ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatMulOne_1", proof ] ∀ (x y : Nat), x * 1 ≤ y ===> ∀ (x y : Nat), ¬ y < x -- n * 1 ===> n -#testOptimize [ "NatMulOne_2" ] ∀ (x y : Nat), 1 * x ≤ y ===> ∀ (x y : Nat), ¬ y < x +#testOptimize [ "NatMulOne_2", proof ] ∀ (x y : Nat), 1 * x ≤ y ===> ∀ (x y : Nat), ¬ y < x -- 1 * x = x ===> True -#testOptimize [ "NatMulOne_3" ] ∀ (x : Nat), 1 * x = x ===> True +#testOptimize [ "NatMulOne_3", proof ] ∀ (x : Nat), 1 * x = x ===> True -- (10 - 9) * x ===> x -#testOptimize [ "NatMulOne_4" ] ∀ (x y : Nat), (10 - 9) * x < y ===> ∀ (x y : Nat), x < y +#testOptimize [ "NatMulOne_4", proof ] ∀ (x y : Nat), (10 - 9) * x < y ===> ∀ (x y : Nat), x < y -- ((((Nat.succ y) - 1) - y) + 1) * x ===> x #testOptimize [ "NatMulOne_5" ] ∀ (x y z : Nat), ((((Nat.succ y) - 1) - y) + 1) * x < z ===> @@ -112,14 +112,14 @@ def natMulCstUnchanged_1 : Expr := (Lean.BinderInfo.default) elab "natMulCstUnchanged_1" : term => return natMulCstUnchanged_1 -#testOptimize [ "NatMulCstUnchanged_1" ] ∀ (x y : Nat), x * 10 ≤ y ===> natMulCstUnchanged_1 +#testOptimize [ "NatMulCstUnchanged_1", proof ] ∀ (x y : Nat), x * 10 ≤ y ===> natMulCstUnchanged_1 -- (100 - 90) * x ===> 10 * x -- TODO: remove unused quantifier when COI performed on forall -#testOptimize [ "NatMulCstUnchanged_2" ] ∀ (x y : Nat), (100 - 90) * x ≤ y ===> natMulCstUnchanged_1 +#testOptimize [ "NatMulCstUnchanged_2", proof ] ∀ (x y : Nat), (100 - 90) * x ≤ y ===> natMulCstUnchanged_1 -- x * (y - z) ===> Nat.mul x (Nat.sub y z) -#testOptimize [ "NatMulCstUnchanged_3" ] ∀ (x y z m : Nat), x * (y - z) < m ===> +#testOptimize [ "NatMulCstUnchanged_3", proof ] ∀ (x y z m : Nat), x * (y - z) < m ===> ∀ (x y z m : Nat), Nat.mul x (Nat.sub y z) < m @@ -179,7 +179,7 @@ elab "natAddCstProp_5" : term => return natAddCstProp_5 -- 10 * (20 * (100 - (x + 190))) = 0 ===> True set_option maxRecDepth 4096 in -#testOptimize [ "NatMulCstProp_12" ] ∀ (x : Nat), 10 * (20 * (100 - (x + 190))) = 0 ===> True +#testOptimize [ "NatMulCstProp_12", proof ] ∀ (x : Nat), 10 * (20 * (100 - (x + 190))) = 0 ===> True /-! Test cases to ensure that simplification rule `N1 * (N2 * n) ===> (N1 "*" N2) * n` is not @@ -210,7 +210,7 @@ def natMulCstPropUnchanged_1 : Expr := elab "natMulCstPropUnchanged_1" : term => return natMulCstPropUnchanged_1 -#testOptimize [ "NatMulCstPropUnchanged_1" ] ∀ (x y : Nat), 40 * (x * y) < y ===> natMulCstPropUnchanged_1 +#testOptimize [ "NatMulCstPropUnchanged_1", proof ] ∀ (x y : Nat), 40 * (x * y) < y ===> natMulCstPropUnchanged_1 -- 40 * (x - y) ===> 40 * (x - y) @@ -237,7 +237,7 @@ def natMulCstPropUnchanged_2 : Expr := elab "natMulCstPropUnchanged_2" : term => return natMulCstPropUnchanged_2 -#testOptimize [ "NatMulCstPropUnchanged_2" ] ∀ (x y : Nat), 40 * (x - y) < y ===> natMulCstPropUnchanged_2 +#testOptimize [ "NatMulCstPropUnchanged_2", proof ] ∀ (x y : Nat), 40 * (x - y) < y ===> natMulCstPropUnchanged_2 -- 40 * (x + y) ===> 40 * (x + y) @@ -264,23 +264,23 @@ def natMulCstPropUnchanged_3 : Expr := elab "natMulCstPropUnchanged_3" : term => return natMulCstPropUnchanged_3 -#testOptimize [ "NatMulCstPropUnchanged_3" ] ∀ (x y : Nat), 40 * (x + y) < y ===> natMulCstPropUnchanged_3 +#testOptimize [ "NatMulCstPropUnchanged_3", proof ] ∀ (x y : Nat), 40 * (x + y) < y ===> natMulCstPropUnchanged_3 /-! Test cases for normalization rule `n1 * n2 ==> n2 * n1 (if n2 <ₒ n1)`. -/ -- x * y = x * y ===> True -#testOptimize [ "NatMulCommut_1" ] ∀ (x y : Nat), x * y = x * y ===> True +#testOptimize [ "NatMulCommut_1", proof ] ∀ (x y : Nat), x * y = x * y ===> True -- x * y = y * x ===> True -#testOptimize [ "NatMulCommut_2" ] ∀ (x y : Nat), x * y = y * x ===> True +#testOptimize [ "NatMulCommut_2", proof ] ∀ (x y : Nat), x * y = y * x ===> True -- x * 10 = 10 * x ===> True -#testOptimize [ "NatMulCommut_3" ] ∀ (x : Nat), x * 10 = 10 * x ===> True +#testOptimize [ "NatMulCommut_3", proof ] ∀ (x : Nat), x * 10 = 10 * x ===> True -- y * x ===> x * y (with `x` declared first) -#testOptimize [ "NatMulCommut_4" ] ∀ (x y z : Nat), z < y * x ===> ∀ (x y z : Nat), z < Nat.mul x y +#testOptimize [ "NatMulCommut_4", proof ] ∀ (x y z : Nat), z < y * x ===> ∀ (x y z : Nat), z < Nat.mul x y -- x * 40 ===> 40 * x def natMulCommut_5 : Expr := @@ -302,29 +302,29 @@ def natMulCommut_5 : Expr := elab "natMulCommut_5" : term => return natMulCommut_5 -#testOptimize [ "NatMulCommut_5" ] ∀ (x y : Nat), y < x * 40 ===> natMulCommut_5 +#testOptimize [ "NatMulCommut_5", proof ] ∀ (x y : Nat), y < x * 40 ===> natMulCommut_5 -- (x * (y * 20)) * z = z * ((y * 20) * x) ===> True #testOptimize [ "NatMulCommut_6" ] ∀ (x y z : Nat), (x * (y * 20)) * z = z * ((y * 20) * x) ===> True --- (x - y) * (p + q) ===> (p + q) * (x - y) -#testOptimize [ "NatMulCommut_7" ] ∀ (x y z p q : Nat), (x - y) * (p + q) < z ===> +#testOptimize [ "NatMulCommut_7", proof ] ∀ (x y z p q : Nat), (x - y) * (p + q) < z ===> ∀ (x y z p q : Nat), Nat.mul (Nat.add p q) (Nat.sub x y) < z --- (x - y) * (p + q) = (p + q) * (x - y) ===> True -#testOptimize [ "NatMulCommut_8" ] ∀ (x y p q : Nat), (x - y) * (p + q) = (p + q) * (x - y) ===> True +#testOptimize [ "NatMulCommut_8", proof ] ∀ (x y p q : Nat), (x - y) * (p + q) = (p + q) * (x - y) ===> True /-! Test cases to ensure that `Nat.mul` is preserved when expected. -/ -- x * (y * 1) = x * y ===> True -#testOptimize [ "NatMulVar_1" ] ∀ (x y : Nat), x * (y * 1) = x * y ===> True +#testOptimize [ "NatMulVar_1", proof ] ∀ (x y : Nat), x * (y * 1) = x * y ===> True -- (x * 1) * y = x * y ===> True -#testOptimize [ "NatMulVar_2" ] ∀ (x y : Nat), (x * 1) * y = x * y ===> True +#testOptimize [ "NatMulVar_2", proof ] ∀ (x y : Nat), (x * 1) * y = x * y ===> True -- (x * 1) * (y * 1) = y * x ===> True -#testOptimize [ "NatMulVar_3" ] ∀ (x y : Nat), (x * 1) * (y * 1) = y * x ===> True +#testOptimize [ "NatMulVar_3", proof ] ∀ (x y : Nat), (x * 1) * (y * 1) = y * x ===> True -- x * y < 10 ===> x * y < 10 def natAddVar_4 : Expr := @@ -346,7 +346,7 @@ def natAddVar_4 : Expr := elab "natAddVar_4" : term => return natAddVar_4 -#testOptimize [ "NatMulVar_4" ] ∀ (x y : Nat), x * y < 10 ===> natAddVar_4 +#testOptimize [ "NatMulVar_4", proof ] ∀ (x y : Nat), x * y < 10 ===> natAddVar_4 /-! Test cases to ensure that constant propagation is properly performed @@ -358,7 +358,7 @@ variable (y : Nat) -- (100 * (30 - ((180 - (x * 1)) - 150))) * ((320 - (y + 400)) - y) ===> 0 set_option maxRecDepth 4096 in -#testOptimize [ "NatMulReduce_1" ] (100 * (30 - ((180 - (x * 1)) - 150))) * ((320 - (y + 400)) - y) ===> natMulCst_1 +#testOptimize [ "NatMulReduce_1", proof ] (100 * (30 - ((180 - (x * 1)) - 150))) * ((320 - (y + 400)) - y) ===> natMulCst_1 -- (100 * (((180 - (x * 40)) - 150) - 30)) * ((((20 - y) - 50) * 24) + 1) ===> 100 def natMulReduce_2 : Expr := Lean.Expr.lit (Lean.Literal.natVal 100) diff --git a/Tests/Reconstruct/Basic.lean b/Tests/Reconstruct/Basic.lean index f917868..adf06a2 100644 --- a/Tests/Reconstruct/Basic.lean +++ b/Tests/Reconstruct/Basic.lean @@ -25,8 +25,10 @@ example : 1 + 2 = 3 := by blaster example : 2 * 3 = 6 := by blaster example : (2 * 3) + 1 = 7 := by blaster --- Constant propagation: N1 + (N2 + n) → (N1 + N2) + n +-- Constant propagation example : ∀ (x : Nat), 10 + (20 + x) = 30 + x := by blaster +example : ∀ (x : Nat), 120 - (40 + x) = 80 - x := by blaster +example : ∀ (x : Nat), 120 - (x + 40) = 80 - x := by blaster -- Nat.add commutativity example : ∀ (m n : Nat), m + n = n + m := by blaster @@ -52,7 +54,7 @@ example : ∀ (m n : Nat), 1 * (m + n) = n + m := by blaster example : ∀ (m n : Nat), (m + n) - 0 = n + m := by blaster -- Multi-arg rewrites -example : ∀ {x y : Nat}, (x + 0) + (y + 0) = x + y := by blaster +example : ∀ {x y : Nat}, (x + 0) + (y - 0) = x + y := by blaster example : ∀ {x y : Nat}, (0 + x) + (0 + y) = x + y := by blaster -example : ∀ {x y : Nat}, (x + 0) + (y + 0) = y + x := by blaster +example : ∀ {x y : Nat}, (x - 0) + (y + 0) = y + x := by blaster example : ∀ {x y : Nat}, (0 + x) + (0 + y) = y + x := by blaster