From f90aaf84ba164737a0b9ca0d779a9eddcb1229f1 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Sun, 20 Apr 2025 20:42:38 +0800 Subject: [PATCH] chore: add some static optimizations --- sjsonnet/src/sjsonnet/Error.scala | 2 - sjsonnet/src/sjsonnet/Evaluator.scala | 3 +- sjsonnet/src/sjsonnet/StaticOptimizer.scala | 71 ++++++++++++++++++++- sjsonnet/src/sjsonnet/Val.scala | 2 + 4 files changed, 73 insertions(+), 5 deletions(-) diff --git a/sjsonnet/src/sjsonnet/Error.scala b/sjsonnet/src/sjsonnet/Error.scala index 49e30f32..e8e63ea8 100644 --- a/sjsonnet/src/sjsonnet/Error.scala +++ b/sjsonnet/src/sjsonnet/Error.scala @@ -1,7 +1,5 @@ package sjsonnet -import fastparse.IndexedParserInput - import scala.util.control.NonFatal /** diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index 99de24b8..26b839b2 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -1,11 +1,10 @@ package sjsonnet -import Expr.{Error => _, _} import sjsonnet.Expr.Member.Visibility +import sjsonnet.Expr.{Error => _, _} import ujson.Value import scala.annotation.tailrec -import scala.collection.mutable /** * Recursively walks the [[Expr]] trees to convert them into into [[Val]] diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index aa2ecc99..de298caa 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -34,11 +34,80 @@ class StaticOptimizer( case Lookup(pos, ValidSuper(_, selfIdx), index) => LookupSuper(pos, selfIdx, index) - case b2 @ BinaryOp(pos, lhs, BinaryOp.OP_in, ValidSuper(_, selfIdx)) => + case BinaryOp(pos, lhs, BinaryOp.OP_in, ValidSuper(_, selfIdx)) => InSuper(pos, lhs, selfIdx) + case BinaryOp(pos, Val.Str(_, key), BinaryOp.OP_in, obj: Val.Obj) if obj.staticSafe => + Val.bool(pos, obj.containsKey(key)) case b2 @ BinaryOp(pos, lhs: Val.Str, BinaryOp.OP_%, rhs) => try ApplyBuiltin1(pos, new Format.PartialApplyFmt(lhs.value), rhs, tailstrict = false) catch { case _: Exception => b2 } + //optimize booleans + case Or(_, Val.False(_), rhs) => transform(rhs) + case Or(pos, Val.True(_), _) => Val.True(pos) + case Or(pos, _, Val.True(_)) => Val.True(pos) + case And(_, Val.True(_), rhs) => transform(rhs) + case And(pos, Val.False(_), _) => Val.False(pos) + case And(pos, _, Val.False(_)) => Val.False(pos) + case UnaryOp(pos, UnaryOp.OP_!, Val.True(_)) => Val.False(pos) + case UnaryOp(pos, UnaryOp.OP_!, Val.False(_)) => Val.True(pos) + case UnaryOp(_, UnaryOp.OP_!, UnaryOp(_, UnaryOp.OP_!, expr)) => expr + //optimize for numbers + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_+, Val.Num(_, r)) => Val.Num(pos, l + r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_-, Val.Num(_, r)) => Val.Num(pos, l - r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_*, Val.Num(_, r)) => Val.Num(pos, l * r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_/, Val.Num(_, r)) if r != 0 => Val.Num(pos, l / r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_%, Val.Num(_, r)) if r != 0 => Val.Num(pos, l % r) + case UnaryOp(pos, UnaryOp.OP_!, Val.False(_)) => Val.True(pos) + case UnaryOp(pos, UnaryOp.OP_!, Val.True(_)) => Val.False(pos) + case UnaryOp(pos, UnaryOp.OP_+, Val.Num(_, v)) => Val.Num(pos, v) + case UnaryOp(pos, UnaryOp.OP_-, Val.Num(_, v)) => Val.Num(pos, -v) + case UnaryOp(pos, UnaryOp.OP_~, Val.Num(_, v)) => Val.Num(pos, ~v.toLong) + //optimize for bitwise + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_&, Val.Num(_, r)) => + Val.Num(pos, l.toLong & r.toLong) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_^, Val.Num(_, r)) => + Val.Num(pos, l.toLong ^ r.toLong) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_|, Val.Num(_, r)) => + Val.Num(pos, l.toLong | r.toLong) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_<<, Val.Num(_, r)) if r.isValidInt => + Val.Num(pos, l.toLong << r.toInt) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_>>, Val.Num(_, r)) if r.isValidInt => + Val.Num(pos, l.toLong >> r.toInt) + //optimize for strings + case BinaryOp(pos, Val.Str(_, l), BinaryOp.OP_+, Val.Str(_, r)) => Val.Str(pos, l + r) + //optimize for comparing + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_==, Val.Num(_, r)) => Val.bool(pos, l == r) + case BinaryOp(pos, Val.Str(_, l), BinaryOp.OP_==, Val.Str(_, r)) => Val.bool(pos, l == r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_!=, Val.Num(_, r)) => Val.bool(pos, l != r) + case BinaryOp(pos, Val.Str(_, l), BinaryOp.OP_!=, Val.Str(_, r)) => Val.bool(pos, l != r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_<, Val.Num(_, r)) => Val.bool(pos, l < r) + case BinaryOp(pos, Val.Str(_, l), BinaryOp.OP_<, Val.Str(_, r)) => Val.bool(pos, l < r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_>, Val.Num(_, r)) => Val.bool(pos, l > r) + case BinaryOp(pos, Val.Str(_, l), BinaryOp.OP_>, Val.Str(_, r)) => Val.bool(pos, l > r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_<=, Val.Num(_, r)) => Val.bool(pos, l <= r) + case BinaryOp(pos, Val.Str(_, l), BinaryOp.OP_<=, Val.Str(_, r)) => Val.bool(pos, l <= r) + case BinaryOp(pos, Val.Num(_, l), BinaryOp.OP_>=, Val.Num(_, r)) => Val.bool(pos, l >= r) + case BinaryOp(pos, Val.Str(_, l), BinaryOp.OP_>=, Val.Str(_, r)) => Val.bool(pos, l >= r) + //optimize for if else + case IfElse(_, Val.True(_), thenExpr, _) => transform(thenExpr) + case IfElse(pos, Val.False(_), _, elseExpr) => `elseExpr` match { + case null => Val.Null(pos) + case _ => transform(elseExpr) + } + //optimize for obj + case b3@BinaryOp(_, lhs: Val.Obj, BinaryOp.OP_+, rhs: Val.Obj) if lhs.staticSafe && rhs.staticSafe => + if (lhs.allKeyNames.isEmpty) { + rhs + } else if (rhs.allKeyNames.isEmpty) { + lhs + } else b3 + //optimize for arr + case b4@BinaryOp(pos, lhs: Val.Arr, BinaryOp.OP_+, rhs: Val.Arr) => + if (lhs.length == 0) { + new Val.Arr(pos, rhs.asLazyArray) + } else if (rhs.length == 0) { + new Val.Arr(pos, lhs.asLazyArray) + } else b4 case e @ Id(pos, name) => scope.get(name) match { diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 14a7e67e..b90414a5 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -407,6 +407,8 @@ object Val{ f(k, v) } } + + def staticSafe:Boolean = static } final class StaticObjectFieldSet(protected val keys: Array[String]) {