diff --git a/liquidjava-verifier/pom.xml b/liquidjava-verifier/pom.xml index 21464fa3..46a4972d 100644 --- a/liquidjava-verifier/pom.xml +++ b/liquidjava-verifier/pom.xml @@ -188,6 +188,11 @@ antlr4-runtime 4.7.1 + + com.google.code.gson + gson + 2.10.1 + diff --git a/liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java b/liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java index b17f51b4..91ab4987 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java +++ b/liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java @@ -26,7 +26,7 @@ public static void printError(CtElement var, Predicate expectedType, Predica } public static void printError(CtElement var, String moreInfo, Predicate expectedType, Predicate cSMT, - HashMap map, ErrorEmitter errorl) { + HashMap map, ErrorEmitter ee) { String resumeMessage = "Type expected:" + expectedType.toString(); // + "; " +"Refinement found:" + // cSMT.toString(); @@ -41,16 +41,16 @@ public static void printError(CtElement var, String moreInfo, Predicate expe // all message sb.append(sbtitle.toString() + "\n\n"); sb.append("Type expected:" + expectedType.toString() + "\n"); - sb.append("Refinement found:" + cSMT.toString() + "\n"); + sb.append("Refinement found: " + cSMT.simplify().getValue() + "\n"); sb.append(printMap(map)); sb.append("Location: " + var.getPosition() + "\n"); sb.append("______________________________________________________\n"); - errorl.addError(resumeMessage, sb.toString(), var.getPosition(), 1, map); + ee.addError(resumeMessage, sb.toString(), var.getPosition(), 1, map); } public static void printStateMismatch(CtElement element, String method, VCImplication constraintForErrorMsg, - String states, HashMap map, ErrorEmitter errorl) { + String states, HashMap map, ErrorEmitter ee) { String resumeMessage = "Failed to check state transitions. " + "Expected possible states:" + states; // + "; // Found @@ -75,11 +75,11 @@ public static void printStateMismatch(CtElement element, String method, VCImplic sb.append("Location: " + element.getPosition() + "\n"); sb.append("______________________________________________________\n"); - errorl.addError(resumeMessage, sb.toString(), element.getPosition(), 1, map); + ee.addError(resumeMessage, sb.toString(), element.getPosition(), 1, map); } public static void printErrorUnknownVariable(CtElement var, String et, String correctRefinement, - HashMap map, ErrorEmitter errorl) { + HashMap map, ErrorEmitter ee) { String resumeMessage = "Encountered unknown variable"; @@ -94,11 +94,11 @@ public static void printErrorUnknownVariable(CtElement var, String et, Strin sb.append("Location: " + var.getPosition() + "\n"); sb.append("______________________________________________________\n"); - errorl.addError(resumeMessage, sb.toString(), var.getPosition(), 2, map); + ee.addError(resumeMessage, sb.toString(), var.getPosition(), 2, map); } public static void printNotFound(CtElement var, Predicate constraint, Predicate constraint2, String msg, - HashMap map, ErrorEmitter errorl) { + HashMap map, ErrorEmitter ee) { StringBuilder sb = new StringBuilder(); sb.append("______________________________________________________\n"); @@ -111,11 +111,11 @@ public static void printNotFound(CtElement var, Predicate constraint, Predic sb.append("Location: " + var.getPosition() + "\n"); sb.append("______________________________________________________\n"); - errorl.addError(msg, sb.toString(), var.getPosition(), 2, map); + ee.addError(msg, sb.toString(), var.getPosition(), 2, map); } public static void printErrorArgs(CtElement var, Predicate expectedType, String msg, - HashMap map, ErrorEmitter errorl) { + HashMap map, ErrorEmitter ee) { StringBuilder sb = new StringBuilder(); sb.append("______________________________________________________\n"); String title = "Error in ghost invocation: " + msg + "\n"; @@ -125,11 +125,11 @@ public static void printErrorArgs(CtElement var, Predicate expectedType, Str sb.append("Location: " + var.getPosition() + "\n"); sb.append("______________________________________________________\n"); - errorl.addError(title, sb.toString(), var.getPosition(), 2, map); + ee.addError(title, sb.toString(), var.getPosition(), 2, map); } public static void printErrorTypeMismatch(CtElement element, Predicate expectedType, String message, - HashMap map, ErrorEmitter errorl) { + HashMap map, ErrorEmitter ee) { StringBuilder sb = new StringBuilder(); sb.append("______________________________________________________\n"); sb.append(message + "\n\n"); @@ -138,11 +138,11 @@ public static void printErrorTypeMismatch(CtElement element, Predicate expectedT sb.append("Location: " + element.getPosition() + "\n"); sb.append("______________________________________________________\n"); - errorl.addError(message, sb.toString(), element.getPosition(), 2, map); + ee.addError(message, sb.toString(), element.getPosition(), 2, map); } public static void printSameStateSetError(CtElement element, Predicate p, String name, - HashMap map, ErrorEmitter errorl) { + HashMap map, ErrorEmitter ee) { String resume = "Error found multiple disjoint states from a State Set in a refinement"; StringBuilder sb = new StringBuilder(); @@ -157,10 +157,10 @@ public static void printSameStateSetError(CtElement element, Predicate p, String sb.append("Location: " + element.getPosition() + "\n"); sb.append("______________________________________________________\n"); - errorl.addError(resume, sb.toString(), element.getPosition(), 1, map); + ee.addError(resume, sb.toString(), element.getPosition(), 1, map); } - public static void printErrorConstructorFromState(CtElement element, CtLiteral from, ErrorEmitter errorl) { + public static void printErrorConstructorFromState(CtElement element, CtLiteral from, ErrorEmitter ee) { StringBuilder sb = new StringBuilder(); sb.append("______________________________________________________\n"); String s = " Error found constructor with FROM state (Constructor's should only have a TO state)\n\n"; @@ -170,10 +170,10 @@ public static void printErrorConstructorFromState(CtElement element, CtLiteral e, int set, CtElement element) { CtLiteral s = (CtLiteral) ce; String f = s.getValue(); if (Character.isUpperCase(f.charAt(0))) { - ErrorHandler.printCostumeError(s, "State name must start with lowercase in '" + f + "'", + ErrorHandler.printCustomError(s, "State name must start with lowercase in '" + f + "'", errorEmitter); } } @@ -161,11 +161,11 @@ private void createStateGhost(String string, CtAnnotation try { gd = RefinementsParser.getGhostDeclaration(string); } catch (ParsingException e) { - ErrorHandler.printCostumeError(ann, "Could not parse the Ghost Function" + e.getMessage(), errorEmitter); + ErrorHandler.printCustomError(ann, "Could not parse the Ghost Function" + e.getMessage(), errorEmitter); return; } if (gd.getParam_types().size() > 0) { - ErrorHandler.printCostumeError(ann, "Ghost States have the class as parameter " + ErrorHandler.printCustomError(ann, "Ghost States have the class as parameter " + "by default, no other parameters are allowed in '" + string + "'", errorEmitter); return; } @@ -224,8 +224,7 @@ protected void getGhostFunction(String value, CtElement element) { context.addGhostFunction(gh); } } catch (ParsingException e) { - ErrorHandler.printCostumeError(element, "Could not parse the Ghost Function" + e.getMessage(), - errorEmitter); + ErrorHandler.printCustomError(element, "Could not parse the Ghost Function" + e.getMessage(), errorEmitter); // e.printStackTrace(); return; } @@ -252,7 +251,7 @@ protected void handleAlias(String value, CtElement element) { } } } catch (ParsingException e) { - ErrorHandler.printCostumeError(element, e.getMessage(), errorEmitter); + ErrorHandler.printCustomError(element, e.getMessage(), errorEmitter); return; // e.printStackTrace(); } diff --git a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java index 82cbf15d..458ae4da 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java +++ b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java @@ -364,7 +364,7 @@ private void printError(Exception e, Predicate premisesBeforeChange, Predicate e } else if (e instanceof NotFoundError) { ErrorHandler.printNotFound(element, cSMTMessageReady, etMessageReady, e.getMessage(), map, errorEmitter); } else { - ErrorHandler.printCostumeError(element, e.getMessage(), errorEmitter); + ErrorHandler.printCustomError(element, e.getMessage(), errorEmitter); // System.err.println("Unknown error:"+e.getMessage()); // e.printStackTrace(); // System.exit(7); diff --git a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/object_checkers/AuxStateHandler.java b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/object_checkers/AuxStateHandler.java index 192ff2de..72ada2df 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/object_checkers/AuxStateHandler.java +++ b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/object_checkers/AuxStateHandler.java @@ -369,7 +369,7 @@ public static void updateGhostField(CtFieldWrite fw, TypeChecker tc) { stateChange.setTo(toPredicate); } catch (ParsingException e) { ErrorHandler - .printCostumeError(fw, + .printCustomError(fw, "ParsingException while constructing assignment update for `" + fw + "` in class `" + fw.getVariable().getDeclaringType() + "` : " + e.getMessage(), tc.getErrorEmitter()); diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java index eca80729..7fb7fa94 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java @@ -21,6 +21,9 @@ import liquidjava.rj_language.ast.LiteralReal; import liquidjava.rj_language.ast.UnaryExpression; import liquidjava.rj_language.ast.Var; +import liquidjava.rj_language.opt.derivation_node.DerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; +import liquidjava.rj_language.opt.ExpressionSimplifier; import liquidjava.rj_language.parsing.ParsingException; import liquidjava.rj_language.parsing.RefinementsParser; import liquidjava.utils.Utils; @@ -212,6 +215,10 @@ public Expression getExpression() { return exp; } + public ValDerivationNode simplify() { + return ExpressionSimplifier.simplify(exp.clone()); + } + public static Predicate createConjunction(Predicate c1, Predicate c2) { return new Predicate(new BinaryExpression(c1.getExpression(), Utils.AND, c2.getExpression())); } diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java index da75ec71..008b651c 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java @@ -1,9 +1,11 @@ package liquidjava.rj_language.ast; -import com.microsoft.z3.Expr; import java.util.ArrayList; import java.util.List; import java.util.Map; + +import com.microsoft.z3.Expr; + import liquidjava.processor.context.Context; import liquidjava.processor.facade.AliasDTO; import liquidjava.rj_language.ast.typing.TypeInfer; @@ -47,6 +49,10 @@ public void setChild(int index, Expression element) { children.set(index, element); } + public boolean isLiteral() { + return this instanceof LiteralInt || this instanceof LiteralReal || this instanceof LiteralBoolean; + } + /** * Substitutes the expression first given expression by the second * diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralInt.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralInt.java index 6a4e97f1..4e98fa57 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralInt.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralInt.java @@ -25,6 +25,10 @@ public String toString() { return Integer.toString(value); } + public int getValue() { + return value; + } + @Override public void getVariableNames(List toAdd) { // end leaf diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralReal.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralReal.java index 2ddc8430..e6c4a810 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralReal.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralReal.java @@ -25,6 +25,10 @@ public String toString() { return Double.toString(value); } + public double getValue() { + return value; + } + @Override public void getVariableNames(List toAdd) { // end leaf diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java new file mode 100644 index 00000000..0d5fe242 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java @@ -0,0 +1,197 @@ +package liquidjava.rj_language.opt; + +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.GroupExpression; +import liquidjava.rj_language.ast.LiteralBoolean; +import liquidjava.rj_language.ast.LiteralInt; +import liquidjava.rj_language.ast.LiteralReal; +import liquidjava.rj_language.ast.UnaryExpression; +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.DerivationNode; +import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; + +public class ConstantFolding { + + /** + * Performs constant folding on a derivation node by evaluating nodes with constant values. Returns a new derivation + * node representing the folding steps taken + */ + public static ValDerivationNode fold(ValDerivationNode node) { + Expression exp = node.getValue(); + if (exp instanceof BinaryExpression) + return foldBinary(node); + + if (exp instanceof UnaryExpression) + return foldUnary(node); + + if (exp instanceof GroupExpression) { + GroupExpression group = (GroupExpression) exp; + if (group.getChildren().size() == 1) { + return fold(new ValDerivationNode(group.getChildren().get(0), node.getOrigin())); + } + } + return node; + } + + /** + * Folds a binary expression node if both children are constant values (e.g. 1 + 2 => 3) + */ + private static ValDerivationNode foldBinary(ValDerivationNode node) { + BinaryExpression binExp = (BinaryExpression) node.getValue(); + DerivationNode parent = node.getOrigin(); + + // fold child nodes + ValDerivationNode leftNode; + ValDerivationNode rightNode; + if (parent instanceof BinaryDerivationNode) { + // has origin (from constant propagation) + BinaryDerivationNode binaryOrigin = (BinaryDerivationNode) parent; + leftNode = fold(binaryOrigin.getLeft()); + rightNode = fold(binaryOrigin.getRight()); + } else { + // no origin + leftNode = fold(new ValDerivationNode(binExp.getFirstOperand(), null)); + rightNode = fold(new ValDerivationNode(binExp.getSecondOperand(), null)); + } + + Expression left = leftNode.getValue(); + Expression right = rightNode.getValue(); + String op = binExp.getOperator(); + binExp.setChild(0, left); + binExp.setChild(1, right); + + // int and int + if (left instanceof LiteralInt && right instanceof LiteralInt) { + int l = ((LiteralInt) left).getValue(); + int r = ((LiteralInt) right).getValue(); + Expression res = switch (op) { + case "+" -> new LiteralInt(l + r); + case "-" -> new LiteralInt(l - r); + case "*" -> new LiteralInt(l * r); + case "/" -> r != 0 ? new LiteralInt(l / r) : null; + case "%" -> r != 0 ? new LiteralInt(l % r) : null; + case "<" -> new LiteralBoolean(l < r); + case "<=" -> new LiteralBoolean(l <= r); + case ">" -> new LiteralBoolean(l > r); + case ">=" -> new LiteralBoolean(l >= r); + case "==" -> new LiteralBoolean(l == r); + case "!=" -> new LiteralBoolean(l != r); + default -> null; + }; + if (res != null) + return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op)); + } + // real and real + else if (left instanceof LiteralReal && right instanceof LiteralReal) { + double l = ((LiteralReal) left).getValue(); + double r = ((LiteralReal) right).getValue(); + Expression res = switch (op) { + case "+" -> new LiteralReal(l + r); + case "-" -> new LiteralReal(l - r); + case "*" -> new LiteralReal(l * r); + case "/" -> r != 0.0 ? new LiteralReal(l / r) : null; + case "%" -> r != 0.0 ? new LiteralReal(l % r) : null; + case "<" -> new LiteralBoolean(l < r); + case "<=" -> new LiteralBoolean(l <= r); + case ">" -> new LiteralBoolean(l > r); + case ">=" -> new LiteralBoolean(l >= r); + case "==" -> new LiteralBoolean(l == r); + case "!=" -> new LiteralBoolean(l != r); + default -> null; + }; + if (res != null) + return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op)); + } + + // mixed int and real + else if ((left instanceof LiteralInt && right instanceof LiteralReal) + || (left instanceof LiteralReal && right instanceof LiteralInt)) { + double l = left instanceof LiteralInt ? ((LiteralInt) left).getValue() : ((LiteralReal) left).getValue(); + double r = right instanceof LiteralInt ? ((LiteralInt) right).getValue() : ((LiteralReal) right).getValue(); + Expression res = switch (op) { + case "+" -> new LiteralReal(l + r); + case "-" -> new LiteralReal(l - r); + case "*" -> new LiteralReal(l * r); + case "/" -> r != 0.0 ? new LiteralReal(l / r) : null; + case "%" -> r != 0.0 ? new LiteralReal(l % r) : null; + case "<" -> new LiteralBoolean(l < r); + case "<=" -> new LiteralBoolean(l <= r); + case ">" -> new LiteralBoolean(l > r); + case ">=" -> new LiteralBoolean(l >= r); + case "==" -> new LiteralBoolean(l == r); + case "!=" -> new LiteralBoolean(l != r); + default -> null; + }; + if (res != null) + return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op)); + } + // bool and bool + else if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) { + boolean l = ((LiteralBoolean) left).isBooleanTrue(); + boolean r = ((LiteralBoolean) right).isBooleanTrue(); + Expression res = switch (op) { + case "&&" -> new LiteralBoolean(l && r); + case "||" -> new LiteralBoolean(l || r); + case "-->" -> new LiteralBoolean(!l || r); + case "==" -> new LiteralBoolean(l == r); + case "!=" -> new LiteralBoolean(l != r); + default -> null; + }; + if (res != null) + return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op)); + } + + // no folding + DerivationNode origin = (leftNode.getOrigin() != null || rightNode.getOrigin() != null) + ? new BinaryDerivationNode(leftNode, rightNode, op) : null; + return new ValDerivationNode(binExp, origin); + } + + /** + * Folds a unary expression node if the child (operand) is a constant value (e.g. !true => false) + */ + private static ValDerivationNode foldUnary(ValDerivationNode node) { + UnaryExpression unaryExp = (UnaryExpression) node.getValue(); + DerivationNode parent = node.getOrigin(); + + // fold child node + ValDerivationNode operandNode; + if (parent instanceof UnaryDerivationNode) { + // has origin (from constant propagation) + UnaryDerivationNode unaryOrigin = (UnaryDerivationNode) parent; + operandNode = fold(unaryOrigin.getOperand()); + } else { + // no origin + operandNode = fold(new ValDerivationNode(unaryExp.getChildren().get(0), null)); + } + Expression operand = operandNode.getValue(); + String operator = unaryExp.getOp(); + unaryExp.setChild(0, operand); + + // unary not + if ("!".equals(operator) && operand instanceof LiteralBoolean) { + // !true => false, !false => true + boolean value = ((LiteralBoolean) operand).isBooleanTrue(); + Expression res = new LiteralBoolean(!value); + return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator)); + } + // unary minus + if ("-".equals(operator)) { + // -(x) => -x + if (operand instanceof LiteralInt) { + Expression res = new LiteralInt(-((LiteralInt) operand).getValue()); + return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator)); + } + if (operand instanceof LiteralReal) { + Expression res = new LiteralReal(-((LiteralReal) operand).getValue()); + return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator)); + } + } + + // no folding + DerivationNode origin = operandNode.getOrigin() != null ? new UnaryDerivationNode(operandNode, operator) : null; + return new ValDerivationNode(unaryExp, origin); + } +} \ No newline at end of file diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java new file mode 100644 index 00000000..a72a9b33 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java @@ -0,0 +1,82 @@ +package liquidjava.rj_language.opt; + +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.UnaryExpression; +import liquidjava.rj_language.ast.Var; +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.DerivationNode; +import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; +import liquidjava.rj_language.opt.derivation_node.VarDerivationNode; + +import java.util.Map; + +public class ConstantPropagation { + + /** + * Performs constant propagation on an expression, by substituting variables with their constant values. Uses the + * VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing + * the propagation steps taken. + */ + public static ValDerivationNode propagate(Expression exp) { + Map substitutions = VariableResolver.resolve(exp); + return propagateRecursive(exp, substitutions); + } + + /** + * Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2) + */ + private static ValDerivationNode propagateRecursive(Expression exp, Map subs) { + + // substitute variable + if (exp instanceof Var) { + Var var = (Var) exp; + String name = var.getName(); + Expression value = subs.get(name); + // substitution + if (value != null) + return new ValDerivationNode(value.clone(), new VarDerivationNode(name)); + + // no substitution + return new ValDerivationNode(var, null); + } + + // lift unary origin + if (exp instanceof UnaryExpression) { + UnaryExpression unary = (UnaryExpression) exp; + ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs); + unary.setChild(0, operand.getValue()); + + DerivationNode origin = operand.getOrigin() != null ? new UnaryDerivationNode(operand, unary.getOp()) + : null; + return new ValDerivationNode(unary, origin); + } + + // lift binary origin + if (exp instanceof BinaryExpression) { + BinaryExpression binary = (BinaryExpression) exp; + ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs); + ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs); + binary.setChild(0, left.getValue()); + binary.setChild(1, right.getValue()); + + DerivationNode origin = (left.getOrigin() != null || right.getOrigin() != null) + ? new BinaryDerivationNode(left, right, binary.getOperator()) : null; + return new ValDerivationNode(binary, origin); + } + + // recursively propagate children + if (exp.hasChildren()) { + Expression propagated = exp.clone(); + for (int i = 0; i < exp.getChildren().size(); i++) { + ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs); + propagated.setChild(i, child.getValue()); + } + return new ValDerivationNode(propagated, null); + } + + // no propagation + return new ValDerivationNode(exp, null); + } +} \ No newline at end of file diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java new file mode 100644 index 00000000..2a022b81 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java @@ -0,0 +1,74 @@ +package liquidjava.rj_language.opt; + +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.LiteralBoolean; +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.DerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; + +public class ExpressionSimplifier { + + /** + * Simplifies an expression by applying constant propagation, constant folding and removing redundant conjuncts + * Returns a derivation node representing the tree of simplifications applied + */ + public static ValDerivationNode simplify(Expression exp) { + ValDerivationNode prop = ConstantPropagation.propagate(exp); + ValDerivationNode fold = ConstantFolding.fold(prop); + return simplifyDerivationTree(fold); + } + + /** + * Recursively simplifies the derivation tree by removing redundant conjuncts + */ + private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node) { + Expression value = node.getValue(); + DerivationNode origin = node.getOrigin(); + + // binary expression with && + if (value instanceof BinaryExpression) { + BinaryExpression binExp = (BinaryExpression) value; + if ("&&".equals(binExp.getOperator()) && origin instanceof BinaryDerivationNode) { + BinaryDerivationNode binOrigin = (BinaryDerivationNode) origin; + + // recursively simplify children + ValDerivationNode leftSimplified = simplifyDerivationTree(binOrigin.getLeft()); + ValDerivationNode rightSimplified = simplifyDerivationTree(binOrigin.getRight()); + + // check if either side is redundant + if (isRedundant(leftSimplified.getValue())) + return rightSimplified; + if (isRedundant(rightSimplified.getValue())) + return leftSimplified; + + // return the conjunction with simplified children + Expression newValue = new BinaryExpression(leftSimplified.getValue(), "&&", rightSimplified.getValue()); + DerivationNode newOrigin = new BinaryDerivationNode(leftSimplified, rightSimplified, "&&"); + return new ValDerivationNode(newValue, newOrigin); + } + } + // no simplification + return node; + } + + /** + * Checks if an expression is redundant (e.g. true or x == x) + */ + private static boolean isRedundant(Expression exp) { + // true + if (exp instanceof LiteralBoolean && ((LiteralBoolean) exp).isBooleanTrue()) { + return true; + } + // x == x + if (exp instanceof BinaryExpression) { + BinaryExpression binExp = (BinaryExpression) exp; + if ("==".equals(binExp.getOperator())) { + Expression left = binExp.getFirstOperand(); + Expression right = binExp.getSecondOperand(); + return left.equals(right); + } + } + return false; + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java new file mode 100644 index 00000000..2ac6d210 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java @@ -0,0 +1,77 @@ +package liquidjava.rj_language.opt; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.Var; + +public class VariableResolver { + + /** + * Extracts variables with constant values from an expression Returns a map from variable names to their values + */ + public static Map resolve(Expression exp) { + Map map = new HashMap<>(); + resolveRecursive(exp, map); + return resolveTransitive(map); + } + + /** + * Recursively extracts variable equalities from an expression (e.g. ... && x == 1 && y == 2 => map: x -> 1, y -> 2) + * Modifies the given map in place + */ + private static void resolveRecursive(Expression exp, Map map) { + if (!(exp instanceof BinaryExpression)) + return; + + BinaryExpression be = (BinaryExpression) exp; + String op = be.getOperator(); + if ("&&".equals(op)) { + resolveRecursive(be.getFirstOperand(), map); + resolveRecursive(be.getSecondOperand(), map); + } else if ("==".equals(op)) { + Expression left = be.getFirstOperand(); + Expression right = be.getSecondOperand(); + if (left instanceof Var && (right.isLiteral() || right instanceof Var)) { + map.put(((Var) left).getName(), right.clone()); + } else if (right instanceof Var && left.isLiteral()) { + map.put(((Var) right).getName(), left.clone()); + } + } + } + + /** + * Handles transitive variable equalities in the map (e.g. map: x -> y, y -> 1 => map: x -> 1, y -> 1) + */ + private static Map resolveTransitive(Map map) { + Map result = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + result.put(entry.getKey(), lookup(entry.getValue(), map, new HashSet<>())); + } + return result; + } + + /** + * Returns the value of a variable by looking up in the map recursively Uses the seen set to avoid circular + * references (e.g. x -> y, y -> x) which would cause infinite recursion + */ + private static Expression lookup(Expression exp, Map map, Set seen) { + if (!(exp instanceof Var)) + return exp; + + String name = exp.toString(); + if (seen.contains(name)) + return exp; // circular reference + + Expression value = map.get(name); + if (value == null) + return exp; + + seen.add(name); + return lookup(value, map, seen); + } +} \ No newline at end of file diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/BinaryDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/BinaryDerivationNode.java new file mode 100644 index 00000000..8c30a802 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/BinaryDerivationNode.java @@ -0,0 +1,26 @@ +package liquidjava.rj_language.opt.derivation_node; + +public class BinaryDerivationNode extends DerivationNode { + + private final String op; + private final ValDerivationNode left; + private final ValDerivationNode right; + + public BinaryDerivationNode(ValDerivationNode left, ValDerivationNode right, String op) { + this.left = left; + this.right = right; + this.op = op; + } + + public ValDerivationNode getLeft() { + return left; + } + + public ValDerivationNode getRight() { + return right; + } + + public String getOp() { + return op; + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/DerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/DerivationNode.java new file mode 100644 index 00000000..a6b08e54 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/DerivationNode.java @@ -0,0 +1,15 @@ +package liquidjava.rj_language.opt.derivation_node; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +public abstract class DerivationNode { + + // disable html escaping to avoid escaping characters like &, >, <, =, etc. + private static final Gson gson = new GsonBuilder().setPrettyPrinting().disableHtmlEscaping().create(); + + @Override + public String toString() { + return gson.toJson(this); + } +} \ No newline at end of file diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/UnaryDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/UnaryDerivationNode.java new file mode 100644 index 00000000..f0693dc3 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/UnaryDerivationNode.java @@ -0,0 +1,20 @@ +package liquidjava.rj_language.opt.derivation_node; + +public class UnaryDerivationNode extends DerivationNode { + + private final String op; + private final ValDerivationNode operand; + + public UnaryDerivationNode(ValDerivationNode operand, String op) { + this.operand = operand; + this.op = op; + } + + public ValDerivationNode getOperand() { + return operand; + } + + public String getOp() { + return op; + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/ValDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/ValDerivationNode.java new file mode 100644 index 00000000..eeb6f21d --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/ValDerivationNode.java @@ -0,0 +1,54 @@ +package liquidjava.rj_language.opt.derivation_node; + +import java.lang.reflect.Type; + +import com.google.gson.JsonElement; +import com.google.gson.JsonNull; +import com.google.gson.JsonPrimitive; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; +import com.google.gson.annotations.JsonAdapter; + +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.LiteralBoolean; +import liquidjava.rj_language.ast.LiteralInt; +import liquidjava.rj_language.ast.LiteralReal; +import liquidjava.rj_language.ast.Var; + +public class ValDerivationNode extends DerivationNode { + + @JsonAdapter(ExpressionSerializer.class) + private final Expression value; + private final DerivationNode origin; + + public ValDerivationNode(Expression exp, DerivationNode origin) { + this.value = exp; + this.origin = origin; + } + + public Expression getValue() { + return value; + } + + public DerivationNode getOrigin() { + return origin; + } + + // Custom serializer to handle Expression subclasses properly + private static class ExpressionSerializer implements JsonSerializer { + @Override + public JsonElement serialize(Expression exp, Type typeOfSrc, JsonSerializationContext context) { + if (exp == null) + return JsonNull.INSTANCE; + if (exp instanceof LiteralInt) + return new JsonPrimitive(((LiteralInt) exp).getValue()); + if (exp instanceof LiteralReal) + return new JsonPrimitive(((LiteralReal) exp).getValue()); + if (exp instanceof LiteralBoolean) + return new JsonPrimitive(((LiteralBoolean) exp).isBooleanTrue()); + if (exp instanceof Var) + return new JsonPrimitive(((Var) exp).getName()); + return new JsonPrimitive(exp.toString()); + } + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java new file mode 100644 index 00000000..1c044f52 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java @@ -0,0 +1,14 @@ +package liquidjava.rj_language.opt.derivation_node; + +public class VarDerivationNode extends DerivationNode { + + private final String var; + + public VarDerivationNode(String var) { + this.var = var; + } + + public String getVar() { + return var; + } +} diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java new file mode 100644 index 00000000..d1629682 --- /dev/null +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java @@ -0,0 +1,338 @@ +package liquidjava.rj_language.opt; + +import static org.junit.jupiter.api.Assertions.*; + +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.LiteralBoolean; +import liquidjava.rj_language.ast.LiteralInt; +import liquidjava.rj_language.ast.UnaryExpression; +import liquidjava.rj_language.ast.Var; +import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.DerivationNode; +import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; +import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; +import liquidjava.rj_language.opt.derivation_node.VarDerivationNode; +import org.junit.jupiter.api.Test; + +/** + * Test suite for expression simplification using constant propagation and folding + */ +class ExpressionSimplifierTest { + + @Test + void testNegation() { + // Given: -a && a == 7 + // Expected: -7 + + Expression varA = new Var("a"); + Expression negA = new UnaryExpression("-", varA); + Expression seven = new LiteralInt(7); + Expression aEquals7 = new BinaryExpression(varA, "==", seven); + Expression fullExpression = new BinaryExpression(negA, "&&", aEquals7); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("-7", result.getValue().toString(), "Expected result to be -7"); + + // 7 from variable a + ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("a")); + + // -7 + UnaryDerivationNode negation = new UnaryDerivationNode(val7, "-"); + ValDerivationNode expected = new ValDerivationNode(new LiteralInt(-7), negation); + + // Compare the derivation trees + assertDerivationEquals(expected, result, ""); + } + + @Test + void testSimpleAddition() { + // Given: a + b && a == 3 && b == 5 + // Expected: 8 (3 + 5) + + Expression varA = new Var("a"); + Expression varB = new Var("b"); + Expression addition = new BinaryExpression(varA, "+", varB); + + Expression three = new LiteralInt(3); + Expression aEquals3 = new BinaryExpression(varA, "==", three); + + Expression five = new LiteralInt(5); + Expression bEquals5 = new BinaryExpression(varB, "==", five); + + Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5); + Expression fullExpression = new BinaryExpression(addition, "&&", conditions); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertEquals("8", result.getValue().toString(), "Expected result to be 8"); + + // 3 from variable a + ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), new VarDerivationNode("a")); + + // 5 from variable b + ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("b")); + + // 3 + 5 + BinaryDerivationNode add3Plus5 = new BinaryDerivationNode(val3, val5, "+"); + ValDerivationNode expected = new ValDerivationNode(new LiteralInt(8), add3Plus5); + + // Compare the derivation trees + assertDerivationEquals(expected, result, ""); + } + + @Test + void testSimpleComparison() { + // Given: (y || true) && !true && y == false + // Expected: false (true && false) + + Expression varY = new Var("y"); + Expression trueExp = new LiteralBoolean(true); + Expression yOrTrue = new BinaryExpression(varY, "||", trueExp); + + Expression notTrue = new UnaryExpression("!", trueExp); + + Expression falseExp = new LiteralBoolean(false); + Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp); + + Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue); + Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean"); + assertFalse(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to befalse"); + + // (y || true) && y == false => false || true = true + ValDerivationNode valFalseForY = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y")); + ValDerivationNode valTrue1 = new ValDerivationNode(new LiteralBoolean(true), null); + BinaryDerivationNode orFalseTrue = new BinaryDerivationNode(valFalseForY, valTrue1, "||"); + ValDerivationNode trueFromOr = new ValDerivationNode(new LiteralBoolean(true), orFalseTrue); + + // !true = false + ValDerivationNode valTrue2 = new ValDerivationNode(new LiteralBoolean(true), null); + UnaryDerivationNode notOp = new UnaryDerivationNode(valTrue2, "!"); + ValDerivationNode falseFromNot = new ValDerivationNode(new LiteralBoolean(false), notOp); + + // true && false = false + BinaryDerivationNode andTrueFalse = new BinaryDerivationNode(trueFromOr, falseFromNot, "&&"); + ValDerivationNode falseFromFirstAnd = new ValDerivationNode(new LiteralBoolean(false), andTrueFalse); + + // y == false + ValDerivationNode valFalseForY2 = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y")); + ValDerivationNode valFalse2 = new ValDerivationNode(new LiteralBoolean(false), null); + BinaryDerivationNode compareFalseFalse = new BinaryDerivationNode(valFalseForY2, valFalse2, "=="); + ValDerivationNode trueFromCompare = new ValDerivationNode(new LiteralBoolean(true), compareFalseFalse); + + // false && true = false + BinaryDerivationNode finalAnd = new BinaryDerivationNode(falseFromFirstAnd, trueFromCompare, "&&"); + ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(false), finalAnd); + + // Compare the derivation trees + assertDerivationEquals(expected, result, ""); + } + + @Test + void testArithmeticWithConstants() { + // Given: (a / b + (-5)) + x && a == 6 && b == 2 + // Expected: -2 + x (6 / 2 = 3, 3 + (-5) = -2) + + Expression varA = new Var("a"); + Expression varB = new Var("b"); + Expression division = new BinaryExpression(varA, "/", varB); + + Expression five = new LiteralInt(5); + Expression negFive = new UnaryExpression("-", five); + + Expression firstSum = new BinaryExpression(division, "+", negFive); + Expression varX = new Var("x"); + Expression fullArithmetic = new BinaryExpression(firstSum, "+", varX); + + Expression six = new LiteralInt(6); + Expression aEquals6 = new BinaryExpression(varA, "==", six); + + Expression two = new LiteralInt(2); + Expression bEquals2 = new BinaryExpression(varB, "==", two); + + Expression allConditions = new BinaryExpression(aEquals6, "&&", bEquals2); + Expression fullExpression = new BinaryExpression(fullArithmetic, "&&", allConditions); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertNotNull(result.getValue(), "Result value should not be null"); + + String resultStr = result.getValue().toString(); + assertEquals("-2 + x", resultStr, "Expected result to be -2 + x"); + + // 6 from variable a + ValDerivationNode val6 = new ValDerivationNode(new LiteralInt(6), new VarDerivationNode("a")); + + // 2 from variable b + ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), new VarDerivationNode("b")); + + // 6 / 2 = 3 + BinaryDerivationNode div6By2 = new BinaryDerivationNode(val6, val2, "/"); + ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), div6By2); + + // -5 from unary negation of 5 + ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), null); + UnaryDerivationNode unaryNeg5 = new UnaryDerivationNode(val5, "-"); + ValDerivationNode valNeg5 = new ValDerivationNode(new LiteralInt(-5), unaryNeg5); + + // 3 + (-5) = -2 + BinaryDerivationNode add3AndNeg5 = new BinaryDerivationNode(val3, valNeg5, "+"); + ValDerivationNode valNeg2 = new ValDerivationNode(new LiteralInt(-2), add3AndNeg5); + + // x (variable with null origin) + ValDerivationNode valX = new ValDerivationNode(new Var("x"), null); + + // -2 + x + BinaryDerivationNode addNeg2AndX = new BinaryDerivationNode(valNeg2, valX, "+"); + Expression expectedResultExpr = new BinaryExpression(new LiteralInt(-2), "+", new Var("x")); + ValDerivationNode expected = new ValDerivationNode(expectedResultExpr, addNeg2AndX); + + // Compare the derivation trees + assertDerivationEquals(expected, result, ""); + } + + @Test + void testComplexArithmeticWithMultipleOperations() { + // Given: (a * 2 + b - 3) == c && a == 5 && b == 7 && c == 14 + // Expected: (5 * 2 + 7 - 3) == 14 => 14 == 14 => true + + Expression varA = new Var("a"); + Expression varB = new Var("b"); + Expression varC = new Var("c"); + + Expression two = new LiteralInt(2); + Expression aTimes2 = new BinaryExpression(varA, "*", two); + + Expression sum = new BinaryExpression(aTimes2, "+", varB); + + Expression three = new LiteralInt(3); + Expression arithmetic = new BinaryExpression(sum, "-", three); + + Expression comparison = new BinaryExpression(arithmetic, "==", varC); + + Expression five = new LiteralInt(5); + Expression aEquals5 = new BinaryExpression(varA, "==", five); + + Expression seven = new LiteralInt(7); + Expression bEquals7 = new BinaryExpression(varB, "==", seven); + + Expression fourteen = new LiteralInt(14); + Expression cEquals14 = new BinaryExpression(varC, "==", fourteen); + + Expression conj1 = new BinaryExpression(aEquals5, "&&", bEquals7); + Expression allConditions = new BinaryExpression(conj1, "&&", cEquals14); + Expression fullExpression = new BinaryExpression(comparison, "&&", allConditions); + + // When + ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + + // Then + assertNotNull(result, "Result should not be null"); + assertNotNull(result.getValue(), "Result value should not be null"); + assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean literal"); + assertTrue(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to be true"); + + // 5 * 2 + 7 - 3 + ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a")); + ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null); + BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*"); + ValDerivationNode val10 = new ValDerivationNode(new LiteralInt(10), mult5Times2); + + ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b")); + BinaryDerivationNode add10Plus7 = new BinaryDerivationNode(val10, val7, "+"); + ValDerivationNode val17 = new ValDerivationNode(new LiteralInt(17), add10Plus7); + + ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), null); + BinaryDerivationNode sub17Minus3 = new BinaryDerivationNode(val17, val3, "-"); + ValDerivationNode val14Left = new ValDerivationNode(new LiteralInt(14), sub17Minus3); + + // 14 from variable c + ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c")); + + // 14 == 14 + BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "=="); + ValDerivationNode trueFromComparison = new ValDerivationNode(new LiteralBoolean(true), compare14); + + // a == 5 => true + ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a")); + ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null); + BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "=="); + ValDerivationNode trueFromA = new ValDerivationNode(new LiteralBoolean(true), compareA5); + + // b == 7 => true + ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b")); + ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null); + BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "=="); + ValDerivationNode trueFromB = new ValDerivationNode(new LiteralBoolean(true), compareB7); + + // (a == 5) && (b == 7) => true + BinaryDerivationNode andAB = new BinaryDerivationNode(trueFromA, trueFromB, "&&"); + ValDerivationNode trueFromAB = new ValDerivationNode(new LiteralBoolean(true), andAB); + + // c == 14 => true + ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c")); + ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null); + BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "=="); + ValDerivationNode trueFromC = new ValDerivationNode(new LiteralBoolean(true), compareC14); + + // ((a == 5) && (b == 7)) && (c == 14) => true + BinaryDerivationNode andABC = new BinaryDerivationNode(trueFromAB, trueFromC, "&&"); + ValDerivationNode trueFromAllConditions = new ValDerivationNode(new LiteralBoolean(true), andABC); + + // 14 == 14 => true + BinaryDerivationNode finalAnd = new BinaryDerivationNode(trueFromComparison, trueFromAllConditions, "&&"); + ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(true), finalAnd); + + // Compare the derivation trees + assertDerivationEquals(expected, result, ""); + } + + /** + * Helper method to compare two derivation nodes recursively + */ + private void assertDerivationEquals(DerivationNode expected, DerivationNode actual, String message) { + if (expected == null && actual == null) + return; + + assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match"); + if (expected instanceof ValDerivationNode) { + ValDerivationNode expectedVal = (ValDerivationNode) expected; + ValDerivationNode actualVal = (ValDerivationNode) actual; + assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(), + message + ": values should match"); + assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin"); + } else if (expected instanceof BinaryDerivationNode) { + BinaryDerivationNode expectedBin = (BinaryDerivationNode) expected; + BinaryDerivationNode actualBin = (BinaryDerivationNode) actual; + assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match"); + assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left"); + assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right"); + } else if (expected instanceof VarDerivationNode) { + VarDerivationNode expectedVar = (VarDerivationNode) expected; + VarDerivationNode actualVar = (VarDerivationNode) actual; + assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match"); + } else if (expected instanceof UnaryDerivationNode) { + UnaryDerivationNode expectedUnary = (UnaryDerivationNode) expected; + UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual; + assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match"); + assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand"); + } + } +}