diff --git a/extensions/BUILD b/extensions/BUILD index 55325186c..35eea53b9 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -143,10 +143,14 @@ cc_test( ":math_ext_decls", ":math_ext_macros", "//checker:standard_library", + "//checker:type_check_issue", "//checker:validation_result", "//common:decl", "//common:function_descriptor", + "//common:type", + "//compiler", "//compiler:compiler_factory", + "//compiler:standard_library", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", @@ -162,6 +166,7 @@ cc_test( "//runtime:activation", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/extensions/math_ext_decls.cc b/extensions/math_ext_decls.cc index ca0487408..a7091cef6 100644 --- a/extensions/math_ext_decls.cc +++ b/extensions/math_ext_decls.cc @@ -128,9 +128,6 @@ absl::Status AddMinMaxDecls(TypeCheckerBuilder& builder) { absl::Status AddSignednessDecls(TypeCheckerBuilder& builder) { const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; - FunctionDecl sqrt_decl; - sqrt_decl.set_name("math.sqrt"); - FunctionDecl sign_decl; sign_decl.set_name("math.sign"); @@ -138,22 +135,35 @@ absl::Status AddSignednessDecls(TypeCheckerBuilder& builder) { abs_decl.set_name("math.abs"); for (const Type& type : kNumerics) { - CEL_RETURN_IF_ERROR(sqrt_decl.AddOverload( - MakeOverloadDecl(absl::StrCat("math_sqrt_", OverloadTypeName(type)), - DoubleType(), type))); CEL_RETURN_IF_ERROR(sign_decl.AddOverload(MakeOverloadDecl( absl::StrCat("math_sign_", OverloadTypeName(type)), type, type))); CEL_RETURN_IF_ERROR(abs_decl.AddOverload(MakeOverloadDecl( absl::StrCat("math_abs_", OverloadTypeName(type)), type, type))); } - CEL_RETURN_IF_ERROR(builder.AddFunction(sqrt_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(sign_decl)); CEL_RETURN_IF_ERROR(builder.AddFunction(abs_decl)); return absl::OkStatus(); } +absl::Status AddSqrtDecls(TypeCheckerBuilder& builder) { + const Type kNumerics[] = {IntType(), DoubleType(), UintType()}; + + FunctionDecl sqrt_decl; + sqrt_decl.set_name("math.sqrt"); + + for (const Type& type : kNumerics) { + CEL_RETURN_IF_ERROR(sqrt_decl.AddOverload( + MakeOverloadDecl(absl::StrCat("math_sqrt_", OverloadTypeName(type)), + DoubleType(), type))); + } + + CEL_RETURN_IF_ERROR(builder.AddFunction(sqrt_decl)); + + return absl::OkStatus(); +} + absl::Status AddFloatingPointDecls(TypeCheckerBuilder& builder) { // Rounding CEL_ASSIGN_OR_RETURN( @@ -270,17 +280,28 @@ absl::Status AddBitwiseDecls(TypeCheckerBuilder& builder) { return absl::OkStatus(); } -absl::Status AddMathExtensionDeclarations(TypeCheckerBuilder& builder) { +absl::Status AddMathExtensionDeclarations(TypeCheckerBuilder& builder, + int version) { CEL_RETURN_IF_ERROR(AddMinMaxDecls(builder)); + if (version == 0) { + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR(AddSignednessDecls(builder)); CEL_RETURN_IF_ERROR(AddFloatingPointDecls(builder)); CEL_RETURN_IF_ERROR(AddBitwiseDecls(builder)); + if (version == 1) { + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR(AddSqrtDecls(builder)); return absl::OkStatus(); } -absl::Status AddMathExtensionMacros(ParserBuilder& builder) { +absl::Status AddMathExtensionMacros(ParserBuilder& builder, int version) { for (const auto& m : math_macros()) { + // At the moment, all macros are supported in all versions. When we add a + // new macro, we must add a version check here. CEL_RETURN_IF_ERROR(builder.AddMacro(m)); } return absl::OkStatus(); @@ -289,16 +310,25 @@ absl::Status AddMathExtensionMacros(ParserBuilder& builder) { } // namespace // Configuration for cel::Compiler to enable the math extension declarations. -CompilerLibrary MathCompilerLibrary() { - return CompilerLibrary(kMathExtensionName, &AddMathExtensionMacros, - &AddMathExtensionDeclarations); +CompilerLibrary MathCompilerLibrary(int version) { + return CompilerLibrary( + kMathExtensionName, + [version](ParserBuilder& builder) { + return AddMathExtensionMacros(builder, version); + }, + [version](TypeCheckerBuilder& builder) { + return AddMathExtensionDeclarations(builder, version); + }); } // Configuration for cel::TypeChecker to enable the math extension declarations. -CheckerLibrary MathCheckerLibrary() { +CheckerLibrary MathCheckerLibrary(int version) { return { .id = kMathExtensionName, - .configure = &AddMathExtensionDeclarations, + .configure = + [version](TypeCheckerBuilder& builder) { + return AddMathExtensionDeclarations(builder, version); + }, }; } diff --git a/extensions/math_ext_decls.h b/extensions/math_ext_decls.h index 31758f77b..624649a39 100644 --- a/extensions/math_ext_decls.h +++ b/extensions/math_ext_decls.h @@ -20,11 +20,13 @@ namespace cel::extensions { +constexpr int kMathExtensionLatestVersion = 2; + // Configuration for cel::Compiler to enable the math extension declarations. -CompilerLibrary MathCompilerLibrary(); +CompilerLibrary MathCompilerLibrary(int version = kMathExtensionLatestVersion); // Configuration for cel::TypeChecker to enable the math extension declarations. -CheckerLibrary MathCheckerLibrary(); +CheckerLibrary MathCheckerLibrary(int version = kMathExtensionLatestVersion); } // namespace cel::extensions diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index b5d0f60b0..3088e6fa8 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -17,8 +17,10 @@ #include #include #include +#include #include "cel/expr/syntax.pb.h" +#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" @@ -26,10 +28,14 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/standard_library.h" +#include "checker/type_check_issue.h" #include "checker/validation_result.h" #include "common/decl.h" #include "common/function_descriptor.h" +#include "common/type.h" +#include "compiler/compiler.h" #include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" @@ -70,6 +76,8 @@ using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::google::api::expr::runtime::test::EqualsCelValue; using ::google::protobuf::Arena; using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::ValuesIn; constexpr absl::string_view kMathMin = "math.@min"; constexpr absl::string_view kMathMax = "math.@max"; @@ -573,5 +581,123 @@ INSTANTIATE_TEST_SUITE_P( {"math.bitShiftRight(4, 1) == 2"}, {"math.bitShiftRight(4u, 1) == 2u"}})); +struct MathExtensionVersionTestCase { + std::string expr; + std::vector expected_supported_versions; +}; + +class MathExtensionVersionTest + : public ::testing::TestWithParam {}; + +TEST_P(MathExtensionVersionTest, MathExtensionVersions) { + const MathExtensionVersionTestCase& test_case = GetParam(); + for (int version = 0; version <= cel::extensions::kMathExtensionLatestVersion; + ++version) { + CompilerLibrary compiler_library = MathCompilerLibrary(version); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(), + CompilerOptions())); + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(test_case.expr)); + if (absl::c_contains(test_case.expected_supported_versions, version)) { + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << "Expected no issues for expr: " << test_case.expr + << " at version: " << version << " but got: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + Contains(Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference")))) + << "Expected undeclared reference for expr: " << test_case.expr + << " at version: " << version; + } + } +}; + +std::vector CreateMathExtensionVersionParams() { + return { + MathExtensionVersionTestCase{ + .expr = "math.least([0,1,2,3])", + .expected_supported_versions = {0, 1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.greatest([0,1,2,3])", + .expected_supported_versions = {0, 1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.ceil(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.floor(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.round(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.trunc(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.isInf(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.isNaN(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.isFinite(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.abs(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.sign(1.5)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitAnd(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitOr(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitXor(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitNot(1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitShiftLeft(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.bitShiftRight(1, 1)", + .expected_supported_versions = {1, 2}, + }, + MathExtensionVersionTestCase{ + .expr = "math.sqrt(1.5)", + .expected_supported_versions = {2}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(MathExtensionVersionTest, MathExtensionVersionTest, + ValuesIn(CreateMathExtensionVersionParams())); + } // namespace } // namespace cel::extensions