Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions Duper/RuleM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ register_option includeDatatypeRules : Bool := {
descr := "Whether to include datatype rules (distinctness, injectivity, and acyclicity)"
}

register_option includeUnsafeAcyclicity : Bool := {
defValue := false
descr := "Whether to include the unsafe datatype acyclicity rule"
}

def getInhabitationReasoning (opts : Options) : Bool :=
inhabitationReasoning.get opts

Expand All @@ -58,9 +53,6 @@ def getIncludeExpensiveRules (opts : Options) : Bool :=
def getIncludeDatatypeRules (opts : Options) : Bool :=
includeDatatypeRules.get opts

def getIncludeUnsafeAcyclicity (opts : Options) : Bool :=
includeUnsafeAcyclicity.get opts

def getInhabitationReasoningM : CoreM Bool := do
let opts ← getOptions
return getInhabitationReasoning opts
Expand All @@ -81,10 +73,6 @@ def getIncludeDatatypeRulesM : CoreM Bool := do
let opts ← getOptions
return getIncludeDatatypeRules opts

def getIncludeUnsafeAcyclicityM : CoreM Bool := do
let opts ← getOptions
return getIncludeUnsafeAcyclicity opts

structure Context where
order : Expr → Expr → Bool → MetaM Comparison
symbolPrecMap : SymbolPrecMap
Expand Down
108 changes: 105 additions & 3 deletions Duper/Rules/DatatypeAcyclicity.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,44 @@ open LitSide

initialize Lean.registerTraceClass `duper.rule.datatypeAcyclicity

theorem one_add_ge (n : Nat) : 1 + n > n := by grind

def addAllRight (head : Expr) (xs : Array Expr) : MetaM Expr := do
xs.foldlM (init := head) fun acc lit => do mkAppM ``Nat.lt_add_right #[← mkAppM ``sizeOf #[lit], acc]

def buildLeftSum (head : Expr) (xs : Array Expr) : MetaM Expr :=
xs.foldlM (init := head) fun acc lit => mkAppM ``HAdd.hAdd #[acc, lit]

def liftAddRight (eq s : Expr) : MetaM Expr := do
let f ← withLocalDeclD `t (mkConst ``Nat) fun t => do
mkLambdaFVars #[t] (← mkAppM ``HAdd.hAdd #[t, s])
mkAppM ``congrArg #[f, eq]

def flattenAddRight (head : Expr) (xs : Array Expr) : MetaM Expr := do
let n := xs.size
if n ≤ 1 then
return ← mkAppM ``Eq.refl #[← buildLeftSum head xs]
let innerSum ← buildLeftSum xs[0]! (xs.extract 1 n)
let lhs ← mkAppM ``HAdd.hAdd #[head, innerSum]
let mut eq ← mkAppM ``Eq.refl #[lhs]
for k in [:n - 1] do
let stillIn := xs.extract 0 (n - 1 - k)
let remaining ← buildLeftSum xs[0]! (stillIn.extract 1 stillIn.size)
let assocSymm ← mkAppM ``Eq.symm #[← mkAppM ``Nat.add_assoc #[head, remaining, xs[n - 1 - k]!]]
let alreadyOut := xs.extract (n - k) n
let lifted ← alreadyOut.foldlM (init := assocSymm) liftAddRight
eq ← mkAppM ``Eq.trans #[eq, lifted]
return eq

partial def bubbleToLeft (spec : Expr) (sizes : Array Expr) (idx : Nat) : MetaM Expr := do
if idx = 0 then return spec
let prefixSum ← buildLeftSum (mkNatLit 1) (sizes.extract 0 (idx - 1))
let comm ← mkAppM ``Nat.add_right_comm #[prefixSum, sizes[idx - 1]!, sizes[idx]!]
let suffix := sizes.extract (idx + 1) sizes.size
let lifted ← suffix.foldlM (init := comm) liftAddRight
let newSpec ← mkAppM ``Eq.trans #[spec, lifted]
bubbleToLeft newSpec (sizes.swapIfInBounds (idx - 1) idx) (idx - 1)

/-- Produces a list of (possibly duplicate) constructor subterms for `e` -/
partial def collectConstructorSubterms (e : Expr) : MetaM (Array Expr) := do
let isConstructor ← matchConstCtor e.getAppFn' (fun _ => pure false) (fun _ _ => pure true)
Expand All @@ -22,6 +60,61 @@ partial def collectConstructorSubterms (e : Expr) : MetaM (Array Expr) := do
else
return #[e]

