Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions extensions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ cc_test(
srcs = ["lists_functions_test.cc"],
deps = [
":lists_functions",
"//checker:type_check_issue",
"//checker:validation_result",
"//common:source",
"//common:value",
Expand All @@ -476,6 +477,7 @@ cc_test(
"//runtime:runtime_builder",
"//runtime:runtime_options",
"//runtime:standard_runtime_builder_factory",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/strings:string_view",
Expand Down
69 changes: 52 additions & 17 deletions extensions/lists_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,8 @@ const Type& ListTypeParamType() {
return *kInstance;
}

absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder) {
absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder,
int version) {
CEL_ASSIGN_OR_RETURN(
FunctionDecl distinct_decl,
MakeFunctionDecl("distinct", MakeMemberOverloadDecl(
Expand Down Expand Up @@ -615,22 +616,40 @@ absl::Status RegisterListsCheckerDecls(TypeCheckerBuilder& builder) {
ListTypeParamType(), ListTypeParamType(), list_type)));
}

CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(slice_decl)));
if (version == 0) {
return absl::OkStatus();
}

CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(flatten_decl)));
if (version == 1) {
return absl::OkStatus();
}

CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(sort_by_key_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(distinct_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(flatten_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(range_decl)));
// MergeFunction is used to combine with the reverse function
// defined in strings extension.
CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(reverse_decl)));
CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(slice_decl)));
return absl::OkStatus();
}

std::vector<Macro> lists_macros() { return {ListSortByMacro()}; }
std::vector<Macro> lists_macros(int version) {
switch (version) {
case 0:
return {};
case 1:
return {};
case 2:
default:
return {ListSortByMacro()};
};
}

