diff --git a/liquidjava-verifier/pom.xml b/liquidjava-verifier/pom.xml index 2461306e..af57236d 100644 --- a/liquidjava-verifier/pom.xml +++ b/liquidjava-verifier/pom.xml @@ -11,7 +11,7 @@ io.github.liquid-java liquidjava-verifier - 0.0.4 + 0.0.8 liquidjava-verifier LiquidJava Verifier https://github.com/liquid-java/liquidjava diff --git a/liquidjava-verifier/src/main/java/liquidjava/diagnostics/errors/RefinementError.java b/liquidjava-verifier/src/main/java/liquidjava/diagnostics/errors/RefinementError.java index a2c13c91..93a8a447 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/diagnostics/errors/RefinementError.java +++ b/liquidjava-verifier/src/main/java/liquidjava/diagnostics/errors/RefinementError.java @@ -12,19 +12,18 @@ */ public class RefinementError extends LJError { - private final String expected; + private final ValDerivationNode expected; private final ValDerivationNode found; - public RefinementError(SourcePosition position, Expression expected, ValDerivationNode found, + public RefinementError(SourcePosition position, ValDerivationNode expected, ValDerivationNode found, TranslationTable translationTable) { - super("Refinement Error", - String.format("%s is not a subtype of %s", found.getValue(), expected.toSimplifiedString()), position, - translationTable); - this.expected = expected.toSimplifiedString(); + super("Refinement Error", String.format("%s is not a subtype of %s", found.getValue(), expected.getValue()), + position, translationTable); + this.expected = expected; this.found = found; } - public String getExpected() { + public ValDerivationNode getExpected() { return expected; } diff --git a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java index 3df794ac..0c2fba5e 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java +++ b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java @@ -52,8 +52,8 @@ public void processSubtyping(Predicate expectedType, List list, CtEl } boolean isSubtype = smtChecks(expected, premises, element.getPosition()); if (!isSubtype) - throw new RefinementError(element.getPosition(), expectedType.getExpression(), - premisesBeforeChange.simplify(), map); + throw new RefinementError(element.getPosition(), expectedType.simplify(), premisesBeforeChange.simplify(), + map); } /** @@ -263,7 +263,7 @@ protected void throwRefinementError(SourcePosition position, Predicate expected, gatherVariables(found, lrv, mainVars); TranslationTable map = new TranslationTable(); Predicate premises = joinPredicates(expected, mainVars, lrv, map).toConjunctions(); - throw new RefinementError(position, expected.getExpression(), premises.simplify(), map); + throw new RefinementError(position, expected.simplify(), premises.simplify(), map); } protected void throwStateRefinementError(SourcePosition position, Predicate found, Predicate expected) diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java index fac053e6..5449949b 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java @@ -271,7 +271,7 @@ public void validateGhostInvocations(Context ctx, Factory f) throws LJError { if (this instanceof FunctionInvocation fi) { // get all ghosts with the matching name List candidates = ctx.getGhosts().stream().filter(g -> g.matches(fi.name)).toList(); - if (candidates.isEmpty()) + if (candidates.isEmpty()) return; // not found error is thrown elsewhere // find matching overload diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java index 5c74897f..5cc0562e 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java @@ -10,6 +10,7 @@ import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; import liquidjava.rj_language.opt.derivation_node.VarDerivationNode; +import java.util.HashMap; import java.util.Map; public class ConstantPropagation { @@ -19,23 +20,37 @@ public class ConstantPropagation { * VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing * the propagation steps taken. */ - public static ValDerivationNode propagate(Expression exp) { + public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin) { Map substitutions = VariableResolver.resolve(exp); - return propagateRecursive(exp, substitutions); + + // map of variable origins from the previous derivation tree + Map varOrigins = new HashMap<>(); + if (previousOrigin != null) { + extractVarOrigins(previousOrigin, varOrigins); + } + return propagateRecursive(exp, substitutions, varOrigins); } /** * Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2) */ - private static ValDerivationNode propagateRecursive(Expression exp, Map subs) { + private static ValDerivationNode propagateRecursive(Expression exp, Map subs, + Map varOrigins) { // substitute variable if (exp instanceof Var var) { String name = var.getName(); Expression value = subs.get(name); // substitution - if (value != null) - return new ValDerivationNode(value.clone(), new VarDerivationNode(name)); + if (value != null) { + // check if this variable has an origin from a previous pass + DerivationNode previousOrigin = varOrigins.get(name); + + // preserve origin if value came from previous derivation + DerivationNode origin = previousOrigin != null ? new VarDerivationNode(name, previousOrigin) + : new VarDerivationNode(name); + return new ValDerivationNode(value.clone(), origin); + } // no substitution return new ValDerivationNode(var, null); @@ -43,31 +58,33 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map varOrigins) { + if (node == null) + return; + + Expression value = node.getValue(); + DerivationNode origin = node.getOrigin(); + + // check for equality expressions + if (value instanceof BinaryExpression binExp && "==".equals(binExp.getOperator()) + && origin instanceof BinaryDerivationNode binOrigin) { + Expression left = binExp.getFirstOperand(); + Expression right = binExp.getSecondOperand(); + + // extract variable name and value derivation from either side + String varName = null; + ValDerivationNode valueDerivation = null; + + if (left instanceof Var var && right.isLiteral()) { + varName = var.getName(); + valueDerivation = binOrigin.getRight(); + } else if (right instanceof Var var && left.isLiteral()) { + varName = var.getName(); + valueDerivation = binOrigin.getLeft(); + } + if (varName != null && valueDerivation != null && valueDerivation.getOrigin() != null) { + varOrigins.put(varName, valueDerivation.getOrigin()); + } + } + + // recursively process the origin tree + if (origin instanceof BinaryDerivationNode binOrigin) { + extractVarOrigins(binOrigin.getLeft(), varOrigins); + extractVarOrigins(binOrigin.getRight(), varOrigins); + } else if (origin instanceof UnaryDerivationNode unaryOrigin) { + extractVarOrigins(unaryOrigin.getOperand(), varOrigins); + } else if (origin instanceof ValDerivationNode valOrigin) { + extractVarOrigins(valOrigin, varOrigins); + } + } } \ No newline at end of file diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java index 4bb4050a..2e43e326 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java @@ -14,41 +14,88 @@ public class ExpressionSimplifier { * Returns a derivation node representing the tree of simplifications applied */ public static ValDerivationNode simplify(Expression exp) { - ValDerivationNode prop = ConstantPropagation.propagate(exp); + ValDerivationNode fixedPoint = simplifyToFixedPoint(null, exp); + return simplifyValDerivationNode(fixedPoint); + } + + /** + * Recursively applies propagation and folding until the expression stops changing (fixed point) Stops early if the + * expression simplifies to 'true', which means we've simplified too much + */ + private static ValDerivationNode simplifyToFixedPoint(ValDerivationNode current, Expression prevExp) { + // apply propagation and folding + ValDerivationNode prop = ConstantPropagation.propagate(prevExp, current); ValDerivationNode fold = ConstantFolding.fold(prop); - return simplifyDerivationTree(fold); + ValDerivationNode simplified = simplifyValDerivationNode(fold); + Expression currExp = simplified.getValue(); + + // fixed point reached + if (current != null && currExp.equals(current.getValue())) { + return current; + } + + // continue simplifying + return simplifyToFixedPoint(simplified, simplified.getValue()); } /** * Recursively simplifies the derivation tree by removing redundant conjuncts */ - private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node) { + private static ValDerivationNode simplifyValDerivationNode(ValDerivationNode node) { Expression value = node.getValue(); DerivationNode origin = node.getOrigin(); // binary expression with && - if (value instanceof BinaryExpression binExp) { - if ("&&".equals(binExp.getOperator()) && origin instanceof BinaryDerivationNode binOrigin) { - // recursively simplify children - ValDerivationNode leftSimplified = simplifyDerivationTree(binOrigin.getLeft()); - ValDerivationNode rightSimplified = simplifyDerivationTree(binOrigin.getRight()); - - // check if either side is redundant - if (isRedundant(leftSimplified.getValue())) - return rightSimplified; - if (isRedundant(rightSimplified.getValue())) - return leftSimplified; - - // return the conjunction with simplified children - Expression newValue = new BinaryExpression(leftSimplified.getValue(), "&&", rightSimplified.getValue()); - DerivationNode newOrigin = new BinaryDerivationNode(leftSimplified, rightSimplified, "&&"); - return new ValDerivationNode(newValue, newOrigin); + if (value instanceof BinaryExpression binExp && "&&".equals(binExp.getOperator())) { + ValDerivationNode leftSimplified; + ValDerivationNode rightSimplified; + + if (origin instanceof BinaryDerivationNode binOrigin) { + leftSimplified = simplifyValDerivationNode(binOrigin.getLeft()); + rightSimplified = simplifyValDerivationNode(binOrigin.getRight()); + } else { + leftSimplified = simplifyValDerivationNode(new ValDerivationNode(binExp.getFirstOperand(), null)); + rightSimplified = simplifyValDerivationNode(new ValDerivationNode(binExp.getSecondOperand(), null)); } + + // check if either side is redundant + if (isRedundant(leftSimplified.getValue())) + return rightSimplified; + if (isRedundant(rightSimplified.getValue())) + return leftSimplified; + + // collapse identical sides (x && x => x) + if (leftSimplified.getValue().equals(rightSimplified.getValue())) { + return leftSimplified; + } + + // collapse symmetric equalities (e.g. x == y && y == x => x == y) + if (isSymmetricEquality(leftSimplified.getValue(), rightSimplified.getValue())) { + return leftSimplified; + } + + // return the conjunction with simplified children + Expression newValue = new BinaryExpression(leftSimplified.getValue(), "&&", rightSimplified.getValue()); + DerivationNode newOrigin = new BinaryDerivationNode(leftSimplified, rightSimplified, "&&"); + return new ValDerivationNode(newValue, newOrigin); } // no simplification return node; } + private static boolean isSymmetricEquality(Expression left, Expression right) { + if (left instanceof BinaryExpression b1 && "==".equals(b1.getOperator()) && right instanceof BinaryExpression b2 + && "==".equals(b2.getOperator())) { + + Expression l1 = b1.getFirstOperand(); + Expression r1 = b1.getSecondOperand(); + Expression l2 = b2.getFirstOperand(); + Expression r2 = b2.getSecondOperand(); + return l1.equals(r2) && r1.equals(l2); + } + return false; + } + /** * Checks if an expression is redundant (e.g. true or x == x) */ diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java index b93c78db..9d5850e6 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java @@ -12,25 +12,30 @@ public class VariableResolver { /** - * Extracts variables with constant values from an expression Returns a map from variable names to their values + * Extracts variables with constant values from an expression + * + * @param exp + * + * @returns map from variable names to their values */ public static Map resolve(Expression exp) { - // if the expression is just a single equality (not a conjunction) don't extract it - // this avoids creating tautologies like "1 == 1" after substitution, which are then simplified to "true" - if (exp instanceof BinaryExpression be) { - if ("==".equals(be.getOperator())) { - return new HashMap<>(); - } - } - Map map = new HashMap<>(); + + // extract variable equalities recursively resolveRecursive(exp, map); + + // remove variables that were not used in the expression + map.entrySet().removeIf(entry -> !hasUsage(exp, entry.getKey())); + + // transitively resolve variables return resolveTransitive(map); } /** * Recursively extracts variable equalities from an expression (e.g. ... && x == 1 && y == 2 => map: x -> 1, y -> 2) - * Modifies the given map in place + * + * @param exp + * @param map */ private static void resolveRecursive(Expression exp, Map map) { if (!(exp instanceof BinaryExpression be)) @@ -43,16 +48,20 @@ private static void resolveRecursive(Expression exp, Map map } else if ("==".equals(op)) { Expression left = be.getFirstOperand(); Expression right = be.getSecondOperand(); - if (left instanceof Var && (right.isLiteral() || right instanceof Var)) { - map.put(((Var) left).getName(), right.clone()); - } else if (right instanceof Var && left.isLiteral()) { - map.put(((Var) right).getName(), left.clone()); + if (left instanceof Var var && right.isLiteral()) { + map.put(var.getName(), right.clone()); + } else if (right instanceof Var var && left.isLiteral()) { + map.put(var.getName(), left.clone()); } } } /** * Handles transitive variable equalities in the map (e.g. map: x -> y, y -> 1 => map: x -> 1, y -> 1) + * + * @param map + * + * @return new map with resolved values */ private static Map resolveTransitive(Map map) { Map result = new HashMap<>(); @@ -65,6 +74,12 @@ private static Map resolveTransitive(Map /** * Returns the value of a variable by looking up in the map recursively Uses the seen set to avoid circular * references (e.g. x -> y, y -> x) which would cause infinite recursion + * + * @param exp + * @param map + * @param seen + * + * @return resolved expression */ private static Expression lookup(Expression exp, Map map, Set seen) { if (!(exp instanceof Var)) @@ -81,4 +96,39 @@ private static Expression lookup(Expression exp, Map map, Se seen.add(name); return lookup(value, map, seen); } + + /** + * Checks if a variable is used in the expression (excluding its own definitions) + * + * @param exp + * @param name + * + * @return true if used, false otherwise + */ + private static boolean hasUsage(Expression exp, String name) { + // exclude own definitions + if (exp instanceof BinaryExpression binary && "==".equals(binary.getOperator())) { + Expression left = binary.getFirstOperand(); + Expression right = binary.getSecondOperand(); + if (left instanceof Var v && v.getName().equals(name) && right.isLiteral()) + return false; + if (right instanceof Var v && v.getName().equals(name) && left.isLiteral()) + return false; + } + + // usage found + if (exp instanceof Var var && var.getName().equals(name)) { + return true; + } + + // recurse children + if (exp.hasChildren()) { + for (Expression child : exp.getChildren()) + if (hasUsage(child, name)) + return true; + } + + // usage not found + return false; + } } \ No newline at end of file diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java index 1c044f52..c134a44e 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java @@ -3,12 +3,23 @@ public class VarDerivationNode extends DerivationNode { private final String var; + private final DerivationNode origin; public VarDerivationNode(String var) { this.var = var; + this.origin = null; + } + + public VarDerivationNode(String var, DerivationNode origin) { + this.var = var; + this.origin = origin; } public String getVar() { return var; } + + public DerivationNode getOrigin() { + return origin; + } } diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java index ff034f93..b49ce805 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java @@ -110,7 +110,7 @@ void testSimpleComparison() { // Then assertNotNull(result, "Result should not be null"); - assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean"); + assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should be a boolean"); assertFalse(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to befalse"); // (y || true) && y == false => false || true = true @@ -246,8 +246,8 @@ void testComplexArithmeticWithMultipleOperations() { // Then assertNotNull(result, "Result should not be null"); assertNotNull(result.getValue(), "Result value should not be null"); - assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean literal"); - assertTrue(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to be true"); + assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should be a boolean literal"); + assertTrue(result.getValue().isBooleanTrue(), "Expected result to be true"); // 5 * 2 + 7 - 3 ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a")); @@ -305,7 +305,64 @@ void testComplexArithmeticWithMultipleOperations() { } @Test - void testSingleEqualityNotSimplifiedToTrue() { + void testFixedPointSimplification() { + // Given: x == -y && y == a / b && a == 6 && b == 3 + // Expected: x == -2 + + Expression varX = new Var("x"); + Expression varY = new Var("y"); + Expression varA = new Var("a"); + Expression varB = new Var("b"); + + Expression aDivB = new BinaryExpression(varA, "/", varB); + Expression yEqualsADivB = new BinaryExpression(varY, "==", aDivB); + Expression negY = new UnaryExpression("-", varY); + Expression xEqualsNegY = new BinaryExpression(varX, "==", negY); + Expression six = new LiteralInt(6); + Expression aEquals6 = new BinaryExpression(varA, "==", six); + Expression three = new LiteralInt(3); + Expression bEquals3 = new BinaryExpression(varB, "==", three); + Expression firstAnd = new BinaryExpression(xEqualsNegY, "&&", yEqualsADivB); + Expression secondAnd = new BinaryExpression(aEquals6, "&&", bEquals3); + Expression fullExpression = new BinaryExpression(firstAnd, "&&", secondAnd); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("x == -2", result.getValue().toString(), "Expected result to be x == -2"); + + // Compare derivation tree structure + + // Build the derivation chain for the right side: + // 6 came from a, 3 came from b + ValDerivationNode val6FromA = new ValDerivationNode(new LiteralInt(6), new VarDerivationNode("a")); + ValDerivationNode val3FromB = new ValDerivationNode(new LiteralInt(3), new VarDerivationNode("b")); + + // 6 / 3 -> 2 + BinaryDerivationNode divOrigin = new BinaryDerivationNode(val6FromA, val3FromB, "/"); + + // 2 came from y, and y's value came from 6 / 2 + VarDerivationNode yChainedOrigin = new VarDerivationNode("y", divOrigin); + ValDerivationNode val2FromY = new ValDerivationNode(new LiteralInt(2), yChainedOrigin); + + // -2 + UnaryDerivationNode negOrigin = new UnaryDerivationNode(val2FromY, "-"); + ValDerivationNode rightNode = new ValDerivationNode(new LiteralInt(-2), negOrigin); + + // Left node x has no origin + ValDerivationNode leftNode = new ValDerivationNode(new Var("x"), null); + + // Root equality + BinaryDerivationNode rootOrigin = new BinaryDerivationNode(leftNode, rightNode, "=="); + ValDerivationNode expected = new ValDerivationNode(result.getValue(), rootOrigin); + + assertDerivationEquals(expected, result, "Derivation tree structure"); + } + + @Test + void testSingleEqualityShouldNotSimplify() { // Given: x == 1 // Expected: x == 1 (should not be simplified to "true") @@ -322,13 +379,177 @@ void testSingleEqualityNotSimplifiedToTrue() { "Single equality should not be simplified to a boolean literal"); // The result should be the original expression unchanged - assertTrue(result.getValue() instanceof BinaryExpression, "Result should still be a binary expression"); + assertInstanceOf(BinaryExpression.class, result.getValue(), "Result should still be a binary expression"); BinaryExpression resultExpr = (BinaryExpression) result.getValue(); assertEquals("==", resultExpr.getOperator(), "Operator should still be =="); assertEquals("x", resultExpr.getFirstOperand().toString(), "Left operand should be x"); assertEquals("1", resultExpr.getSecondOperand().toString(), "Right operand should be 1"); } + @Test + void testTwoEqualitiesShouldNotSimplify() { + // Given: x == 1 && y == 2 + // Expected: x == 1 && y == 2 (should not be simplified to "true") + + Expression varX = new Var("x"); + Expression one = new LiteralInt(1); + Expression xEquals1 = new BinaryExpression(varX, "==", one); + + Expression varY = new Var("y"); + Expression two = new LiteralInt(2); + Expression yEquals2 = new BinaryExpression(varY, "==", two); + Expression fullExpression = new BinaryExpression(xEquals1, "&&", yEquals2); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("x == 1 && y == 2", result.getValue().toString(), + "Two equalities should not be simplified to a boolean literal"); + + // The result should be the original expression unchanged + assertInstanceOf(BinaryExpression.class, result.getValue(), "Result should still be a binary expression"); + BinaryExpression resultExpr = (BinaryExpression) result.getValue(); + assertEquals("&&", resultExpr.getOperator(), "Operator should still be &&"); + assertEquals("x == 1", resultExpr.getFirstOperand().toString(), "Left operand should be x == 1"); + assertEquals("y == 2", resultExpr.getSecondOperand().toString(), "Right operand should be y == 2"); + } + + @Test + void testSameVarTwiceShouldSimplifyToSingle() { + // Given: x && x + // Expected: x + + Expression varX = new Var("x"); + Expression fullExpression = new BinaryExpression(varX, "&&", varX); + // When + + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + // Then + + assertNotNull(result, "Result should not be null"); + assertEquals("x", result.getValue().toString(), + "Same variable twice should be simplified to a single variable"); + } + + @Test + void testSameEqualityTwiceShouldSimplifyToSingle() { + // Given: x == 1 && x == 1 + // Expected: x == 1 + + Expression varX = new Var("x"); + Expression one = new LiteralInt(1); + Expression xEquals1First = new BinaryExpression(varX, "==", one); + Expression xEquals1Second = new BinaryExpression(varX, "==", one); + Expression fullExpression = new BinaryExpression(xEquals1First, "&&", xEquals1Second); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("x == 1", result.getValue().toString(), + "Same equality twice should be simplified to a single equality"); + } + + @Test + void testSameExpressionTwiceShouldSimplifyToSingle() { + // Given: a + b == 1 && a + b == 1 + // Expected: a + b == 1 + + Expression varA = new Var("a"); + Expression varB = new Var("b"); + Expression sum = new BinaryExpression(varA, "+", varB); + Expression one = new LiteralInt(1); + Expression sumEquals3First = new BinaryExpression(sum, "==", one); + Expression sumEquals3Second = new BinaryExpression(sum, "==", one); + Expression fullExpression = new BinaryExpression(sumEquals3First, "&&", sumEquals3Second); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("a + b == 1", result.getValue().toString(), + "Same expression twice should be simplified to a single equality"); + } + + @Test + void testSymmetricEqualityShouldSimplify() { + // Given: x == y && y == x + // Expected: x == y + + Expression varX = new Var("x"); + Expression varY = new Var("y"); + Expression xEqualsY = new BinaryExpression(varX, "==", varY); + Expression yEqualsX = new BinaryExpression(varY, "==", varX); + Expression fullExpression = new BinaryExpression(xEqualsY, "&&", yEqualsX); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("x == y", result.getValue().toString(), + "Symmetric equality should be simplified to a single equality"); + } + + @Test + void testRealExpression() { + // Given: #a_5 == -#fresh_4 && #fresh_4 == #x_2 / #y_3 && #x_2 == #x_0 && #x_0 == 6 && #y_3 == #y_1 && #y_1 == 3 + // Expected: #a_5 == -2 + + Expression varA5 = new Var("#a_5"); + Expression varFresh4 = new Var("#fresh_4"); + Expression varX2 = new Var("#x_2"); + Expression varY3 = new Var("#y_3"); + Expression varX0 = new Var("#x_0"); + Expression varY1 = new Var("#y_1"); + Expression six = new LiteralInt(6); + Expression three = new LiteralInt(3); + Expression fresh4EqualsX2DivY3 = new BinaryExpression(varFresh4, "==", new BinaryExpression(varX2, "/", varY3)); + Expression x2EqualsX0 = new BinaryExpression(varX2, "==", varX0); + Expression x0Equals6 = new BinaryExpression(varX0, "==", six); + Expression y3EqualsY1 = new BinaryExpression(varY3, "==", varY1); + Expression y1Equals3 = new BinaryExpression(varY1, "==", three); + Expression negFresh4 = new UnaryExpression("-", varFresh4); + Expression a5EqualsNegFresh4 = new BinaryExpression(varA5, "==", negFresh4); + Expression firstAnd = new BinaryExpression(a5EqualsNegFresh4, "&&", fresh4EqualsX2DivY3); + Expression secondAnd = new BinaryExpression(x2EqualsX0, "&&", x0Equals6); + Expression thirdAnd = new BinaryExpression(y3EqualsY1, "&&", y1Equals3); + Expression firstBigAnd = new BinaryExpression(firstAnd, "&&", secondAnd); + Expression fullExpression = new BinaryExpression(firstBigAnd, "&&", thirdAnd); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("#a_5 == -2", result.getValue().toString(), "Expected result to be #a_5 == -2"); + + } + + @Test + void testTransitive() { + // Given: a == b && b == 1 + // Expected: a == 1 + + Expression varA = new Var("a"); + Expression varB = new Var("b"); + Expression one = new LiteralInt(1); + Expression aEqualsB = new BinaryExpression(varA, "==", varB); + Expression bEquals1 = new BinaryExpression(varB, "==", one); + Expression fullExpression = new BinaryExpression(aEqualsB, "&&", bEquals1); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1"); + } + /** * Helper method to compare two derivation nodes recursively */ @@ -336,25 +557,22 @@ private void assertDerivationEquals(DerivationNode expected, DerivationNode actu if (expected == null && actual == null) return; + assertNotNull(expected); assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match"); - if (expected instanceof ValDerivationNode) { - ValDerivationNode expectedVal = (ValDerivationNode) expected; + if (expected instanceof ValDerivationNode expectedVal) { ValDerivationNode actualVal = (ValDerivationNode) actual; assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(), message + ": values should match"); assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin"); - } else if (expected instanceof BinaryDerivationNode) { - BinaryDerivationNode expectedBin = (BinaryDerivationNode) expected; + } else if (expected instanceof BinaryDerivationNode expectedBin) { BinaryDerivationNode actualBin = (BinaryDerivationNode) actual; assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match"); assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left"); assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right"); - } else if (expected instanceof VarDerivationNode) { - VarDerivationNode expectedVar = (VarDerivationNode) expected; + } else if (expected instanceof VarDerivationNode expectedVar) { VarDerivationNode actualVar = (VarDerivationNode) actual; assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match"); - } else if (expected instanceof UnaryDerivationNode) { - UnaryDerivationNode expectedUnary = (UnaryDerivationNode) expected; + } else if (expected instanceof UnaryDerivationNode expectedUnary) { UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual; assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match"); assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand"); diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VariableResolverTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VariableResolverTest.java index 0d08e9a2..6f799aa2 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VariableResolverTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VariableResolverTest.java @@ -107,4 +107,39 @@ void testGroupedEquality() { Map result = VariableResolver.resolve(grouped); assertTrue(result.isEmpty(), "Grouped single equality should not extract variable mapping"); } + + @Test + void testCircularDependency() { + // x == y && y == x should not extract anything due to circular dependency + Expression varX = new Var("x"); + Expression varY = new Var("y"); + + Expression xEqualsY = new BinaryExpression(varX, "==", varY); + Expression yEqualsX = new BinaryExpression(varY, "==", varX); + Expression conjunction = new BinaryExpression(xEqualsY, "&&", yEqualsX); + + Map result = VariableResolver.resolve(conjunction); + assertTrue(result.isEmpty(), "Circular dependency should not extract variable mappings"); + } + + @Test + void testUnusedEqualitiesShouldBeIgnored() { + // z > 0 && x == 1 && y == 2 && z == 3 + Expression varX = new Var("x"); + Expression varY = new Var("y"); + Expression varZ = new Var("z"); + Expression one = new LiteralInt(1); + Expression two = new LiteralInt(2); + Expression three = new LiteralInt(3); + Expression zero = new LiteralInt(0); + Expression zGreaterZero = new BinaryExpression(varZ, ">", zero); + Expression xEquals1 = new BinaryExpression(varX, "==", one); + Expression yEquals2 = new BinaryExpression(varY, "==", two); + Expression zEquals3 = new BinaryExpression(varZ, "==", three); + Expression conditions = new BinaryExpression(xEquals1, "&&", new BinaryExpression(yEquals2, "&&", zEquals3)); + Expression fullExpr = new BinaryExpression(zGreaterZero, "&&", conditions); + Map result = VariableResolver.resolve(fullExpr); + assertEquals(1, result.size(), "Should only extract used variable z"); + assertEquals("3", result.get("z").toString()); + } }