/-- Builds a proof of `sizeOf lhs > sizeOf rhs` with `rhs` guaranteed to be a subterm of `lhs` -/
partial def buildGtProof (lhs : Expr) (rhs : Expr) : MetaM Expr := do
let ctor := lhs.getAppFn'
let some ctorName := ctor.constName?
| throwError "datatypeAcyclicity: lhs head is not a constant"
let ctorType ← inferType ctor
let explicitLhsArgs ← forallTelescopeReducing ctorType fun binders _ =>
(binders.zip lhs.getAppArgs).filterMapM fun (b, a) => do
if (← b.fvarId!.getDecl).binderInfo.isExplicit then pure (some a)
else pure none
let specName := ctorName ++ `sizeOf_spec
unless (← getEnv).contains specName do throwError "datatypeAcyclicity: no sizeOf_spec for {ctorName}"
let specExpr ← mkConstWithFreshMVarLevels specName
let specImplicitCount ← forallTelescopeReducing (← inferType specExpr) fun binders _ => do
let mut count := 0
for b in binders do
if !(← b.fvarId!.getDecl).binderInfo.isExplicit then
count := count + 1
else
break
pure count
let nones := Array.replicate specImplicitCount none
let specApplied ← mkAppOptM specName (nones ++ explicitLhsArgs.map some)
let sizes ← explicitLhsArgs.mapM fun a => mkAppM ``sizeOf #[a]
match ← explicitLhsArgs.findIdxM? (fun a => isDefEq a rhs) with
| some rhsIdx => -- base case: `rhsIdx` is where `rhs` was found in the direct subterms of `lhs`
let rearranged ← bubbleToLeft specApplied sizes rhsIdx
let oneAddGe ← mkAppM ``one_add_ge #[← mkAppM ``sizeOf #[rhs]]
let lhsArgsExtra := explicitLhsArgs.eraseIdx! rhsIdx
let gtProof ← addAllRight oneAddGe lhsArgsExtra
let rearrangedSymm ← mkAppM ``Eq.symm #[rearranged]
let motive ← withLocalDeclD `y (mkConst ``Nat) fun y => do
mkLambdaFVars #[y] (← mkAppM ``LT.lt #[← mkAppM ``sizeOf #[rhs], y])
mkAppOptM ``Eq.subst #[none, some motive, none, none, some rearrangedSymm, some gtProof]
| none => -- recursive case: `subtermIdx` is a direct subterm of `lhs` with `rhs` as a subterm
let some subtermIdx ← explicitLhsArgs.findIdxM? (fun term => do
let subterms ← collectConstructorSubterms term
subterms.anyM (fun s => isDefEq s rhs))
| throwError "datatypeAcyclicity: rhs {rhs} not found among subterms"
let subterm := explicitLhsArgs[subtermIdx]!
let subProof ← buildGtProof subterm rhs
let lhsArgsExtra := explicitLhsArgs.eraseIdx! subtermIdx
let subProof ← addAllRight subProof lhsArgsExtra
let subProof ← mkAppM ``Nat.lt_add_left #[mkNatLit 1, subProof]
let rearranged ← bubbleToLeft specApplied sizes subtermIdx
let rearrangedSymm ← mkAppM ``Eq.symm #[rearranged]
let restInOrder := (sizes.extract 0 subtermIdx) ++ (sizes.extract (subtermIdx + 1) sizes.size)
let parenInner := #[sizes[subtermIdx]!] ++ restInOrder
let flattenEq ← flattenAddRight (mkNatLit 1) parenInner
let combined ← mkAppM ``Eq.trans #[flattenEq, rearrangedSymm]
let sizeOfRhs ← mkAppM ``sizeOf #[rhs]
let motive ← withLocalDeclD `y (mkConst ``Nat) fun y => do
mkLambdaFVars #[y] (← mkAppM ``LT.lt #[sizeOfRhs, y])
mkAppOptM ``Eq.subst #[none, some motive, none, none, some combined, some subProof]

/-- Returns `none` if `lit` does not compare constructor subterms, and returns `some litside` if `lit.litside`
is a subterm of the constructor it is being compared to. Note that `lit.litside` may not itself be a constructor
(e.g. `xs` is a constructor subterm of `x :: xs`) -/
Expand Down Expand Up @@ -61,10 +154,19 @@ def mkDatatypeAcyclicityProof (removedLitNum : Nat) (litSide : LitSide) (premise
let litTyMVar ← mkFreshExprMVar lit.ty
let abstrLam ← mkLambdaFVars #[litTyMVar] $ ← mkAppOptM ``sizeOf #[some lit.ty, some sizeOfInst, some litTyMVar]
let sizeOfEq ← mkAppM ``congrArg #[abstrLam, h] -- Has the type `sizeOf lit.lhs = sizeOf lit.rhs`
-- Need to generate a term of type `¬(sizeOf lit.lhs = sizeOf lit.rhs)`
let sizeOfEq ←
match litSide with
| lhs => mkAppM ``Eq.symm #[sizeOfEq]
| rhs => pure sizeOfEq
let lit : Lit :=
match litSide with
| lhs => lit.symm
| rhs => lit
let sizeOfEqFalseMVar ← mkFreshExprMVar $ ← mkAppM ``Not #[← inferType sizeOfEq] -- Has the type `¬(sizeOf lit.lhs = sizeOf lit.rhs)`
let sizeOfEqFalseMVarId := sizeOfEqFalseMVar.mvarId!
-- **TODO**: Figure out how to assign `sizeOfEqFalseMVar` an actual term
let gtProof ← buildGtProof lit.lhs lit.rhs
let neProof ← mkAppM ``Nat.ne_of_gt #[gtProof]
sizeOfEqFalseMVarId.assign neProof
let proofCase := mkApp2 (mkConst ``False.elim [levelZero]) body $ mkApp sizeOfEqFalseMVar sizeOfEq -- Has the type `body`
trace[duper.rule.datatypeAcyclicity] "lit: {lit}, lit.ty: {lit.ty}, sizeOfInst: {sizeOfInst}, abstrLam: {abstrLam}, sizeOfEq: {sizeOfEq}"
trace[duper.rule.datatypeAcyclicity] "sizeOfEqFalseMVar: {sizeOfEqFalseMVar}, proofCase: {proofCase}"
Expand All @@ -86,7 +188,7 @@ def datatypeAcyclicity : MSimpRule := fun c => do
| some side =>
if lit.sign then -- `lit` is never true so `lit` can be removed from `c`
let res := c.eraseIdx i
let yC ← yieldClause res "datatypeAcyclicity" none -- $ mkDatatypeAcyclicityProof i side
let yC ← yieldClause res "datatypeAcyclicity" $ mkDatatypeAcyclicityProof i side
trace[duper.rule.datatypeAcyclicity] "datatypeAcyclicity applied to {c.lits} to yield {yC.1}"
return some #[yC]
else -- `lit` is a tautology so the clause `c` can simply be removed
Expand Down
53 changes: 2 additions & 51 deletions Duper/Saturate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ open SimpResult

def forwardSimpRules : ProverM (Array SimpRule) := do
let subsumptionTrie ← getSubsumptionTrie
if (← getIncludeExpensiveRulesM) && (← getIncludeDatatypeRulesM) && (← getIncludeUnsafeAcyclicityM) then
if (← getIncludeExpensiveRulesM) && (← getIncludeDatatypeRulesM) then
return #[
betaEtaReduction.toSimpRule,
clausificationStep.toSimpRule,
Expand All @@ -164,31 +164,6 @@ def forwardSimpRules : ProverM (Array SimpRule) := do
(forwardNegativeSimplifyReflect subsumptionTrie).toSimpRule,
identBoolHoist.toSimpRule -- Higher order rule
]
else if (← getIncludeExpensiveRulesM) && (← getIncludeDatatypeRulesM) && !(← getIncludeUnsafeAcyclicityM) then
return #[
betaEtaReduction.toSimpRule,
clausificationStep.toSimpRule,
syntacticTautologyDeletion1.toSimpRule,
syntacticTautologyDeletion2.toSimpRule,
boolSimp.toSimpRule,
syntacticTautologyDeletion3.toSimpRule,
elimDupLit.toSimpRule,
elimResolvedLit.toSimpRule,
destructiveEqualityResolution.toSimpRule,
identPropFalseElim.toSimpRule,
identBoolFalseElim.toSimpRule,
datatypeDistinctness.toSimpRule, -- Inductive datatype rule
datatypeInjectivity.toSimpRule, -- Inductive datatype rule
-- datatypeAcyclicity.toSimpRule, -- Inductive datatype rule
decElim.toSimpRule,
(forwardDemodulation (← getDemodSidePremiseIdx)).toSimpRule,
(forwardClauseSubsumption subsumptionTrie).toSimpRule,
(forwardEqualitySubsumption subsumptionTrie).toSimpRule,
(forwardContextualLiteralCutting subsumptionTrie).toSimpRule,
(forwardPositiveSimplifyReflect subsumptionTrie).toSimpRule,
(forwardNegativeSimplifyReflect subsumptionTrie).toSimpRule,
identBoolHoist.toSimpRule -- Higher order rule
]
else if (← getIncludeExpensiveRulesM) && !(← getIncludeDatatypeRulesM) then
return #[
betaEtaReduction.toSimpRule,
Expand All @@ -211,7 +186,7 @@ def forwardSimpRules : ProverM (Array SimpRule) := do
(forwardNegativeSimplifyReflect subsumptionTrie).toSimpRule,
identBoolHoist.toSimpRule -- Higher order rule
]
else if !(← getIncludeExpensiveRulesM) && (← getIncludeDatatypeRulesM) && (← getIncludeUnsafeAcyclicityM) then
else if !(← getIncludeExpensiveRulesM) && (← getIncludeDatatypeRulesM) then
return #[
betaEtaReduction.toSimpRule,
clausificationStep.toSimpRule,
Expand All @@ -235,30 +210,6 @@ def forwardSimpRules : ProverM (Array SimpRule) := do
(forwardNegativeSimplifyReflect subsumptionTrie).toSimpRule,
identBoolHoist.toSimpRule -- Higher order rule
]
else if !(← getIncludeExpensiveRulesM) && (← getIncludeDatatypeRulesM) && !(← getIncludeUnsafeAcyclicityM) then
return #[
betaEtaReduction.toSimpRule,
clausificationStep.toSimpRule,
syntacticTautologyDeletion1.toSimpRule,
syntacticTautologyDeletion2.toSimpRule,
boolSimp.toSimpRule,
syntacticTautologyDeletion3.toSimpRule,
elimDupLit.toSimpRule,
elimResolvedLit.toSimpRule,
destructiveEqualityResolution.toSimpRule,
identPropFalseElim.toSimpRule,
identBoolFalseElim.toSimpRule,
datatypeDistinctness.toSimpRule, -- Inductive datatype rule
datatypeInjectivity.toSimpRule, -- Inductive datatype rule
-- datatypeAcyclicity.toSimpRule, -- Inductive datatype rule
(forwardDemodulation (← getDemodSidePremiseIdx)).toSimpRule,
(forwardClauseSubsumption subsumptionTrie).toSimpRule,
(forwardEqualitySubsumption subsumptionTrie).toSimpRule,
(forwardContextualLiteralCutting subsumptionTrie).toSimpRule,
(forwardPositiveSimplifyReflect subsumptionTrie).toSimpRule,
(forwardNegativeSimplifyReflect subsumptionTrie).toSimpRule,
identBoolHoist.toSimpRule -- Higher order rule
]
else -- !(← getIncludeExpensiveRulesM) && !(← getIncludeDatatypeRulesM)
return #[
betaEtaReduction.toSimpRule,
Expand Down
16 changes: 16 additions & 0 deletions Duper/Tests/test_regression.lean
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,19 @@ example (t1 : Type 1) (t2 : Type 2) (x : myType4 t1 t2) :
set_option duper.collectDatatypes true in
example (P : α × β → Prop) (h : ∀ x : α, ∀ y : β, P (x, y)) : ∀ z : α × β, P z := by
duper [*] {portfolioInstance := 7}

example (x : List Nat) : x ≠ 0 :: x := by
duper [*] {portfolioInstance := 7}

example (x : myType3) : const6 x ≠ x := by
duper

inductive TriTree (t : Type _) where
| node : TriTree t → TriTree t → TriTree t → TriTree t
| leaf : t → TriTree t

example (x y z : TriTree Nat) : x ≠ TriTree.node x y z := by
duper [*] {portfolioInstance := 7}

example (x y z : TriTree Nat) : TriTree.node x y z ≠ z := by
duper [*]