absl::Status ConfigureParser(ParserBuilder& builder) {
for (const Macro& macro : lists_macros()) {
absl::Status ConfigureParser(ParserBuilder& builder, int version) {
for (const Macro& macro : lists_macros(version)) {
CEL_RETURN_IF_ERROR(builder.AddMacro(macro));
}
return absl::OkStatus();
Expand All @@ -639,28 +658,44 @@ absl::Status ConfigureParser(ParserBuilder& builder) {
} // namespace

absl::Status RegisterListsFunctions(FunctionRegistry& registry,
const RuntimeOptions& options) {
CEL_RETURN_IF_ERROR(RegisterListDistinctFunction(registry));
const RuntimeOptions& options,
int version) {
CEL_RETURN_IF_ERROR(RegisterListSliceFunction(registry));
if (version == 0) {
return absl::OkStatus();
}

// Since version 1
CEL_RETURN_IF_ERROR(RegisterListFlattenFunction(registry));
if (version == 1) {
return absl::OkStatus();
}

// Since version 2
CEL_RETURN_IF_ERROR(RegisterListDistinctFunction(registry));
CEL_RETURN_IF_ERROR(RegisterListRangeFunction(registry));
CEL_RETURN_IF_ERROR(RegisterListReverseFunction(registry));
CEL_RETURN_IF_ERROR(RegisterListSliceFunction(registry));
CEL_RETURN_IF_ERROR(RegisterListSortFunction(registry));
return absl::OkStatus();
}

absl::Status RegisterListsMacros(MacroRegistry& registry,
const ParserOptions&) {
return registry.RegisterMacros(lists_macros());
absl::Status RegisterListsMacros(MacroRegistry& registry, const ParserOptions&,
int version) {
return registry.RegisterMacros(lists_macros(version));
}

CheckerLibrary ListsCheckerLibrary() {
return {.id = "cel.lib.ext.lists", .configure = RegisterListsCheckerDecls};
CheckerLibrary ListsCheckerLibrary(int version) {
return {.id = "cel.lib.ext.lists",
.configure = [version](TypeCheckerBuilder& builder) {
return RegisterListsCheckerDecls(builder, version);
}};
}

CompilerLibrary ListsCompilerLibrary() {
auto lib = CompilerLibrary::FromCheckerLibrary(ListsCheckerLibrary());
lib.configure_parser = ConfigureParser;
CompilerLibrary ListsCompilerLibrary(int version) {
auto lib = CompilerLibrary::FromCheckerLibrary(ListsCheckerLibrary(version));
lib.configure_parser = [version](ParserBuilder& builder) {
return ConfigureParser(builder, version);
};
return lib;
}

Expand Down
57 changes: 35 additions & 22 deletions extensions/lists_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,65 +25,78 @@

namespace cel::extensions {

constexpr int kListsExtensionLatestVersion = 2;

// Register implementations for list extension functions.
//
// lists.range(n: int) -> list(int)
//
// <list(T)>.distinct() -> list(T)
// === Since version 0 ===
// <list(T)>.slice(start: int, end: int) -> list(T)
//
// === Since version 1 ===
// <list(dyn)>.flatten() -> list(dyn)
// <list(dyn)>.flatten(limit: int) -> list(dyn)
//
// === Since version 2 ===
// lists.range(n: int) -> list(int)
//
// <list(T)>.distinct() -> list(T)
//
// <list(T)>.reverse() -> list(T)
//
// <list(T)>.sort() -> list(T)
//
// <list(T)>.slice(start: int, end: int) -> list(T)
absl::Status RegisterListsFunctions(FunctionRegistry& registry,
const RuntimeOptions& options);
const RuntimeOptions& options,
int version = kListsExtensionLatestVersion);

// Register list macros.
//
// === Since version 2 ===
//
// <list(T)>.sortBy(<element name>, <element key expression>)
absl::Status RegisterListsMacros(MacroRegistry& registry,
const ParserOptions& options);
const ParserOptions& options,
int version = kListsExtensionLatestVersion);

// Type check declarations for the lists extension library.
// Provides decls for the following functions:
//
// lists.range(n: int) -> list(int)
//
// <list(T)>.distinct() -> list(T)
// === Since version 0 ===
// <list(T)>.slice(start: int, end: int) -> list(T)
//
// === Since version 1 ===
// <list(dyn)>.flatten() -> list(dyn)
// <list(dyn)>.flatten(limit: int) -> list(dyn)
//
// === Since version 2 ===
// lists.range(n: int) -> list(int)
//
// <list(T)>.distinct() -> list(T)
//
// <list(T)>.reverse() -> list(T)
//
// <list(T_)>.sort() -> list(T_) where T_ is partially orderable
//
// <list(T)>.slice(start: int, end: int) -> list(T)
CheckerLibrary ListsCheckerLibrary();
CheckerLibrary ListsCheckerLibrary(int version = kListsExtensionLatestVersion);

// Provides decls for the following functions:
//
// lists.range(n: int) -> list(int)
//
// <list(T)>.distinct() -> list(T)
// === Since version 0 ===
// <list(T)>.slice(start: int, end: int) -> list(T)
//
// === Since version 1 ===
// <list(dyn)>.flatten() -> list(dyn)
// <list(dyn)>.flatten(limit: int) -> list(dyn)
//
// <list(T)>.reverse() -> list(T)
//
// <list(T_)>.sort() -> list(T_) where T_ is partially orderable
// === Since version 2 ===
// lists.range(n: int) -> list(int)
//
// <list(T)>.slice(start: int, end: int) -> list(T)
// <list(T)>.distinct() -> list(T)
//
// and the following macros:
// <list(T)>.reverse() -> list(T)
//
// <list(T)>.sortBy(<element name>, <element key expression>)
CompilerLibrary ListsCompilerLibrary();
// <list(T_)>.sort() -> list(T_) where T_ is partially orderable
CompilerLibrary ListsCompilerLibrary(
int version = kListsExtensionLatestVersion);

} // namespace cel::extensions

Expand Down
80 changes: 80 additions & 0 deletions extensions/lists_functions_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
#include <vector>

#include "cel/expr/syntax.pb.h"
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/strings/string_view.h"
#include "checker/type_check_issue.h"
#include "checker/validation_result.h"
#include "common/source.h"
#include "common/value.h"
Expand Down Expand Up @@ -54,7 +56,9 @@ using ::cel::test::ErrorValueIs;
using ::cel::expr::Expr;
using ::cel::expr::ParsedExpr;
using ::cel::expr::SourceInfo;
using ::testing::Contains;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::ValuesIn;

struct TestInfo {
Expand Down Expand Up @@ -377,5 +381,81 @@ std::vector<ListCheckerTestCase> createListsCheckerParams() {
INSTANTIATE_TEST_SUITE_P(ListsCheckerLibraryTest, ListsCheckerLibraryTest,
ValuesIn(createListsCheckerParams()));

struct ListsExtensionVersionTestCase {
std::string expr;
std::vector<int> expected_supported_versions;
};

class ListsExtensionVersionTest
: public ::testing::TestWithParam<ListsExtensionVersionTestCase> {};

TEST_P(ListsExtensionVersionTest, ListsExtensionVersions) {
const ListsExtensionVersionTestCase& test_case = GetParam();
for (int version = 0;
version <= cel::extensions::kListsExtensionLatestVersion; ++version) {
CompilerLibrary compiler_library = ListsCompilerLibrary(version);

ASSERT_OK_AND_ASSIGN(
std::unique_ptr<CompilerBuilder> builder,
cel::NewCompilerBuilder(internal::GetTestingDescriptorPool(),
CompilerOptions()));
ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk());
ASSERT_THAT(builder->AddLibrary(std::move(compiler_library)), IsOk());

ASSERT_OK_AND_ASSIGN(std::unique_ptr<Compiler> compiler, builder->Build());
ASSERT_OK_AND_ASSIGN(ValidationResult result,
compiler->Compile(test_case.expr));
if (absl::c_contains(test_case.expected_supported_versions, version)) {
EXPECT_THAT(result.GetIssues(), IsEmpty())
<< "Expected no issues for expr: " << test_case.expr
<< " at version: " << version << " but got: " << result.FormatError();
} else {
EXPECT_THAT(result.GetIssues(),
Contains(Property(&TypeCheckIssue::message,
HasSubstr("undeclared reference"))));
}
}
};

std::vector<ListsExtensionVersionTestCase> CreateListsExtensionVersionParams() {
return {
ListsExtensionVersionTestCase{
.expr = "[0,1,2,3].slice(0, 2)",
.expected_supported_versions = {0, 1, 2},
},
ListsExtensionVersionTestCase{
.expr = "[[0]].flatten()",
.expected_supported_versions = {1, 2},
},
ListsExtensionVersionTestCase{
.expr = "[[0]].flatten(1)",
.expected_supported_versions = {1, 2},
},
ListsExtensionVersionTestCase{
.expr = "[1,2,3,4].sort()",
.expected_supported_versions = {2},
},
ListsExtensionVersionTestCase{
.expr = "[1,2,3,4].sortBy(x, x)",
.expected_supported_versions = {2},
},
ListsExtensionVersionTestCase{
.expr = "[1,2,3,4].distinct()",
.expected_supported_versions = {2},
},
ListsExtensionVersionTestCase{
.expr = "lists.range(4)",
.expected_supported_versions = {2},
},
ListsExtensionVersionTestCase{
.expr = "[1,2,3,4].reverse()",
.expected_supported_versions = {2},
},
};
}

INSTANTIATE_TEST_SUITE_P(ListsExtensionVersionTest, ListsExtensionVersionTest,
ValuesIn(CreateListsExtensionVersionParams()));

} // namespace
} // namespace cel::extensions