diff --git a/liquidjava-verifier/pom.xml b/liquidjava-verifier/pom.xml
index 21464fa3..46a4972d 100644
--- a/liquidjava-verifier/pom.xml
+++ b/liquidjava-verifier/pom.xml
@@ -188,6 +188,11 @@
antlr4-runtime
4.7.1
+
+ com.google.code.gson
+ gson
+ 2.10.1
+
diff --git a/liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java b/liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java
index b17f51b4..91ab4987 100644
--- a/liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java
+++ b/liquidjava-verifier/src/main/java/liquidjava/errors/ErrorHandler.java
@@ -26,7 +26,7 @@ public static void printError(CtElement var, Predicate expectedType, Predica
}
public static void printError(CtElement var, String moreInfo, Predicate expectedType, Predicate cSMT,
- HashMap map, ErrorEmitter errorl) {
+ HashMap map, ErrorEmitter ee) {
String resumeMessage = "Type expected:" + expectedType.toString(); // + "; " +"Refinement found:" +
// cSMT.toString();
@@ -41,16 +41,16 @@ public static void printError(CtElement var, String moreInfo, Predicate expe
// all message
sb.append(sbtitle.toString() + "\n\n");
sb.append("Type expected:" + expectedType.toString() + "\n");
- sb.append("Refinement found:" + cSMT.toString() + "\n");
+ sb.append("Refinement found: " + cSMT.simplify().getValue() + "\n");
sb.append(printMap(map));
sb.append("Location: " + var.getPosition() + "\n");
sb.append("______________________________________________________\n");
- errorl.addError(resumeMessage, sb.toString(), var.getPosition(), 1, map);
+ ee.addError(resumeMessage, sb.toString(), var.getPosition(), 1, map);
}
public static void printStateMismatch(CtElement element, String method, VCImplication constraintForErrorMsg,
- String states, HashMap map, ErrorEmitter errorl) {
+ String states, HashMap map, ErrorEmitter ee) {
String resumeMessage = "Failed to check state transitions. " + "Expected possible states:" + states; // + ";
// Found
@@ -75,11 +75,11 @@ public static void printStateMismatch(CtElement element, String method, VCImplic
sb.append("Location: " + element.getPosition() + "\n");
sb.append("______________________________________________________\n");
- errorl.addError(resumeMessage, sb.toString(), element.getPosition(), 1, map);
+ ee.addError(resumeMessage, sb.toString(), element.getPosition(), 1, map);
}
public static void printErrorUnknownVariable(CtElement var, String et, String correctRefinement,
- HashMap map, ErrorEmitter errorl) {
+ HashMap map, ErrorEmitter ee) {
String resumeMessage = "Encountered unknown variable";
@@ -94,11 +94,11 @@ public static void printErrorUnknownVariable(CtElement var, String et, Strin
sb.append("Location: " + var.getPosition() + "\n");
sb.append("______________________________________________________\n");
- errorl.addError(resumeMessage, sb.toString(), var.getPosition(), 2, map);
+ ee.addError(resumeMessage, sb.toString(), var.getPosition(), 2, map);
}
public static void printNotFound(CtElement var, Predicate constraint, Predicate constraint2, String msg,
- HashMap map, ErrorEmitter errorl) {
+ HashMap map, ErrorEmitter ee) {
StringBuilder sb = new StringBuilder();
sb.append("______________________________________________________\n");
@@ -111,11 +111,11 @@ public static void printNotFound(CtElement var, Predicate constraint, Predic
sb.append("Location: " + var.getPosition() + "\n");
sb.append("______________________________________________________\n");
- errorl.addError(msg, sb.toString(), var.getPosition(), 2, map);
+ ee.addError(msg, sb.toString(), var.getPosition(), 2, map);
}
public static void printErrorArgs(CtElement var, Predicate expectedType, String msg,
- HashMap map, ErrorEmitter errorl) {
+ HashMap map, ErrorEmitter ee) {
StringBuilder sb = new StringBuilder();
sb.append("______________________________________________________\n");
String title = "Error in ghost invocation: " + msg + "\n";
@@ -125,11 +125,11 @@ public static void printErrorArgs(CtElement var, Predicate expectedType, Str
sb.append("Location: " + var.getPosition() + "\n");
sb.append("______________________________________________________\n");
- errorl.addError(title, sb.toString(), var.getPosition(), 2, map);
+ ee.addError(title, sb.toString(), var.getPosition(), 2, map);
}
public static void printErrorTypeMismatch(CtElement element, Predicate expectedType, String message,
- HashMap map, ErrorEmitter errorl) {
+ HashMap map, ErrorEmitter ee) {
StringBuilder sb = new StringBuilder();
sb.append("______________________________________________________\n");
sb.append(message + "\n\n");
@@ -138,11 +138,11 @@ public static void printErrorTypeMismatch(CtElement element, Predicate expectedT
sb.append("Location: " + element.getPosition() + "\n");
sb.append("______________________________________________________\n");
- errorl.addError(message, sb.toString(), element.getPosition(), 2, map);
+ ee.addError(message, sb.toString(), element.getPosition(), 2, map);
}
public static void printSameStateSetError(CtElement element, Predicate p, String name,
- HashMap map, ErrorEmitter errorl) {
+ HashMap map, ErrorEmitter ee) {
String resume = "Error found multiple disjoint states from a State Set in a refinement";
StringBuilder sb = new StringBuilder();
@@ -157,10 +157,10 @@ public static void printSameStateSetError(CtElement element, Predicate p, String
sb.append("Location: " + element.getPosition() + "\n");
sb.append("______________________________________________________\n");
- errorl.addError(resume, sb.toString(), element.getPosition(), 1, map);
+ ee.addError(resume, sb.toString(), element.getPosition(), 1, map);
}
- public static void printErrorConstructorFromState(CtElement element, CtLiteral from, ErrorEmitter errorl) {
+ public static void printErrorConstructorFromState(CtElement element, CtLiteral from, ErrorEmitter ee) {
StringBuilder sb = new StringBuilder();
sb.append("______________________________________________________\n");
String s = " Error found constructor with FROM state (Constructor's should only have a TO state)\n\n";
@@ -170,10 +170,10 @@ public static void printErrorConstructorFromState(CtElement element, CtLiteral e, int set, CtElement element) {
CtLiteral s = (CtLiteral) ce;
String f = s.getValue();
if (Character.isUpperCase(f.charAt(0))) {
- ErrorHandler.printCostumeError(s, "State name must start with lowercase in '" + f + "'",
+ ErrorHandler.printCustomError(s, "State name must start with lowercase in '" + f + "'",
errorEmitter);
}
}
@@ -161,11 +161,11 @@ private void createStateGhost(String string, CtAnnotation extends Annotation>
try {
gd = RefinementsParser.getGhostDeclaration(string);
} catch (ParsingException e) {
- ErrorHandler.printCostumeError(ann, "Could not parse the Ghost Function" + e.getMessage(), errorEmitter);
+ ErrorHandler.printCustomError(ann, "Could not parse the Ghost Function" + e.getMessage(), errorEmitter);
return;
}
if (gd.getParam_types().size() > 0) {
- ErrorHandler.printCostumeError(ann, "Ghost States have the class as parameter "
+ ErrorHandler.printCustomError(ann, "Ghost States have the class as parameter "
+ "by default, no other parameters are allowed in '" + string + "'", errorEmitter);
return;
}
@@ -224,8 +224,7 @@ protected void getGhostFunction(String value, CtElement element) {
context.addGhostFunction(gh);
}
} catch (ParsingException e) {
- ErrorHandler.printCostumeError(element, "Could not parse the Ghost Function" + e.getMessage(),
- errorEmitter);
+ ErrorHandler.printCustomError(element, "Could not parse the Ghost Function" + e.getMessage(), errorEmitter);
// e.printStackTrace();
return;
}
@@ -252,7 +251,7 @@ protected void handleAlias(String value, CtElement element) {
}
}
} catch (ParsingException e) {
- ErrorHandler.printCostumeError(element, e.getMessage(), errorEmitter);
+ ErrorHandler.printCustomError(element, e.getMessage(), errorEmitter);
return;
// e.printStackTrace();
}
diff --git a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java
index 82cbf15d..458ae4da 100644
--- a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java
+++ b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/VCChecker.java
@@ -364,7 +364,7 @@ private void printError(Exception e, Predicate premisesBeforeChange, Predicate e
} else if (e instanceof NotFoundError) {
ErrorHandler.printNotFound(element, cSMTMessageReady, etMessageReady, e.getMessage(), map, errorEmitter);
} else {
- ErrorHandler.printCostumeError(element, e.getMessage(), errorEmitter);
+ ErrorHandler.printCustomError(element, e.getMessage(), errorEmitter);
// System.err.println("Unknown error:"+e.getMessage());
// e.printStackTrace();
// System.exit(7);
diff --git a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/object_checkers/AuxStateHandler.java b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/object_checkers/AuxStateHandler.java
index 192ff2de..72ada2df 100644
--- a/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/object_checkers/AuxStateHandler.java
+++ b/liquidjava-verifier/src/main/java/liquidjava/processor/refinement_checker/object_checkers/AuxStateHandler.java
@@ -369,7 +369,7 @@ public static void updateGhostField(CtFieldWrite> fw, TypeChecker tc) {
stateChange.setTo(toPredicate);
} catch (ParsingException e) {
ErrorHandler
- .printCostumeError(fw,
+ .printCustomError(fw,
"ParsingException while constructing assignment update for `" + fw + "` in class `"
+ fw.getVariable().getDeclaringType() + "` : " + e.getMessage(),
tc.getErrorEmitter());
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java
index eca80729..7fb7fa94 100644
--- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/Predicate.java
@@ -21,6 +21,9 @@
import liquidjava.rj_language.ast.LiteralReal;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.ast.Var;
+import liquidjava.rj_language.opt.derivation_node.DerivationNode;
+import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
+import liquidjava.rj_language.opt.ExpressionSimplifier;
import liquidjava.rj_language.parsing.ParsingException;
import liquidjava.rj_language.parsing.RefinementsParser;
import liquidjava.utils.Utils;
@@ -212,6 +215,10 @@ public Expression getExpression() {
return exp;
}
+ public ValDerivationNode simplify() {
+ return ExpressionSimplifier.simplify(exp.clone());
+ }
+
public static Predicate createConjunction(Predicate c1, Predicate c2) {
return new Predicate(new BinaryExpression(c1.getExpression(), Utils.AND, c2.getExpression()));
}
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java
index da75ec71..008b651c 100644
--- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/Expression.java
@@ -1,9 +1,11 @@
package liquidjava.rj_language.ast;
-import com.microsoft.z3.Expr;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+
+import com.microsoft.z3.Expr;
+
import liquidjava.processor.context.Context;
import liquidjava.processor.facade.AliasDTO;
import liquidjava.rj_language.ast.typing.TypeInfer;
@@ -47,6 +49,10 @@ public void setChild(int index, Expression element) {
children.set(index, element);
}
+ public boolean isLiteral() {
+ return this instanceof LiteralInt || this instanceof LiteralReal || this instanceof LiteralBoolean;
+ }
+
/**
* Substitutes the expression first given expression by the second
*
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralInt.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralInt.java
index 6a4e97f1..4e98fa57 100644
--- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralInt.java
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralInt.java
@@ -25,6 +25,10 @@ public String toString() {
return Integer.toString(value);
}
+ public int getValue() {
+ return value;
+ }
+
@Override
public void getVariableNames(List toAdd) {
// end leaf
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralReal.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralReal.java
index 2ddc8430..e6c4a810 100644
--- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralReal.java
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/ast/LiteralReal.java
@@ -25,6 +25,10 @@ public String toString() {
return Double.toString(value);
}
+ public double getValue() {
+ return value;
+ }
+
@Override
public void getVariableNames(List toAdd) {
// end leaf
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java
new file mode 100644
index 00000000..0d5fe242
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantFolding.java
@@ -0,0 +1,197 @@
+package liquidjava.rj_language.opt;
+
+import liquidjava.rj_language.ast.BinaryExpression;
+import liquidjava.rj_language.ast.Expression;
+import liquidjava.rj_language.ast.GroupExpression;
+import liquidjava.rj_language.ast.LiteralBoolean;
+import liquidjava.rj_language.ast.LiteralInt;
+import liquidjava.rj_language.ast.LiteralReal;
+import liquidjava.rj_language.ast.UnaryExpression;
+import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.DerivationNode;
+import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
+
+public class ConstantFolding {
+
+ /**
+ * Performs constant folding on a derivation node by evaluating nodes with constant values. Returns a new derivation
+ * node representing the folding steps taken
+ */
+ public static ValDerivationNode fold(ValDerivationNode node) {
+ Expression exp = node.getValue();
+ if (exp instanceof BinaryExpression)
+ return foldBinary(node);
+
+ if (exp instanceof UnaryExpression)
+ return foldUnary(node);
+
+ if (exp instanceof GroupExpression) {
+ GroupExpression group = (GroupExpression) exp;
+ if (group.getChildren().size() == 1) {
+ return fold(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()));
+ }
+ }
+ return node;
+ }
+
+ /**
+ * Folds a binary expression node if both children are constant values (e.g. 1 + 2 => 3)
+ */
+ private static ValDerivationNode foldBinary(ValDerivationNode node) {
+ BinaryExpression binExp = (BinaryExpression) node.getValue();
+ DerivationNode parent = node.getOrigin();
+
+ // fold child nodes
+ ValDerivationNode leftNode;
+ ValDerivationNode rightNode;
+ if (parent instanceof BinaryDerivationNode) {
+ // has origin (from constant propagation)
+ BinaryDerivationNode binaryOrigin = (BinaryDerivationNode) parent;
+ leftNode = fold(binaryOrigin.getLeft());
+ rightNode = fold(binaryOrigin.getRight());
+ } else {
+ // no origin
+ leftNode = fold(new ValDerivationNode(binExp.getFirstOperand(), null));
+ rightNode = fold(new ValDerivationNode(binExp.getSecondOperand(), null));
+ }
+
+ Expression left = leftNode.getValue();
+ Expression right = rightNode.getValue();
+ String op = binExp.getOperator();
+ binExp.setChild(0, left);
+ binExp.setChild(1, right);
+
+ // int and int
+ if (left instanceof LiteralInt && right instanceof LiteralInt) {
+ int l = ((LiteralInt) left).getValue();
+ int r = ((LiteralInt) right).getValue();
+ Expression res = switch (op) {
+ case "+" -> new LiteralInt(l + r);
+ case "-" -> new LiteralInt(l - r);
+ case "*" -> new LiteralInt(l * r);
+ case "/" -> r != 0 ? new LiteralInt(l / r) : null;
+ case "%" -> r != 0 ? new LiteralInt(l % r) : null;
+ case "<" -> new LiteralBoolean(l < r);
+ case "<=" -> new LiteralBoolean(l <= r);
+ case ">" -> new LiteralBoolean(l > r);
+ case ">=" -> new LiteralBoolean(l >= r);
+ case "==" -> new LiteralBoolean(l == r);
+ case "!=" -> new LiteralBoolean(l != r);
+ default -> null;
+ };
+ if (res != null)
+ return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
+ }
+ // real and real
+ else if (left instanceof LiteralReal && right instanceof LiteralReal) {
+ double l = ((LiteralReal) left).getValue();
+ double r = ((LiteralReal) right).getValue();
+ Expression res = switch (op) {
+ case "+" -> new LiteralReal(l + r);
+ case "-" -> new LiteralReal(l - r);
+ case "*" -> new LiteralReal(l * r);
+ case "/" -> r != 0.0 ? new LiteralReal(l / r) : null;
+ case "%" -> r != 0.0 ? new LiteralReal(l % r) : null;
+ case "<" -> new LiteralBoolean(l < r);
+ case "<=" -> new LiteralBoolean(l <= r);
+ case ">" -> new LiteralBoolean(l > r);
+ case ">=" -> new LiteralBoolean(l >= r);
+ case "==" -> new LiteralBoolean(l == r);
+ case "!=" -> new LiteralBoolean(l != r);
+ default -> null;
+ };
+ if (res != null)
+ return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
+ }
+
+ // mixed int and real
+ else if ((left instanceof LiteralInt && right instanceof LiteralReal)
+ || (left instanceof LiteralReal && right instanceof LiteralInt)) {
+ double l = left instanceof LiteralInt ? ((LiteralInt) left).getValue() : ((LiteralReal) left).getValue();
+ double r = right instanceof LiteralInt ? ((LiteralInt) right).getValue() : ((LiteralReal) right).getValue();
+ Expression res = switch (op) {
+ case "+" -> new LiteralReal(l + r);
+ case "-" -> new LiteralReal(l - r);
+ case "*" -> new LiteralReal(l * r);
+ case "/" -> r != 0.0 ? new LiteralReal(l / r) : null;
+ case "%" -> r != 0.0 ? new LiteralReal(l % r) : null;
+ case "<" -> new LiteralBoolean(l < r);
+ case "<=" -> new LiteralBoolean(l <= r);
+ case ">" -> new LiteralBoolean(l > r);
+ case ">=" -> new LiteralBoolean(l >= r);
+ case "==" -> new LiteralBoolean(l == r);
+ case "!=" -> new LiteralBoolean(l != r);
+ default -> null;
+ };
+ if (res != null)
+ return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
+ }
+ // bool and bool
+ else if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) {
+ boolean l = ((LiteralBoolean) left).isBooleanTrue();
+ boolean r = ((LiteralBoolean) right).isBooleanTrue();
+ Expression res = switch (op) {
+ case "&&" -> new LiteralBoolean(l && r);
+ case "||" -> new LiteralBoolean(l || r);
+ case "-->" -> new LiteralBoolean(!l || r);
+ case "==" -> new LiteralBoolean(l == r);
+ case "!=" -> new LiteralBoolean(l != r);
+ default -> null;
+ };
+ if (res != null)
+ return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
+ }
+
+ // no folding
+ DerivationNode origin = (leftNode.getOrigin() != null || rightNode.getOrigin() != null)
+ ? new BinaryDerivationNode(leftNode, rightNode, op) : null;
+ return new ValDerivationNode(binExp, origin);
+ }
+
+ /**
+ * Folds a unary expression node if the child (operand) is a constant value (e.g. !true => false)
+ */
+ private static ValDerivationNode foldUnary(ValDerivationNode node) {
+ UnaryExpression unaryExp = (UnaryExpression) node.getValue();
+ DerivationNode parent = node.getOrigin();
+
+ // fold child node
+ ValDerivationNode operandNode;
+ if (parent instanceof UnaryDerivationNode) {
+ // has origin (from constant propagation)
+ UnaryDerivationNode unaryOrigin = (UnaryDerivationNode) parent;
+ operandNode = fold(unaryOrigin.getOperand());
+ } else {
+ // no origin
+ operandNode = fold(new ValDerivationNode(unaryExp.getChildren().get(0), null));
+ }
+ Expression operand = operandNode.getValue();
+ String operator = unaryExp.getOp();
+ unaryExp.setChild(0, operand);
+
+ // unary not
+ if ("!".equals(operator) && operand instanceof LiteralBoolean) {
+ // !true => false, !false => true
+ boolean value = ((LiteralBoolean) operand).isBooleanTrue();
+ Expression res = new LiteralBoolean(!value);
+ return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
+ }
+ // unary minus
+ if ("-".equals(operator)) {
+ // -(x) => -x
+ if (operand instanceof LiteralInt) {
+ Expression res = new LiteralInt(-((LiteralInt) operand).getValue());
+ return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
+ }
+ if (operand instanceof LiteralReal) {
+ Expression res = new LiteralReal(-((LiteralReal) operand).getValue());
+ return new ValDerivationNode(res, new UnaryDerivationNode(operandNode, operator));
+ }
+ }
+
+ // no folding
+ DerivationNode origin = operandNode.getOrigin() != null ? new UnaryDerivationNode(operandNode, operator) : null;
+ return new ValDerivationNode(unaryExp, origin);
+ }
+}
\ No newline at end of file
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java
new file mode 100644
index 00000000..a72a9b33
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ConstantPropagation.java
@@ -0,0 +1,82 @@
+package liquidjava.rj_language.opt;
+
+import liquidjava.rj_language.ast.BinaryExpression;
+import liquidjava.rj_language.ast.Expression;
+import liquidjava.rj_language.ast.UnaryExpression;
+import liquidjava.rj_language.ast.Var;
+import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.DerivationNode;
+import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
+
+import java.util.Map;
+
+public class ConstantPropagation {
+
+ /**
+ * Performs constant propagation on an expression, by substituting variables with their constant values. Uses the
+ * VariableResolver to extract variable equalities from the expression first. Returns a derivation node representing
+ * the propagation steps taken.
+ */
+ public static ValDerivationNode propagate(Expression exp) {
+ Map substitutions = VariableResolver.resolve(exp);
+ return propagateRecursive(exp, substitutions);
+ }
+
+ /**
+ * Recursively performs constant propagation on an expression (e.g. x + y && x == 1 && y == 2 => 1 + 2)
+ */
+ private static ValDerivationNode propagateRecursive(Expression exp, Map subs) {
+
+ // substitute variable
+ if (exp instanceof Var) {
+ Var var = (Var) exp;
+ String name = var.getName();
+ Expression value = subs.get(name);
+ // substitution
+ if (value != null)
+ return new ValDerivationNode(value.clone(), new VarDerivationNode(name));
+
+ // no substitution
+ return new ValDerivationNode(var, null);
+ }
+
+ // lift unary origin
+ if (exp instanceof UnaryExpression) {
+ UnaryExpression unary = (UnaryExpression) exp;
+ ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs);
+ unary.setChild(0, operand.getValue());
+
+ DerivationNode origin = operand.getOrigin() != null ? new UnaryDerivationNode(operand, unary.getOp())
+ : null;
+ return new ValDerivationNode(unary, origin);
+ }
+
+ // lift binary origin
+ if (exp instanceof BinaryExpression) {
+ BinaryExpression binary = (BinaryExpression) exp;
+ ValDerivationNode left = propagateRecursive(binary.getFirstOperand(), subs);
+ ValDerivationNode right = propagateRecursive(binary.getSecondOperand(), subs);
+ binary.setChild(0, left.getValue());
+ binary.setChild(1, right.getValue());
+
+ DerivationNode origin = (left.getOrigin() != null || right.getOrigin() != null)
+ ? new BinaryDerivationNode(left, right, binary.getOperator()) : null;
+ return new ValDerivationNode(binary, origin);
+ }
+
+ // recursively propagate children
+ if (exp.hasChildren()) {
+ Expression propagated = exp.clone();
+ for (int i = 0; i < exp.getChildren().size(); i++) {
+ ValDerivationNode child = propagateRecursive(exp.getChildren().get(i), subs);
+ propagated.setChild(i, child.getValue());
+ }
+ return new ValDerivationNode(propagated, null);
+ }
+
+ // no propagation
+ return new ValDerivationNode(exp, null);
+ }
+}
\ No newline at end of file
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java
new file mode 100644
index 00000000..2a022b81
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionSimplifier.java
@@ -0,0 +1,74 @@
+package liquidjava.rj_language.opt;
+
+import liquidjava.rj_language.ast.BinaryExpression;
+import liquidjava.rj_language.ast.Expression;
+import liquidjava.rj_language.ast.LiteralBoolean;
+import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.DerivationNode;
+import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
+
+public class ExpressionSimplifier {
+
+ /**
+ * Simplifies an expression by applying constant propagation, constant folding and removing redundant conjuncts
+ * Returns a derivation node representing the tree of simplifications applied
+ */
+ public static ValDerivationNode simplify(Expression exp) {
+ ValDerivationNode prop = ConstantPropagation.propagate(exp);
+ ValDerivationNode fold = ConstantFolding.fold(prop);
+ return simplifyDerivationTree(fold);
+ }
+
+ /**
+ * Recursively simplifies the derivation tree by removing redundant conjuncts
+ */
+ private static ValDerivationNode simplifyDerivationTree(ValDerivationNode node) {
+ Expression value = node.getValue();
+ DerivationNode origin = node.getOrigin();
+
+ // binary expression with &&
+ if (value instanceof BinaryExpression) {
+ BinaryExpression binExp = (BinaryExpression) value;
+ if ("&&".equals(binExp.getOperator()) && origin instanceof BinaryDerivationNode) {
+ BinaryDerivationNode binOrigin = (BinaryDerivationNode) origin;
+
+ // recursively simplify children
+ ValDerivationNode leftSimplified = simplifyDerivationTree(binOrigin.getLeft());
+ ValDerivationNode rightSimplified = simplifyDerivationTree(binOrigin.getRight());
+
+ // check if either side is redundant
+ if (isRedundant(leftSimplified.getValue()))
+ return rightSimplified;
+ if (isRedundant(rightSimplified.getValue()))
+ return leftSimplified;
+
+ // return the conjunction with simplified children
+ Expression newValue = new BinaryExpression(leftSimplified.getValue(), "&&", rightSimplified.getValue());
+ DerivationNode newOrigin = new BinaryDerivationNode(leftSimplified, rightSimplified, "&&");
+ return new ValDerivationNode(newValue, newOrigin);
+ }
+ }
+ // no simplification
+ return node;
+ }
+
+ /**
+ * Checks if an expression is redundant (e.g. true or x == x)
+ */
+ private static boolean isRedundant(Expression exp) {
+ // true
+ if (exp instanceof LiteralBoolean && ((LiteralBoolean) exp).isBooleanTrue()) {
+ return true;
+ }
+ // x == x
+ if (exp instanceof BinaryExpression) {
+ BinaryExpression binExp = (BinaryExpression) exp;
+ if ("==".equals(binExp.getOperator())) {
+ Expression left = binExp.getFirstOperand();
+ Expression right = binExp.getSecondOperand();
+ return left.equals(right);
+ }
+ }
+ return false;
+ }
+}
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java
new file mode 100644
index 00000000..2ac6d210
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java
@@ -0,0 +1,77 @@
+package liquidjava.rj_language.opt;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import liquidjava.rj_language.ast.BinaryExpression;
+import liquidjava.rj_language.ast.Expression;
+import liquidjava.rj_language.ast.Var;
+
+public class VariableResolver {
+
+ /**
+ * Extracts variables with constant values from an expression Returns a map from variable names to their values
+ */
+ public static Map resolve(Expression exp) {
+ Map map = new HashMap<>();
+ resolveRecursive(exp, map);
+ return resolveTransitive(map);
+ }
+
+ /**
+ * Recursively extracts variable equalities from an expression (e.g. ... && x == 1 && y == 2 => map: x -> 1, y -> 2)
+ * Modifies the given map in place
+ */
+ private static void resolveRecursive(Expression exp, Map map) {
+ if (!(exp instanceof BinaryExpression))
+ return;
+
+ BinaryExpression be = (BinaryExpression) exp;
+ String op = be.getOperator();
+ if ("&&".equals(op)) {
+ resolveRecursive(be.getFirstOperand(), map);
+ resolveRecursive(be.getSecondOperand(), map);
+ } else if ("==".equals(op)) {
+ Expression left = be.getFirstOperand();
+ Expression right = be.getSecondOperand();
+ if (left instanceof Var && (right.isLiteral() || right instanceof Var)) {
+ map.put(((Var) left).getName(), right.clone());
+ } else if (right instanceof Var && left.isLiteral()) {
+ map.put(((Var) right).getName(), left.clone());
+ }
+ }
+ }
+
+ /**
+ * Handles transitive variable equalities in the map (e.g. map: x -> y, y -> 1 => map: x -> 1, y -> 1)
+ */
+ private static Map resolveTransitive(Map map) {
+ Map result = new HashMap<>();
+ for (Map.Entry entry : map.entrySet()) {
+ result.put(entry.getKey(), lookup(entry.getValue(), map, new HashSet<>()));
+ }
+ return result;
+ }
+
+ /**
+ * Returns the value of a variable by looking up in the map recursively Uses the seen set to avoid circular
+ * references (e.g. x -> y, y -> x) which would cause infinite recursion
+ */
+ private static Expression lookup(Expression exp, Map map, Set seen) {
+ if (!(exp instanceof Var))
+ return exp;
+
+ String name = exp.toString();
+ if (seen.contains(name))
+ return exp; // circular reference
+
+ Expression value = map.get(name);
+ if (value == null)
+ return exp;
+
+ seen.add(name);
+ return lookup(value, map, seen);
+ }
+}
\ No newline at end of file
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/BinaryDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/BinaryDerivationNode.java
new file mode 100644
index 00000000..8c30a802
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/BinaryDerivationNode.java
@@ -0,0 +1,26 @@
+package liquidjava.rj_language.opt.derivation_node;
+
+public class BinaryDerivationNode extends DerivationNode {
+
+ private final String op;
+ private final ValDerivationNode left;
+ private final ValDerivationNode right;
+
+ public BinaryDerivationNode(ValDerivationNode left, ValDerivationNode right, String op) {
+ this.left = left;
+ this.right = right;
+ this.op = op;
+ }
+
+ public ValDerivationNode getLeft() {
+ return left;
+ }
+
+ public ValDerivationNode getRight() {
+ return right;
+ }
+
+ public String getOp() {
+ return op;
+ }
+}
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/DerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/DerivationNode.java
new file mode 100644
index 00000000..a6b08e54
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/DerivationNode.java
@@ -0,0 +1,15 @@
+package liquidjava.rj_language.opt.derivation_node;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+
+public abstract class DerivationNode {
+
+ // disable html escaping to avoid escaping characters like &, >, <, =, etc.
+ private static final Gson gson = new GsonBuilder().setPrettyPrinting().disableHtmlEscaping().create();
+
+ @Override
+ public String toString() {
+ return gson.toJson(this);
+ }
+}
\ No newline at end of file
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/UnaryDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/UnaryDerivationNode.java
new file mode 100644
index 00000000..f0693dc3
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/UnaryDerivationNode.java
@@ -0,0 +1,20 @@
+package liquidjava.rj_language.opt.derivation_node;
+
+public class UnaryDerivationNode extends DerivationNode {
+
+ private final String op;
+ private final ValDerivationNode operand;
+
+ public UnaryDerivationNode(ValDerivationNode operand, String op) {
+ this.operand = operand;
+ this.op = op;
+ }
+
+ public ValDerivationNode getOperand() {
+ return operand;
+ }
+
+ public String getOp() {
+ return op;
+ }
+}
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/ValDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/ValDerivationNode.java
new file mode 100644
index 00000000..eeb6f21d
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/ValDerivationNode.java
@@ -0,0 +1,54 @@
+package liquidjava.rj_language.opt.derivation_node;
+
+import java.lang.reflect.Type;
+
+import com.google.gson.JsonElement;
+import com.google.gson.JsonNull;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+import com.google.gson.annotations.JsonAdapter;
+
+import liquidjava.rj_language.ast.Expression;
+import liquidjava.rj_language.ast.LiteralBoolean;
+import liquidjava.rj_language.ast.LiteralInt;
+import liquidjava.rj_language.ast.LiteralReal;
+import liquidjava.rj_language.ast.Var;
+
+public class ValDerivationNode extends DerivationNode {
+
+ @JsonAdapter(ExpressionSerializer.class)
+ private final Expression value;
+ private final DerivationNode origin;
+
+ public ValDerivationNode(Expression exp, DerivationNode origin) {
+ this.value = exp;
+ this.origin = origin;
+ }
+
+ public Expression getValue() {
+ return value;
+ }
+
+ public DerivationNode getOrigin() {
+ return origin;
+ }
+
+ // Custom serializer to handle Expression subclasses properly
+ private static class ExpressionSerializer implements JsonSerializer {
+ @Override
+ public JsonElement serialize(Expression exp, Type typeOfSrc, JsonSerializationContext context) {
+ if (exp == null)
+ return JsonNull.INSTANCE;
+ if (exp instanceof LiteralInt)
+ return new JsonPrimitive(((LiteralInt) exp).getValue());
+ if (exp instanceof LiteralReal)
+ return new JsonPrimitive(((LiteralReal) exp).getValue());
+ if (exp instanceof LiteralBoolean)
+ return new JsonPrimitive(((LiteralBoolean) exp).isBooleanTrue());
+ if (exp instanceof Var)
+ return new JsonPrimitive(((Var) exp).getName());
+ return new JsonPrimitive(exp.toString());
+ }
+ }
+}
diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java
new file mode 100644
index 00000000..1c044f52
--- /dev/null
+++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/derivation_node/VarDerivationNode.java
@@ -0,0 +1,14 @@
+package liquidjava.rj_language.opt.derivation_node;
+
+public class VarDerivationNode extends DerivationNode {
+
+ private final String var;
+
+ public VarDerivationNode(String var) {
+ this.var = var;
+ }
+
+ public String getVar() {
+ return var;
+ }
+}
diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java
new file mode 100644
index 00000000..d1629682
--- /dev/null
+++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java
@@ -0,0 +1,338 @@
+package liquidjava.rj_language.opt;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import liquidjava.rj_language.ast.BinaryExpression;
+import liquidjava.rj_language.ast.Expression;
+import liquidjava.rj_language.ast.LiteralBoolean;
+import liquidjava.rj_language.ast.LiteralInt;
+import liquidjava.rj_language.ast.UnaryExpression;
+import liquidjava.rj_language.ast.Var;
+import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.DerivationNode;
+import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
+import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Test suite for expression simplification using constant propagation and folding
+ */
+class ExpressionSimplifierTest {
+
+ @Test
+ void testNegation() {
+ // Given: -a && a == 7
+ // Expected: -7
+
+ Expression varA = new Var("a");
+ Expression negA = new UnaryExpression("-", varA);
+ Expression seven = new LiteralInt(7);
+ Expression aEquals7 = new BinaryExpression(varA, "==", seven);
+ Expression fullExpression = new BinaryExpression(negA, "&&", aEquals7);
+
+ // When
+ ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
+
+ // Then
+ assertNotNull(result, "Result should not be null");
+ assertEquals("-7", result.getValue().toString(), "Expected result to be -7");
+
+ // 7 from variable a
+ ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("a"));
+
+ // -7
+ UnaryDerivationNode negation = new UnaryDerivationNode(val7, "-");
+ ValDerivationNode expected = new ValDerivationNode(new LiteralInt(-7), negation);
+
+ // Compare the derivation trees
+ assertDerivationEquals(expected, result, "");
+ }
+
+ @Test
+ void testSimpleAddition() {
+ // Given: a + b && a == 3 && b == 5
+ // Expected: 8 (3 + 5)
+
+ Expression varA = new Var("a");
+ Expression varB = new Var("b");
+ Expression addition = new BinaryExpression(varA, "+", varB);
+
+ Expression three = new LiteralInt(3);
+ Expression aEquals3 = new BinaryExpression(varA, "==", three);
+
+ Expression five = new LiteralInt(5);
+ Expression bEquals5 = new BinaryExpression(varB, "==", five);
+
+ Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5);
+ Expression fullExpression = new BinaryExpression(addition, "&&", conditions);
+
+ // When
+ ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
+
+ // Then
+ assertNotNull(result, "Result should not be null");
+ assertEquals("8", result.getValue().toString(), "Expected result to be 8");
+
+ // 3 from variable a
+ ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), new VarDerivationNode("a"));
+
+ // 5 from variable b
+ ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("b"));
+
+ // 3 + 5
+ BinaryDerivationNode add3Plus5 = new BinaryDerivationNode(val3, val5, "+");
+ ValDerivationNode expected = new ValDerivationNode(new LiteralInt(8), add3Plus5);
+
+ // Compare the derivation trees
+ assertDerivationEquals(expected, result, "");
+ }
+
+ @Test
+ void testSimpleComparison() {
+ // Given: (y || true) && !true && y == false
+ // Expected: false (true && false)
+
+ Expression varY = new Var("y");
+ Expression trueExp = new LiteralBoolean(true);
+ Expression yOrTrue = new BinaryExpression(varY, "||", trueExp);
+
+ Expression notTrue = new UnaryExpression("!", trueExp);
+
+ Expression falseExp = new LiteralBoolean(false);
+ Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp);
+
+ Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue);
+ Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse);
+
+ // When
+ ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
+
+ // Then
+ assertNotNull(result, "Result should not be null");
+ assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean");
+ assertFalse(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to befalse");
+
+ // (y || true) && y == false => false || true = true
+ ValDerivationNode valFalseForY = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y"));
+ ValDerivationNode valTrue1 = new ValDerivationNode(new LiteralBoolean(true), null);
+ BinaryDerivationNode orFalseTrue = new BinaryDerivationNode(valFalseForY, valTrue1, "||");
+ ValDerivationNode trueFromOr = new ValDerivationNode(new LiteralBoolean(true), orFalseTrue);
+
+ // !true = false
+ ValDerivationNode valTrue2 = new ValDerivationNode(new LiteralBoolean(true), null);
+ UnaryDerivationNode notOp = new UnaryDerivationNode(valTrue2, "!");
+ ValDerivationNode falseFromNot = new ValDerivationNode(new LiteralBoolean(false), notOp);
+
+ // true && false = false
+ BinaryDerivationNode andTrueFalse = new BinaryDerivationNode(trueFromOr, falseFromNot, "&&");
+ ValDerivationNode falseFromFirstAnd = new ValDerivationNode(new LiteralBoolean(false), andTrueFalse);
+
+ // y == false
+ ValDerivationNode valFalseForY2 = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y"));
+ ValDerivationNode valFalse2 = new ValDerivationNode(new LiteralBoolean(false), null);
+ BinaryDerivationNode compareFalseFalse = new BinaryDerivationNode(valFalseForY2, valFalse2, "==");
+ ValDerivationNode trueFromCompare = new ValDerivationNode(new LiteralBoolean(true), compareFalseFalse);
+
+ // false && true = false
+ BinaryDerivationNode finalAnd = new BinaryDerivationNode(falseFromFirstAnd, trueFromCompare, "&&");
+ ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(false), finalAnd);
+
+ // Compare the derivation trees
+ assertDerivationEquals(expected, result, "");
+ }
+
+ @Test
+ void testArithmeticWithConstants() {
+ // Given: (a / b + (-5)) + x && a == 6 && b == 2
+ // Expected: -2 + x (6 / 2 = 3, 3 + (-5) = -2)
+
+ Expression varA = new Var("a");
+ Expression varB = new Var("b");
+ Expression division = new BinaryExpression(varA, "/", varB);
+
+ Expression five = new LiteralInt(5);
+ Expression negFive = new UnaryExpression("-", five);
+
+ Expression firstSum = new BinaryExpression(division, "+", negFive);
+ Expression varX = new Var("x");
+ Expression fullArithmetic = new BinaryExpression(firstSum, "+", varX);
+
+ Expression six = new LiteralInt(6);
+ Expression aEquals6 = new BinaryExpression(varA, "==", six);
+
+ Expression two = new LiteralInt(2);
+ Expression bEquals2 = new BinaryExpression(varB, "==", two);
+
+ Expression allConditions = new BinaryExpression(aEquals6, "&&", bEquals2);
+ Expression fullExpression = new BinaryExpression(fullArithmetic, "&&", allConditions);
+
+ // When
+ ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
+
+ // Then
+ assertNotNull(result, "Result should not be null");
+ assertNotNull(result.getValue(), "Result value should not be null");
+
+ String resultStr = result.getValue().toString();
+ assertEquals("-2 + x", resultStr, "Expected result to be -2 + x");
+
+ // 6 from variable a
+ ValDerivationNode val6 = new ValDerivationNode(new LiteralInt(6), new VarDerivationNode("a"));
+
+ // 2 from variable b
+ ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), new VarDerivationNode("b"));
+
+ // 6 / 2 = 3
+ BinaryDerivationNode div6By2 = new BinaryDerivationNode(val6, val2, "/");
+ ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), div6By2);
+
+ // -5 from unary negation of 5
+ ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), null);
+ UnaryDerivationNode unaryNeg5 = new UnaryDerivationNode(val5, "-");
+ ValDerivationNode valNeg5 = new ValDerivationNode(new LiteralInt(-5), unaryNeg5);
+
+ // 3 + (-5) = -2
+ BinaryDerivationNode add3AndNeg5 = new BinaryDerivationNode(val3, valNeg5, "+");
+ ValDerivationNode valNeg2 = new ValDerivationNode(new LiteralInt(-2), add3AndNeg5);
+
+ // x (variable with null origin)
+ ValDerivationNode valX = new ValDerivationNode(new Var("x"), null);
+
+ // -2 + x
+ BinaryDerivationNode addNeg2AndX = new BinaryDerivationNode(valNeg2, valX, "+");
+ Expression expectedResultExpr = new BinaryExpression(new LiteralInt(-2), "+", new Var("x"));
+ ValDerivationNode expected = new ValDerivationNode(expectedResultExpr, addNeg2AndX);
+
+ // Compare the derivation trees
+ assertDerivationEquals(expected, result, "");
+ }
+
+ @Test
+ void testComplexArithmeticWithMultipleOperations() {
+ // Given: (a * 2 + b - 3) == c && a == 5 && b == 7 && c == 14
+ // Expected: (5 * 2 + 7 - 3) == 14 => 14 == 14 => true
+
+ Expression varA = new Var("a");
+ Expression varB = new Var("b");
+ Expression varC = new Var("c");
+
+ Expression two = new LiteralInt(2);
+ Expression aTimes2 = new BinaryExpression(varA, "*", two);
+
+ Expression sum = new BinaryExpression(aTimes2, "+", varB);
+
+ Expression three = new LiteralInt(3);
+ Expression arithmetic = new BinaryExpression(sum, "-", three);
+
+ Expression comparison = new BinaryExpression(arithmetic, "==", varC);
+
+ Expression five = new LiteralInt(5);
+ Expression aEquals5 = new BinaryExpression(varA, "==", five);
+
+ Expression seven = new LiteralInt(7);
+ Expression bEquals7 = new BinaryExpression(varB, "==", seven);
+
+ Expression fourteen = new LiteralInt(14);
+ Expression cEquals14 = new BinaryExpression(varC, "==", fourteen);
+
+ Expression conj1 = new BinaryExpression(aEquals5, "&&", bEquals7);
+ Expression allConditions = new BinaryExpression(conj1, "&&", cEquals14);
+ Expression fullExpression = new BinaryExpression(comparison, "&&", allConditions);
+
+ // When
+ ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
+
+ // Then
+ assertNotNull(result, "Result should not be null");
+ assertNotNull(result.getValue(), "Result value should not be null");
+ assertTrue(result.getValue() instanceof LiteralBoolean, "Result should be a boolean literal");
+ assertTrue(((LiteralBoolean) result.getValue()).isBooleanTrue(), "Expected result to be true");
+
+ // 5 * 2 + 7 - 3
+ ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
+ ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null);
+ BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*");
+ ValDerivationNode val10 = new ValDerivationNode(new LiteralInt(10), mult5Times2);
+
+ ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b"));
+ BinaryDerivationNode add10Plus7 = new BinaryDerivationNode(val10, val7, "+");
+ ValDerivationNode val17 = new ValDerivationNode(new LiteralInt(17), add10Plus7);
+
+ ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), null);
+ BinaryDerivationNode sub17Minus3 = new BinaryDerivationNode(val17, val3, "-");
+ ValDerivationNode val14Left = new ValDerivationNode(new LiteralInt(14), sub17Minus3);
+
+ // 14 from variable c
+ ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
+
+ // 14 == 14
+ BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "==");
+ ValDerivationNode trueFromComparison = new ValDerivationNode(new LiteralBoolean(true), compare14);
+
+ // a == 5 => true
+ ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a"));
+ ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null);
+ BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "==");
+ ValDerivationNode trueFromA = new ValDerivationNode(new LiteralBoolean(true), compareA5);
+
+ // b == 7 => true
+ ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b"));
+ ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null);
+ BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "==");
+ ValDerivationNode trueFromB = new ValDerivationNode(new LiteralBoolean(true), compareB7);
+
+ // (a == 5) && (b == 7) => true
+ BinaryDerivationNode andAB = new BinaryDerivationNode(trueFromA, trueFromB, "&&");
+ ValDerivationNode trueFromAB = new ValDerivationNode(new LiteralBoolean(true), andAB);
+
+ // c == 14 => true
+ ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c"));
+ ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null);
+ BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "==");
+ ValDerivationNode trueFromC = new ValDerivationNode(new LiteralBoolean(true), compareC14);
+
+ // ((a == 5) && (b == 7)) && (c == 14) => true
+ BinaryDerivationNode andABC = new BinaryDerivationNode(trueFromAB, trueFromC, "&&");
+ ValDerivationNode trueFromAllConditions = new ValDerivationNode(new LiteralBoolean(true), andABC);
+
+ // 14 == 14 => true
+ BinaryDerivationNode finalAnd = new BinaryDerivationNode(trueFromComparison, trueFromAllConditions, "&&");
+ ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(true), finalAnd);
+
+ // Compare the derivation trees
+ assertDerivationEquals(expected, result, "");
+ }
+
+ /**
+ * Helper method to compare two derivation nodes recursively
+ */
+ private void assertDerivationEquals(DerivationNode expected, DerivationNode actual, String message) {
+ if (expected == null && actual == null)
+ return;
+
+ assertEquals(expected.getClass(), actual.getClass(), message + ": node types should match");
+ if (expected instanceof ValDerivationNode) {
+ ValDerivationNode expectedVal = (ValDerivationNode) expected;
+ ValDerivationNode actualVal = (ValDerivationNode) actual;
+ assertEquals(expectedVal.getValue().toString(), actualVal.getValue().toString(),
+ message + ": values should match");
+ assertDerivationEquals(expectedVal.getOrigin(), actualVal.getOrigin(), message + " > origin");
+ } else if (expected instanceof BinaryDerivationNode) {
+ BinaryDerivationNode expectedBin = (BinaryDerivationNode) expected;
+ BinaryDerivationNode actualBin = (BinaryDerivationNode) actual;
+ assertEquals(expectedBin.getOp(), actualBin.getOp(), message + ": operators should match");
+ assertDerivationEquals(expectedBin.getLeft(), actualBin.getLeft(), message + " > left");
+ assertDerivationEquals(expectedBin.getRight(), actualBin.getRight(), message + " > right");
+ } else if (expected instanceof VarDerivationNode) {
+ VarDerivationNode expectedVar = (VarDerivationNode) expected;
+ VarDerivationNode actualVar = (VarDerivationNode) actual;
+ assertEquals(expectedVar.getVar(), actualVar.getVar(), message + ": variables should match");
+ } else if (expected instanceof UnaryDerivationNode) {
+ UnaryDerivationNode expectedUnary = (UnaryDerivationNode) expected;
+ UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual;
+ assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match");
+ assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand");
+ }
+ }
+}