Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions extensions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,14 @@ cc_test(
":math_ext_decls",
":math_ext_macros",
"//checker:standard_library",
"//checker:type_check_issue",
"//checker:validation_result",
"//common:decl",
"//common:function_descriptor",
"//common:type",
"//compiler",
"//compiler:compiler_factory",
"//compiler:standard_library",
"//eval/public:activation",
"//eval/public:builtin_func_registrar",
"//eval/public:cel_expr_builder_factory",
Expand All @@ -162,6 +166,7 @@ cc_test(
"//runtime:activation",
"//runtime:runtime_options",
"//runtime:standard_runtime_builder_factory",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/strings",
Expand Down
58 changes: 44 additions & 14 deletions extensions/math_ext_decls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,32 +128,42 @@ absl::Status AddMinMaxDecls(TypeCheckerBuilder& builder) {
absl::Status AddSignednessDecls(TypeCheckerBuilder& builder) {
const Type kNumerics[] = {IntType(), DoubleType(), UintType()};

FunctionDecl sqrt_decl;
sqrt_decl.set_name("math.sqrt");

FunctionDecl sign_decl;
sign_decl.set_name("math.sign");

FunctionDecl abs_decl;
abs_decl.set_name("math.abs");

for (const Type& type : kNumerics) {
CEL_RETURN_IF_ERROR(sqrt_decl.AddOverload(
MakeOverloadDecl(absl::StrCat("math_sqrt_", OverloadTypeName(type)),
DoubleType(), type)));
CEL_RETURN_IF_ERROR(sign_decl.AddOverload(MakeOverloadDecl(
absl::StrCat("math_sign_", OverloadTypeName(type)), type, type)));
CEL_RETURN_IF_ERROR(abs_decl.AddOverload(MakeOverloadDecl(
absl::StrCat("math_abs_", OverloadTypeName(type)), type, type)));
}

CEL_RETURN_IF_ERROR(builder.AddFunction(sqrt_decl));
CEL_RETURN_IF_ERROR(builder.AddFunction(sign_decl));
CEL_RETURN_IF_ERROR(builder.AddFunction(abs_decl));

return absl::OkStatus();
}

absl::Status AddSqrtDecls(TypeCheckerBuilder& builder) {
const Type kNumerics[] = {IntType(), DoubleType(), UintType()};

FunctionDecl sqrt_decl;
sqrt_decl.set_name("math.sqrt");

for (const Type& type : kNumerics) {
CEL_RETURN_IF_ERROR(sqrt_decl.AddOverload(
MakeOverloadDecl(absl::StrCat("math_sqrt_", OverloadTypeName(type)),
DoubleType(), type)));
}

CEL_RETURN_IF_ERROR(builder.AddFunction(sqrt_decl));

return absl::OkStatus();
}

absl::Status AddFloatingPointDecls(TypeCheckerBuilder& builder) {
// Rounding
CEL_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -270,17 +280,28 @@ absl::Status AddBitwiseDecls(TypeCheckerBuilder& builder) {
return absl::OkStatus();
}

absl::Status AddMathExtensionDeclarations(TypeCheckerBuilder& builder) {
absl::Status AddMathExtensionDeclarations(TypeCheckerBuilder& builder,
int version) {
CEL_RETURN_IF_ERROR(AddMinMaxDecls(builder));
if (version == 0) {
return absl::OkStatus();
}

CEL_RETURN_IF_ERROR(AddSignednessDecls(builder));
CEL_RETURN_IF_ERROR(AddFloatingPointDecls(builder));
CEL_RETURN_IF_ERROR(AddBitwiseDecls(builder));
if (version == 1) {
return absl::OkStatus();
}
CEL_RETURN_IF_ERROR(AddSqrtDecls(builder));

return absl::OkStatus();
}

absl::Status AddMathExtensionMacros(ParserBuilder& builder) {
absl::Status AddMathExtensionMacros(ParserBuilder& builder, int version) {
for (const auto& m : math_macros()) {
// At the moment, all macros are supported in all versions. When we add a
// new macro, we must add a version check here.
CEL_RETURN_IF_ERROR(builder.AddMacro(m));
}
return absl::OkStatus();
Expand All @@ -289,16 +310,25 @@ absl::Status AddMathExtensionMacros(ParserBuilder& builder) {
} // namespace

// Configuration for cel::Compiler to enable the math extension declarations.
CompilerLibrary MathCompilerLibrary() {
return CompilerLibrary(kMathExtensionName, &AddMathExtensionMacros,
&AddMathExtensionDeclarations);
CompilerLibrary MathCompilerLibrary(int version) {
return CompilerLibrary(
kMathExtensionName,
[version](ParserBuilder& builder) {
return AddMathExtensionMacros(builder, version);
},
[version](TypeCheckerBuilder& builder) {
return AddMathExtensionDeclarations(builder, version);
});
}

// Configuration for cel::TypeChecker to enable the math extension declarations.
CheckerLibrary MathCheckerLibrary() {
CheckerLibrary MathCheckerLibrary(int version) {
return {
.id = kMathExtensionName,
.configure = &AddMathExtensionDeclarations,
.configure =
[version](TypeCheckerBuilder& builder) {
return AddMathExtensionDeclarations(builder, version);
},
};
}

Expand Down
6 changes: 4 additions & 2 deletions extensions/math_ext_decls.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@

namespace cel::extensions {

constexpr int kMathExtensionLatestVersion = 2;

// Configuration for cel::Compiler to enable the math extension declarations.
CompilerLibrary MathCompilerLibrary();
CompilerLibrary MathCompilerLibrary(int version = kMathExtensionLatestVersion);

// Configuration for cel::TypeChecker to enable the math extension declarations.
CheckerLibrary MathCheckerLibrary();
CheckerLibrary MathCheckerLibrary(int version = kMathExtensionLatestVersion);

} // namespace cel::extensions

Expand Down
126 changes: 126 additions & 0 deletions extensions/math_ext_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,25 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "checker/standard_library.h"
#include "checker/type_check_issue.h"
#include "checker/validation_result.h"
#include "common/decl.h"
#include "common/function_descriptor.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/standard_library.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_expr_builder_factory.h"
Expand Down Expand Up @@ -70,6 +76,8 @@ using ::google::api::expr::runtime::RegisterBuiltinFunctions;
using ::google::api::expr::runtime::test::EqualsCelValue;
using ::google::protobuf::Arena;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::ValuesIn;

constexpr absl::string_view kMathMin = "math.@min";
constexpr absl::string_view kMathMax = "math.@max";
Expand Down Expand Up @@ -573,5 +581,123 @@ INSTANTIATE_TEST_SUITE_P(
{"math.bitShiftRight(4, 1) == 2"},
{"math.bitShiftRight(4u, 1) == 2u"}}));

struct MathExtensionVersionTestCase {
std::string expr;
std::vector<int> expected_supported_versions;
};

class MathExtensionVersionTest
: public ::testing::TestWithParam<MathExtensionVersionTestCase> {};

TEST_P(MathExtensionVersionTest, MathExtensionVersions) {
const MathExtensionVersionTestCase& test_case = GetParam();
for (int version = 0; version <= cel::extensions::kMathExtensionLatestVersion;
++version) {
CompilerLibrary compiler_library = MathCompilerLibrary(version);

ASSERT_OK_AND_ASSIGN(
std::unique_ptr<CompilerBuilder> builder,
cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(),
CompilerOptions()));
ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk());

ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());
ASSERT_OK_AND_ASSIGN(ValidationResult result,
compiler->Compile(test_case.expr));
if (absl::c_contains(test_case.expected_supported_versions, version)) {
EXPECT_THAT(result.GetIssues(), IsEmpty())
<< "Expected no issues for expr: " << test_case.expr
<< " at version: " << version << " but got: " << result.FormatError();
} else {
EXPECT_THAT(result.GetIssues(),
Contains(Property(&TypeCheckIssue::message,
HasSubstr("undeclared reference"))))
<< "Expected undeclared reference for expr: " << test_case.expr
<< " at version: " << version;
}
}
};

std::vector<MathExtensionVersionTestCase> CreateMathExtensionVersionParams() {
return {
MathExtensionVersionTestCase{
.expr = "math.least([0,1,2,3])",
.expected_supported_versions = {0, 1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.greatest([0,1,2,3])",
.expected_supported_versions = {0, 1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.ceil(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.floor(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.round(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.trunc(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.isInf(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.isNaN(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.isFinite(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.abs(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.sign(1.5)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.bitAnd(1, 1)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.bitOr(1, 1)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.bitXor(1, 1)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.bitNot(1)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.bitShiftLeft(1, 1)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.bitShiftRight(1, 1)",
.expected_supported_versions = {1, 2},
},
MathExtensionVersionTestCase{
.expr = "math.sqrt(1.5)",
.expected_supported_versions = {2},
},
};
}

INSTANTIATE_TEST_SUITE_P(MathExtensionVersionTest, MathExtensionVersionTest,
ValuesIn(CreateMathExtensionVersionParams()));

} // namespace
} // namespace cel::extensions