From 19b41218dd4626a93946df99839b50bc63d4c8d5 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Mon, 11 May 2026 19:30:22 -0300 Subject: [PATCH 01/13] Update --- include/quiver/c/expression/expression.h | 12 ++ include/quiver/expression/expression.h | 12 ++ include/quiver/expression/expression_node.h | 5 +- src/c/expression/expression.cpp | 33 +++++ src/expression/expression.cpp | 16 +++ src/expression/expression_node.cpp | 33 ++++- tests/test_expression.cpp | 151 ++++++++++++++++++++ 7 files changed, 257 insertions(+), 5 deletions(-) diff --git a/include/quiver/c/expression/expression.h b/include/quiver/c/expression/expression.h index 9b491fd7..68922682 100644 --- a/include/quiver/c/expression/expression.h +++ b/include/quiver/c/expression/expression.h @@ -20,6 +20,15 @@ typedef enum { QUIVER_EXPRESSION_OPERATION_DIVIDE = 3, } quiver_expression_operation_t; +// Unary operation kind +typedef enum { + QUIVER_EXPRESSION_UNARY_OPERATION_NEGATE = 0, + QUIVER_EXPRESSION_UNARY_OPERATION_ABS = 1, + QUIVER_EXPRESSION_UNARY_OPERATION_SQRT = 2, + QUIVER_EXPRESSION_UNARY_OPERATION_LOG = 3, + QUIVER_EXPRESSION_UNARY_OPERATION_EXP = 4, +} quiver_expression_unary_operation_t; + // Construction QUIVER_C_API quiver_error_t quiver_expression_from_file(quiver_binary_file_t* file, quiver_expression_t** out); @@ -39,6 +48,9 @@ QUIVER_C_API quiver_error_t quiver_expression_apply_scalar_left(quiver_expressio double lhs, quiver_expression_t* rhs, quiver_expression_t** out); +QUIVER_C_API quiver_error_t quiver_expression_apply_unary(quiver_expression_unary_operation_t operation, + quiver_expression_t* operand, + quiver_expression_t** out); // Materialization QUIVER_C_API quiver_error_t quiver_expression_save(quiver_expression_t* expression, const char* path); diff --git a/include/quiver/expression/expression.h b/include/quiver/expression/expression.h index de92b866..3ef765dd 100644 --- a/include/quiver/expression/expression.h +++ b/include/quiver/expression/expression.h @@ -42,6 +42,12 @@ 
class QUIVER_API Expression { friend Expression operator/(const Expression&, double); friend Expression operator/(double, const Expression&); + friend Expression operator-(const Expression&); + friend Expression abs(const Expression&); + friend Expression sqrt(const Expression&); + friend Expression log(const Expression&); + friend Expression exp(const Expression&); + std::shared_ptr node_; }; @@ -61,6 +67,12 @@ QUIVER_API Expression operator/(const Expression& lhs, const Expression& rhs); QUIVER_API Expression operator/(const Expression& lhs, double rhs); QUIVER_API Expression operator/(double lhs, const Expression& rhs); +QUIVER_API Expression operator-(const Expression& operand); +QUIVER_API Expression abs(const Expression& operand); +QUIVER_API Expression sqrt(const Expression& operand); +QUIVER_API Expression log(const Expression& operand); +QUIVER_API Expression exp(const Expression& operand); + } // namespace quiver #endif // QUIVER_EXPRESSION_H diff --git a/include/quiver/expression/expression_node.h b/include/quiver/expression/expression_node.h index 3c84f78f..24738d4f 100644 --- a/include/quiver/expression/expression_node.h +++ b/include/quiver/expression/expression_node.h @@ -89,7 +89,7 @@ class QUIVER_API ExpressionBinary final : public ExpressionNode { class QUIVER_API ExpressionUnary final : public ExpressionNode { public: - enum class Operation { Unspecified }; + enum class Operation { Negate, Abs, Sqrt, Log, Exp }; ExpressionUnary(Operation operation, std::shared_ptr operand); @@ -98,8 +98,11 @@ class QUIVER_API ExpressionUnary final : public ExpressionNode { void collect_input_files(std::vector& out) const override; private: + static double apply(Operation operation, double x); + Operation operation_; std::shared_ptr operand_; + mutable std::vector operand_row_buf_; }; class QUIVER_API ExpressionTernary final : public ExpressionNode { diff --git a/src/c/expression/expression.cpp b/src/c/expression/expression.cpp index 4db3afd5..261934bd 100644 --- 
a/src/c/expression/expression.cpp +++ b/src/c/expression/expression.cpp @@ -23,6 +23,22 @@ quiver::Expression dispatch(quiver_expression_operation_t operation, const Lhs& throw std::runtime_error("Cannot apply: unknown expression operation"); } +quiver::Expression dispatch_unary(quiver_expression_unary_operation_t operation, const quiver::Expression& operand) { + switch (operation) { + case QUIVER_EXPRESSION_UNARY_OPERATION_NEGATE: + return -operand; + case QUIVER_EXPRESSION_UNARY_OPERATION_ABS: + return quiver::abs(operand); + case QUIVER_EXPRESSION_UNARY_OPERATION_SQRT: + return quiver::sqrt(operand); + case QUIVER_EXPRESSION_UNARY_OPERATION_LOG: + return quiver::log(operand); + case QUIVER_EXPRESSION_UNARY_OPERATION_EXP: + return quiver::exp(operand); + } + throw std::runtime_error("Cannot apply: unknown expression unary operation"); +} + } // namespace extern "C" { @@ -107,6 +123,23 @@ QUIVER_C_API quiver_error_t quiver_expression_apply_scalar_left(quiver_expressio } } +QUIVER_C_API quiver_error_t quiver_expression_apply_unary(quiver_expression_unary_operation_t operation, + quiver_expression_t* operand, + quiver_expression_t** out) { + QUIVER_REQUIRE(operand, out); + + try { + *out = new quiver_expression(dispatch_unary(operation, operand->expression)); + return QUIVER_OK; + } catch (const std::bad_alloc&) { + quiver_set_last_error("Memory allocation failed"); + return QUIVER_ERROR; + } catch (const std::exception& e) { + quiver_set_last_error(e.what()); + return QUIVER_ERROR; + } +} + // Materialization QUIVER_C_API quiver_error_t quiver_expression_save(quiver_expression_t* expression, const char* path) { diff --git a/src/expression/expression.cpp b/src/expression/expression.cpp index e85e672c..984734c1 100644 --- a/src/expression/expression.cpp +++ b/src/expression/expression.cpp @@ -132,4 +132,20 @@ Expression operator/(double lhs, const Expression& rhs) { return Expression(std::make_shared(ExpressionBinary::Operation::Divide, scalar, rhs.node_)); } 
+Expression operator-(const Expression& operand) { + return Expression(std::make_shared(ExpressionUnary::Operation::Negate, operand.node_)); +} +Expression abs(const Expression& operand) { + return Expression(std::make_shared(ExpressionUnary::Operation::Abs, operand.node_)); +} +Expression sqrt(const Expression& operand) { + return Expression(std::make_shared(ExpressionUnary::Operation::Sqrt, operand.node_)); +} +Expression log(const Expression& operand) { + return Expression(std::make_shared(ExpressionUnary::Operation::Log, operand.node_)); +} +Expression exp(const Expression& operand) { + return Expression(std::make_shared(ExpressionUnary::Operation::Exp, operand.node_)); +} + } // namespace quiver diff --git a/src/expression/expression_node.cpp b/src/expression/expression_node.cpp index 16a1c128..8bb94ae8 100644 --- a/src/expression/expression_node.cpp +++ b/src/expression/expression_node.cpp @@ -373,15 +373,40 @@ void ExpressionBinary::collect_input_files(std::vector& out) const rhs_->collect_input_files(out); } +double ExpressionUnary::apply(Operation operation, double x) { + switch (operation) { + case Operation::Negate: + return -x; + case Operation::Abs: + return std::abs(x); + case Operation::Sqrt: + return std::sqrt(x); + case Operation::Log: + return std::log(x); + case Operation::Exp: + return std::exp(x); + } + throw std::runtime_error("Cannot apply: unhandled ExpressionUnary::Operation variant"); +} + ExpressionUnary::ExpressionUnary(Operation operation, std::shared_ptr operand) - : operation_(operation), operand_(std::move(operand)) {} + : operation_(operation), operand_(std::move(operand)) { + operand_row_buf_.resize(operand_->metadata().labels.size()); +} const BinaryMetadata& ExpressionUnary::metadata() const { - throw std::runtime_error("Cannot get_metadata: ExpressionUnary is not yet implemented"); + return operand_->metadata(); } -void ExpressionUnary::compute_row(const std::vector& /*dims*/, std::vector& /*out*/) const { - throw 
std::runtime_error("Cannot compute_row: ExpressionUnary is not yet implemented"); +void ExpressionUnary::compute_row(const std::vector& dims, std::vector& out) const { + const auto n = operand_row_buf_.size(); + if (out.size() != n) { + out.resize(n); + } + operand_->compute_row(dims, operand_row_buf_); + for (size_t k = 0; k < n; ++k) { + out[k] = apply(operation_, operand_row_buf_[k]); + } } void ExpressionUnary::collect_input_files(std::vector& out) const { diff --git a/tests/test_expression.cpp b/tests/test_expression.cpp index 7b355d5d..1fbde3b0 100644 --- a/tests/test_expression.cpp +++ b/tests/test_expression.cpp @@ -1566,3 +1566,154 @@ TEST_F(ExpressionFixture, SaveReleasesInternalHandlesOnDestruction) { auto reopened_writer = BinaryFile::open_file(path_a, 'w', md); // must not throw EXPECT_TRUE(reopened_writer.is_open()); } + +TEST_F(ExpressionFixture, UnaryNegate) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 100 + dims[1] * 10 + static_cast(k)); + }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression e = -Expression(a); + e.save(path_out); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], -va[i]); +} + +TEST_F(ExpressionFixture, UnaryAbs) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + // Alternate sign so abs has work to do. + const double v = static_cast(dims[0] * 10 + dims[1] + static_cast(k)); + return (dims[0] % 2 == 0) ? 
-v : v; + }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression e = abs(Expression(a)); + e.save(path_out); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::abs(va[i])); +} + +TEST_F(ExpressionFixture, UnarySqrt) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 100 + dims[1] * 10 + static_cast(k) + 1); // > 0 + }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression e = sqrt(Expression(a)); + e.save(path_out); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::sqrt(va[i])); +} + +TEST_F(ExpressionFixture, UnarySqrtPropagatesNaNOnNegative) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t) { return -1.0; }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression e = sqrt(Expression(a)); + e.save(path_out); + + auto vo = read_all_cells(path_out); + for (size_t i = 0; i < vo.size(); ++i) + EXPECT_TRUE(std::isnan(vo[i])); +} + +TEST_F(ExpressionFixture, UnaryLog) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 10 + dims[1] + static_cast(k) + 1); // > 0 + }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression e = log(Expression(a)); + e.save(path_out); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::log(va[i])); +} + +TEST_F(ExpressionFixture, UnaryExp) { + auto md = make_simple_metadata(); + // Small magnitudes keep exp() in a comfortable numerical range. 
+ write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] + dims[1] + static_cast(k)) * 0.1; + }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression e = exp(Expression(a)); + e.save(path_out); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::exp(va[i])); +} + +TEST_F(ExpressionFixture, UnaryMetadataPreserved) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression neg = -Expression(a); + + const auto& m = neg.metadata(); + EXPECT_EQ(m.unit, "MW"); + ASSERT_EQ(m.dimensions.size(), 2u); + EXPECT_EQ(m.dimensions[0].name, "row"); + EXPECT_EQ(m.dimensions[0].size, 3); + EXPECT_EQ(m.dimensions[1].name, "col"); + EXPECT_EQ(m.dimensions[1].size, 2); + ASSERT_EQ(m.labels.size(), 2u); + EXPECT_EQ(m.labels[0], "val1"); + EXPECT_EQ(m.labels[1], "val2"); +} + +TEST_F(ExpressionFixture, UnaryComposes) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); + }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression e = abs(-Expression(a)); + e.save(path_out); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::abs(va[i])); +} + +TEST_F(ExpressionFixture, UnaryComposesWithBinary) { + // Spot-check that -(a + b) parses cleanly and produces the expected result. 
+ auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] + dims[1] + static_cast(k)); + }); + write_qvr(path_b, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 2 + dims[1] + static_cast(k)); + }); + auto a = BinaryFile::open_file(path_a, 'r'); + auto b = BinaryFile::open_file(path_b, 'r'); + Expression e = -(Expression(a) + Expression(b)); + e.save(path_out); + + auto va = read_all_cells(path_a); + auto vb = read_all_cells(path_b); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], -(va[i] + vb[i])); +} From 9c3de76bcc03b7e992fce91c7706235b539ede9d Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Mon, 11 May 2026 19:55:53 -0300 Subject: [PATCH 02/13] Update --- CLAUDE.md | 18 ++- bindings/julia/src/c_api.jl | 12 ++ bindings/julia/src/expression.jl | 18 +++ bindings/julia/test/test_expression.jl | 197 +++++++++++++++++++++++++ tests/test_c_api_expression.cpp | 128 ++++++++++++++++ 5 files changed, 365 insertions(+), 8 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index ed8e9152..5a5b0069 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -172,7 +172,7 @@ struct Database::Impl { Binary subsystem: `BinaryFile` and `CSVConverter` use Pimpl (hide file I/O dependencies). `BinaryMetadata`, `Dimension`, `TimeProperties` are plain value types. -Expression subsystem: `Expression` is a plain value type wrapping `shared_ptr` — no Pimpl. `ExpressionNode` is an abstract base with virtual `metadata()` / `compute_row()`; concrete subclasses `ExpressionFile`, `ExpressionScalar`, `ExpressionBinary`, `ExpressionAggregate`, `ExpressionAggregateAgents` are exposed via `QUIVER_API` and use Rule of Zero. Polymorphism is justified by the recursive tree shape (`ExpressionBinary` and the aggregation nodes own child `shared_ptr` operands). 
Scaffold subclasses `ExpressionUnary`, `ExpressionTernary` exist with the same shape but throw `"not yet implemented"` from their virtuals — operations land in follow-up work. +Expression subsystem: `Expression` is a plain value type wrapping `shared_ptr` — no Pimpl. `ExpressionNode` is an abstract base with virtual `metadata()` / `compute_row()`; concrete subclasses `ExpressionFile`, `ExpressionScalar`, `ExpressionBinary`, `ExpressionUnary`, `ExpressionAggregate`, `ExpressionAggregateAgents` are exposed via `QUIVER_API` and use Rule of Zero. Polymorphism is justified by the recursive tree shape (`ExpressionBinary`, `ExpressionUnary`, and the aggregation nodes own child `shared_ptr` operands). Scaffold subclass `ExpressionTernary` exists with the same shape but throws `"not yet implemented"` from its virtuals — operations land in follow-up work. Classes with no private dependencies (`Element`, `Row`, `Migration`, `Migrations`, `GroupMetadata`, `ScalarMetadata`, `CSVOptions`, `BinaryMetadata`, `Dimension`, `TimeProperties`, `Expression`) are plain value types — direct members, no Pimpl, Rule of Zero (compiler-generated copy/move/destructor). @@ -482,31 +482,33 @@ Profiled with 480×500×31 dimensions (~7.3M read/write calls). Main hot-path co (Note: `FileNode::compute_row` was renamed to `ExpressionFile::compute_row`; the cache-amortization point still applies.) ### Expression Subsystem -Lazy expressions over `.qvr` binary files. Build a DAG using `+ - * /` operator overloads, materialize via `save()`. +Lazy expressions over `.qvr` binary files. Build a DAG using `+ - * /` operator overloads (binary and unary minus) and unary math free functions, materialize via `save()`. 
```cpp auto a = BinaryFile::open_file("a", 'r'); auto b = BinaryFile::open_file("b", 'r'); -Expression result = (a + b) * 2.0; +Expression result = abs((a + b) * 2.0 - sqrt(Expression(a))); result.save("output"); // writes output.qvr + output.toml ``` - `Expression` value type (header `quiver/expression/expression.h`): - Constructors: `Expression(const BinaryFile&)` (implicit, enables `bf_a + bf_b`), `Expression(shared_ptr)` - - Accessors: `metadata()`, `node()` + - Accessors: `metadata()` - Materialize: `save(path)` — iterates via `first_dimensions`/`next_dimensions`, calls `compute_row()` per cell, writes to a new `.qvr`. Throws if `path` collides (after `weakly_canonical`) with any input file in the DAG. - Aggregation: `aggregate(dimension, op, [parameter])` collapses a dimension; `aggregate_agents(op, [parameter])` collapses the label axis. `op` is one of `"sum" | "mean" | "min" | "max" | "percentile"` (string tags, validated in C++). `percentile` requires a `parameter` fraction in `[0, 1]`; nullary ops reject `parameter`. -- Operator overloads (12 total): `+ - * /` × {expr+expr, expr+double, double+expr}. +- Operator overloads (12 binary + 1 unary): `+ - * /` × {expr+expr, expr+double, double+expr}, plus unary `-expr`. +- Free functions in `quiver::` for unary math: `abs(expr)`, `sqrt(expr)`, `log(expr)`, `exp(expr)`. - `ExpressionNode` hierarchy (header `quiver/expression/expression_node.h`): - `ExpressionNode` (abstract): `metadata()`, `compute_row(dims, out)` - `ExpressionFile`: lazy reads from a `.qvr`. Caches an open `BinaryFile` and a reusable `unordered_map` across calls (mutable members; not thread-safe per instance). - `ExpressionScalar`: broadcasts a constant across the operand's label space. - `ExpressionBinary`: combines two operands with `ExpressionBinary::Operation::{Add,Subtract,Multiply,Divide}` (nested enum). 
Constructor pre-computes broadcast metadata (`build_broadcast_metadata`), reusable input/output buffers, and `lhs_to_out_`/`rhs_to_out_` index translation tables. The `apply(Operation, double, double)` operation-dispatch is a private static member. + - `ExpressionUnary`: applies a single-operand math function with `ExpressionUnary::Operation::{Negate,Abs,Sqrt,Log,Exp}` (nested enum). `metadata()` returns the operand's metadata unchanged (no dimensional analysis — `sqrt(MW)` stays as `MW`). Constructor pre-allocates a reusable `operand_row_buf_`. Lets IEEE-754 NaN/inf propagate naturally (`sqrt(-1) → NaN`, `log(0) → -inf`); no NaN special-casing. The `apply(Operation, double)` operation-dispatch is a private static member. - `ExpressionAggregate`: collapses a named dimension. `Operation::{Sum,Mean,Min,Max,Percentile}` (nested enum). Constructor eagerly removes the dim from output metadata, rewires child time-dim `parent_dimension_index` transitively (a time dim whose parent was removed re-points to the removed dim's grandparent, or `-1`), and pre-allocates index translation + reusable buffers. Skips NaN inputs during accumulation; all-NaN range yields NaN. - `ExpressionAggregateAgents`: collapses the label axis to a single entry named after the operation (e.g., `"sum"`, `"mean"`, `"percentile"`). Dimensions, `initial_datetime`, `unit` unchanged. Same NaN policy as `ExpressionAggregate`. - - `ExpressionUnary`, `ExpressionTernary` (scaffolds): same shape, but their `metadata()` and `compute_row()` throw `Cannot {operation}: ExpressionXxx is not yet implemented`. Each carries a placeholder nested `enum class Operation { Unspecified }` until concrete operations are designed. -- Validation is **eager** at construction for `ExpressionBinary`, `ExpressionAggregate`, `ExpressionAggregateAgents` (units/dim sizes/time-dim properties/label sizes/initial datetimes for binary; dim existence + op/parameter consistency + output metadata validity for aggregations). 
Computation is **lazy**: no I/O until `save()`. -- The binary-operation enum is nested as `ExpressionBinary::Operation`. Aggregation operations are parsed via static `parse_operation(string)` methods on each aggregation class. The C API surface keeps its own stable enum `quiver_expression_operation_t`. + - `ExpressionTernary` (scaffold): same shape as the real nodes, but its `metadata()` and `compute_row()` throw `Cannot {operation}: ExpressionTernary is not yet implemented`. Carries a placeholder nested `enum class Operation { Unspecified }` until concrete operations are designed. +- Validation is **eager** at construction for `ExpressionBinary`, `ExpressionAggregate`, `ExpressionAggregateAgents` (units/dim sizes/time-dim properties/label sizes/initial datetimes for binary; dim existence + op/parameter consistency + output metadata validity for aggregations). `ExpressionUnary` has no inputs to cross-validate so its constructor just sizes the row buffer. Computation is **lazy**: no I/O until `save()`. +- The binary-operation enum is nested as `ExpressionBinary::Operation`; the unary-operation enum is nested as `ExpressionUnary::Operation`. Aggregation operations are parsed via static `parse_operation(string)` methods on each aggregation class. The C API surface keeps its own stable enums `quiver_expression_operation_t` and `quiver_expression_unary_operation_t`. 
### LuaRunner Class Executes Lua scripts with database access: diff --git a/bindings/julia/src/c_api.jl b/bindings/julia/src/c_api.jl index adc9be6b..1a4a3c8c 100644 --- a/bindings/julia/src/c_api.jl +++ b/bindings/julia/src/c_api.jl @@ -646,6 +646,14 @@ const quiver_expression_t = quiver_expression QUIVER_EXPRESSION_OPERATION_DIVIDE = 3 end +@cenum quiver_expression_unary_operation_t::UInt32 begin + QUIVER_EXPRESSION_UNARY_OPERATION_NEGATE = 0 + QUIVER_EXPRESSION_UNARY_OPERATION_ABS = 1 + QUIVER_EXPRESSION_UNARY_OPERATION_SQRT = 2 + QUIVER_EXPRESSION_UNARY_OPERATION_LOG = 3 + QUIVER_EXPRESSION_UNARY_OPERATION_EXP = 4 +end + function quiver_expression_from_file(file, out) @ccall libquiver_c.quiver_expression_from_file(file::Ptr{quiver_binary_file_t}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t end @@ -666,6 +674,10 @@ function quiver_expression_apply_scalar_left(operation, lhs, rhs, out) @ccall libquiver_c.quiver_expression_apply_scalar_left(operation::quiver_expression_operation_t, lhs::Cdouble, rhs::Ptr{quiver_expression_t}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t end +function quiver_expression_apply_unary(operation, operand, out) + @ccall libquiver_c.quiver_expression_apply_unary(operation::quiver_expression_unary_operation_t, operand::Ptr{quiver_expression_t}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t +end + function quiver_expression_save(expression, path) @ccall libquiver_c.quiver_expression_save(expression::Ptr{quiver_expression_t}, path::Ptr{Cchar})::quiver_error_t end diff --git a/bindings/julia/src/expression.jl b/bindings/julia/src/expression.jl index 74b1f56b..c3a0b924 100644 --- a/bindings/julia/src/expression.jl +++ b/bindings/julia/src/expression.jl @@ -40,6 +40,18 @@ function _binop(operation, lhs::Real, rhs::Expression) return Expression(out[]) end +function _unop(operation, e::Expression) + out = Ref{Ptr{C.quiver_expression}}(C_NULL) + check(C.quiver_expression_apply_unary(operation, e.ptr, out)) + return 
Expression(out[]) +end + +Base.:-(a::Expression) = _unop(C.QUIVER_EXPRESSION_UNARY_OPERATION_NEGATE, a) +Base.abs(a::Expression) = _unop(C.QUIVER_EXPRESSION_UNARY_OPERATION_ABS, a) +Base.sqrt(a::Expression) = _unop(C.QUIVER_EXPRESSION_UNARY_OPERATION_SQRT, a) +Base.log(a::Expression) = _unop(C.QUIVER_EXPRESSION_UNARY_OPERATION_LOG, a) +Base.exp(a::Expression) = _unop(C.QUIVER_EXPRESSION_UNARY_OPERATION_EXP, a) + Base.:+(a::Expression, b::Expression) = _binop(C.QUIVER_EXPRESSION_OPERATION_ADD, a, b) Base.:+(a::Expression, b::Real) = _binop(C.QUIVER_EXPRESSION_OPERATION_ADD, a, b) Base.:+(a::Real, b::Expression) = _binop(C.QUIVER_EXPRESSION_OPERATION_ADD, a, b) @@ -80,6 +92,12 @@ Base.:/(a::Real, b::Binary.File) = a / Expression(b) Base.:/(a::Binary.File, b::Expression) = Expression(a) / b Base.:/(a::Expression, b::Binary.File) = a / Expression(b) +Base.:-(a::Binary.File) = -Expression(a) +Base.abs(a::Binary.File) = abs(Expression(a)) +Base.sqrt(a::Binary.File) = sqrt(Expression(a)) +Base.log(a::Binary.File) = log(Expression(a)) +Base.exp(a::Binary.File) = exp(Expression(a)) + function save(e::Expression, path::String) check(C.quiver_expression_save(e.ptr, path)) return nothing diff --git a/bindings/julia/test/test_expression.jl b/bindings/julia/test/test_expression.jl index 20b48d87..8dfc5529 100644 --- a/bindings/julia/test/test_expression.jl +++ b/bindings/julia/test/test_expression.jl @@ -1305,6 +1305,203 @@ end cleanup(path_a, path_out) end end + + # ========================================================================== + # Unary operators (Negate, Abs, Sqrt, Log, Exp) + # ========================================================================== + + @testset "Unary negate on Expression" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 100 + c * 10 + k) + with_expr(path_a) do a + result = -a + Quiver.save(result, path_out) + return Quiver.close!(result) + end + @test read_all_cells(path_out) == 
.-read_all_cells(path_a) + finally + cleanup(path_a, path_out) + end + end + + @testset "Unary abs on Expression" begin + path_a, path_out = make_path("a"), make_path("out") + try + # Alternate sign so abs has work to do. + write_fixture(path_a, (r, c, k) -> (r % 2 == 0 ? -1 : 1) * (r * 10 + c + k)) + with_expr(path_a) do a + result = abs(a) + Quiver.save(result, path_out) + return Quiver.close!(result) + end + @test read_all_cells(path_out) == abs.(read_all_cells(path_a)) + finally + cleanup(path_a, path_out) + end + end + + @testset "Unary sqrt on Expression" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 100 + c * 10 + k + 1) # > 0 + with_expr(path_a) do a + result = sqrt(a) + Quiver.save(result, path_out) + return Quiver.close!(result) + end + @test read_all_cells(path_out) ≈ sqrt.(read_all_cells(path_a)) + finally + cleanup(path_a, path_out) + end + end + + @testset "Unary sqrt propagates NaN on negative" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (_, _, _) -> -1.0) + with_expr(path_a) do a + result = sqrt(a) + Quiver.save(result, path_out) + return Quiver.close!(result) + end + # Read with allow_nulls = true since sqrt(-1) produces NaN cells. 
+ file = Quiver.Binary.open_file(path_out; mode = 'r') + try + for r in 1:3, c in 1:2 + cell = Quiver.Binary.read(file; allow_nulls = true, row = r, col = c) + @test all(isnan, cell) + end + finally + Quiver.Binary.close!(file) + end + finally + cleanup(path_a, path_out) + end + end + + @testset "Unary log on Expression" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 10 + c + k + 1) # > 0 + with_expr(path_a) do a + result = log(a) + Quiver.save(result, path_out) + return Quiver.close!(result) + end + @test read_all_cells(path_out) ≈ log.(read_all_cells(path_a)) + finally + cleanup(path_a, path_out) + end + end + + @testset "Unary exp on Expression" begin + path_a, path_out = make_path("a"), make_path("out") + try + # Small magnitudes to keep exp() comfortable numerically. + write_fixture(path_a, (r, c, k) -> (r + c + k) * 0.1) + with_expr(path_a) do a + result = exp(a) + Quiver.save(result, path_out) + return Quiver.close!(result) + end + @test read_all_cells(path_out) ≈ exp.(read_all_cells(path_a)) + finally + cleanup(path_a, path_out) + end + end + + @testset "Unary metadata preserved" begin + path_a = make_path("a") + try + write_fixture(path_a, (_, _, _) -> 1.0) + with_expr(path_a) do a + neg = -a + try + md = Quiver.get_metadata(neg) + @test Quiver.Binary.get_unit(md) == "MW" + @test Quiver.Binary.get_labels(md) == ["val1", "val2"] + dims = Quiver.Binary.get_dimensions(md) + @test length(dims) == 2 + @test dims[1].name == "row" + @test dims[2].name == "col" + finally + Quiver.close!(neg) + end + end + finally + cleanup(path_a) + end + end + + @testset "Unary composes (abs of negate)" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 10 + c + k) + with_expr(path_a) do a + result = abs(-a) + Quiver.save(result, path_out) + return Quiver.close!(result) + end + @test read_all_cells(path_out) == abs.(read_all_cells(path_a)) + finally + 
cleanup(path_a, path_out) + end + end + + @testset "Unary composes with binary: -(a + b)" begin + path_a, path_b, path_out = make_path("a"), make_path("b"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r + c + k) + write_fixture(path_b, (r, c, k) -> r * 2 + c + k) + with_expr(path_a) do a + with_expr(path_b) do b + result = -(a + b) + Quiver.save(result, path_out) + return Quiver.close!(result) + end + end + @test read_all_cells(path_out) == .-(read_all_cells(path_a) .+ read_all_cells(path_b)) + finally + cleanup(path_a, path_b, path_out) + end + end + + @testset "Unary on Binary.File shortcut" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 100 + c * 10 + k + 1) + file = Quiver.Binary.open_file(path_a; mode = 'r') + try + result = sqrt(file) + Quiver.save(result, path_out) + Quiver.close!(result) + finally + Quiver.Binary.close!(file) + end + @test read_all_cells(path_out) ≈ sqrt.(read_all_cells(path_a)) + finally + cleanup(path_a, path_out) + end + end + + @testset "Unary negate on Binary.File shortcut" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 100 + c * 10 + k) + file = Quiver.Binary.open_file(path_a; mode = 'r') + try + result = -file + Quiver.save(result, path_out) + Quiver.close!(result) + finally + Quiver.Binary.close!(file) + end + @test read_all_cells(path_out) == .-read_all_cells(path_a) + finally + cleanup(path_a, path_out) + end + end end end diff --git a/tests/test_c_api_expression.cpp b/tests/test_c_api_expression.cpp index 4fa9f6e9..783b2c55 100644 --- a/tests/test_c_api_expression.cpp +++ b/tests/test_c_api_expression.cpp @@ -1201,3 +1201,131 @@ TEST_F(ExpressionCApiFixture, FromUnopenedBinaryFile) { for (size_t i = 0; i < orig.size(); ++i) EXPECT_DOUBLE_EQ(orig[i], copy[i]); } + +TEST_F(ExpressionCApiFixture, UnaryNegate) { + write_fixture(path_a, [](int r, int c, int k) { return static_cast(r * 100 + c * 10 + 
k); }); + + auto* a = expr_from_file(path_a); + quiver_expression_t* neg = nullptr; + ASSERT_EQ(quiver_expression_apply_unary(QUIVER_EXPRESSION_UNARY_OPERATION_NEGATE, a, &neg), QUIVER_OK); + ASSERT_EQ(quiver_expression_save(neg, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(neg); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], -va[i]); +} + +TEST_F(ExpressionCApiFixture, UnaryAbs) { + write_fixture(path_a, [](int r, int c, int k) { + const double v = static_cast(r * 10 + c + k); + return (r % 2 == 0) ? -v : v; + }); + + auto* a = expr_from_file(path_a); + quiver_expression_t* abs_e = nullptr; + ASSERT_EQ(quiver_expression_apply_unary(QUIVER_EXPRESSION_UNARY_OPERATION_ABS, a, &abs_e), QUIVER_OK); + ASSERT_EQ(quiver_expression_save(abs_e, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(abs_e); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::abs(va[i])); +} + +TEST_F(ExpressionCApiFixture, UnarySqrt) { + write_fixture(path_a, [](int r, int c, int k) { return static_cast(r * 100 + c * 10 + k + 1); }); + + auto* a = expr_from_file(path_a); + quiver_expression_t* sqrt_e = nullptr; + ASSERT_EQ(quiver_expression_apply_unary(QUIVER_EXPRESSION_UNARY_OPERATION_SQRT, a, &sqrt_e), QUIVER_OK); + ASSERT_EQ(quiver_expression_save(sqrt_e, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(sqrt_e); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::sqrt(va[i])); +} + +TEST_F(ExpressionCApiFixture, UnarySqrtPropagatesNaNOnNegative) { + write_fixture(path_a, [](int, int, int) { 
return -1.0; }); + + auto* a = expr_from_file(path_a); + quiver_expression_t* sqrt_e = nullptr; + ASSERT_EQ(quiver_expression_apply_unary(QUIVER_EXPRESSION_UNARY_OPERATION_SQRT, a, &sqrt_e), QUIVER_OK); + ASSERT_EQ(quiver_expression_save(sqrt_e, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(sqrt_e); + + // Read with allow_nulls=1 since sqrt(-1) produces NaN cells. + quiver_binary_file_t* f = nullptr; + ASSERT_EQ(quiver_binary_file_open_file(path_out.c_str(), 'r', nullptr, &f), QUIVER_OK); + const char* dim_names[] = {"row", "col"}; + for (int64_t r = 1; r <= 3; ++r) { + for (int64_t c = 1; c <= 2; ++c) { + int64_t dim_values[] = {r, c}; + double* data = nullptr; + size_t count = 0; + ASSERT_EQ(quiver_binary_file_read(f, dim_names, dim_values, 2, /*allow_nulls=*/1, &data, &count), + QUIVER_OK); + for (size_t i = 0; i < count; ++i) + EXPECT_TRUE(std::isnan(data[i])); + quiver_binary_file_free_float_array(data); + } + } + quiver_binary_file_close(f); +} + +TEST_F(ExpressionCApiFixture, UnaryLog) { + write_fixture(path_a, [](int r, int c, int k) { return static_cast(r * 10 + c + k + 1); }); + + auto* a = expr_from_file(path_a); + quiver_expression_t* log_e = nullptr; + ASSERT_EQ(quiver_expression_apply_unary(QUIVER_EXPRESSION_UNARY_OPERATION_LOG, a, &log_e), QUIVER_OK); + ASSERT_EQ(quiver_expression_save(log_e, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(log_e); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::log(va[i])); +} + +TEST_F(ExpressionCApiFixture, UnaryExp) { + write_fixture(path_a, [](int r, int c, int k) { return static_cast(r + c + k) * 0.1; }); + + auto* a = expr_from_file(path_a); + quiver_expression_t* exp_e = nullptr; + ASSERT_EQ(quiver_expression_apply_unary(QUIVER_EXPRESSION_UNARY_OPERATION_EXP, a, &exp_e), QUIVER_OK); + 
ASSERT_EQ(quiver_expression_save(exp_e, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(exp_e); + + auto va = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(va.size(), vo.size()); + for (size_t i = 0; i < va.size(); ++i) + EXPECT_DOUBLE_EQ(vo[i], std::exp(va[i])); +} + +TEST_F(ExpressionCApiFixture, UnaryNullArguments) { + write_fixture(path_a, [](int, int, int) { return 1.0; }); + auto* a = expr_from_file(path_a); + + quiver_expression_t* out = nullptr; + EXPECT_EQ(quiver_expression_apply_unary(QUIVER_EXPRESSION_UNARY_OPERATION_NEGATE, nullptr, &out), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_apply_unary(QUIVER_EXPRESSION_UNARY_OPERATION_NEGATE, a, nullptr), QUIVER_ERROR); + + quiver_expression_close(a); +} From 9e6e4bf8a4ed4c8be44ad57caabc72fead984025 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Mon, 11 May 2026 21:48:44 -0300 Subject: [PATCH 03/13] Update --- include/quiver/c/expression/expression.h | 10 + include/quiver/expression/expression.h | 5 + include/quiver/expression/expression_node.h | 36 +++- src/c/expression/expression.cpp | 31 +++ src/expression/expression.cpp | 7 + src/expression/expression_node.cpp | 227 +++++++++++++++++++- 6 files changed, 298 insertions(+), 18 deletions(-) diff --git a/include/quiver/c/expression/expression.h b/include/quiver/c/expression/expression.h index 68922682..814cd3fe 100644 --- a/include/quiver/c/expression/expression.h +++ b/include/quiver/c/expression/expression.h @@ -29,6 +29,11 @@ typedef enum { QUIVER_EXPRESSION_UNARY_OPERATION_EXP = 4, } quiver_expression_unary_operation_t; +// Ternary operation kind +typedef enum { + QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE = 0, +} quiver_expression_ternary_operation_t; + // Construction QUIVER_C_API quiver_error_t quiver_expression_from_file(quiver_binary_file_t* file, quiver_expression_t** out); @@ -51,6 +56,11 @@ QUIVER_C_API quiver_error_t quiver_expression_apply_scalar_left(quiver_expressio 
QUIVER_C_API quiver_error_t quiver_expression_apply_unary(quiver_expression_unary_operation_t operation, quiver_expression_t* operand, quiver_expression_t** out); +QUIVER_C_API quiver_error_t quiver_expression_apply_ternary(quiver_expression_ternary_operation_t operation, + quiver_expression_t* condition, + quiver_expression_t* then_value, + quiver_expression_t* else_value, + quiver_expression_t** out); // Materialization QUIVER_C_API quiver_error_t quiver_expression_save(quiver_expression_t* expression, const char* path); diff --git a/include/quiver/expression/expression.h b/include/quiver/expression/expression.h index 3ef765dd..f27f28c7 100644 --- a/include/quiver/expression/expression.h +++ b/include/quiver/expression/expression.h @@ -48,6 +48,8 @@ class QUIVER_API Expression { friend Expression log(const Expression&); friend Expression exp(const Expression&); + friend Expression ifelse(const Expression&, const Expression&, const Expression&); + std::shared_ptr node_; }; @@ -73,6 +75,9 @@ QUIVER_API Expression sqrt(const Expression& operand); QUIVER_API Expression log(const Expression& operand); QUIVER_API Expression exp(const Expression& operand); +QUIVER_API Expression +ifelse(const Expression& condition, const Expression& then_value, const Expression& else_value); + } // namespace quiver #endif // QUIVER_EXPRESSION_H diff --git a/include/quiver/expression/expression_node.h b/include/quiver/expression/expression_node.h index 24738d4f..0a78ada8 100644 --- a/include/quiver/expression/expression_node.h +++ b/include/quiver/expression/expression_node.h @@ -107,22 +107,44 @@ class QUIVER_API ExpressionUnary final : public ExpressionNode { class QUIVER_API ExpressionTernary final : public ExpressionNode { public: - enum class Operation { Unspecified }; + enum class Operation { IfElse }; ExpressionTernary(Operation operation, - std::shared_ptr first, - std::shared_ptr second, - std::shared_ptr third); + std::shared_ptr condition, + std::shared_ptr then_value, + 
std::shared_ptr else_value); const BinaryMetadata& metadata() const override; void compute_row(const std::vector& dims, std::vector& out) const override; void collect_input_files(std::vector& out) const override; private: + static double apply(Operation operation, double condition, double then_value, double else_value); + Operation operation_; - std::shared_ptr first_; - std::shared_ptr second_; - std::shared_ptr third_; + std::shared_ptr condition_; + std::shared_ptr then_value_; + std::shared_ptr else_value_; + BinaryMetadata broadcast_meta_; + + std::vector condition_dim_sizes_; + std::vector then_dim_sizes_; + std::vector else_dim_sizes_; + + std::vector condition_to_out_; + std::vector then_to_out_; + std::vector else_to_out_; + + size_t condition_label_count_ = 0; + size_t then_label_count_ = 0; + size_t else_label_count_ = 0; + + mutable std::vector condition_dims_buf_; + mutable std::vector then_dims_buf_; + mutable std::vector else_dims_buf_; + mutable std::vector condition_buf_; + mutable std::vector then_buf_; + mutable std::vector else_buf_; }; class QUIVER_API ExpressionAggregate final : public ExpressionNode { diff --git a/src/c/expression/expression.cpp b/src/c/expression/expression.cpp index 261934bd..98b01b17 100644 --- a/src/c/expression/expression.cpp +++ b/src/c/expression/expression.cpp @@ -39,6 +39,17 @@ quiver::Expression dispatch_unary(quiver_expression_unary_operation_t operation, throw std::runtime_error("Cannot apply: unknown expression unary operation"); } +quiver::Expression dispatch_ternary(quiver_expression_ternary_operation_t operation, + const quiver::Expression& condition, + const quiver::Expression& then_value, + const quiver::Expression& else_value) { + switch (operation) { + case QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE: + return quiver::ifelse(condition, then_value, else_value); + } + throw std::runtime_error("Cannot apply: unknown expression ternary operation"); +} + } // namespace extern "C" { @@ -140,6 +151,26 @@ 
QUIVER_C_API quiver_error_t quiver_expression_apply_unary(quiver_expression_unar } } +QUIVER_C_API quiver_error_t quiver_expression_apply_ternary(quiver_expression_ternary_operation_t operation, + quiver_expression_t* condition, + quiver_expression_t* then_value, + quiver_expression_t* else_value, + quiver_expression_t** out) { + QUIVER_REQUIRE(condition, then_value, else_value, out); + + try { + *out = new quiver_expression( + dispatch_ternary(operation, condition->expression, then_value->expression, else_value->expression)); + return QUIVER_OK; + } catch (const std::bad_alloc&) { + quiver_set_last_error("Memory allocation failed"); + return QUIVER_ERROR; + } catch (const std::exception& e) { + quiver_set_last_error(e.what()); + return QUIVER_ERROR; + } +} + // Materialization QUIVER_C_API quiver_error_t quiver_expression_save(quiver_expression_t* expression, const char* path) { diff --git a/src/expression/expression.cpp b/src/expression/expression.cpp index 984734c1..b064991d 100644 --- a/src/expression/expression.cpp +++ b/src/expression/expression.cpp @@ -148,4 +148,11 @@ Expression exp(const Expression& operand) { return Expression(std::make_shared(ExpressionUnary::Operation::Exp, operand.node_)); } +Expression ifelse(const Expression& condition, const Expression& then_value, const Expression& else_value) { + return Expression(std::make_shared(ExpressionTernary::Operation::IfElse, + condition.node_, + then_value.node_, + else_value.node_)); +} + } // namespace quiver diff --git a/src/expression/expression_node.cpp b/src/expression/expression_node.cpp index 8bb94ae8..8eead0e0 100644 --- a/src/expression/expression_node.cpp +++ b/src/expression/expression_node.cpp @@ -71,11 +71,13 @@ std::string parent_name_of(int64_t parent_idx, const BinaryMetadata& m) { return (parent_idx >= 0) ? 
m.dimensions[parent_idx].name : std::string{}; } -void validate_compatibility(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { +void validate_unit_match(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { if (lhs.unit != rhs.unit) { throw std::runtime_error("Cannot apply: units differ ('" + lhs.unit + "' vs '" + rhs.unit + "')"); } +} +void validate_shape_compatibility(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { for (const auto& l_dim : lhs.dimensions) { auto r_idx = find_dim_index(rhs.dimensions, l_dim.name); if (r_idx < 0) { @@ -135,6 +137,11 @@ void validate_compatibility(const BinaryMetadata& lhs, const BinaryMetadata& rhs } } +void validate_compatibility(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { + validate_unit_match(lhs, rhs); + validate_shape_compatibility(lhs, rhs); +} + std::vector compute_output_labels(const std::vector& l_labels, const std::vector& r_labels) { const auto ll = l_labels.size(); @@ -207,6 +214,105 @@ build_broadcast_metadata(const BinaryMetadata& lhs, const BinaryMetadata& rhs, s return out; } +std::vector compute_ternary_output_labels(const std::vector& c_labels, + const std::vector& t_labels, + const std::vector& e_labels) { + const std::vector*> non_singleton = [&] { + std::vector*> v; + if (c_labels.size() > 1) + v.push_back(&c_labels); + if (t_labels.size() > 1) + v.push_back(&t_labels); + if (e_labels.size() > 1) + v.push_back(&e_labels); + return v; + }(); + + if (non_singleton.empty()) { + return t_labels; + } + + for (size_t i = 1; i < non_singleton.size(); ++i) { + if (*non_singleton[i] != *non_singleton[0]) { + throw std::runtime_error("Cannot apply: labels are incompatible across operands " + "(non-singleton label sets must match)"); + } + } + return *non_singleton[0]; +} + +BinaryMetadata build_ternary_broadcast_metadata(const BinaryMetadata& cond, + const BinaryMetadata& then_meta, + const BinaryMetadata& else_meta, + std::vector output_labels) { + BinaryMetadata out; + out.version = 
then_meta.version; + out.unit = then_meta.unit; // validated == else_meta.unit; cond.unit is ignored + out.labels = std::move(output_labels); + + const auto then_has_time = any_time_dim(then_meta.dimensions); + const auto else_has_time = any_time_dim(else_meta.dimensions); + const auto cond_has_time = any_time_dim(cond.dimensions); + if (then_has_time) { + out.initial_datetime = then_meta.initial_datetime; + } else if (else_has_time) { + out.initial_datetime = else_meta.initial_datetime; + } else if (cond_has_time) { + out.initial_datetime = cond.initial_datetime; + } else { + out.initial_datetime = then_meta.initial_datetime; + } + + const std::vector sources = {&cond, &then_meta, &else_meta}; + std::unordered_map output_index_by_name; + for (const auto* src : sources) { + for (const auto& dim : src->dimensions) { + if (output_index_by_name.count(dim.name)) + continue; + int64_t out_size = dim.size; + for (const auto* other : sources) { + if (other == src) + continue; + auto idx = find_dim_index(other->dimensions, dim.name); + if (idx >= 0) { + out_size = std::max(out_size, other->dimensions[idx].size); + } + } + Dimension d{dim.name, out_size, dim.time}; + out.dimensions.push_back(std::move(d)); + output_index_by_name[dim.name] = static_cast(out.dimensions.size()) - 1; + } + } + + for (auto& out_d : out.dimensions) { + if (!out_d.is_time_dimension()) + continue; + const BinaryMetadata* src_meta = nullptr; + int src_idx = -1; + for (const auto* s : sources) { + src_idx = find_dim_index(s->dimensions, out_d.name); + if (src_idx >= 0) { + src_meta = s; + break; + } + } + int64_t src_parent_idx = src_meta->dimensions[src_idx].time->parent_dimension_index; + if (src_parent_idx < 0) { + out_d.time->parent_dimension_index = -1; + continue; + } + const std::string& parent_name = src_meta->dimensions[src_parent_idx].name; + out_d.time->parent_dimension_index = output_index_by_name.find(parent_name)->second; + } + + for (const auto& d : out.dimensions) { + if 
(d.is_time_dimension()) { + ++out.number_of_time_dimensions; + } + } + return out; +} + template Op parse_aggregation_operation_name(const std::string& name, const std::string& fn_label) { if (name == "sum") @@ -413,24 +519,123 @@ void ExpressionUnary::collect_input_files(std::vector& out) const { operand_->collect_input_files(out); } +double ExpressionTernary::apply(Operation operation, double condition, double then_value, double else_value) { + switch (operation) { + case Operation::IfElse: + if (std::isnan(condition)) { + return std::numeric_limits::quiet_NaN(); + } + return (condition != 0.0) ? then_value : else_value; + } + throw std::runtime_error("Cannot apply: unhandled ExpressionTernary::Operation variant"); +} + ExpressionTernary::ExpressionTernary(Operation operation, - std::shared_ptr first, - std::shared_ptr second, - std::shared_ptr third) - : operation_(operation), first_(std::move(first)), second_(std::move(second)), third_(std::move(third)) {} + std::shared_ptr condition, + std::shared_ptr then_value, + std::shared_ptr else_value) + : operation_(operation), condition_(std::move(condition)), then_value_(std::move(then_value)), + else_value_(std::move(else_value)) { + const auto& condition_meta = condition_->metadata(); + const auto& then_meta = then_value_->metadata(); + const auto& else_meta = else_value_->metadata(); + + validate_unit_match(then_meta, else_meta); + validate_shape_compatibility(condition_meta, then_meta); + validate_shape_compatibility(then_meta, else_meta); + validate_shape_compatibility(condition_meta, else_meta); + + auto output_labels = compute_ternary_output_labels(condition_meta.labels, then_meta.labels, else_meta.labels); + broadcast_meta_ = + build_ternary_broadcast_metadata(condition_meta, then_meta, else_meta, std::move(output_labels)); + broadcast_meta_.validate(); + + const auto& out_dims = broadcast_meta_.dimensions; + condition_dim_sizes_.assign(out_dims.size(), 0); + then_dim_sizes_.assign(out_dims.size(), 0); + 
else_dim_sizes_.assign(out_dims.size(), 0); + condition_to_out_.assign(condition_meta.dimensions.size(), -1); + then_to_out_.assign(then_meta.dimensions.size(), -1); + else_to_out_.assign(else_meta.dimensions.size(), -1); + + for (size_t out_i = 0; out_i < out_dims.size(); ++out_i) { + const auto ci = find_dim_index(condition_meta.dimensions, out_dims[out_i].name); + const auto ti = find_dim_index(then_meta.dimensions, out_dims[out_i].name); + const auto ei = find_dim_index(else_meta.dimensions, out_dims[out_i].name); + condition_dim_sizes_[out_i] = (ci >= 0) ? condition_meta.dimensions[ci].size : 0; + then_dim_sizes_[out_i] = (ti >= 0) ? then_meta.dimensions[ti].size : 0; + else_dim_sizes_[out_i] = (ei >= 0) ? else_meta.dimensions[ei].size : 0; + if (ci >= 0) + condition_to_out_[ci] = static_cast(out_i); + if (ti >= 0) + then_to_out_[ti] = static_cast(out_i); + if (ei >= 0) + else_to_out_[ei] = static_cast(out_i); + } + + condition_label_count_ = condition_meta.labels.size(); + then_label_count_ = then_meta.labels.size(); + else_label_count_ = else_meta.labels.size(); + + condition_dims_buf_.resize(condition_meta.dimensions.size()); + then_dims_buf_.resize(then_meta.dimensions.size()); + else_dims_buf_.resize(else_meta.dimensions.size()); + condition_buf_.resize(condition_label_count_); + then_buf_.resize(then_label_count_); + else_buf_.resize(else_label_count_); +} const BinaryMetadata& ExpressionTernary::metadata() const { - throw std::runtime_error("Cannot get_metadata: ExpressionTernary is not yet implemented"); + return broadcast_meta_; } -void ExpressionTernary::compute_row(const std::vector& /*dims*/, std::vector& /*out*/) const { - throw std::runtime_error("Cannot compute_row: ExpressionTernary is not yet implemented"); +void ExpressionTernary::compute_row(const std::vector& dims, std::vector& out) const { + const auto out_label_count = broadcast_meta_.labels.size(); + if (out.size() != out_label_count) { + out.resize(out_label_count); + } + + for (size_t 
ci = 0; ci < condition_dims_buf_.size(); ++ci) { + const auto out_i = condition_to_out_[ci]; + auto coord = dims[out_i]; + if (condition_dim_sizes_[out_i] == 1) { + coord = 1; + } + condition_dims_buf_[ci] = coord; + } + for (size_t ti = 0; ti < then_dims_buf_.size(); ++ti) { + const auto out_i = then_to_out_[ti]; + auto coord = dims[out_i]; + if (then_dim_sizes_[out_i] == 1) { + coord = 1; + } + then_dims_buf_[ti] = coord; + } + for (size_t ei = 0; ei < else_dims_buf_.size(); ++ei) { + const auto out_i = else_to_out_[ei]; + auto coord = dims[out_i]; + if (else_dim_sizes_[out_i] == 1) { + coord = 1; + } + else_dims_buf_[ei] = coord; + } + + condition_->compute_row(condition_dims_buf_, condition_buf_); + then_value_->compute_row(then_dims_buf_, then_buf_); + else_value_->compute_row(else_dims_buf_, else_buf_); + + for (size_t k = 0; k < out_label_count; ++k) { + const size_t ck = (condition_label_count_ == 1) ? 0 : k; + const size_t tk = (then_label_count_ == 1) ? 0 : k; + const size_t ek = (else_label_count_ == 1) ? 
0 : k; + out[k] = apply(operation_, condition_buf_[ck], then_buf_[tk], else_buf_[ek]); + } } void ExpressionTernary::collect_input_files(std::vector& out) const { - first_->collect_input_files(out); - second_->collect_input_files(out); - third_->collect_input_files(out); + condition_->collect_input_files(out); + then_value_->collect_input_files(out); + else_value_->collect_input_files(out); } ExpressionAggregate::Operation ExpressionAggregate::parse_operation(const std::string& name) { From 0a5570f5f7f80cf1c0693b6f86fbb6fae71c139c Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Mon, 11 May 2026 22:16:40 -0300 Subject: [PATCH 04/13] Update --- CLAUDE.md | 3 +- bindings/julia/src/c_api.jl | 8 + bindings/julia/src/expression.jl | 22 +++ bindings/julia/test/test_expression.jl | 199 +++++++++++++++++++++ tests/test_c_api_expression.cpp | 155 ++++++++++++++++ tests/test_expression.cpp | 236 +++++++++++++++++++++++++ 6 files changed, 622 insertions(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index 5a5b0069..3d1ec777 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -172,7 +172,7 @@ struct Database::Impl { Binary subsystem: `BinaryFile` and `CSVConverter` use Pimpl (hide file I/O dependencies). `BinaryMetadata`, `Dimension`, `TimeProperties` are plain value types. -Expression subsystem: `Expression` is a plain value type wrapping `shared_ptr` — no Pimpl. `ExpressionNode` is an abstract base with virtual `metadata()` / `compute_row()`; concrete subclasses `ExpressionFile`, `ExpressionScalar`, `ExpressionBinary`, `ExpressionUnary`, `ExpressionAggregate`, `ExpressionAggregateAgents` are exposed via `QUIVER_API` and use Rule of Zero. Polymorphism is justified by the recursive tree shape (`ExpressionBinary`, `ExpressionUnary`, and the aggregation nodes own child `shared_ptr` operands). Scaffold subclass `ExpressionTernary` exists with the same shape but throws `"not yet implemented"` from its virtuals — operations land in follow-up work. 
+Expression subsystem: `Expression` is a plain value type wrapping `shared_ptr` — no Pimpl. `ExpressionNode` is an abstract base with virtual `metadata()` / `compute_row()`; concrete subclasses `ExpressionFile`, `ExpressionScalar`, `ExpressionBinary`, `ExpressionUnary`, `ExpressionTernary`, `ExpressionAggregate`, `ExpressionAggregateAgents` are exposed via `QUIVER_API` and use Rule of Zero. Polymorphism is justified by the recursive tree shape (`ExpressionBinary`, `ExpressionUnary`, `ExpressionTernary`, and the aggregation nodes own child `shared_ptr` operands). Classes with no private dependencies (`Element`, `Row`, `Migration`, `Migrations`, `GroupMetadata`, `ScalarMetadata`, `CSVOptions`, `BinaryMetadata`, `Dimension`, `TimeProperties`, `Expression`) are plain value types — direct members, no Pimpl, Rule of Zero (compiler-generated copy/move/destructor). @@ -498,6 +498,7 @@ result.save("output"); // writes output.qvr + output.toml - Aggregation: `aggregate(dimension, op, [parameter])` collapses a dimension; `aggregate_agents(op, [parameter])` collapses the label axis. `op` is one of `"sum" | "mean" | "min" | "max" | "percentile"` (string tags, validated in C++). `percentile` requires a `parameter` fraction in `[0, 1]`; nullary ops reject `parameter`. - Operator overloads (12 binary + 1 unary): `+ - * /` × {expr+expr, expr+double, double+expr}, plus unary `-expr`. - Free functions in `quiver::` for unary math: `abs(expr)`, `sqrt(expr)`, `log(expr)`, `exp(expr)`. +- Free function `ifelse(cond, then_value, else_value)` selects per-element: NaN cond → NaN; `cond != 0` → `then_value`; else → `else_value`. `then` and `else` units must match; `cond`'s unit is ignored. - `ExpressionNode` hierarchy (header `quiver/expression/expression_node.h`): - `ExpressionNode` (abstract): `metadata()`, `compute_row(dims, out)` - `ExpressionFile`: lazy reads from a `.qvr`. 
Caches an open `BinaryFile` and a reusable `unordered_map` across calls (mutable members; not thread-safe per instance). diff --git a/bindings/julia/src/c_api.jl b/bindings/julia/src/c_api.jl index 1a4a3c8c..3758601a 100644 --- a/bindings/julia/src/c_api.jl +++ b/bindings/julia/src/c_api.jl @@ -654,6 +654,10 @@ end QUIVER_EXPRESSION_UNARY_OPERATION_EXP = 4 end +@cenum quiver_expression_ternary_operation_t::UInt32 begin + QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE = 0 +end + function quiver_expression_from_file(file, out) @ccall libquiver_c.quiver_expression_from_file(file::Ptr{quiver_binary_file_t}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t end @@ -678,6 +682,10 @@ function quiver_expression_apply_unary(operation, operand, out) @ccall libquiver_c.quiver_expression_apply_unary(operation::quiver_expression_unary_operation_t, operand::Ptr{quiver_expression_t}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t end +function quiver_expression_apply_ternary(operation, condition, then_value, else_value, out) + @ccall libquiver_c.quiver_expression_apply_ternary(operation::quiver_expression_ternary_operation_t, condition::Ptr{quiver_expression_t}, then_value::Ptr{quiver_expression_t}, else_value::Ptr{quiver_expression_t}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t +end + function quiver_expression_save(expression, path) @ccall libquiver_c.quiver_expression_save(expression::Ptr{quiver_expression_t}, path::Ptr{Cchar})::quiver_error_t end diff --git a/bindings/julia/src/expression.jl b/bindings/julia/src/expression.jl index c3a0b924..6369e423 100644 --- a/bindings/julia/src/expression.jl +++ b/bindings/julia/src/expression.jl @@ -98,6 +98,28 @@ Base.sqrt(a::Binary.File) = sqrt(Expression(a)) Base.log(a::Binary.File) = log(Expression(a)) Base.exp(a::Binary.File) = exp(Expression(a)) +function Base.ifelse(condition::Expression, then_value::Expression, else_value::Expression) + out = Ref{Ptr{C.quiver_expression}}(C_NULL) + 
check(C.quiver_expression_apply_ternary(C.QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, + condition.ptr, then_value.ptr, else_value.ptr, out)) + return Expression(out[]) +end + +Base.ifelse(condition::Binary.File, then_value::Binary.File, else_value::Binary.File) = + ifelse(Expression(condition), Expression(then_value), Expression(else_value)) +Base.ifelse(condition::Binary.File, then_value::Expression, else_value::Expression) = + ifelse(Expression(condition), then_value, else_value) +Base.ifelse(condition::Expression, then_value::Binary.File, else_value::Expression) = + ifelse(condition, Expression(then_value), else_value) +Base.ifelse(condition::Expression, then_value::Expression, else_value::Binary.File) = + ifelse(condition, then_value, Expression(else_value)) +Base.ifelse(condition::Binary.File, then_value::Binary.File, else_value::Expression) = + ifelse(Expression(condition), Expression(then_value), else_value) +Base.ifelse(condition::Binary.File, then_value::Expression, else_value::Binary.File) = + ifelse(Expression(condition), then_value, Expression(else_value)) +Base.ifelse(condition::Expression, then_value::Binary.File, else_value::Binary.File) = + ifelse(condition, Expression(then_value), Expression(else_value)) + function save(e::Expression, path::String) check(C.quiver_expression_save(e.ptr, path)) return nothing diff --git a/bindings/julia/test/test_expression.jl b/bindings/julia/test/test_expression.jl index 8dfc5529..2d01936b 100644 --- a/bindings/julia/test/test_expression.jl +++ b/bindings/julia/test/test_expression.jl @@ -1502,6 +1502,205 @@ end cleanup(path_a, path_out) end end + + @testset "ifelse selects by condition" begin + path_cond, path_then, path_else, path_out = + make_path("cond"), make_path("then"), make_path("else"), make_path("out") + try + write_fixture(path_cond, (r, _c, _k) -> r == 1 ? 
1.0 : 0.0) + write_fixture(path_then, (_r, _c, _k) -> 10.0) + write_fixture(path_else, (_r, _c, _k) -> 20.0) + cond = Quiver.Binary.open_file(path_cond; mode = 'r') + then_v = Quiver.Binary.open_file(path_then; mode = 'r') + else_v = Quiver.Binary.open_file(path_else; mode = 'r') + try + result = ifelse(Quiver.Expression(cond), Quiver.Expression(then_v), Quiver.Expression(else_v)) + Quiver.save(result, path_out) + Quiver.close!(result) + finally + Quiver.Binary.close!(cond) + Quiver.Binary.close!(then_v) + Quiver.Binary.close!(else_v) + end + vc = read_all_cells(path_cond) + vo = read_all_cells(path_out) + for i in eachindex(vo) + @test vo[i] == (vc[i] != 0.0 ? 10.0 : 20.0) + end + finally + cleanup(path_cond, path_then, path_else, path_out) + end + end + + @testset "ifelse propagates NaN in condition" begin + path_cond, path_then, path_else, path_out = + make_path("cond"), make_path("then"), make_path("else"), make_path("out") + try + write_fixture(path_cond, (r, c, _k) -> (r == 1 && c == 1) ? 
NaN : 1.0) + write_fixture(path_then, (_r, _c, _k) -> 7.0) + write_fixture(path_else, (_r, _c, _k) -> -7.0) + cond = Quiver.Binary.open_file(path_cond; mode = 'r') + then_v = Quiver.Binary.open_file(path_then; mode = 'r') + else_v = Quiver.Binary.open_file(path_else; mode = 'r') + try + result = ifelse(Quiver.Expression(cond), Quiver.Expression(then_v), Quiver.Expression(else_v)) + Quiver.save(result, path_out) + Quiver.close!(result) + finally + Quiver.Binary.close!(cond) + Quiver.Binary.close!(then_v) + Quiver.Binary.close!(else_v) + end + file = Quiver.Binary.open_file(path_out; mode = 'r') + cell_11 = Quiver.Binary.read(file; allow_nulls = true, row = 1, col = 1) + cell_22 = Quiver.Binary.read(file; row = 2, col = 2) + Quiver.Binary.close!(file) + @test all(isnan, cell_11) + @test cell_22 == [7.0, 7.0] + finally + cleanup(path_cond, path_then, path_else, path_out) + end + end + + @testset "ifelse unselected-branch NaN does not propagate" begin + path_cond, path_then, path_else, path_out = + make_path("cond"), make_path("then"), make_path("else"), make_path("out") + try + write_fixture(path_cond, (_r, _c, _k) -> 1.0) + write_fixture(path_then, (_r, _c, _k) -> 42.0) + write_fixture(path_else, (_r, _c, _k) -> NaN) + cond = Quiver.Binary.open_file(path_cond; mode = 'r') + then_v = Quiver.Binary.open_file(path_then; mode = 'r') + else_v = Quiver.Binary.open_file(path_else; mode = 'r') + try + result = ifelse(Quiver.Expression(cond), Quiver.Expression(then_v), Quiver.Expression(else_v)) + Quiver.save(result, path_out) + Quiver.close!(result) + finally + Quiver.Binary.close!(cond) + Quiver.Binary.close!(then_v) + Quiver.Binary.close!(else_v) + end + @test all(v -> v == 42.0, read_all_cells(path_out)) + finally + cleanup(path_cond, path_then, path_else, path_out) + end + end + + @testset "ifelse condition unit ignored" begin + path_cond, path_then, path_else, path_out = + make_path("cond"), make_path("then"), make_path("else"), make_path("out") + try + md_cond = 
make_metadata(; unit = "flag") # cond.unit = flag + md_branch = make_metadata(; unit = "MW") # then/else MW + write_fixture_with_metadata(path_cond, md_cond, (_r, _c, _k) -> 1.0) + write_fixture_with_metadata(path_then, md_branch, (_r, _c, _k) -> 10.0) + write_fixture_with_metadata(path_else, md_branch, (_r, _c, _k) -> 20.0) + cond = Quiver.Binary.open_file(path_cond; mode = 'r') + then_v = Quiver.Binary.open_file(path_then; mode = 'r') + else_v = Quiver.Binary.open_file(path_else; mode = 'r') + try + result = ifelse(Quiver.Expression(cond), Quiver.Expression(then_v), Quiver.Expression(else_v)) + meta = Quiver.get_metadata(result) + @test Quiver.Binary.get_unit(meta) == "MW" + Quiver.save(result, path_out) + Quiver.close!(result) + finally + Quiver.Binary.close!(cond) + Quiver.Binary.close!(then_v) + Quiver.Binary.close!(else_v) + end + @test all(v -> v == 10.0, read_all_cells(path_out)) + finally + cleanup(path_cond, path_then, path_else, path_out) + end + end + + @testset "ifelse unit mismatch between then and else throws" begin + path_cond, path_then, path_else = make_path("cond"), make_path("then"), make_path("else") + try + md = make_metadata(; unit = "MW") + md_other = make_metadata(; unit = "kWh") + write_fixture_with_metadata(path_cond, md, (_r, _c, _k) -> 1.0) + write_fixture_with_metadata(path_then, md, (_r, _c, _k) -> 1.0) + write_fixture_with_metadata(path_else, md_other, (_r, _c, _k) -> 1.0) + cond = Quiver.Binary.open_file(path_cond; mode = 'r') + then_v = Quiver.Binary.open_file(path_then; mode = 'r') + else_v = Quiver.Binary.open_file(path_else; mode = 'r') + try + @test_throws Quiver.DatabaseException ifelse( + Quiver.Expression(cond), Quiver.Expression(then_v), Quiver.Expression(else_v), + ) + finally + Quiver.Binary.close!(cond) + Quiver.Binary.close!(then_v) + Quiver.Binary.close!(else_v) + end + finally + cleanup(path_cond, path_then, path_else) + end + end + + @testset "ifelse on Binary.File shortcuts" begin + path_cond, path_then, path_else, 
path_out = + make_path("cond"), make_path("then"), make_path("else"), make_path("out") + try + write_fixture(path_cond, (r, _c, _k) -> r == 1 ? 1.0 : 0.0) + write_fixture(path_then, (_r, _c, _k) -> 100.0) + write_fixture(path_else, (_r, _c, _k) -> 200.0) + cond = Quiver.Binary.open_file(path_cond; mode = 'r') + then_v = Quiver.Binary.open_file(path_then; mode = 'r') + else_v = Quiver.Binary.open_file(path_else; mode = 'r') + try + # All three as Binary.File + result = ifelse(cond, then_v, else_v) + Quiver.save(result, path_out) + Quiver.close!(result) + finally + Quiver.Binary.close!(cond) + Quiver.Binary.close!(then_v) + Quiver.Binary.close!(else_v) + end + vc = read_all_cells(path_cond) + vo = read_all_cells(path_out) + for i in eachindex(vo) + @test vo[i] == (vc[i] != 0.0 ? 100.0 : 200.0) + end + finally + cleanup(path_cond, path_then, path_else, path_out) + end + end + + @testset "ifelse chains with binary ops" begin + path_cond, path_then, path_else, path_out = + make_path("cond"), make_path("then"), make_path("else"), make_path("out") + try + write_fixture(path_cond, (r, _c, _k) -> r == 1 ? 1.0 : 0.0) + write_fixture(path_then, (_r, _c, _k) -> 10.0) + write_fixture(path_else, (_r, _c, _k) -> 20.0) + cond = Quiver.Binary.open_file(path_cond; mode = 'r') + then_v = Quiver.Binary.open_file(path_then; mode = 'r') + else_v = Quiver.Binary.open_file(path_else; mode = 'r') + try + # 2 * ifelse(cond, then, else) + 1 + result = 2.0 * ifelse(Quiver.Expression(cond), Quiver.Expression(then_v), Quiver.Expression(else_v)) + 1.0 + Quiver.save(result, path_out) + Quiver.close!(result) + finally + Quiver.Binary.close!(cond) + Quiver.Binary.close!(then_v) + Quiver.Binary.close!(else_v) + end + vc = read_all_cells(path_cond) + vo = read_all_cells(path_out) + for i in eachindex(vo) + base = vc[i] != 0.0 ? 
10.0 : 20.0 + @test vo[i] == 2.0 * base + 1.0 + end + finally + cleanup(path_cond, path_then, path_else, path_out) + end + end end end diff --git a/tests/test_c_api_expression.cpp b/tests/test_c_api_expression.cpp index 783b2c55..82810200 100644 --- a/tests/test_c_api_expression.cpp +++ b/tests/test_c_api_expression.cpp @@ -1329,3 +1329,158 @@ TEST_F(ExpressionCApiFixture, UnaryNullArguments) { quiver_expression_close(a); } + +// ============================================================================ +// Ternary operations (ifelse) +// ============================================================================ + +TEST_F(ExpressionCApiFixture, ApplyTernaryIfElse) { + write_fixture(path_a, [](int r, int /*c*/, int /*k*/) { return (r == 1) ? 1.0 : 0.0; }); + write_fixture(path_b, [](int /*r*/, int /*c*/, int /*k*/) { return 10.0; }); + write_fixture(path_c, [](int /*r*/, int /*c*/, int /*k*/) { return 20.0; }); + + auto* cond = expr_from_file(path_a); + auto* then_v = expr_from_file(path_b); + auto* else_v = expr_from_file(path_c); + quiver_expression_t* result = nullptr; + ASSERT_EQ(quiver_expression_apply_ternary( + QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &result), + QUIVER_OK); + ASSERT_EQ(quiver_expression_save(result, path_out.c_str()), QUIVER_OK); + quiver_expression_close(cond); + quiver_expression_close(then_v); + quiver_expression_close(else_v); + quiver_expression_close(result); + + auto vc = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + ASSERT_EQ(vc.size(), vo.size()); + for (size_t i = 0; i < vo.size(); ++i) { + const double expected = (vc[i] != 0.0) ? 10.0 : 20.0; + EXPECT_DOUBLE_EQ(vo[i], expected); + } +} + +TEST_F(ExpressionCApiFixture, ApplyTernaryIfElsePropagatesNaN) { + const double nan_v = std::numeric_limits::quiet_NaN(); + write_fixture(path_a, [&](int r, int c, int /*k*/) { return (r == 1 && c == 1) ? 
nan_v : 1.0; }); + write_fixture(path_b, [](int, int, int) { return 7.0; }); + write_fixture(path_c, [](int, int, int) { return -7.0; }); + + auto* cond = expr_from_file(path_a); + auto* then_v = expr_from_file(path_b); + auto* else_v = expr_from_file(path_c); + quiver_expression_t* result = nullptr; + ASSERT_EQ(quiver_expression_apply_ternary( + QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &result), + QUIVER_OK); + ASSERT_EQ(quiver_expression_save(result, path_out.c_str()), QUIVER_OK); + quiver_expression_close(cond); + quiver_expression_close(then_v); + quiver_expression_close(else_v); + quiver_expression_close(result); + + // Re-read output allowing nulls so NaN cells don't trip the reader. + quiver_binary_file_t* f = nullptr; + ASSERT_EQ(quiver_binary_file_open_file(path_out.c_str(), 'r', nullptr, &f), QUIVER_OK); + const char* dim_names[] = {"row", "col"}; + int64_t cell_11[] = {1, 1}; + int64_t cell_22[] = {2, 2}; + double* data = nullptr; + size_t count = 0; + ASSERT_EQ(quiver_binary_file_read(f, dim_names, cell_11, 2, /*allow_nulls=*/1, &data, &count), QUIVER_OK); + EXPECT_TRUE(std::isnan(data[0])); + EXPECT_TRUE(std::isnan(data[1])); + quiver_binary_file_free_float_array(data); + + ASSERT_EQ(quiver_binary_file_read(f, dim_names, cell_22, 2, /*allow_nulls=*/0, &data, &count), QUIVER_OK); + EXPECT_DOUBLE_EQ(data[0], 7.0); + EXPECT_DOUBLE_EQ(data[1], 7.0); + quiver_binary_file_free_float_array(data); + quiver_binary_file_close(f); +} + +TEST_F(ExpressionCApiFixture, ApplyTernaryNullArguments) { + write_fixture(path_a, [](int, int, int) { return 1.0; }); + write_fixture(path_b, [](int, int, int) { return 2.0; }); + write_fixture(path_c, [](int, int, int) { return 3.0; }); + auto* cond = expr_from_file(path_a); + auto* then_v = expr_from_file(path_b); + auto* else_v = expr_from_file(path_c); + + quiver_expression_t* out = nullptr; + EXPECT_EQ(quiver_expression_apply_ternary( + QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, nullptr, then_v, 
else_v, &out), + QUIVER_ERROR); + EXPECT_EQ(quiver_expression_apply_ternary( + QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, nullptr, else_v, &out), + QUIVER_ERROR); + EXPECT_EQ(quiver_expression_apply_ternary( + QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, nullptr, &out), + QUIVER_ERROR); + EXPECT_EQ( + quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, nullptr), + QUIVER_ERROR); + + quiver_expression_close(cond); + quiver_expression_close(then_v); + quiver_expression_close(else_v); +} + +TEST_F(ExpressionCApiFixture, ApplyTernaryUnitMismatch) { + auto* md_mw = make_metadata(3, 2, "MW", {"val1", "val2"}); + auto* md_kwh = make_metadata(3, 2, "kWh", {"val1", "val2"}); + write_fixture_with_metadata(path_a, md_mw, [](int, int, int) { return 1.0; }); + write_fixture_with_metadata(path_b, md_mw, [](int, int, int) { return 2.0; }); + write_fixture_with_metadata(path_c, md_kwh, [](int, int, int) { return 3.0; }); + quiver_binary_metadata_free(md_mw); + quiver_binary_metadata_free(md_kwh); + + auto* cond = expr_from_file(path_a); + auto* then_v = expr_from_file(path_b); + auto* else_v = expr_from_file(path_c); + quiver_expression_t* out = nullptr; + EXPECT_EQ( + quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &out), + QUIVER_ERROR); + const char* msg = quiver_get_last_error(); + ASSERT_NE(msg, nullptr); + EXPECT_NE(std::string(msg).find("units differ"), std::string::npos); + + quiver_expression_close(cond); + quiver_expression_close(then_v); + quiver_expression_close(else_v); +} + +TEST_F(ExpressionCApiFixture, ApplyTernaryShapeMismatch) { + auto* md_3x2 = make_metadata(3, 2, "MW", {"val1", "val2"}); + auto* md_4x2 = make_metadata(4, 2, "MW", {"val1", "val2"}); + write_fixture_with_metadata(path_a, md_3x2, [](int, int, int) { return 1.0; }); + write_fixture_with_metadata(path_b, md_3x2, [](int, int, int) { return 2.0; }); + // path_c uses 4x2 (incompatible 
with 3x2) + quiver_binary_file_t* f = nullptr; + ASSERT_EQ(quiver_binary_file_open_file(path_c.c_str(), 'w', md_4x2, &f), QUIVER_OK); + const char* dim_names[] = {"row", "col"}; + for (int64_t r = 1; r <= 4; ++r) { + for (int64_t c = 1; c <= 2; ++c) { + int64_t dvs[] = {r, c}; + double data[] = {3.0, 3.0}; + ASSERT_EQ(quiver_binary_file_write(f, dim_names, dvs, 2, data, 2), QUIVER_OK); + } + } + ASSERT_EQ(quiver_binary_file_close(f), QUIVER_OK); + quiver_binary_metadata_free(md_3x2); + quiver_binary_metadata_free(md_4x2); + + auto* cond = expr_from_file(path_a); + auto* then_v = expr_from_file(path_b); + auto* else_v = expr_from_file(path_c); + quiver_expression_t* out = nullptr; + EXPECT_EQ( + quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &out), + QUIVER_ERROR); + + quiver_expression_close(cond); + quiver_expression_close(then_v); + quiver_expression_close(else_v); +} diff --git a/tests/test_expression.cpp b/tests/test_expression.cpp index 1fbde3b0..8ca69161 100644 --- a/tests/test_expression.cpp +++ b/tests/test_expression.cpp @@ -1717,3 +1717,239 @@ TEST_F(ExpressionFixture, UnaryComposesWithBinary) { for (size_t i = 0; i < va.size(); ++i) EXPECT_DOUBLE_EQ(vo[i], -(va[i] + vb[i])); } + +TEST_F(ExpressionFixture, IfElseSelectsByCondition) { + // cond: nonzero at row 1, zero at rows 2-3 + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t /*k*/) { + return (dims[0] == 1) ? 
1.0 : 0.0; // condition + }); + write_qvr(path_b, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 100 + dims[1] * 10 + static_cast(k)); // then + }); + write_qvr(path_c, md, [](const std::vector& dims, size_t k) { + return -static_cast(dims[0] * 100 + dims[1] * 10 + static_cast(k)); // else + }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + Expression e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); + e.save(path_out); + + auto vc = read_all_cells(path_a); + auto vt = read_all_cells(path_b); + auto ve = read_all_cells(path_c); + auto vo = read_all_cells(path_out); + ASSERT_EQ(vc.size(), vo.size()); + for (size_t i = 0; i < vo.size(); ++i) { + const double expected = (vc[i] != 0.0) ? vt[i] : ve[i]; + EXPECT_DOUBLE_EQ(vo[i], expected) << " at index " << i; + } +} + +TEST_F(ExpressionFixture, IfElsePropagatesNaNInCondition) { + auto md = make_simple_metadata(); + const double nan_v = std::numeric_limits::quiet_NaN(); + // cond: NaN at (1,1), 1 elsewhere + write_qvr(path_a, md, [&](const std::vector& dims, size_t /*k*/) { + return (dims[0] == 1 && dims[1] == 1) ? 
nan_v : 1.0; + }); + write_qvr(path_b, md, [](const std::vector&, size_t) { return 7.0; }); + write_qvr(path_c, md, [](const std::vector&, size_t) { return -7.0; }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + Expression e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); + e.save(path_out); + + auto reopened = BinaryFile::open_file(path_out, 'r'); + auto cell_11 = reopened.read({{"row", 1}, {"col", 1}}, true); + auto cell_22 = reopened.read({{"row", 2}, {"col", 2}}, true); + EXPECT_TRUE(std::isnan(cell_11[0])); + EXPECT_TRUE(std::isnan(cell_11[1])); + EXPECT_DOUBLE_EQ(cell_22[0], 7.0); + EXPECT_DOUBLE_EQ(cell_22[1], 7.0); +} + +TEST_F(ExpressionFixture, IfElseUnselectedBranchNaNDoesNotPropagate) { + auto md = make_simple_metadata(); + const double nan_v = std::numeric_limits::quiet_NaN(); + // cond is always 1 (true) -> only `then` matters; else NaN is irrelevant. 
+ write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); + write_qvr(path_b, md, [](const std::vector&, size_t) { return 42.0; }); + write_qvr(path_c, md, [&](const std::vector&, size_t) { return nan_v; }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + Expression e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); + e.save(path_out); + + auto vo = read_all_cells(path_out); + for (double v : vo) { + EXPECT_DOUBLE_EQ(v, 42.0); + } +} + +TEST_F(ExpressionFixture, IfElseBroadcastsConditionSizeOneDim) { + auto md_cond = BinaryMetadata::from_element(Element() + .set("version", "1") + .set("initial_datetime", "2025-01-01T00:00:00") + .set("unit", "flag") + .set("dimensions", {"row", "col"}) + .set("dimension_sizes", {1, 2}) // broadcast row + .set("labels", {"val1", "val2"})); + auto md_full = make_simple_metadata(); + // cond at row=1 is [1, 0] for col=1, col=2 respectively (per-column mask) + write_qvr(path_a, md_cond, [](const std::vector& dims, size_t /*k*/) { + return (dims[1] == 1) ? 1.0 : 0.0; + }); + write_qvr(path_b, md_full, [](const std::vector&, size_t) { return 100.0; }); + write_qvr(path_c, md_full, [](const std::vector&, size_t) { return -100.0; }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + Expression e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); + e.save(path_out); + + // Output row=3 (from full), col=2 (matches). For all rows: col=1 selects then (100), + // col=2 selects else (-100). Condition's size-1 row dim broadcasts. 
+ auto reopened = BinaryFile::open_file(path_out, 'r'); + EXPECT_EQ(reopened.get_metadata().dimensions[0].size, 3); + for (int64_t r = 1; r <= 3; ++r) { + auto cell_r1 = reopened.read({{"row", r}, {"col", 1}}, true); + auto cell_r2 = reopened.read({{"row", r}, {"col", 2}}, true); + EXPECT_DOUBLE_EQ(cell_r1[0], 100.0); + EXPECT_DOUBLE_EQ(cell_r2[0], -100.0); + } +} + +TEST_F(ExpressionFixture, IfElseBroadcastsLabels) { + auto md_single = BinaryMetadata::from_element(Element() + .set("version", "1") + .set("initial_datetime", "2025-01-01T00:00:00") + .set("unit", "flag") + .set("dimensions", {"row", "col"}) + .set("dimension_sizes", {3, 2}) + .set("labels", {"only"})); // 1 label + auto md_full = make_simple_metadata(); // 2 labels + write_qvr(path_a, md_single, [](const std::vector& dims, size_t /*k*/) { + return (dims[0] == 1) ? 1.0 : 0.0; + }); + write_qvr(path_b, md_full, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); + }); + write_qvr(path_c, md_full, [](const std::vector& dims, size_t k) { + return -static_cast(dims[0] * 10 + dims[1] + static_cast(k)); + }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + Expression e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); + e.save(path_out); + + auto reopened = BinaryFile::open_file(path_out, 'r'); + ASSERT_EQ(reopened.get_metadata().labels.size(), 2u); + auto cell_11 = reopened.read({{"row", 1}, {"col", 1}}, true); + auto cell_22 = reopened.read({{"row", 2}, {"col", 2}}, true); + // row=1: cond=1 -> then (positive); row=2: cond=0 -> else (negative). 
+ EXPECT_DOUBLE_EQ(cell_11[0], 1.0 * 10 + 1.0 + 0); + EXPECT_DOUBLE_EQ(cell_11[1], 1.0 * 10 + 1.0 + 1); + EXPECT_DOUBLE_EQ(cell_22[0], -(2.0 * 10 + 2.0 + 0)); + EXPECT_DOUBLE_EQ(cell_22[1], -(2.0 * 10 + 2.0 + 1)); +} + +TEST_F(ExpressionFixture, IfElseUnitMismatchThenElseThrows) { + auto md_t = make_simple_metadata(); // unit "MW" + auto md_f = BinaryMetadata::from_element(Element() + .set("version", "1") + .set("initial_datetime", "2025-01-01T00:00:00") + .set("unit", "kWh") + .set("dimensions", {"row", "col"}) + .set("dimension_sizes", {3, 2}) + .set("labels", {"val1", "val2"})); + write_qvr(path_a, md_t, [](const std::vector&, size_t) { return 1.0; }); + write_qvr(path_b, md_t, [](const std::vector&, size_t) { return 2.0; }); + write_qvr(path_c, md_f, [](const std::vector&, size_t) { return 3.0; }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + EXPECT_THROW({ auto e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); }, + std::runtime_error); +} + +TEST_F(ExpressionFixture, IfElseConditionUnitIgnored) { + // cond has a different unit than then/else; should succeed. 
+ auto md_cond = BinaryMetadata::from_element(Element() + .set("version", "1") + .set("initial_datetime", "2025-01-01T00:00:00") + .set("unit", "flag") + .set("dimensions", {"row", "col"}) + .set("dimension_sizes", {3, 2}) + .set("labels", {"val1", "val2"})); + auto md_branch = make_simple_metadata(); // "MW" + write_qvr(path_a, md_cond, [](const std::vector&, size_t) { return 1.0; }); + write_qvr(path_b, md_branch, [](const std::vector&, size_t) { return 10.0; }); + write_qvr(path_c, md_branch, [](const std::vector&, size_t) { return 20.0; }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + Expression e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); + e.save(path_out); + + auto reopened = BinaryFile::open_file(path_out, 'r'); + EXPECT_EQ(reopened.get_metadata().unit, "MW"); // output unit = then.unit (== else.unit) + auto vo = read_all_cells(path_out); + for (double v : vo) { + EXPECT_DOUBLE_EQ(v, 10.0); // all cells select then (cond=1) + } +} + +TEST_F(ExpressionFixture, IfElseShapeMismatchThrows) { + auto md_t = make_simple_metadata(); // 3x2 + auto md_f = BinaryMetadata::from_element(Element() + .set("version", "1") + .set("initial_datetime", "2025-01-01T00:00:00") + .set("unit", "MW") + .set("dimensions", {"row", "col"}) + .set("dimension_sizes", {4, 2}) // size 4 vs 3 + .set("labels", {"val1", "val2"})); + write_qvr(path_a, md_t, [](const std::vector&, size_t) { return 1.0; }); + write_qvr(path_b, md_t, [](const std::vector&, size_t) { return 2.0; }); + write_qvr(path_c, md_f, [](const std::vector&, size_t) { return 3.0; }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + EXPECT_THROW({ auto e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); }, + std::runtime_error); +} + +TEST_F(ExpressionFixture, 
IfElseChainsWithBinary) { + // Verify ifelse composes with binary ops: 2 * ifelse(cond, a, b) + 1 + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t /*k*/) { + return (dims[0] == 1) ? 1.0 : 0.0; + }); + write_qvr(path_b, md, [](const std::vector&, size_t) { return 10.0; }); + write_qvr(path_c, md, [](const std::vector&, size_t) { return 20.0; }); + + auto cond = BinaryFile::open_file(path_a, 'r'); + auto then_v = BinaryFile::open_file(path_b, 'r'); + auto else_v = BinaryFile::open_file(path_c, 'r'); + Expression e = 2.0 * ifelse(Expression(cond), Expression(then_v), Expression(else_v)) + 1.0; + e.save(path_out); + + auto vc = read_all_cells(path_a); + auto vo = read_all_cells(path_out); + for (size_t i = 0; i < vo.size(); ++i) { + const double base = (vc[i] != 0.0) ? 10.0 : 20.0; + EXPECT_DOUBLE_EQ(vo[i], 2.0 * base + 1.0); + } +} From 24deb4d9f300236e70d0dfd3641e93c79479fec4 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Mon, 11 May 2026 22:19:12 -0300 Subject: [PATCH 05/13] Update --- CLAUDE.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 3d1ec777..e5f3861e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -505,11 +505,11 @@ result.save("output"); // writes output.qvr + output.toml - `ExpressionScalar`: broadcasts a constant across the operand's label space. - `ExpressionBinary`: combines two operands with `ExpressionBinary::Operation::{Add,Subtract,Multiply,Divide}` (nested enum). Constructor pre-computes broadcast metadata (`build_broadcast_metadata`), reusable input/output buffers, and `lhs_to_out_`/`rhs_to_out_` index translation tables. The `apply(Operation, double, double)` operation-dispatch is a private static member. - `ExpressionUnary`: applies a single-operand math function with `ExpressionUnary::Operation::{Negate,Abs,Sqrt,Log,Exp}` (nested enum). `metadata()` returns the operand's metadata unchanged (no dimensional analysis — `sqrt(MW)` stays as `MW`). 
Constructor pre-allocates a reusable `operand_row_buf_`. Lets IEEE-754 NaN/inf propagate naturally (`sqrt(-1) → NaN`, `log(0) → -inf`); no NaN special-casing. The `apply(Operation, double)` operation-dispatch is a private static member. + - `ExpressionTernary`: selects per-element across three operands. `Operation::{IfElse}` (nested enum). For `IfElse`: NaN in `condition` → NaN; `condition != 0` → `then_value`; else `else_value`. Constructor eagerly validates (`then` and `else` units must match; `condition`'s unit is ignored; shapes broadcast across all three pairs), pre-builds broadcast metadata via `build_ternary_broadcast_metadata`, pre-allocates per-operand dim/label translation tables and reusable buffers. The `apply(Operation, double, double, double)` operation-dispatch is a private static member. - `ExpressionAggregate`: collapses a named dimension. `Operation::{Sum,Mean,Min,Max,Percentile}` (nested enum). Constructor eagerly removes the dim from output metadata, rewires child time-dim `parent_dimension_index` transitively (a time dim whose parent was removed re-points to the removed dim's grandparent, or `-1`), and pre-allocates index translation + reusable buffers. Skips NaN inputs during accumulation; all-NaN range yields NaN. - `ExpressionAggregateAgents`: collapses the label axis to a single entry named after the operation (e.g., `"sum"`, `"mean"`, `"percentile"`). Dimensions, `initial_datetime`, `unit` unchanged. Same NaN policy as `ExpressionAggregate`. - - `ExpressionTernary` (scaffold): same shape as the real nodes, but its `metadata()` and `compute_row()` throw `Cannot {operation}: ExpressionTernary is not yet implemented`. Carries a placeholder nested `enum class Operation { Unspecified }` until concrete operations are designed. 
-- Validation is **eager** at construction for `ExpressionBinary`, `ExpressionAggregate`, `ExpressionAggregateAgents` (units/dim sizes/time-dim properties/label sizes/initial datetimes for binary; dim existence + op/parameter consistency + output metadata validity for aggregations). `ExpressionUnary` has no inputs to cross-validate so its constructor just sizes the row buffer. Computation is **lazy**: no I/O until `save()`. -- The binary-operation enum is nested as `ExpressionBinary::Operation`; the unary-operation enum is nested as `ExpressionUnary::Operation`. Aggregation operations are parsed via static `parse_operation(string)` methods on each aggregation class. The C API surface keeps its own stable enums `quiver_expression_operation_t` and `quiver_expression_unary_operation_t`. +- Validation is **eager** at construction for `ExpressionBinary`, `ExpressionTernary`, `ExpressionAggregate`, `ExpressionAggregateAgents` (units/dim sizes/time-dim properties/label sizes/initial datetimes for binary and ternary; dim existence + op/parameter consistency + output metadata validity for aggregations). `ExpressionUnary` has no inputs to cross-validate so its constructor just sizes the row buffer. Computation is **lazy**: no I/O until `save()`. +- The binary-operation enum is nested as `ExpressionBinary::Operation`; the unary-operation enum is nested as `ExpressionUnary::Operation`; the ternary-operation enum is nested as `ExpressionTernary::Operation`. Aggregation operations are parsed via static `parse_operation(string)` methods on each aggregation class. The C API surface keeps its own stable enums `quiver_expression_operation_t`, `quiver_expression_unary_operation_t`, and `quiver_expression_ternary_operation_t`. 
### LuaRunner Class Executes Lua scripts with database access: From dbb8e9dbe4243b68da43848478216cc1e7b6c9c5 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Mon, 11 May 2026 22:59:26 -0300 Subject: [PATCH 06/13] Update --- include/quiver/expression/expression.h | 3 +-- src/expression/expression.cpp | 6 ++--- src/expression/expression_node.cpp | 3 +-- tests/test_c_api_expression.cpp | 34 ++++++++++++-------------- tests/test_expression.cpp | 22 ++++++----------- 5 files changed, 27 insertions(+), 41 deletions(-) diff --git a/include/quiver/expression/expression.h b/include/quiver/expression/expression.h index f27f28c7..b5571a99 100644 --- a/include/quiver/expression/expression.h +++ b/include/quiver/expression/expression.h @@ -75,8 +75,7 @@ QUIVER_API Expression sqrt(const Expression& operand); QUIVER_API Expression log(const Expression& operand); QUIVER_API Expression exp(const Expression& operand); -QUIVER_API Expression -ifelse(const Expression& condition, const Expression& then_value, const Expression& else_value); +QUIVER_API Expression ifelse(const Expression& condition, const Expression& then_value, const Expression& else_value); } // namespace quiver diff --git a/src/expression/expression.cpp b/src/expression/expression.cpp index b064991d..ac0e2fa0 100644 --- a/src/expression/expression.cpp +++ b/src/expression/expression.cpp @@ -149,10 +149,8 @@ Expression exp(const Expression& operand) { } Expression ifelse(const Expression& condition, const Expression& then_value, const Expression& else_value) { - return Expression(std::make_shared(ExpressionTernary::Operation::IfElse, - condition.node_, - then_value.node_, - else_value.node_)); + return Expression(std::make_shared( + ExpressionTernary::Operation::IfElse, condition.node_, then_value.node_, else_value.node_)); } } // namespace quiver diff --git a/src/expression/expression_node.cpp b/src/expression/expression_node.cpp index 8eead0e0..bd9a1349 100644 --- a/src/expression/expression_node.cpp +++ 
b/src/expression/expression_node.cpp @@ -546,8 +546,7 @@ ExpressionTernary::ExpressionTernary(Operation operation, validate_shape_compatibility(condition_meta, else_meta); auto output_labels = compute_ternary_output_labels(condition_meta.labels, then_meta.labels, else_meta.labels); - broadcast_meta_ = - build_ternary_broadcast_metadata(condition_meta, then_meta, else_meta, std::move(output_labels)); + broadcast_meta_ = build_ternary_broadcast_metadata(condition_meta, then_meta, else_meta, std::move(output_labels)); broadcast_meta_.validate(); const auto& out_dims = broadcast_meta_.dimensions; diff --git a/tests/test_c_api_expression.cpp b/tests/test_c_api_expression.cpp index 82810200..95104655 100644 --- a/tests/test_c_api_expression.cpp +++ b/tests/test_c_api_expression.cpp @@ -1343,9 +1343,9 @@ TEST_F(ExpressionCApiFixture, ApplyTernaryIfElse) { auto* then_v = expr_from_file(path_b); auto* else_v = expr_from_file(path_c); quiver_expression_t* result = nullptr; - ASSERT_EQ(quiver_expression_apply_ternary( - QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &result), - QUIVER_OK); + ASSERT_EQ( + quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &result), + QUIVER_OK); ASSERT_EQ(quiver_expression_save(result, path_out.c_str()), QUIVER_OK); quiver_expression_close(cond); quiver_expression_close(then_v); @@ -1371,9 +1371,9 @@ TEST_F(ExpressionCApiFixture, ApplyTernaryIfElsePropagatesNaN) { auto* then_v = expr_from_file(path_b); auto* else_v = expr_from_file(path_c); quiver_expression_t* result = nullptr; - ASSERT_EQ(quiver_expression_apply_ternary( - QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &result), - QUIVER_OK); + ASSERT_EQ( + quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &result), + QUIVER_OK); ASSERT_EQ(quiver_expression_save(result, path_out.c_str()), QUIVER_OK); quiver_expression_close(cond); 
quiver_expression_close(then_v); @@ -1409,14 +1409,12 @@ TEST_F(ExpressionCApiFixture, ApplyTernaryNullArguments) { auto* else_v = expr_from_file(path_c); quiver_expression_t* out = nullptr; - EXPECT_EQ(quiver_expression_apply_ternary( - QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, nullptr, then_v, else_v, &out), - QUIVER_ERROR); - EXPECT_EQ(quiver_expression_apply_ternary( - QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, nullptr, else_v, &out), + EXPECT_EQ( + quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, nullptr, then_v, else_v, &out), + QUIVER_ERROR); + EXPECT_EQ(quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, nullptr, else_v, &out), QUIVER_ERROR); - EXPECT_EQ(quiver_expression_apply_ternary( - QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, nullptr, &out), + EXPECT_EQ(quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, nullptr, &out), QUIVER_ERROR); EXPECT_EQ( quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, nullptr), @@ -1440,9 +1438,8 @@ TEST_F(ExpressionCApiFixture, ApplyTernaryUnitMismatch) { auto* then_v = expr_from_file(path_b); auto* else_v = expr_from_file(path_c); quiver_expression_t* out = nullptr; - EXPECT_EQ( - quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &out), - QUIVER_ERROR); + EXPECT_EQ(quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &out), + QUIVER_ERROR); const char* msg = quiver_get_last_error(); ASSERT_NE(msg, nullptr); EXPECT_NE(std::string(msg).find("units differ"), std::string::npos); @@ -1476,9 +1473,8 @@ TEST_F(ExpressionCApiFixture, ApplyTernaryShapeMismatch) { auto* then_v = expr_from_file(path_b); auto* else_v = expr_from_file(path_c); quiver_expression_t* out = nullptr; - EXPECT_EQ( - quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, 
else_v, &out), - QUIVER_ERROR); + EXPECT_EQ(quiver_expression_apply_ternary(QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, cond, then_v, else_v, &out), + QUIVER_ERROR); quiver_expression_close(cond); quiver_expression_close(then_v); diff --git a/tests/test_expression.cpp b/tests/test_expression.cpp index 8ca69161..6d76abc3 100644 --- a/tests/test_expression.cpp +++ b/tests/test_expression.cpp @@ -1803,9 +1803,8 @@ TEST_F(ExpressionFixture, IfElseBroadcastsConditionSizeOneDim) { .set("labels", {"val1", "val2"})); auto md_full = make_simple_metadata(); // cond at row=1 is [1, 0] for col=1, col=2 respectively (per-column mask) - write_qvr(path_a, md_cond, [](const std::vector& dims, size_t /*k*/) { - return (dims[1] == 1) ? 1.0 : 0.0; - }); + write_qvr( + path_a, md_cond, [](const std::vector& dims, size_t /*k*/) { return (dims[1] == 1) ? 1.0 : 0.0; }); write_qvr(path_b, md_full, [](const std::vector&, size_t) { return 100.0; }); write_qvr(path_c, md_full, [](const std::vector&, size_t) { return -100.0; }); @@ -1835,10 +1834,9 @@ TEST_F(ExpressionFixture, IfElseBroadcastsLabels) { .set("dimensions", {"row", "col"}) .set("dimension_sizes", {3, 2}) .set("labels", {"only"})); // 1 label - auto md_full = make_simple_metadata(); // 2 labels - write_qvr(path_a, md_single, [](const std::vector& dims, size_t /*k*/) { - return (dims[0] == 1) ? 1.0 : 0.0; - }); + auto md_full = make_simple_metadata(); // 2 labels + write_qvr( + path_a, md_single, [](const std::vector& dims, size_t /*k*/) { return (dims[0] == 1) ? 
1.0 : 0.0; }); write_qvr(path_b, md_full, [](const std::vector& dims, size_t k) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); @@ -1879,8 +1877,7 @@ TEST_F(ExpressionFixture, IfElseUnitMismatchThenElseThrows) { auto cond = BinaryFile::open_file(path_a, 'r'); auto then_v = BinaryFile::open_file(path_b, 'r'); auto else_v = BinaryFile::open_file(path_c, 'r'); - EXPECT_THROW({ auto e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); }, - std::runtime_error); + EXPECT_THROW({ auto e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); }, std::runtime_error); } TEST_F(ExpressionFixture, IfElseConditionUnitIgnored) { @@ -1927,16 +1924,13 @@ TEST_F(ExpressionFixture, IfElseShapeMismatchThrows) { auto cond = BinaryFile::open_file(path_a, 'r'); auto then_v = BinaryFile::open_file(path_b, 'r'); auto else_v = BinaryFile::open_file(path_c, 'r'); - EXPECT_THROW({ auto e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); }, - std::runtime_error); + EXPECT_THROW({ auto e = ifelse(Expression(cond), Expression(then_v), Expression(else_v)); }, std::runtime_error); } TEST_F(ExpressionFixture, IfElseChainsWithBinary) { // Verify ifelse composes with binary ops: 2 * ifelse(cond, a, b) + 1 auto md = make_simple_metadata(); - write_qvr(path_a, md, [](const std::vector& dims, size_t /*k*/) { - return (dims[0] == 1) ? 1.0 : 0.0; - }); + write_qvr(path_a, md, [](const std::vector& dims, size_t /*k*/) { return (dims[0] == 1) ? 
1.0 : 0.0; }); write_qvr(path_b, md, [](const std::vector&, size_t) { return 10.0; }); write_qvr(path_c, md, [](const std::vector&, size_t) { return 20.0; }); From 82d6425160cb31c10bb0710a9c0798a7ac687e3b Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Mon, 11 May 2026 23:00:02 -0300 Subject: [PATCH 07/13] Update --- bindings/julia/src/expression.jl | 6 ++++-- bindings/julia/test/test_expression.jl | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bindings/julia/src/expression.jl b/bindings/julia/src/expression.jl index 6369e423..e7da3ca6 100644 --- a/bindings/julia/src/expression.jl +++ b/bindings/julia/src/expression.jl @@ -100,8 +100,10 @@ Base.exp(a::Binary.File) = exp(Expression(a)) function Base.ifelse(condition::Expression, then_value::Expression, else_value::Expression) out = Ref{Ptr{C.quiver_expression}}(C_NULL) - check(C.quiver_expression_apply_ternary(C.QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, - condition.ptr, then_value.ptr, else_value.ptr, out)) + check( + C.quiver_expression_apply_ternary(C.QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE, + condition.ptr, then_value.ptr, else_value.ptr, out), + ) return Expression(out[]) end diff --git a/bindings/julia/test/test_expression.jl b/bindings/julia/test/test_expression.jl index 2d01936b..e643b391 100644 --- a/bindings/julia/test/test_expression.jl +++ b/bindings/julia/test/test_expression.jl @@ -1683,7 +1683,8 @@ end else_v = Quiver.Binary.open_file(path_else; mode = 'r') try # 2 * ifelse(cond, then, else) + 1 - result = 2.0 * ifelse(Quiver.Expression(cond), Quiver.Expression(then_v), Quiver.Expression(else_v)) + 1.0 + result = + 2.0 * ifelse(Quiver.Expression(cond), Quiver.Expression(then_v), Quiver.Expression(else_v)) + 1.0 Quiver.save(result, path_out) Quiver.close!(result) finally From e4bf964045d2b810bc1153a1a78f1836e8e72262 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Wed, 13 May 2026 23:41:54 -0300 Subject: [PATCH 08/13] Update --- CLAUDE.md | 9 +- bindings/julia/src/c_api.jl 
| 8 + bindings/julia/src/expression.jl | 27 ++++ bindings/julia/test/test_expression.jl | 157 ++++++++++++++++++ include/quiver/c/expression/expression.h | 13 ++ include/quiver/expression/expression.h | 6 + include/quiver/expression/expression_node.h | 30 ++++ src/c/expression/expression.cpp | 50 ++++++ src/expression/expression.cpp | 8 + src/expression/expression_node.cpp | 90 +++++++++++ tests/test_c_api_expression.cpp | 144 +++++++++++++++++ tests/test_expression.cpp | 169 ++++++++++++++++++++ 12 files changed, 708 insertions(+), 3 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index e5f3861e..24cc0ad6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -172,7 +172,7 @@ struct Database::Impl { Binary subsystem: `BinaryFile` and `CSVConverter` use Pimpl (hide file I/O dependencies). `BinaryMetadata`, `Dimension`, `TimeProperties` are plain value types. -Expression subsystem: `Expression` is a plain value type wrapping `shared_ptr` — no Pimpl. `ExpressionNode` is an abstract base with virtual `metadata()` / `compute_row()`; concrete subclasses `ExpressionFile`, `ExpressionScalar`, `ExpressionBinary`, `ExpressionUnary`, `ExpressionTernary`, `ExpressionAggregate`, `ExpressionAggregateAgents` are exposed via `QUIVER_API` and use Rule of Zero. Polymorphism is justified by the recursive tree shape (`ExpressionBinary`, `ExpressionUnary`, `ExpressionTernary`, and the aggregation nodes own child `shared_ptr` operands). +Expression subsystem: `Expression` is a plain value type wrapping `shared_ptr` — no Pimpl. `ExpressionNode` is an abstract base with virtual `metadata()` / `compute_row()`; concrete subclasses `ExpressionFile`, `ExpressionScalar`, `ExpressionBinary`, `ExpressionUnary`, `ExpressionTernary`, `ExpressionAggregate`, `ExpressionAggregateAgents`, `ExpressionSelectAgents`, `ExpressionRenameAgents` are exposed via `QUIVER_API` and use Rule of Zero. 
Polymorphism is justified by the recursive tree shape (`ExpressionBinary`, `ExpressionUnary`, `ExpressionTernary`, and the aggregation / label-projection nodes own child `shared_ptr` operands). Classes with no private dependencies (`Element`, `Row`, `Migration`, `Migrations`, `GroupMetadata`, `ScalarMetadata`, `CSVOptions`, `BinaryMetadata`, `Dimension`, `TimeProperties`, `Expression`) are plain value types — direct members, no Pimpl, Rule of Zero (compiler-generated copy/move/destructor). @@ -496,6 +496,7 @@ result.save("output"); // writes output.qvr + output.toml - Accessors: `metadata()` - Materialize: `save(path)` — iterates via `first_dimensions`/`next_dimensions`, calls `compute_row()` per cell, writes to a new `.qvr`. Throws if `path` collides (after `weakly_canonical`) with any input file in the DAG. - Aggregation: `aggregate(dimension, op, [parameter])` collapses a dimension; `aggregate_agents(op, [parameter])` collapses the label axis. `op` is one of `"sum" | "mean" | "min" | "max" | "percentile"` (string tags, validated in C++). `percentile` requires a `parameter` fraction in `[0, 1]`; nullary ops reject `parameter`. + - Label-axis projection: `select_agents(labels)` keeps (and may reorder) a chosen subset of operand labels; `rename_agents(mapping)` rewrites labels in place via a partial `{old: new}` map. Both validate eagerly: `select_agents` throws if any requested label is absent; `rename_agents` throws on duplicate keys or unknown keys, and `BinaryMetadata::validate()` rejects renames that produce duplicate output labels. - Operator overloads (12 binary + 1 unary): `+ - * /` × {expr+expr, expr+double, double+expr}, plus unary `-expr`. - Free functions in `quiver::` for unary math: `abs(expr)`, `sqrt(expr)`, `log(expr)`, `exp(expr)`. - Free function `ifelse(cond, then_value, else_value)` selects per-element: NaN cond → NaN; `cond != 0` → `then_value`; else → `else_value`. `then` and `else` units must match; `cond`'s unit is ignored. 
@@ -508,8 +509,10 @@ result.save("output"); // writes output.qvr + output.toml - `ExpressionTernary`: selects per-element across three operands. `Operation::{IfElse}` (nested enum). For `IfElse`: NaN in `condition` → NaN; `condition != 0` → `then_value`; else `else_value`. Constructor eagerly validates (`then` and `else` units must match; `condition`'s unit is ignored; shapes broadcast across all three pairs), pre-builds broadcast metadata via `build_ternary_broadcast_metadata`, pre-allocates per-operand dim/label translation tables and reusable buffers. The `apply(Operation, double, double, double)` operation-dispatch is a private static member. - `ExpressionAggregate`: collapses a named dimension. `Operation::{Sum,Mean,Min,Max,Percentile}` (nested enum). Constructor eagerly removes the dim from output metadata, rewires child time-dim `parent_dimension_index` transitively (a time dim whose parent was removed re-points to the removed dim's grandparent, or `-1`), and pre-allocates index translation + reusable buffers. Skips NaN inputs during accumulation; all-NaN range yields NaN. - `ExpressionAggregateAgents`: collapses the label axis to a single entry named after the operation (e.g., `"sum"`, `"mean"`, `"percentile"`). Dimensions, `initial_datetime`, `unit` unchanged. Same NaN policy as `ExpressionAggregate`. -- Validation is **eager** at construction for `ExpressionBinary`, `ExpressionTernary`, `ExpressionAggregate`, `ExpressionAggregateAgents` (units/dim sizes/time-dim properties/label sizes/initial datetimes for binary and ternary; dim existence + op/parameter consistency + output metadata validity for aggregations). `ExpressionUnary` has no inputs to cross-validate so its constructor just sizes the row buffer. Computation is **lazy**: no I/O until `save()`. 
-- The binary-operation enum is nested as `ExpressionBinary::Operation`; the unary-operation enum is nested as `ExpressionUnary::Operation`; the ternary-operation enum is nested as `ExpressionTernary::Operation`. Aggregation operations are parsed via static `parse_operation(string)` methods on each aggregation class. The C API surface keeps its own stable enums `quiver_expression_operation_t`, `quiver_expression_unary_operation_t`, and `quiver_expression_ternary_operation_t`. + - `ExpressionSelectAgents`: projects the operand onto a caller-supplied label list. Constructor pre-computes a `selected_indices_` table mapping each output position → operand label index, copies operand metadata with `labels` replaced, and calls `output_meta_.validate()` (which rejects duplicate output labels). Missing labels throw `"Cannot select_agents: label not found: '<label>'"`. `compute_row` reads the operand row into a reusable buffer and gathers selected columns into `out`. + - `ExpressionRenameAgents`: rewrites operand labels via a partial `{old: new}` mapping. Constructor builds a rename map (duplicate keys throw), walks operand labels swapping matched names, verifies every key was used (unmatched keys throw), and calls `output_meta_.validate()` (rejects collisions like `val1→val2` when `val2` already exists). `compute_row` forwards directly to the operand — count and order are unchanged so no per-row reshuffle is needed. +- Validation is **eager** at construction for `ExpressionBinary`, `ExpressionTernary`, `ExpressionAggregate`, `ExpressionAggregateAgents`, `ExpressionSelectAgents`, `ExpressionRenameAgents` (units/dim sizes/time-dim properties/label sizes/initial datetimes for binary and ternary; dim existence + op/parameter consistency + output metadata validity for aggregations; label existence + uniqueness for label-axis projections). `ExpressionUnary` has no inputs to cross-validate so its constructor just sizes the row buffer. Computation is **lazy**: no I/O until `save()`.
+- The binary-operation enum is nested as `ExpressionBinary::Operation`; the unary-operation enum is nested as `ExpressionUnary::Operation`; the ternary-operation enum is nested as `ExpressionTernary::Operation`. Aggregation operations are parsed via static `parse_operation(string)` methods on each aggregation class. Label-axis projection nodes (`ExpressionSelectAgents`, `ExpressionRenameAgents`) have no operation enum — their behavior is fully specified by the label list / rename map. The C API surface keeps its own stable enums `quiver_expression_operation_t`, `quiver_expression_unary_operation_t`, and `quiver_expression_ternary_operation_t`. ### LuaRunner Class Executes Lua scripts with database access: diff --git a/bindings/julia/src/c_api.jl b/bindings/julia/src/c_api.jl index 3758601a..d641beb0 100644 --- a/bindings/julia/src/c_api.jl +++ b/bindings/julia/src/c_api.jl @@ -702,6 +702,14 @@ function quiver_expression_aggregate_agents(expression, operation, parameter, ou @ccall libquiver_c.quiver_expression_aggregate_agents(expression::Ptr{quiver_expression_t}, operation::Ptr{Cchar}, parameter::Ptr{Cdouble}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t end +function quiver_expression_select_agents(expression, labels, label_count, out) + @ccall libquiver_c.quiver_expression_select_agents(expression::Ptr{quiver_expression_t}, labels::Ptr{Ptr{Cchar}}, label_count::Csize_t, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t +end + +function quiver_expression_rename_agents(expression, old_labels, new_labels, mapping_count, out) + @ccall libquiver_c.quiver_expression_rename_agents(expression::Ptr{quiver_expression_t}, old_labels::Ptr{Ptr{Cchar}}, new_labels::Ptr{Ptr{Cchar}}, mapping_count::Csize_t, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t +end + #! 
format: on diff --git a/bindings/julia/src/expression.jl b/bindings/julia/src/expression.jl index e7da3ca6..feb1ddf9 100644 --- a/bindings/julia/src/expression.jl +++ b/bindings/julia/src/expression.jl @@ -166,3 +166,30 @@ end function aggregate_agents(f::Binary.File, operation::String, parameter::Optional{Real} = nothing) return aggregate_agents(Expression(f), operation, parameter) end + +function select_agents(e::Expression, labels::Vector{<:AbstractString}) + cstrings = [Base.cconvert(Cstring, s) for s in labels] + ptrs = [Base.unsafe_convert(Cstring, cs) for cs in cstrings] + out = Ref{Ptr{C.quiver_expression}}(C_NULL) + GC.@preserve cstrings begin + check(C.quiver_expression_select_agents(e.ptr, ptrs, length(labels), out)) + end + return Expression(out[]) +end + +function rename_agents(e::Expression, mapping::AbstractDict{<:AbstractString, <:AbstractString}) + old_labels = String[String(k) for k in keys(mapping)] + new_labels = String[String(mapping[k]) for k in old_labels] + old_cstrings = [Base.cconvert(Cstring, s) for s in old_labels] + new_cstrings = [Base.cconvert(Cstring, s) for s in new_labels] + old_ptrs = [Base.unsafe_convert(Cstring, cs) for cs in old_cstrings] + new_ptrs = [Base.unsafe_convert(Cstring, cs) for cs in new_cstrings] + out = Ref{Ptr{C.quiver_expression}}(C_NULL) + GC.@preserve old_cstrings new_cstrings begin + check(C.quiver_expression_rename_agents(e.ptr, old_ptrs, new_ptrs, length(old_labels), out)) + end + return Expression(out[]) +end + +select_agents(f::Binary.File, labels::Vector{<:AbstractString}) = select_agents(Expression(f), labels) +rename_agents(f::Binary.File, mapping::AbstractDict) = rename_agents(Expression(f), mapping) diff --git a/bindings/julia/test/test_expression.jl b/bindings/julia/test/test_expression.jl index e643b391..b8b7f2ec 100644 --- a/bindings/julia/test/test_expression.jl +++ b/bindings/julia/test/test_expression.jl @@ -1702,6 +1702,163 @@ end cleanup(path_cond, path_then, path_else, path_out) end end + + # 
========================================================================= + # select_agents — label-axis projection + # ========================================================================= + + @testset "select_agents subset" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 10 + c + k) + with_expr(path_a) do e + out = Quiver.select_agents(e, ["val2"]) + md = Quiver.get_metadata(out) + @test Quiver.Binary.get_labels(md) == ["val2"] + Quiver.save(out, path_out) + return Quiver.close!(out) + end + # val2 is k=2 → cell = 10r + c + 2. + @test read_one_cell(path_out; row = 1, col = 1) == [13.0] + @test read_one_cell(path_out; row = 3, col = 2) == [34.0] + finally + cleanup(path_a, path_out) + end + end + + @testset "select_agents reorder" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 10 + c + k) + with_expr(path_a) do e + out = Quiver.select_agents(e, ["val2", "val1"]) + md = Quiver.get_metadata(out) + @test Quiver.Binary.get_labels(md) == ["val2", "val1"] + Quiver.save(out, path_out) + return Quiver.close!(out) + end + # First entry is val2 (k=2 → 10r+c+2), second is val1 (k=1 → 10r+c+1). 
+ @test read_one_cell(path_out; row = 1, col = 1) == [13.0, 12.0] + finally + cleanup(path_a, path_out) + end + end + + @testset "select_agents missing label throws" begin + path_a = make_path("a") + try + write_fixture(path_a, (r, c, k) -> 1.0) + with_expr(path_a) do e + @test_throws Quiver.DatabaseException Quiver.select_agents(e, ["nope"]) + end + finally + cleanup(path_a) + end + end + + @testset "select_agents on Binary.File shortcut" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 10 + c + k) + file = Quiver.Binary.open_file(path_a; mode = 'r') + try + out = Quiver.select_agents(file, ["val1"]) + Quiver.save(out, path_out) + Quiver.close!(out) + finally + Quiver.Binary.close!(file) + end + @test read_one_cell(path_out; row = 1, col = 1) == [12.0] + finally + cleanup(path_a, path_out) + end + end + + # ========================================================================= + # rename_agents — label-axis rename + # ========================================================================= + + @testset "rename_agents partial mapping" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> r * 10 + c + k) + with_expr(path_a) do e + out = Quiver.rename_agents(e, Dict("val1" => "alpha")) + md = Quiver.get_metadata(out) + @test Quiver.Binary.get_labels(md) == ["alpha", "val2"] + Quiver.save(out, path_out) + return Quiver.close!(out) + end + # Values unchanged; both labels still present in (val1->alpha, val2) order. 
+ @test read_one_cell(path_out; row = 1, col = 1) == [12.0, 13.0] + finally + cleanup(path_a, path_out) + end + end + + @testset "rename_agents all labels" begin + path_a = make_path("a") + try + write_fixture(path_a, (r, c, k) -> 1.0) + with_expr(path_a) do e + out = Quiver.rename_agents(e, Dict("val1" => "alpha", "val2" => "beta")) + md = Quiver.get_metadata(out) + labels = Quiver.Binary.get_labels(md) + @test sort(labels) == ["alpha", "beta"] + return Quiver.close!(out) + end + finally + cleanup(path_a) + end + end + + @testset "rename_agents missing key throws" begin + path_a = make_path("a") + try + write_fixture(path_a, (r, c, k) -> 1.0) + with_expr(path_a) do e + @test_throws Quiver.DatabaseException Quiver.rename_agents(e, Dict("nope" => "x")) + end + finally + cleanup(path_a) + end + end + + @testset "rename_agents collision throws" begin + path_a = make_path("a") + try + write_fixture(path_a, (r, c, k) -> 1.0) + with_expr(path_a) do e + @test_throws Quiver.DatabaseException Quiver.rename_agents(e, Dict("val1" => "val2")) + end + finally + cleanup(path_a) + end + end + + @testset "rename_agents on Binary.File shortcut" begin + path_a, path_out = make_path("a"), make_path("out") + try + write_fixture(path_a, (r, c, k) -> 1.0) + file = Quiver.Binary.open_file(path_a; mode = 'r') + try + out = Quiver.rename_agents(file, Dict("val1" => "alpha")) + Quiver.save(out, path_out) + Quiver.close!(out) + finally + Quiver.Binary.close!(file) + end + reopened = Quiver.Binary.open_file(path_out; mode = 'r') + try + md = Quiver.Binary.get_metadata(reopened) + @test Quiver.Binary.get_labels(md) == ["alpha", "val2"] + finally + Quiver.Binary.close!(reopened) + end + finally + cleanup(path_a, path_out) + end + end end end diff --git a/include/quiver/c/expression/expression.h b/include/quiver/c/expression/expression.h index 814cd3fe..4c40a9b9 100644 --- a/include/quiver/c/expression/expression.h +++ b/include/quiver/c/expression/expression.h @@ -82,6 +82,19 @@ QUIVER_C_API 
quiver_error_t quiver_expression_aggregate_agents(quiver_expression const double* parameter, quiver_expression_t** out); +// Label-axis projection. Selects (and may reorder) labels from the operand; duplicate labels are rejected. +QUIVER_C_API quiver_error_t quiver_expression_select_agents(quiver_expression_t* expression, + const char* const* labels, + size_t label_count, + quiver_expression_t** out); + +// Label-axis rename via partial mapping (parallel arrays). Unmapped labels keep their original name. +QUIVER_C_API quiver_error_t quiver_expression_rename_agents(quiver_expression_t* expression, + const char* const* old_labels, + const char* const* new_labels, + size_t mapping_count, + quiver_expression_t** out); + #ifdef __cplusplus } #endif diff --git a/include/quiver/expression/expression.h b/include/quiver/expression/expression.h index b5571a99..bf45ac54 100644 --- a/include/quiver/expression/expression.h +++ b/include/quiver/expression/expression.h @@ -9,6 +9,8 @@ #include #include #include +#include +#include namespace quiver { @@ -28,6 +30,10 @@ class QUIVER_API Expression { Expression aggregate_agents(const std::string& operation, std::optional parameter = std::nullopt) const; + Expression select_agents(const std::vector& labels) const; + + Expression rename_agents(const std::vector>& mapping) const; + private: friend Expression operator+(const Expression&, const Expression&); friend Expression operator+(const Expression&, double); diff --git a/include/quiver/expression/expression_node.h b/include/quiver/expression/expression_node.h index 0a78ada8..3b71ac69 100644 --- a/include/quiver/expression/expression_node.h +++ b/include/quiver/expression/expression_node.h @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace quiver { @@ -199,6 +200,35 @@ class QUIVER_API ExpressionAggregateAgents final : public ExpressionNode { mutable std::vector operand_row_buf_; }; +class QUIVER_API ExpressionSelectAgents final : public ExpressionNode { +public: +
ExpressionSelectAgents(std::shared_ptr operand, std::vector labels); + + const BinaryMetadata& metadata() const override; + void compute_row(const std::vector& dims, std::vector& out) const override; + void collect_input_files(std::vector& out) const override; + +private: + std::shared_ptr operand_; + BinaryMetadata output_meta_; + std::vector selected_indices_; + mutable std::vector operand_row_buf_; +}; + +class QUIVER_API ExpressionRenameAgents final : public ExpressionNode { +public: + ExpressionRenameAgents(std::shared_ptr operand, + std::vector> mapping); + + const BinaryMetadata& metadata() const override; + void compute_row(const std::vector& dims, std::vector& out) const override; + void collect_input_files(std::vector& out) const override; + +private: + std::shared_ptr operand_; + BinaryMetadata output_meta_; +}; + } // namespace quiver #endif // QUIVER_EXPRESSION_NODE_H diff --git a/src/c/expression/expression.cpp b/src/c/expression/expression.cpp index 98b01b17..70d52b33 100644 --- a/src/c/expression/expression.cpp +++ b/src/c/expression/expression.cpp @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include namespace { @@ -241,4 +244,51 @@ QUIVER_C_API quiver_error_t quiver_expression_aggregate_agents(quiver_expression } } +QUIVER_C_API quiver_error_t quiver_expression_select_agents(quiver_expression_t* expression, + const char* const* labels, + size_t label_count, + quiver_expression_t** out) { + QUIVER_REQUIRE(expression, labels, out); + + try { + std::vector selected; + selected.reserve(label_count); + for (size_t i = 0; i < label_count; ++i) { + selected.emplace_back(labels[i]); + } + *out = new quiver_expression(expression->expression.select_agents(selected)); + return QUIVER_OK; + } catch (const std::bad_alloc&) { + quiver_set_last_error("Memory allocation failed"); + return QUIVER_ERROR; + } catch (const std::exception& e) { + quiver_set_last_error(e.what()); + return QUIVER_ERROR; + } +} + +QUIVER_C_API quiver_error_t 
quiver_expression_rename_agents(quiver_expression_t* expression, + const char* const* old_labels, + const char* const* new_labels, + size_t mapping_count, + quiver_expression_t** out) { + QUIVER_REQUIRE(expression, old_labels, new_labels, out); + + try { + std::vector> mapping; + mapping.reserve(mapping_count); + for (size_t i = 0; i < mapping_count; ++i) { + mapping.emplace_back(old_labels[i], new_labels[i]); + } + *out = new quiver_expression(expression->expression.rename_agents(mapping)); + return QUIVER_OK; + } catch (const std::bad_alloc&) { + quiver_set_last_error("Memory allocation failed"); + return QUIVER_ERROR; + } catch (const std::exception& e) { + quiver_set_last_error(e.what()); + return QUIVER_ERROR; + } +} + } // extern "C" diff --git a/src/expression/expression.cpp b/src/expression/expression.cpp index ac0e2fa0..d3cd3c0d 100644 --- a/src/expression/expression.cpp +++ b/src/expression/expression.cpp @@ -35,6 +35,14 @@ Expression Expression::aggregate_agents(const std::string& operation, std::optio return Expression(std::make_shared(op, node_, parameter)); } +Expression Expression::select_agents(const std::vector& labels) const { + return Expression(std::make_shared(node_, labels)); +} + +Expression Expression::rename_agents(const std::vector>& mapping) const { + return Expression(std::make_shared(node_, mapping)); +} + void Expression::save(const std::string& path) const { std::vector input_files; node_->collect_input_files(input_files); diff --git a/src/expression/expression_node.cpp b/src/expression/expression_node.cpp index bd9a1349..14703b95 100644 --- a/src/expression/expression_node.cpp +++ b/src/expression/expression_node.cpp @@ -920,4 +920,94 @@ void ExpressionAggregateAgents::collect_input_files(std::vector& ou operand_->collect_input_files(out); } +ExpressionSelectAgents::ExpressionSelectAgents(std::shared_ptr operand, std::vector labels) + : operand_(std::move(operand)) { + const auto& operand_meta = operand_->metadata(); + + 
std::unordered_map label_to_index; + label_to_index.reserve(operand_meta.labels.size()); + for (size_t i = 0; i < operand_meta.labels.size(); ++i) { + label_to_index.emplace(operand_meta.labels[i], i); + } + + selected_indices_.reserve(labels.size()); + for (const auto& label : labels) { + auto it = label_to_index.find(label); + if (it == label_to_index.end()) { + throw std::runtime_error("Cannot select_agents: label not found: '" + label + "'"); + } + selected_indices_.push_back(it->second); + } + + output_meta_ = operand_meta; + output_meta_.labels = std::move(labels); + output_meta_.validate(); + + operand_row_buf_.resize(operand_meta.labels.size()); +} + +const BinaryMetadata& ExpressionSelectAgents::metadata() const { + return output_meta_; +} + +void ExpressionSelectAgents::compute_row(const std::vector& dims, std::vector& out) const { + operand_->compute_row(dims, operand_row_buf_); + if (out.size() != selected_indices_.size()) { + out.resize(selected_indices_.size()); + } + for (size_t i = 0; i < selected_indices_.size(); ++i) { + out[i] = operand_row_buf_[selected_indices_[i]]; + } +} + +void ExpressionSelectAgents::collect_input_files(std::vector& out) const { + operand_->collect_input_files(out); +} + +ExpressionRenameAgents::ExpressionRenameAgents(std::shared_ptr operand, + std::vector> mapping) + : operand_(std::move(operand)) { + const auto& operand_meta = operand_->metadata(); + + std::unordered_map rename_map; + std::unordered_map used; + rename_map.reserve(mapping.size()); + used.reserve(mapping.size()); + for (auto& entry : mapping) { + if (!rename_map.emplace(entry.first, std::move(entry.second)).second) { + throw std::runtime_error("Cannot rename_agents: duplicate key '" + entry.first + "'"); + } + used.emplace(entry.first, false); + } + + output_meta_ = operand_meta; + for (auto& label : output_meta_.labels) { + auto it = rename_map.find(label); + if (it != rename_map.end()) { + label = it->second; + used[it->first] = true; + } + } + + for 
(const auto& entry : used) { + if (!entry.second) { + throw std::runtime_error("Cannot rename_agents: label not found: '" + entry.first + "'"); + } + } + + output_meta_.validate(); +} + +const BinaryMetadata& ExpressionRenameAgents::metadata() const { + return output_meta_; +} + +void ExpressionRenameAgents::compute_row(const std::vector& dims, std::vector& out) const { + operand_->compute_row(dims, out); +} + +void ExpressionRenameAgents::collect_input_files(std::vector& out) const { + operand_->collect_input_files(out); +} + } // namespace quiver diff --git a/tests/test_c_api_expression.cpp b/tests/test_c_api_expression.cpp index 95104655..80faa835 100644 --- a/tests/test_c_api_expression.cpp +++ b/tests/test_c_api_expression.cpp @@ -1480,3 +1480,147 @@ TEST_F(ExpressionCApiFixture, ApplyTernaryShapeMismatch) { quiver_expression_close(then_v); quiver_expression_close(else_v); } + +// ============================================================================ +// quiver_expression_select_agents +// ============================================================================ + +TEST_F(ExpressionCApiFixture, SelectAgentsSubset) { + write_fixture(path_a, [](int r, int c, int k) { return static_cast(r * 10 + c + k); }); + + auto* a = expr_from_file(path_a); + const char* labels[] = {"val2"}; + quiver_expression_t* sel = nullptr; + ASSERT_EQ(quiver_expression_select_agents(a, labels, 1, &sel), QUIVER_OK); + + quiver_binary_metadata_t* out_md = nullptr; + ASSERT_EQ(quiver_expression_get_metadata(sel, &out_md), QUIVER_OK); + char** out_labels = nullptr; + size_t out_count = 0; + ASSERT_EQ(quiver_binary_metadata_get_labels(out_md, &out_labels, &out_count), QUIVER_OK); + ASSERT_EQ(out_count, 1u); + EXPECT_STREQ(out_labels[0], "val2"); + quiver_binary_metadata_free_string_array(out_labels, out_count); + quiver_binary_metadata_free(out_md); + + ASSERT_EQ(quiver_expression_save(sel, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(sel); 
+ + // val2 is k=1, so cell value = 10r + c + 1. + auto cell = read_one_cell(path_out, {"row", "col"}, {1, 1}); + ASSERT_EQ(cell.size(), 1u); + EXPECT_DOUBLE_EQ(cell[0], 12.0); + auto cell2 = read_one_cell(path_out, {"row", "col"}, {3, 2}); + EXPECT_DOUBLE_EQ(cell2[0], 33.0); +} + +TEST_F(ExpressionCApiFixture, SelectAgentsReorder) { + write_fixture(path_a, [](int r, int c, int k) { return static_cast(r * 10 + c + k); }); + + auto* a = expr_from_file(path_a); + const char* labels[] = {"val2", "val1"}; + quiver_expression_t* sel = nullptr; + ASSERT_EQ(quiver_expression_select_agents(a, labels, 2, &sel), QUIVER_OK); + ASSERT_EQ(quiver_expression_save(sel, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(sel); + + auto cell = read_one_cell(path_out, {"row", "col"}, {1, 1}); + ASSERT_EQ(cell.size(), 2u); + EXPECT_DOUBLE_EQ(cell[0], 12.0); // val2 + EXPECT_DOUBLE_EQ(cell[1], 11.0); // val1 +} + +TEST_F(ExpressionCApiFixture, SelectAgentsMissingReturnsError) { + write_fixture(path_a, [](int, int, int) { return 1.0; }); + auto* a = expr_from_file(path_a); + const char* labels[] = {"nope"}; + quiver_expression_t* sel = nullptr; + EXPECT_EQ(quiver_expression_select_agents(a, labels, 1, &sel), QUIVER_ERROR); + EXPECT_EQ(sel, nullptr); + quiver_expression_close(a); +} + +TEST_F(ExpressionCApiFixture, SelectAgentsNullArguments) { + write_fixture(path_a, [](int, int, int) { return 1.0; }); + auto* a = expr_from_file(path_a); + const char* labels[] = {"val1"}; + quiver_expression_t* sel = nullptr; + + EXPECT_EQ(quiver_expression_select_agents(nullptr, labels, 1, &sel), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_select_agents(a, nullptr, 1, &sel), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_select_agents(a, labels, 1, nullptr), QUIVER_ERROR); + + quiver_expression_close(a); +} + +// ============================================================================ +// quiver_expression_rename_agents +// 
============================================================================ + +TEST_F(ExpressionCApiFixture, RenameAgentsPartial) { + write_fixture(path_a, [](int r, int c, int k) { return static_cast(r * 10 + c + k); }); + + auto* a = expr_from_file(path_a); + const char* old_labels[] = {"val1"}; + const char* new_labels[] = {"alpha"}; + quiver_expression_t* ren = nullptr; + ASSERT_EQ(quiver_expression_rename_agents(a, old_labels, new_labels, 1, &ren), QUIVER_OK); + + quiver_binary_metadata_t* out_md = nullptr; + ASSERT_EQ(quiver_expression_get_metadata(ren, &out_md), QUIVER_OK); + char** out_labels = nullptr; + size_t out_count = 0; + ASSERT_EQ(quiver_binary_metadata_get_labels(out_md, &out_labels, &out_count), QUIVER_OK); + ASSERT_EQ(out_count, 2u); + EXPECT_STREQ(out_labels[0], "alpha"); + EXPECT_STREQ(out_labels[1], "val2"); + quiver_binary_metadata_free_string_array(out_labels, out_count); + quiver_binary_metadata_free(out_md); + + ASSERT_EQ(quiver_expression_save(ren, path_out.c_str()), QUIVER_OK); + quiver_expression_close(a); + quiver_expression_close(ren); + + auto cell = read_one_cell(path_out, {"row", "col"}, {1, 1}); + ASSERT_EQ(cell.size(), 2u); + EXPECT_DOUBLE_EQ(cell[0], 11.0); // unchanged (was val1) + EXPECT_DOUBLE_EQ(cell[1], 12.0); // unchanged (val2) +} + +TEST_F(ExpressionCApiFixture, RenameAgentsMissingReturnsError) { + write_fixture(path_a, [](int, int, int) { return 1.0; }); + auto* a = expr_from_file(path_a); + const char* old_labels[] = {"nope"}; + const char* new_labels[] = {"x"}; + quiver_expression_t* ren = nullptr; + EXPECT_EQ(quiver_expression_rename_agents(a, old_labels, new_labels, 1, &ren), QUIVER_ERROR); + EXPECT_EQ(ren, nullptr); + quiver_expression_close(a); +} + +TEST_F(ExpressionCApiFixture, RenameAgentsCollisionReturnsError) { + write_fixture(path_a, [](int, int, int) { return 1.0; }); + auto* a = expr_from_file(path_a); + const char* old_labels[] = {"val1"}; + const char* new_labels[] = {"val2"}; + quiver_expression_t* ren 
= nullptr; + EXPECT_EQ(quiver_expression_rename_agents(a, old_labels, new_labels, 1, &ren), QUIVER_ERROR); + quiver_expression_close(a); +} + +TEST_F(ExpressionCApiFixture, RenameAgentsNullArguments) { + write_fixture(path_a, [](int, int, int) { return 1.0; }); + auto* a = expr_from_file(path_a); + const char* old_labels[] = {"val1"}; + const char* new_labels[] = {"alpha"}; + quiver_expression_t* ren = nullptr; + + EXPECT_EQ(quiver_expression_rename_agents(nullptr, old_labels, new_labels, 1, &ren), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_rename_agents(a, nullptr, new_labels, 1, &ren), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_rename_agents(a, old_labels, nullptr, 1, &ren), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_rename_agents(a, old_labels, new_labels, 1, nullptr), QUIVER_ERROR); + + quiver_expression_close(a); +} diff --git a/tests/test_expression.cpp b/tests/test_expression.cpp index 6d76abc3..945de306 100644 --- a/tests/test_expression.cpp +++ b/tests/test_expression.cpp @@ -1947,3 +1947,172 @@ TEST_F(ExpressionFixture, IfElseChainsWithBinary) { EXPECT_DOUBLE_EQ(vo[i], 2.0 * base + 1.0); } } + +// ============================================================================= +// ExpressionSelectAgents — label-axis projection +// ============================================================================= + +TEST_F(ExpressionFixture, SelectAgentsSubset) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); + }); + auto a = BinaryFile::open_file(path_a, 'r'); + auto out = Expression(a).select_agents({"val2"}); + out.save(path_out); + + const auto& m = out.metadata(); + ASSERT_EQ(m.labels.size(), 1u); + EXPECT_EQ(m.labels[0], "val2"); + + auto vo = read_all_cells(path_out); + ASSERT_EQ(vo.size(), 6u); + // val2 is k=1: cells = 10r + c + 1 + EXPECT_DOUBLE_EQ(vo[0], 12.0); // r=1, c=1 + EXPECT_DOUBLE_EQ(vo[1], 13.0); // r=1, c=2 + 
EXPECT_DOUBLE_EQ(vo[2], 22.0); // r=2, c=1 + EXPECT_DOUBLE_EQ(vo[3], 23.0); + EXPECT_DOUBLE_EQ(vo[4], 32.0); + EXPECT_DOUBLE_EQ(vo[5], 33.0); +} + +TEST_F(ExpressionFixture, SelectAgentsReorder) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); + }); + auto a = BinaryFile::open_file(path_a, 'r'); + auto out = Expression(a).select_agents({"val2", "val1"}); + out.save(path_out); + + const auto& m = out.metadata(); + ASSERT_EQ(m.labels.size(), 2u); + EXPECT_EQ(m.labels[0], "val2"); + EXPECT_EQ(m.labels[1], "val1"); + + auto vo = read_all_cells(path_out); + // Per cell pair: [val2 first (= 10r+c+1), val1 second (= 10r+c)] + EXPECT_DOUBLE_EQ(vo[0], 12.0); // r=1, c=1, val2 + EXPECT_DOUBLE_EQ(vo[1], 11.0); // r=1, c=1, val1 +} + +TEST_F(ExpressionFixture, SelectAgentsDuplicateThrows) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t k) { return static_cast(k); }); + auto a = BinaryFile::open_file(path_a, 'r'); + + // BinaryMetadata::validate requires unique labels, so duplicates are rejected at construction. 
+ EXPECT_THROW(Expression(a).select_agents({"val1", "val1"}), std::runtime_error); +} + +TEST_F(ExpressionFixture, SelectAgentsMissingThrows) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); + auto a = BinaryFile::open_file(path_a, 'r'); + EXPECT_THROW(Expression(a).select_agents({"nope"}), std::runtime_error); +} + +TEST_F(ExpressionFixture, SelectAgentsAfterBinary) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t k) { return 10.0 + static_cast(k); }); + write_qvr(path_b, md, [](const std::vector&, size_t k) { return 20.0 + static_cast(k); }); + auto a = BinaryFile::open_file(path_a, 'r'); + auto b = BinaryFile::open_file(path_b, 'r'); + auto sum = Expression(a) + Expression(b); + sum.select_agents({"val1"}).save(path_out); + + auto vo = read_all_cells(path_out); + ASSERT_EQ(vo.size(), 6u); + // val1 (k=0): 10 + 20 = 30 in every cell. + for (double v : vo) EXPECT_DOUBLE_EQ(v, 30.0); +} + +// ============================================================================= +// ExpressionRenameAgents — label-axis rename +// ============================================================================= + +TEST_F(ExpressionFixture, RenameAgentsPartial) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); + }); + auto a = BinaryFile::open_file(path_a, 'r'); + auto out = Expression(a).rename_agents({{"val1", "alpha"}}); + out.save(path_out); + + const auto& m = out.metadata(); + ASSERT_EQ(m.labels.size(), 2u); + EXPECT_EQ(m.labels[0], "alpha"); + EXPECT_EQ(m.labels[1], "val2"); + + auto orig = read_all_cells(path_a); + auto renamed = read_all_cells(path_out); + ASSERT_EQ(orig.size(), renamed.size()); + for (size_t i = 0; i < orig.size(); ++i) EXPECT_DOUBLE_EQ(orig[i], renamed[i]); +} + +TEST_F(ExpressionFixture, RenameAgentsAll) { + auto md = 
make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); + auto a = BinaryFile::open_file(path_a, 'r'); + auto out = Expression(a).rename_agents({{"val1", "alpha"}, {"val2", "beta"}}); + + const auto& m = out.metadata(); + ASSERT_EQ(m.labels.size(), 2u); + EXPECT_EQ(m.labels[0], "alpha"); + EXPECT_EQ(m.labels[1], "beta"); +} + +TEST_F(ExpressionFixture, RenameAgentsMissingThrows) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); + auto a = BinaryFile::open_file(path_a, 'r'); + EXPECT_THROW(Expression(a).rename_agents({{"nope", "x"}}), std::runtime_error); +} + +TEST_F(ExpressionFixture, RenameAgentsCollisionThrows) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); + auto a = BinaryFile::open_file(path_a, 'r'); + // Rename val1 -> val2 leaves output labels = {"val2", "val2"} which validate() rejects. + EXPECT_THROW(Expression(a).rename_agents({{"val1", "val2"}}), std::runtime_error); +} + +TEST_F(ExpressionFixture, RenameAgentsDuplicateKeyThrows) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); + auto a = BinaryFile::open_file(path_a, 'r'); + EXPECT_THROW(Expression(a).rename_agents({{"val1", "alpha"}, {"val1", "beta"}}), std::runtime_error); +} + +TEST_F(ExpressionFixture, ChainSelectThenRename) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector& dims, size_t k) { + return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); + }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression(a).select_agents({"val2"}).rename_agents({{"val2", "renamed"}}).save(path_out); + + auto reopened = BinaryFile::open_file(path_out, 'r'); + const auto& m = reopened.get_metadata(); + ASSERT_EQ(m.labels.size(), 1u); + EXPECT_EQ(m.labels[0], "renamed"); + + auto vo = read_all_cells(path_out); + // Same values as val2 column from 
the original (10r + c + 1). + EXPECT_DOUBLE_EQ(vo[0], 12.0); + EXPECT_DOUBLE_EQ(vo[1], 13.0); +} + +TEST_F(ExpressionFixture, RenameAgentsSaveProducesReadableFile) { + auto md = make_simple_metadata(); + write_qvr(path_a, md, [](const std::vector&, size_t) { return 7.0; }); + auto a = BinaryFile::open_file(path_a, 'r'); + Expression(a).rename_agents({{"val1", "alpha"}}).save(path_out); + + auto reopened = BinaryFile::open_file(path_out, 'r'); + const auto& m = reopened.get_metadata(); + ASSERT_EQ(m.labels.size(), 2u); + EXPECT_EQ(m.labels[0], "alpha"); + EXPECT_EQ(m.labels[1], "val2"); +} From e298e354bd7e57d0bd9639c08798a9d1818d84d6 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Wed, 13 May 2026 23:45:11 -0300 Subject: [PATCH 09/13] Update --- tests/test_expression.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_expression.cpp b/tests/test_expression.cpp index 945de306..a643d727 100644 --- a/tests/test_expression.cpp +++ b/tests/test_expression.cpp @@ -2024,7 +2024,8 @@ TEST_F(ExpressionFixture, SelectAgentsAfterBinary) { auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 6u); // val1 (k=0): 10 + 20 = 30 in every cell. 
- for (double v : vo) EXPECT_DOUBLE_EQ(v, 30.0); + for (double v : vo) + EXPECT_DOUBLE_EQ(v, 30.0); } // ============================================================================= @@ -2048,7 +2049,8 @@ TEST_F(ExpressionFixture, RenameAgentsPartial) { auto orig = read_all_cells(path_a); auto renamed = read_all_cells(path_out); ASSERT_EQ(orig.size(), renamed.size()); - for (size_t i = 0; i < orig.size(); ++i) EXPECT_DOUBLE_EQ(orig[i], renamed[i]); + for (size_t i = 0; i < orig.size(); ++i) + EXPECT_DOUBLE_EQ(orig[i], renamed[i]); } TEST_F(ExpressionFixture, RenameAgentsAll) { From a950ab44fee1d750b28aa37863358e8ad997ea75 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Thu, 14 May 2026 07:13:42 -0300 Subject: [PATCH 10/13] refactor: split expression_node.cpp into per-class files Replace the 1013-line expression_node.cpp with one .cpp per node class (ExpressionFile, Scalar, Binary, Unary, Ternary, Aggregate, AggregateAgents, SelectAgents, RenameAgents) plus a shared expression_helpers.h holding the broadcast/validation/aggregation helpers previously in an anonymous namespace. Mechanical refactor: no behavior change, all 7 test suites pass unchanged. Also fix a pre-existing C2375 linkage mismatch by adding QUIVER_API to friend declarations in expression.h. The friends were declared without the DLL attribute while the matching free function decls had QUIVER_API; MSVC /permissive- rejects this. The mismatch was introduced when friend decls replaced the public node() accessor, and was masked until now by stale incremental build state. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 11 +- include/quiver/expression/expression.h | 40 +- src/CMakeLists.txt | 10 +- src/expression/expression_aggregate.cpp | 195 ++++ .../expression_aggregate_agents.cpp | 122 ++ src/expression/expression_binary.cpp | 109 ++ src/expression/expression_file.cpp | 30 + src/expression/expression_helpers.h | 353 ++++++ src/expression/expression_node.cpp | 1013 ----------------- src/expression/expression_rename_agents.cpp | 59 + src/expression/expression_scalar.cpp | 22 + src/expression/expression_select_agents.cpp | 57 + src/expression/expression_ternary.cpp | 133 +++ src/expression/expression_unary.cpp | 52 + 14 files changed, 1171 insertions(+), 1035 deletions(-) create mode 100644 src/expression/expression_aggregate.cpp create mode 100644 src/expression/expression_aggregate_agents.cpp create mode 100644 src/expression/expression_binary.cpp create mode 100644 src/expression/expression_file.cpp create mode 100644 src/expression/expression_helpers.h delete mode 100644 src/expression/expression_node.cpp create mode 100644 src/expression/expression_rename_agents.cpp create mode 100644 src/expression/expression_scalar.cpp create mode 100644 src/expression/expression_select_agents.cpp create mode 100644 src/expression/expression_ternary.cpp create mode 100644 src/expression/expression_unary.cpp diff --git a/CLAUDE.md b/CLAUDE.md index 24cc0ad6..bd58c29f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -44,7 +44,16 @@ src/binary/ # Binary C++ implementation time_properties.cpp # TimeFrequency string conversion src/expression/ # Expression C++ implementation expression.cpp # Expression class, operator overloads, save engine - expression_node.cpp # ExpressionFile/Scalar/Binary impls + Unary/Ternary/Aggregation scaffold impls + validation/broadcast helpers + expression_helpers.h # Shared inline helpers (validation, broadcast metadata, aggregation templates, percentile) + expression_file.cpp # ExpressionFile (leaf reading from 
.qvr) + expression_scalar.cpp # ExpressionScalar (constant broadcast) + expression_binary.cpp # ExpressionBinary (Add/Sub/Mul/Div) + expression_unary.cpp # ExpressionUnary (Negate/Abs/Sqrt/Log/Exp) + expression_ternary.cpp # ExpressionTernary (IfElse) + expression_aggregate.cpp # ExpressionAggregate (dimension-axis Sum/Mean/Min/Max/Percentile) + expression_aggregate_agents.cpp # ExpressionAggregateAgents (label-axis reduction) + expression_select_agents.cpp # ExpressionSelectAgents (label-axis projection) + expression_rename_agents.cpp # ExpressionRenameAgents (label-axis rename) src/c/ # C API implementation internal.h # Shared structs (quiver_database, quiver_element, quiver_binary_file, quiver_binary_metadata), QUIVER_REQUIRE macro database_helpers.h # Marshaling templates, strdup_safe, metadata converters diff --git a/include/quiver/expression/expression.h b/include/quiver/expression/expression.h index bf45ac54..730e3826 100644 --- a/include/quiver/expression/expression.h +++ b/include/quiver/expression/expression.h @@ -35,26 +35,26 @@ class QUIVER_API Expression { Expression rename_agents(const std::vector>& mapping) const; private: - friend Expression operator+(const Expression&, const Expression&); - friend Expression operator+(const Expression&, double); - friend Expression operator+(double, const Expression&); - friend Expression operator-(const Expression&, const Expression&); - friend Expression operator-(const Expression&, double); - friend Expression operator-(double, const Expression&); - friend Expression operator*(const Expression&, const Expression&); - friend Expression operator*(const Expression&, double); - friend Expression operator*(double, const Expression&); - friend Expression operator/(const Expression&, const Expression&); - friend Expression operator/(const Expression&, double); - friend Expression operator/(double, const Expression&); - - friend Expression operator-(const Expression&); - friend Expression abs(const Expression&); - 
friend Expression sqrt(const Expression&); - friend Expression log(const Expression&); - friend Expression exp(const Expression&); - - friend Expression ifelse(const Expression&, const Expression&, const Expression&); + friend QUIVER_API Expression operator+(const Expression&, const Expression&); + friend QUIVER_API Expression operator+(const Expression&, double); + friend QUIVER_API Expression operator+(double, const Expression&); + friend QUIVER_API Expression operator-(const Expression&, const Expression&); + friend QUIVER_API Expression operator-(const Expression&, double); + friend QUIVER_API Expression operator-(double, const Expression&); + friend QUIVER_API Expression operator*(const Expression&, const Expression&); + friend QUIVER_API Expression operator*(const Expression&, double); + friend QUIVER_API Expression operator*(double, const Expression&); + friend QUIVER_API Expression operator/(const Expression&, const Expression&); + friend QUIVER_API Expression operator/(const Expression&, double); + friend QUIVER_API Expression operator/(double, const Expression&); + + friend QUIVER_API Expression operator-(const Expression&); + friend QUIVER_API Expression abs(const Expression&); + friend QUIVER_API Expression sqrt(const Expression&); + friend QUIVER_API Expression log(const Expression&); + friend QUIVER_API Expression exp(const Expression&); + + friend QUIVER_API Expression ifelse(const Expression&, const Expression&, const Expression&); std::shared_ptr node_; }; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 38689d65..b9f18260 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -27,7 +27,15 @@ set(QUIVER_SOURCES binary/dimension.cpp binary/time_properties.cpp expression/expression.cpp - expression/expression_node.cpp + expression/expression_file.cpp + expression/expression_scalar.cpp + expression/expression_binary.cpp + expression/expression_unary.cpp + expression/expression_ternary.cpp + expression/expression_aggregate.cpp + 
expression/expression_aggregate_agents.cpp + expression/expression_select_agents.cpp + expression/expression_rename_agents.cpp ) # Build type diff --git a/src/expression/expression_aggregate.cpp b/src/expression/expression_aggregate.cpp new file mode 100644 index 00000000..1a3f95f9 --- /dev/null +++ b/src/expression/expression_aggregate.cpp @@ -0,0 +1,195 @@ +#include "quiver/expression/expression_node.h" + +#include "expression_helpers.h" +#include "quiver/binary/iteration.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace quiver { + +ExpressionAggregate::Operation ExpressionAggregate::parse_operation(const std::string& name) { + return parse_aggregation_operation_name(name, "aggregate"); +} + +ExpressionAggregate::ExpressionAggregate(Operation operation, + std::shared_ptr operand, + std::string dimension_name, + std::optional parameter) + : operation_(operation), operand_(std::move(operand)), dimension_name_(std::move(dimension_name)), + parameter_(parameter) { + validate_aggregation_param(operation_, parameter_, "aggregate"); + + const auto& operand_meta = operand_->metadata(); + + const int reduced_idx = find_dim_index(operand_meta.dimensions, dimension_name_); + if (reduced_idx < 0) { + throw std::runtime_error("Dimension not found: '" + dimension_name_ + "' in operand metadata"); + } + reduced_operand_index_ = reduced_idx; + + const auto& reduced_dim = operand_meta.dimensions[reduced_idx]; + const int64_t grandparent_orig_idx = + reduced_dim.is_time_dimension() ? 
reduced_dim.time->parent_dimension_index : -1; + + output_meta_ = operand_meta; + output_meta_.dimensions.erase(output_meta_.dimensions.begin() + reduced_idx); + if (reduced_dim.is_time_dimension()) { + --output_meta_.number_of_time_dimensions; + } + + operand_to_out_.assign(operand_meta.dimensions.size(), -1); + for (size_t i = 0; i < operand_meta.dimensions.size(); ++i) { + if (static_cast(i) == reduced_idx) { + continue; + } + operand_to_out_[i] = (static_cast(i) < reduced_idx) ? static_cast(i) : static_cast(i) - 1; + } + + for (size_t out_i = 0; out_i < output_meta_.dimensions.size(); ++out_i) { + auto& out_dim = output_meta_.dimensions[out_i]; + if (!out_dim.is_time_dimension()) { + continue; + } + const int operand_idx = + (static_cast(out_i) < reduced_idx) ? static_cast(out_i) : static_cast(out_i) + 1; + const int64_t orig_parent = operand_meta.dimensions[operand_idx].time->parent_dimension_index; + if (orig_parent < 0) { + out_dim.time->parent_dimension_index = -1; + } else if (orig_parent == reduced_idx) { + if (grandparent_orig_idx < 0) { + out_dim.time->parent_dimension_index = -1; + } else { + out_dim.time->parent_dimension_index = + (grandparent_orig_idx < reduced_idx) ? grandparent_orig_idx : grandparent_orig_idx - 1; + } + } else { + out_dim.time->parent_dimension_index = (orig_parent < reduced_idx) ? 
orig_parent : orig_parent - 1; + } + } + + output_meta_.validate(); + + operand_dims_buf_.resize(operand_meta.dimensions.size()); + operand_row_buf_.resize(operand_meta.labels.size()); + if (operation_ == Operation::Percentile) { + percentile_scratch_.resize(operand_meta.labels.size()); + } +} + +const BinaryMetadata& ExpressionAggregate::metadata() const { + return output_meta_; +} + +void ExpressionAggregate::compute_row(const std::vector& dims, std::vector& out) const { + const auto label_count = operand_row_buf_.size(); + if (out.size() != label_count) { + out.resize(label_count); + } + + const auto& operand_meta = operand_->metadata(); + + for (size_t i = 0; i < operand_meta.dimensions.size(); ++i) { + if (static_cast(i) == reduced_operand_index_) { + continue; + } + operand_dims_buf_[i] = dims[operand_to_out_[i]]; + } + operand_dims_buf_[reduced_operand_index_] = 1; + + const auto& reduced_dim = operand_meta.dimensions[reduced_operand_index_]; + int64_t start = 1; + int64_t end = reduced_dim.size; + if (reduced_dim.is_time_dimension()) { + const auto& tp = *reduced_dim.time; + const int64_t parent_idx = tp.parent_dimension_index; + const auto sizes = dimension_sizes_at_values(operand_meta, operand_dims_buf_); + end = sizes[reduced_operand_index_]; + if (parent_idx < 0) { + start = tp.initial_value; + } else { + const auto& parent_dim = operand_meta.dimensions[parent_idx]; + const int64_t parent_initial = parent_dim.is_time_dimension() ? parent_dim.time->initial_value : 1; + start = (operand_dims_buf_[parent_idx] == parent_initial) ? 
tp.initial_value : 1; + } + } + + std::vector sum_buf(label_count, 0.0); + std::vector count_buf(label_count, 0); + std::vector min_buf(label_count, std::numeric_limits::infinity()); + std::vector max_buf(label_count, -std::numeric_limits::infinity()); + + if (operation_ == Operation::Percentile) { + for (auto& scratch : percentile_scratch_) { + scratch.clear(); + } + } + + for (int64_t v = start; v <= end; ++v) { + operand_dims_buf_[reduced_operand_index_] = v; + operand_->compute_row(operand_dims_buf_, operand_row_buf_); + + for (size_t k = 0; k < label_count; ++k) { + const double value = operand_row_buf_[k]; + if (std::isnan(value)) { + continue; + } + switch (operation_) { + case Operation::Sum: + case Operation::Mean: + sum_buf[k] += value; + ++count_buf[k]; + break; + case Operation::Min: + if (value < min_buf[k]) { + min_buf[k] = value; + } + ++count_buf[k]; + break; + case Operation::Max: + if (value > max_buf[k]) { + max_buf[k] = value; + } + ++count_buf[k]; + break; + case Operation::Percentile: + percentile_scratch_[k].push_back(value); + break; + } + } + } + + const double nan_value = std::numeric_limits::quiet_NaN(); + for (size_t k = 0; k < label_count; ++k) { + switch (operation_) { + case Operation::Sum: + out[k] = (count_buf[k] > 0) ? sum_buf[k] : nan_value; + break; + case Operation::Mean: + out[k] = (count_buf[k] > 0) ? sum_buf[k] / static_cast(count_buf[k]) : nan_value; + break; + case Operation::Min: + out[k] = (count_buf[k] > 0) ? min_buf[k] : nan_value; + break; + case Operation::Max: + out[k] = (count_buf[k] > 0) ? 
max_buf[k] : nan_value; + break; + case Operation::Percentile: + out[k] = compute_percentile(percentile_scratch_[k], *parameter_); + break; + } + } +} + +void ExpressionAggregate::collect_input_files(std::vector& out) const { + operand_->collect_input_files(out); +} + +} // namespace quiver diff --git a/src/expression/expression_aggregate_agents.cpp b/src/expression/expression_aggregate_agents.cpp new file mode 100644 index 00000000..016cef0a --- /dev/null +++ b/src/expression/expression_aggregate_agents.cpp @@ -0,0 +1,122 @@ +#include "quiver/expression/expression_node.h" + +#include "expression_helpers.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace quiver { + +ExpressionAggregateAgents::Operation ExpressionAggregateAgents::parse_operation(const std::string& name) { + return parse_aggregation_operation_name(name, "aggregate_agents"); +} + +ExpressionAggregateAgents::ExpressionAggregateAgents(Operation operation, + std::shared_ptr operand, + std::optional parameter) + : operation_(operation), operand_(std::move(operand)), parameter_(parameter) { + validate_aggregation_param(operation_, parameter_, "aggregate_agents"); + + const auto& operand_meta = operand_->metadata(); + output_meta_ = operand_meta; + output_meta_.labels = {aggregation_operation_label(operation_)}; + output_meta_.validate(); + + operand_row_buf_.resize(operand_meta.labels.size()); +} + +const BinaryMetadata& ExpressionAggregateAgents::metadata() const { + return output_meta_; +} + +void ExpressionAggregateAgents::compute_row(const std::vector& dims, std::vector& out) const { + if (out.size() != 1) { + out.resize(1); + } + + operand_->compute_row(dims, operand_row_buf_); + + const double nan_value = std::numeric_limits::quiet_NaN(); + + switch (operation_) { + case Operation::Sum: { + double sum = 0.0; + int64_t count = 0; + for (double v : operand_row_buf_) { + if (std::isnan(v)) { + continue; + } + sum += v; + ++count; + } + out[0] = (count > 0) 
? sum : nan_value; + break; + } + case Operation::Mean: { + double sum = 0.0; + int64_t count = 0; + for (double v : operand_row_buf_) { + if (std::isnan(v)) { + continue; + } + sum += v; + ++count; + } + out[0] = (count > 0) ? sum / static_cast(count) : nan_value; + break; + } + case Operation::Min: { + double m = std::numeric_limits::infinity(); + int64_t count = 0; + for (double v : operand_row_buf_) { + if (std::isnan(v)) { + continue; + } + if (v < m) { + m = v; + } + ++count; + } + out[0] = (count > 0) ? m : nan_value; + break; + } + case Operation::Max: { + double m = -std::numeric_limits::infinity(); + int64_t count = 0; + for (double v : operand_row_buf_) { + if (std::isnan(v)) { + continue; + } + if (v > m) { + m = v; + } + ++count; + } + out[0] = (count > 0) ? m : nan_value; + break; + } + case Operation::Percentile: { + std::vector scratch; + scratch.reserve(operand_row_buf_.size()); + for (double v : operand_row_buf_) { + if (!std::isnan(v)) { + scratch.push_back(v); + } + } + out[0] = compute_percentile(scratch, *parameter_); + break; + } + } +} + +void ExpressionAggregateAgents::collect_input_files(std::vector& out) const { + operand_->collect_input_files(out); +} + +} // namespace quiver diff --git a/src/expression/expression_binary.cpp b/src/expression/expression_binary.cpp new file mode 100644 index 00000000..8a414f47 --- /dev/null +++ b/src/expression/expression_binary.cpp @@ -0,0 +1,109 @@ +#include "quiver/expression/expression_node.h" + +#include "expression_helpers.h" + +#include +#include +#include +#include +#include + +namespace quiver { + +double ExpressionBinary::apply(Operation operation, double lhs, double rhs) { + switch (operation) { + case Operation::Add: + return lhs + rhs; + case Operation::Subtract: + return lhs - rhs; + case Operation::Multiply: + return lhs * rhs; + case Operation::Divide: + return lhs / rhs; + } + throw std::runtime_error("Cannot apply: unhandled ExpressionBinary::Operation variant"); +} + 
+ExpressionBinary::ExpressionBinary(Operation operation, + std::shared_ptr lhs, + std::shared_ptr rhs) + : operation_(operation), lhs_(std::move(lhs)), rhs_(std::move(rhs)) { + const auto& lhs_meta = lhs_->metadata(); + const auto& rhs_meta = rhs_->metadata(); + + validate_compatibility(lhs_meta, rhs_meta); + auto output_labels = compute_output_labels(lhs_meta.labels, rhs_meta.labels); + broadcast_meta_ = build_broadcast_metadata(lhs_meta, rhs_meta, std::move(output_labels)); + broadcast_meta_.validate(); + + const auto& out_dims = broadcast_meta_.dimensions; + lhs_dim_sizes_.assign(out_dims.size(), 0); + rhs_dim_sizes_.assign(out_dims.size(), 0); + lhs_to_out_.assign(lhs_meta.dimensions.size(), -1); + rhs_to_out_.assign(rhs_meta.dimensions.size(), -1); + + for (size_t out_i = 0; out_i < out_dims.size(); ++out_i) { + const auto li = find_dim_index(lhs_meta.dimensions, out_dims[out_i].name); + const auto ri = find_dim_index(rhs_meta.dimensions, out_dims[out_i].name); + lhs_dim_sizes_[out_i] = (li >= 0) ? lhs_meta.dimensions[li].size : 0; + rhs_dim_sizes_[out_i] = (ri >= 0) ? 
rhs_meta.dimensions[ri].size : 0; + if (li >= 0) { + lhs_to_out_[li] = static_cast(out_i); + } + if (ri >= 0) { + rhs_to_out_[ri] = static_cast(out_i); + } + } + + lhs_label_count_ = lhs_meta.labels.size(); + rhs_label_count_ = rhs_meta.labels.size(); + + lhs_dims_buf_.resize(lhs_meta.dimensions.size()); + rhs_dims_buf_.resize(rhs_meta.dimensions.size()); + lhs_buf_.resize(lhs_label_count_); + rhs_buf_.resize(rhs_label_count_); +} + +const BinaryMetadata& ExpressionBinary::metadata() const { + return broadcast_meta_; +} + +void ExpressionBinary::compute_row(const std::vector& dims, std::vector& out) const { + const auto out_label_count = broadcast_meta_.labels.size(); + if (out.size() != out_label_count) { + out.resize(out_label_count); + } + + for (size_t li = 0; li < lhs_dims_buf_.size(); ++li) { + const auto out_i = lhs_to_out_[li]; + auto coord = dims[out_i]; + if (lhs_dim_sizes_[out_i] == 1) { + coord = 1; + } + lhs_dims_buf_[li] = coord; + } + for (size_t ri = 0; ri < rhs_dims_buf_.size(); ++ri) { + const auto out_i = rhs_to_out_[ri]; + auto coord = dims[out_i]; + if (rhs_dim_sizes_[out_i] == 1) { + coord = 1; + } + rhs_dims_buf_[ri] = coord; + } + + lhs_->compute_row(lhs_dims_buf_, lhs_buf_); + rhs_->compute_row(rhs_dims_buf_, rhs_buf_); + + for (size_t k = 0; k < out_label_count; ++k) { + const size_t li = (lhs_label_count_ == 1) ? 0 : k; + const size_t ri = (rhs_label_count_ == 1) ? 
0 : k; + out[k] = apply(operation_, lhs_buf_[li], rhs_buf_[ri]); + } +} + +void ExpressionBinary::collect_input_files(std::vector& out) const { + lhs_->collect_input_files(out); + rhs_->collect_input_files(out); +} + +} // namespace quiver diff --git a/src/expression/expression_file.cpp b/src/expression/expression_file.cpp new file mode 100644 index 00000000..d9dd2a54 --- /dev/null +++ b/src/expression/expression_file.cpp @@ -0,0 +1,30 @@ +#include "quiver/expression/expression_node.h" + +#include "quiver/binary/binary_file.h" + +#include +#include +#include + +namespace quiver { + +ExpressionFile::ExpressionFile(const std::string& path) : meta_(BinaryMetadata::from_toml_file(path)), file_(path) { + dim_map_.reserve(meta_.dimensions.size()); +} + +const BinaryMetadata& ExpressionFile::metadata() const { + return meta_; +} + +void ExpressionFile::compute_row(const std::vector& dims, std::vector& out) const { + for (size_t i = 0; i < meta_.dimensions.size(); ++i) { + dim_map_[meta_.dimensions[i].name] = dims[i]; + } + out = file_.read(dim_map_, /*allow_nulls=*/true); +} + +void ExpressionFile::collect_input_files(std::vector& out) const { + out.push_back(&file_); +} + +} // namespace quiver diff --git a/src/expression/expression_helpers.h b/src/expression/expression_helpers.h new file mode 100644 index 00000000..a553b822 --- /dev/null +++ b/src/expression/expression_helpers.h @@ -0,0 +1,353 @@ +#ifndef QUIVER_EXPRESSION_HELPERS_H +#define QUIVER_EXPRESSION_HELPERS_H + +#include "quiver/binary/binary_metadata.h" +#include "quiver/expression/expression_node.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace quiver { + +inline int find_dim_index(const std::vector& dims, const std::string& name) { + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i].name == name) { + return static_cast(i); + } + } + return -1; +} + +inline bool any_time_dim(const std::vector& dims) { + for (const auto& d : dims) { + if 
(d.is_time_dimension()) { + return true; + } + } + return false; +} + +inline std::string parent_name_of(int64_t parent_idx, const BinaryMetadata& m) { + return (parent_idx >= 0) ? m.dimensions[parent_idx].name : std::string{}; +} + +inline void validate_unit_match(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { + if (lhs.unit != rhs.unit) { + throw std::runtime_error("Cannot apply: units differ ('" + lhs.unit + "' vs '" + rhs.unit + "')"); + } +} + +inline void validate_shape_compatibility(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { + for (const auto& l_dim : lhs.dimensions) { + auto r_idx = find_dim_index(rhs.dimensions, l_dim.name); + if (r_idx < 0) { + continue; + } + + auto l_size = l_dim.size; + auto r_size = rhs.dimensions[r_idx].size; + + if (l_size == r_size) { + continue; + } + + if (l_size == 1 || r_size == 1) { + continue; + } + + throw std::runtime_error("Cannot apply: dimension '" + l_dim.name + "' has incompatible sizes " + + std::to_string(l_size) + " vs " + std::to_string(r_size) + + " (broadcasting requires n x n, 1 x n, or n x 1)"); + } + + for (const auto& l_dim : lhs.dimensions) { + auto r_idx = find_dim_index(rhs.dimensions, l_dim.name); + if (r_idx < 0) { + continue; + } + + const auto& r_dim = rhs.dimensions[r_idx]; + const auto l_time = l_dim.is_time_dimension(); + const auto r_time = r_dim.is_time_dimension(); + if (l_time != r_time) { + const std::string time_side = l_time ? "lhs" : "rhs"; + const std::string nontime_side = l_time ? 
"rhs" : "lhs"; + throw std::runtime_error("Cannot apply: dimension '" + l_dim.name + "' is a time dimension on " + + time_side + " but not on " + nontime_side); + } + + if (!l_time) { + continue; + } + + const auto& lp = *l_dim.time; + const auto& rp = *r_dim.time; + const auto l_parent = parent_name_of(lp.parent_dimension_index, lhs); + const auto r_parent = parent_name_of(rp.parent_dimension_index, rhs); + if (lp.frequency != rp.frequency || lp.initial_value != rp.initial_value || l_parent != r_parent) { + throw std::runtime_error("Cannot apply: time dimension '" + l_dim.name + + "' has incompatible TimeProperties"); + } + } + + const auto lhs_has_time = any_time_dim(lhs.dimensions); + const auto rhs_has_time = any_time_dim(rhs.dimensions); + if (lhs_has_time && rhs_has_time && lhs.initial_datetime != rhs.initial_datetime) { + throw std::runtime_error("Cannot apply: initial_datetime differs"); + } +} + +inline void validate_compatibility(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { + validate_unit_match(lhs, rhs); + validate_shape_compatibility(lhs, rhs); +} + +inline std::vector compute_output_labels(const std::vector& l_labels, + const std::vector& r_labels) { + const auto ll = l_labels.size(); + const auto rl = r_labels.size(); + if (ll == rl) { + if (l_labels != r_labels) { + throw std::runtime_error("Cannot apply: labels have same size " + std::to_string(ll) + + " but different content"); + } + return l_labels; + } + if (ll == 1 && rl > 1) { + return r_labels; + } + if (rl == 1 && ll > 1) { + return l_labels; + } + throw std::runtime_error("Cannot apply: labels have incompatible sizes " + std::to_string(ll) + " vs " + + std::to_string(rl)); +} + +inline BinaryMetadata +build_broadcast_metadata(const BinaryMetadata& lhs, const BinaryMetadata& rhs, std::vector output_labels) { + BinaryMetadata out; + out.version = lhs.version; + out.unit = lhs.unit; + out.labels = std::move(output_labels); + + const auto lhs_has_time = any_time_dim(lhs.dimensions); 
+ const auto rhs_has_time = any_time_dim(rhs.dimensions); + out.initial_datetime = + lhs_has_time ? lhs.initial_datetime : (rhs_has_time ? rhs.initial_datetime : lhs.initial_datetime); + + std::unordered_map output_index_by_name; + for (const auto& l_dim : lhs.dimensions) { + auto r_idx = find_dim_index(rhs.dimensions, l_dim.name); + int64_t out_size = (r_idx >= 0) ? std::max(l_dim.size, rhs.dimensions[r_idx].size) : l_dim.size; + Dimension d{l_dim.name, out_size, l_dim.time}; + out.dimensions.push_back(std::move(d)); + output_index_by_name[l_dim.name] = static_cast(out.dimensions.size()) - 1; + } + for (const auto& r_dim : rhs.dimensions) { + if (output_index_by_name.count(r_dim.name)) + continue; + out.dimensions.push_back(r_dim); + output_index_by_name[r_dim.name] = static_cast(out.dimensions.size()) - 1; + } + for (auto& out_d : out.dimensions) { + if (!out_d.is_time_dimension()) + continue; + auto src_idx = find_dim_index(lhs.dimensions, out_d.name); + const auto* src_meta = &lhs; + if (src_idx < 0) { + src_idx = find_dim_index(rhs.dimensions, out_d.name); + src_meta = &rhs; + } + int64_t src_parent_idx = src_meta->dimensions[src_idx].time->parent_dimension_index; + if (src_parent_idx < 0) { + out_d.time->parent_dimension_index = -1; + continue; + } + const std::string& parent_name = src_meta->dimensions[src_parent_idx].name; + out_d.time->parent_dimension_index = output_index_by_name.find(parent_name)->second; + } + for (const auto& d : out.dimensions) { + if (d.is_time_dimension()) { + ++out.number_of_time_dimensions; + } + } + return out; +} + +inline std::vector compute_ternary_output_labels(const std::vector& c_labels, + const std::vector& t_labels, + const std::vector& e_labels) { + const std::vector*> non_singleton = [&] { + std::vector*> v; + if (c_labels.size() > 1) + v.push_back(&c_labels); + if (t_labels.size() > 1) + v.push_back(&t_labels); + if (e_labels.size() > 1) + v.push_back(&e_labels); + return v; + }(); + + if (non_singleton.empty()) { + 
return t_labels; + } + + for (size_t i = 1; i < non_singleton.size(); ++i) { + if (*non_singleton[i] != *non_singleton[0]) { + throw std::runtime_error("Cannot apply: labels are incompatible across operands " + "(non-singleton label sets must match)"); + } + } + return *non_singleton[0]; +} + +inline BinaryMetadata build_ternary_broadcast_metadata(const BinaryMetadata& cond, + const BinaryMetadata& then_meta, + const BinaryMetadata& else_meta, + std::vector output_labels) { + BinaryMetadata out; + out.version = then_meta.version; + out.unit = then_meta.unit; + out.labels = std::move(output_labels); + + const auto then_has_time = any_time_dim(then_meta.dimensions); + const auto else_has_time = any_time_dim(else_meta.dimensions); + const auto cond_has_time = any_time_dim(cond.dimensions); + if (then_has_time) { + out.initial_datetime = then_meta.initial_datetime; + } else if (else_has_time) { + out.initial_datetime = else_meta.initial_datetime; + } else if (cond_has_time) { + out.initial_datetime = cond.initial_datetime; + } else { + out.initial_datetime = then_meta.initial_datetime; + } + + const std::vector sources = {&cond, &then_meta, &else_meta}; + std::unordered_map output_index_by_name; + for (const auto* src : sources) { + for (const auto& dim : src->dimensions) { + if (output_index_by_name.count(dim.name)) + continue; + int64_t out_size = dim.size; + for (const auto* other : sources) { + if (other == src) + continue; + auto idx = find_dim_index(other->dimensions, dim.name); + if (idx >= 0) { + out_size = std::max(out_size, other->dimensions[idx].size); + } + } + Dimension d{dim.name, out_size, dim.time}; + out.dimensions.push_back(std::move(d)); + output_index_by_name[dim.name] = static_cast(out.dimensions.size()) - 1; + } + } + + for (auto& out_d : out.dimensions) { + if (!out_d.is_time_dimension()) + continue; + const BinaryMetadata* src_meta = nullptr; + int src_idx = -1; + for (const auto* s : sources) { + src_idx = find_dim_index(s->dimensions, 
out_d.name); + if (src_idx >= 0) { + src_meta = s; + break; + } + } + int64_t src_parent_idx = src_meta->dimensions[src_idx].time->parent_dimension_index; + if (src_parent_idx < 0) { + out_d.time->parent_dimension_index = -1; + continue; + } + const std::string& parent_name = src_meta->dimensions[src_parent_idx].name; + out_d.time->parent_dimension_index = output_index_by_name.find(parent_name)->second; + } + + for (const auto& d : out.dimensions) { + if (d.is_time_dimension()) { + ++out.number_of_time_dimensions; + } + } + return out; +} + +template +Op parse_aggregation_operation_name(const std::string& name, const std::string& fn_label) { + if (name == "sum") + return Op::Sum; + if (name == "mean") + return Op::Mean; + if (name == "min") + return Op::Min; + if (name == "max") + return Op::Max; + if (name == "percentile") + return Op::Percentile; + throw std::runtime_error("Cannot " + fn_label + ": unknown operation '" + name + + "' (expected one of: sum, mean, min, max, percentile)"); +} + +template +std::string aggregation_operation_label(Op op) { + switch (op) { + case Op::Sum: + return "sum"; + case Op::Mean: + return "mean"; + case Op::Min: + return "min"; + case Op::Max: + return "max"; + case Op::Percentile: + return "percentile"; + } + throw std::runtime_error("Cannot label aggregation: unhandled Operation variant"); +} + +template +void validate_aggregation_param(Op op, std::optional parameter, const std::string& fn_label) { + const bool needs_param = (op == Op::Percentile); + if (needs_param && !parameter.has_value()) { + throw std::runtime_error("Cannot " + fn_label + ": operation 'percentile' requires a parameter"); + } + if (!needs_param && parameter.has_value()) { + throw std::runtime_error("Cannot " + fn_label + ": operation '" + aggregation_operation_label(op) + + "' does not accept a parameter"); + } + if (needs_param && (*parameter < 0.0 || *parameter > 1.0)) { + throw std::runtime_error("Cannot " + fn_label + ": percentile must be in [0, 1], 
got " + + std::to_string(*parameter)); + } +} + +inline double compute_percentile(std::vector& values, double fraction) { + if (values.empty()) { + return std::numeric_limits::quiet_NaN(); + } + std::sort(values.begin(), values.end()); + const auto n = values.size(); + if (n == 1) { + return values[0]; + } + const auto pos = static_cast(n - 1) * fraction; + const auto lo = static_cast(std::floor(pos)); + const auto hi = static_cast(std::ceil(pos)); + if (lo == hi) { + return values[lo]; + } + const double frac = pos - static_cast(lo); + return values[lo] * (1.0 - frac) + values[hi] * frac; +} + +} // namespace quiver + +#endif // QUIVER_EXPRESSION_HELPERS_H diff --git a/src/expression/expression_node.cpp b/src/expression/expression_node.cpp deleted file mode 100644 index 14703b95..00000000 --- a/src/expression/expression_node.cpp +++ /dev/null @@ -1,1013 +0,0 @@ -#include "quiver/expression/expression_node.h" - -#include "quiver/binary/binary_file.h" -#include "quiver/binary/iteration.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace quiver { - -ExpressionFile::ExpressionFile(const std::string& path) : meta_(BinaryMetadata::from_toml_file(path)), file_(path) { - dim_map_.reserve(meta_.dimensions.size()); -} - -const BinaryMetadata& ExpressionFile::metadata() const { - return meta_; -} - -void ExpressionFile::compute_row(const std::vector& dims, std::vector& out) const { - for (size_t i = 0; i < meta_.dimensions.size(); ++i) { - dim_map_[meta_.dimensions[i].name] = dims[i]; - } - out = file_.read(dim_map_, /*allow_nulls=*/true); -} - -void ExpressionFile::collect_input_files(std::vector& out) const { - out.push_back(&file_); -} - -ExpressionScalar::ExpressionScalar(double value, BinaryMetadata broadcast_meta) - : value_(value), broadcast_meta_(std::move(broadcast_meta)) {} - -const BinaryMetadata& ExpressionScalar::metadata() const { - return broadcast_meta_; -} - -void ExpressionScalar::compute_row(const 
std::vector& /*dims*/, std::vector& out) const { - out.assign(broadcast_meta_.labels.size(), value_); -} - -void ExpressionScalar::collect_input_files(std::vector& /*out*/) const {} - -namespace { - -int find_dim_index(const std::vector& dims, const std::string& name) { - for (size_t i = 0; i < dims.size(); ++i) { - if (dims[i].name == name) { - return static_cast(i); - } - } - return -1; -} - -bool any_time_dim(const std::vector& dims) { - for (const auto& d : dims) { - if (d.is_time_dimension()) { - return true; - } - } - return false; -} - -std::string parent_name_of(int64_t parent_idx, const BinaryMetadata& m) { - return (parent_idx >= 0) ? m.dimensions[parent_idx].name : std::string{}; -} - -void validate_unit_match(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { - if (lhs.unit != rhs.unit) { - throw std::runtime_error("Cannot apply: units differ ('" + lhs.unit + "' vs '" + rhs.unit + "')"); - } -} - -void validate_shape_compatibility(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { - for (const auto& l_dim : lhs.dimensions) { - auto r_idx = find_dim_index(rhs.dimensions, l_dim.name); - if (r_idx < 0) { - continue; - } - - auto l_size = l_dim.size; - auto r_size = rhs.dimensions[r_idx].size; - - if (l_size == r_size) { - continue; - } - - if (l_size == 1 || r_size == 1) { - continue; - } - - throw std::runtime_error("Cannot apply: dimension '" + l_dim.name + "' has incompatible sizes " + - std::to_string(l_size) + " vs " + std::to_string(r_size) + - " (broadcasting requires n x n, 1 x n, or n x 1)"); - } - - for (const auto& l_dim : lhs.dimensions) { - auto r_idx = find_dim_index(rhs.dimensions, l_dim.name); - if (r_idx < 0) { - continue; - } - - const auto& r_dim = rhs.dimensions[r_idx]; - const auto l_time = l_dim.is_time_dimension(); - const auto r_time = r_dim.is_time_dimension(); - if (l_time != r_time) { - const std::string time_side = l_time ? "lhs" : "rhs"; - const std::string nontime_side = l_time ? 
"rhs" : "lhs"; - throw std::runtime_error("Cannot apply: dimension '" + l_dim.name + "' is a time dimension on " + - time_side + " but not on " + nontime_side); - } - - if (!l_time) { - continue; - } - - const auto& lp = *l_dim.time; - const auto& rp = *r_dim.time; - const auto l_parent = parent_name_of(lp.parent_dimension_index, lhs); - const auto r_parent = parent_name_of(rp.parent_dimension_index, rhs); - if (lp.frequency != rp.frequency || lp.initial_value != rp.initial_value || l_parent != r_parent) { - throw std::runtime_error("Cannot apply: time dimension '" + l_dim.name + - "' has incompatible TimeProperties"); - } - } - - const auto lhs_has_time = any_time_dim(lhs.dimensions); - const auto rhs_has_time = any_time_dim(rhs.dimensions); - if (lhs_has_time && rhs_has_time && lhs.initial_datetime != rhs.initial_datetime) { - throw std::runtime_error("Cannot apply: initial_datetime differs"); - } -} - -void validate_compatibility(const BinaryMetadata& lhs, const BinaryMetadata& rhs) { - validate_unit_match(lhs, rhs); - validate_shape_compatibility(lhs, rhs); -} - -std::vector compute_output_labels(const std::vector& l_labels, - const std::vector& r_labels) { - const auto ll = l_labels.size(); - const auto rl = r_labels.size(); - if (ll == rl) { - if (l_labels != r_labels) { - throw std::runtime_error("Cannot apply: labels have same size " + std::to_string(ll) + - " but different content"); - } - return l_labels; - } - if (ll == 1 && rl > 1) { - return r_labels; - } - if (rl == 1 && ll > 1) { - return l_labels; - } - throw std::runtime_error("Cannot apply: labels have incompatible sizes " + std::to_string(ll) + " vs " + - std::to_string(rl)); -} - -BinaryMetadata -build_broadcast_metadata(const BinaryMetadata& lhs, const BinaryMetadata& rhs, std::vector output_labels) { - BinaryMetadata out; - out.version = lhs.version; - out.unit = lhs.unit; - out.labels = std::move(output_labels); - - const auto lhs_has_time = any_time_dim(lhs.dimensions); - const auto 
rhs_has_time = any_time_dim(rhs.dimensions); - out.initial_datetime = - lhs_has_time ? lhs.initial_datetime : (rhs_has_time ? rhs.initial_datetime : lhs.initial_datetime); - - std::unordered_map output_index_by_name; - for (const auto& l_dim : lhs.dimensions) { - auto r_idx = find_dim_index(rhs.dimensions, l_dim.name); - int64_t out_size = (r_idx >= 0) ? std::max(l_dim.size, rhs.dimensions[r_idx].size) : l_dim.size; - Dimension d{l_dim.name, out_size, l_dim.time}; // copy time props from lhs (rhs equal) - out.dimensions.push_back(std::move(d)); - output_index_by_name[l_dim.name] = static_cast(out.dimensions.size()) - 1; - } - for (const auto& r_dim : rhs.dimensions) { - if (output_index_by_name.count(r_dim.name)) - continue; // already placed - out.dimensions.push_back(r_dim); - output_index_by_name[r_dim.name] = static_cast(out.dimensions.size()) - 1; - } - for (auto& out_d : out.dimensions) { - if (!out_d.is_time_dimension()) - continue; - auto src_idx = find_dim_index(lhs.dimensions, out_d.name); - const auto* src_meta = &lhs; - if (src_idx < 0) { - src_idx = find_dim_index(rhs.dimensions, out_d.name); - src_meta = &rhs; - } - int64_t src_parent_idx = src_meta->dimensions[src_idx].time->parent_dimension_index; - if (src_parent_idx < 0) { - out_d.time->parent_dimension_index = -1; - continue; - } - const std::string& parent_name = src_meta->dimensions[src_parent_idx].name; - out_d.time->parent_dimension_index = output_index_by_name.find(parent_name)->second; - } - for (const auto& d : out.dimensions) { - if (d.is_time_dimension()) { - ++out.number_of_time_dimensions; - } - } - return out; -} - -std::vector compute_ternary_output_labels(const std::vector& c_labels, - const std::vector& t_labels, - const std::vector& e_labels) { - const std::vector*> non_singleton = [&] { - std::vector*> v; - if (c_labels.size() > 1) - v.push_back(&c_labels); - if (t_labels.size() > 1) - v.push_back(&t_labels); - if (e_labels.size() > 1) - v.push_back(&e_labels); - return v; - }(); 
- - if (non_singleton.empty()) { - return t_labels; - } - - for (size_t i = 1; i < non_singleton.size(); ++i) { - if (*non_singleton[i] != *non_singleton[0]) { - throw std::runtime_error("Cannot apply: labels are incompatible across operands " - "(non-singleton label sets must match)"); - } - } - return *non_singleton[0]; -} - -BinaryMetadata build_ternary_broadcast_metadata(const BinaryMetadata& cond, - const BinaryMetadata& then_meta, - const BinaryMetadata& else_meta, - std::vector output_labels) { - BinaryMetadata out; - out.version = then_meta.version; - out.unit = then_meta.unit; // validated == else_meta.unit; cond.unit is ignored - out.labels = std::move(output_labels); - - const auto then_has_time = any_time_dim(then_meta.dimensions); - const auto else_has_time = any_time_dim(else_meta.dimensions); - const auto cond_has_time = any_time_dim(cond.dimensions); - if (then_has_time) { - out.initial_datetime = then_meta.initial_datetime; - } else if (else_has_time) { - out.initial_datetime = else_meta.initial_datetime; - } else if (cond_has_time) { - out.initial_datetime = cond.initial_datetime; - } else { - out.initial_datetime = then_meta.initial_datetime; - } - - const std::vector sources = {&cond, &then_meta, &else_meta}; - std::unordered_map output_index_by_name; - for (const auto* src : sources) { - for (const auto& dim : src->dimensions) { - if (output_index_by_name.count(dim.name)) - continue; - int64_t out_size = dim.size; - for (const auto* other : sources) { - if (other == src) - continue; - auto idx = find_dim_index(other->dimensions, dim.name); - if (idx >= 0) { - out_size = std::max(out_size, other->dimensions[idx].size); - } - } - Dimension d{dim.name, out_size, dim.time}; - out.dimensions.push_back(std::move(d)); - output_index_by_name[dim.name] = static_cast(out.dimensions.size()) - 1; - } - } - - for (auto& out_d : out.dimensions) { - if (!out_d.is_time_dimension()) - continue; - const BinaryMetadata* src_meta = nullptr; - int src_idx = -1; - 
for (const auto* s : sources) { - src_idx = find_dim_index(s->dimensions, out_d.name); - if (src_idx >= 0) { - src_meta = s; - break; - } - } - int64_t src_parent_idx = src_meta->dimensions[src_idx].time->parent_dimension_index; - if (src_parent_idx < 0) { - out_d.time->parent_dimension_index = -1; - continue; - } - const std::string& parent_name = src_meta->dimensions[src_parent_idx].name; - out_d.time->parent_dimension_index = output_index_by_name.find(parent_name)->second; - } - - for (const auto& d : out.dimensions) { - if (d.is_time_dimension()) { - ++out.number_of_time_dimensions; - } - } - return out; -} - -template -Op parse_aggregation_operation_name(const std::string& name, const std::string& fn_label) { - if (name == "sum") - return Op::Sum; - if (name == "mean") - return Op::Mean; - if (name == "min") - return Op::Min; - if (name == "max") - return Op::Max; - if (name == "percentile") - return Op::Percentile; - throw std::runtime_error("Cannot " + fn_label + ": unknown operation '" + name + - "' (expected one of: sum, mean, min, max, percentile)"); -} - -template -std::string aggregation_operation_label(Op op) { - switch (op) { - case Op::Sum: - return "sum"; - case Op::Mean: - return "mean"; - case Op::Min: - return "min"; - case Op::Max: - return "max"; - case Op::Percentile: - return "percentile"; - } - throw std::runtime_error("Cannot label aggregation: unhandled Operation variant"); -} - -template -void validate_aggregation_param(Op op, std::optional parameter, const std::string& fn_label) { - const bool needs_param = (op == Op::Percentile); - if (needs_param && !parameter.has_value()) { - throw std::runtime_error("Cannot " + fn_label + ": operation 'percentile' requires a parameter"); - } - if (!needs_param && parameter.has_value()) { - throw std::runtime_error("Cannot " + fn_label + ": operation '" + aggregation_operation_label(op) + - "' does not accept a parameter"); - } - if (needs_param && (*parameter < 0.0 || *parameter > 1.0)) { - throw 
std::runtime_error("Cannot " + fn_label + ": percentile must be in [0, 1], got " + - std::to_string(*parameter)); - } -} - -double compute_percentile(std::vector& values, double fraction) { - if (values.empty()) { - return std::numeric_limits::quiet_NaN(); - } - std::sort(values.begin(), values.end()); - const auto n = values.size(); - if (n == 1) { - return values[0]; - } - const auto pos = static_cast(n - 1) * fraction; - const auto lo = static_cast(std::floor(pos)); - const auto hi = static_cast(std::ceil(pos)); - if (lo == hi) { - return values[lo]; - } - const double frac = pos - static_cast(lo); - return values[lo] * (1.0 - frac) + values[hi] * frac; -} - -} // namespace - -double ExpressionBinary::apply(Operation operation, double lhs, double rhs) { - switch (operation) { - case Operation::Add: - return lhs + rhs; - case Operation::Subtract: - return lhs - rhs; - case Operation::Multiply: - return lhs * rhs; - case Operation::Divide: - return lhs / rhs; - } - throw std::runtime_error("Cannot apply: unhandled ExpressionBinary::Operation variant"); -} - -ExpressionBinary::ExpressionBinary(Operation operation, - std::shared_ptr lhs, - std::shared_ptr rhs) - : operation_(operation), lhs_(std::move(lhs)), rhs_(std::move(rhs)) { - const auto& lhs_meta = lhs_->metadata(); - const auto& rhs_meta = rhs_->metadata(); - - validate_compatibility(lhs_meta, rhs_meta); - auto output_labels = compute_output_labels(lhs_meta.labels, rhs_meta.labels); - broadcast_meta_ = build_broadcast_metadata(lhs_meta, rhs_meta, std::move(output_labels)); - broadcast_meta_.validate(); - - const auto& out_dims = broadcast_meta_.dimensions; - lhs_dim_sizes_.assign(out_dims.size(), 0); - rhs_dim_sizes_.assign(out_dims.size(), 0); - lhs_to_out_.assign(lhs_meta.dimensions.size(), -1); - rhs_to_out_.assign(rhs_meta.dimensions.size(), -1); - - for (size_t out_i = 0; out_i < out_dims.size(); ++out_i) { - const auto li = find_dim_index(lhs_meta.dimensions, out_dims[out_i].name); - const auto ri = 
find_dim_index(rhs_meta.dimensions, out_dims[out_i].name); - lhs_dim_sizes_[out_i] = (li >= 0) ? lhs_meta.dimensions[li].size : 0; - rhs_dim_sizes_[out_i] = (ri >= 0) ? rhs_meta.dimensions[ri].size : 0; - if (li >= 0) { - lhs_to_out_[li] = static_cast(out_i); - } - if (ri >= 0) { - rhs_to_out_[ri] = static_cast(out_i); - } - } - - lhs_label_count_ = lhs_meta.labels.size(); - rhs_label_count_ = rhs_meta.labels.size(); - - lhs_dims_buf_.resize(lhs_meta.dimensions.size()); - rhs_dims_buf_.resize(rhs_meta.dimensions.size()); - lhs_buf_.resize(lhs_label_count_); - rhs_buf_.resize(rhs_label_count_); -} - -const BinaryMetadata& ExpressionBinary::metadata() const { - return broadcast_meta_; -} - -void ExpressionBinary::compute_row(const std::vector& dims, std::vector& out) const { - const auto out_label_count = broadcast_meta_.labels.size(); - if (out.size() != out_label_count) { - out.resize(out_label_count); - } - - for (size_t li = 0; li < lhs_dims_buf_.size(); ++li) { - const auto out_i = lhs_to_out_[li]; - auto coord = dims[out_i]; - if (lhs_dim_sizes_[out_i] == 1) { - coord = 1; // size-1 broadcast clamp - } - lhs_dims_buf_[li] = coord; - } - for (size_t ri = 0; ri < rhs_dims_buf_.size(); ++ri) { - const auto out_i = rhs_to_out_[ri]; - auto coord = dims[out_i]; - if (rhs_dim_sizes_[out_i] == 1) { - coord = 1; - } - rhs_dims_buf_[ri] = coord; - } - - lhs_->compute_row(lhs_dims_buf_, lhs_buf_); - rhs_->compute_row(rhs_dims_buf_, rhs_buf_); - - for (size_t k = 0; k < out_label_count; ++k) { - const size_t li = (lhs_label_count_ == 1) ? 0 : k; - const size_t ri = (rhs_label_count_ == 1) ? 
0 : k; - out[k] = apply(operation_, lhs_buf_[li], rhs_buf_[ri]); - } -} - -void ExpressionBinary::collect_input_files(std::vector& out) const { - lhs_->collect_input_files(out); - rhs_->collect_input_files(out); -} - -double ExpressionUnary::apply(Operation operation, double x) { - switch (operation) { - case Operation::Negate: - return -x; - case Operation::Abs: - return std::abs(x); - case Operation::Sqrt: - return std::sqrt(x); - case Operation::Log: - return std::log(x); - case Operation::Exp: - return std::exp(x); - } - throw std::runtime_error("Cannot apply: unhandled ExpressionUnary::Operation variant"); -} - -ExpressionUnary::ExpressionUnary(Operation operation, std::shared_ptr operand) - : operation_(operation), operand_(std::move(operand)) { - operand_row_buf_.resize(operand_->metadata().labels.size()); -} - -const BinaryMetadata& ExpressionUnary::metadata() const { - return operand_->metadata(); -} - -void ExpressionUnary::compute_row(const std::vector& dims, std::vector& out) const { - const auto n = operand_row_buf_.size(); - if (out.size() != n) { - out.resize(n); - } - operand_->compute_row(dims, operand_row_buf_); - for (size_t k = 0; k < n; ++k) { - out[k] = apply(operation_, operand_row_buf_[k]); - } -} - -void ExpressionUnary::collect_input_files(std::vector& out) const { - operand_->collect_input_files(out); -} - -double ExpressionTernary::apply(Operation operation, double condition, double then_value, double else_value) { - switch (operation) { - case Operation::IfElse: - if (std::isnan(condition)) { - return std::numeric_limits::quiet_NaN(); - } - return (condition != 0.0) ? 
then_value : else_value; - } - throw std::runtime_error("Cannot apply: unhandled ExpressionTernary::Operation variant"); -} - -ExpressionTernary::ExpressionTernary(Operation operation, - std::shared_ptr condition, - std::shared_ptr then_value, - std::shared_ptr else_value) - : operation_(operation), condition_(std::move(condition)), then_value_(std::move(then_value)), - else_value_(std::move(else_value)) { - const auto& condition_meta = condition_->metadata(); - const auto& then_meta = then_value_->metadata(); - const auto& else_meta = else_value_->metadata(); - - validate_unit_match(then_meta, else_meta); - validate_shape_compatibility(condition_meta, then_meta); - validate_shape_compatibility(then_meta, else_meta); - validate_shape_compatibility(condition_meta, else_meta); - - auto output_labels = compute_ternary_output_labels(condition_meta.labels, then_meta.labels, else_meta.labels); - broadcast_meta_ = build_ternary_broadcast_metadata(condition_meta, then_meta, else_meta, std::move(output_labels)); - broadcast_meta_.validate(); - - const auto& out_dims = broadcast_meta_.dimensions; - condition_dim_sizes_.assign(out_dims.size(), 0); - then_dim_sizes_.assign(out_dims.size(), 0); - else_dim_sizes_.assign(out_dims.size(), 0); - condition_to_out_.assign(condition_meta.dimensions.size(), -1); - then_to_out_.assign(then_meta.dimensions.size(), -1); - else_to_out_.assign(else_meta.dimensions.size(), -1); - - for (size_t out_i = 0; out_i < out_dims.size(); ++out_i) { - const auto ci = find_dim_index(condition_meta.dimensions, out_dims[out_i].name); - const auto ti = find_dim_index(then_meta.dimensions, out_dims[out_i].name); - const auto ei = find_dim_index(else_meta.dimensions, out_dims[out_i].name); - condition_dim_sizes_[out_i] = (ci >= 0) ? condition_meta.dimensions[ci].size : 0; - then_dim_sizes_[out_i] = (ti >= 0) ? then_meta.dimensions[ti].size : 0; - else_dim_sizes_[out_i] = (ei >= 0) ? 
else_meta.dimensions[ei].size : 0; - if (ci >= 0) - condition_to_out_[ci] = static_cast(out_i); - if (ti >= 0) - then_to_out_[ti] = static_cast(out_i); - if (ei >= 0) - else_to_out_[ei] = static_cast(out_i); - } - - condition_label_count_ = condition_meta.labels.size(); - then_label_count_ = then_meta.labels.size(); - else_label_count_ = else_meta.labels.size(); - - condition_dims_buf_.resize(condition_meta.dimensions.size()); - then_dims_buf_.resize(then_meta.dimensions.size()); - else_dims_buf_.resize(else_meta.dimensions.size()); - condition_buf_.resize(condition_label_count_); - then_buf_.resize(then_label_count_); - else_buf_.resize(else_label_count_); -} - -const BinaryMetadata& ExpressionTernary::metadata() const { - return broadcast_meta_; -} - -void ExpressionTernary::compute_row(const std::vector& dims, std::vector& out) const { - const auto out_label_count = broadcast_meta_.labels.size(); - if (out.size() != out_label_count) { - out.resize(out_label_count); - } - - for (size_t ci = 0; ci < condition_dims_buf_.size(); ++ci) { - const auto out_i = condition_to_out_[ci]; - auto coord = dims[out_i]; - if (condition_dim_sizes_[out_i] == 1) { - coord = 1; - } - condition_dims_buf_[ci] = coord; - } - for (size_t ti = 0; ti < then_dims_buf_.size(); ++ti) { - const auto out_i = then_to_out_[ti]; - auto coord = dims[out_i]; - if (then_dim_sizes_[out_i] == 1) { - coord = 1; - } - then_dims_buf_[ti] = coord; - } - for (size_t ei = 0; ei < else_dims_buf_.size(); ++ei) { - const auto out_i = else_to_out_[ei]; - auto coord = dims[out_i]; - if (else_dim_sizes_[out_i] == 1) { - coord = 1; - } - else_dims_buf_[ei] = coord; - } - - condition_->compute_row(condition_dims_buf_, condition_buf_); - then_value_->compute_row(then_dims_buf_, then_buf_); - else_value_->compute_row(else_dims_buf_, else_buf_); - - for (size_t k = 0; k < out_label_count; ++k) { - const size_t ck = (condition_label_count_ == 1) ? 0 : k; - const size_t tk = (then_label_count_ == 1) ? 
0 : k; - const size_t ek = (else_label_count_ == 1) ? 0 : k; - out[k] = apply(operation_, condition_buf_[ck], then_buf_[tk], else_buf_[ek]); - } -} - -void ExpressionTernary::collect_input_files(std::vector& out) const { - condition_->collect_input_files(out); - then_value_->collect_input_files(out); - else_value_->collect_input_files(out); -} - -ExpressionAggregate::Operation ExpressionAggregate::parse_operation(const std::string& name) { - return parse_aggregation_operation_name(name, "aggregate"); -} - -ExpressionAggregate::ExpressionAggregate(Operation operation, - std::shared_ptr operand, - std::string dimension_name, - std::optional parameter) - : operation_(operation), operand_(std::move(operand)), dimension_name_(std::move(dimension_name)), - parameter_(parameter) { - validate_aggregation_param(operation_, parameter_, "aggregate"); - - const auto& operand_meta = operand_->metadata(); - - const int reduced_idx = find_dim_index(operand_meta.dimensions, dimension_name_); - if (reduced_idx < 0) { - throw std::runtime_error("Dimension not found: '" + dimension_name_ + "' in operand metadata"); - } - reduced_operand_index_ = reduced_idx; - - const auto& reduced_dim = operand_meta.dimensions[reduced_idx]; - const int64_t grandparent_orig_idx = - reduced_dim.is_time_dimension() ? reduced_dim.time->parent_dimension_index : -1; - - output_meta_ = operand_meta; - output_meta_.dimensions.erase(output_meta_.dimensions.begin() + reduced_idx); - if (reduced_dim.is_time_dimension()) { - --output_meta_.number_of_time_dimensions; - } - - operand_to_out_.assign(operand_meta.dimensions.size(), -1); - for (size_t i = 0; i < operand_meta.dimensions.size(); ++i) { - if (static_cast(i) == reduced_idx) { - continue; - } - operand_to_out_[i] = (static_cast(i) < reduced_idx) ? 
static_cast(i) : static_cast(i) - 1; - } - - for (size_t out_i = 0; out_i < output_meta_.dimensions.size(); ++out_i) { - auto& out_dim = output_meta_.dimensions[out_i]; - if (!out_dim.is_time_dimension()) { - continue; - } - const int operand_idx = - (static_cast(out_i) < reduced_idx) ? static_cast(out_i) : static_cast(out_i) + 1; - const int64_t orig_parent = operand_meta.dimensions[operand_idx].time->parent_dimension_index; - if (orig_parent < 0) { - out_dim.time->parent_dimension_index = -1; - } else if (orig_parent == reduced_idx) { - if (grandparent_orig_idx < 0) { - out_dim.time->parent_dimension_index = -1; - } else { - out_dim.time->parent_dimension_index = - (grandparent_orig_idx < reduced_idx) ? grandparent_orig_idx : grandparent_orig_idx - 1; - } - } else { - out_dim.time->parent_dimension_index = (orig_parent < reduced_idx) ? orig_parent : orig_parent - 1; - } - } - - output_meta_.validate(); - - operand_dims_buf_.resize(operand_meta.dimensions.size()); - operand_row_buf_.resize(operand_meta.labels.size()); - if (operation_ == Operation::Percentile) { - percentile_scratch_.resize(operand_meta.labels.size()); - } -} - -const BinaryMetadata& ExpressionAggregate::metadata() const { - return output_meta_; -} - -void ExpressionAggregate::compute_row(const std::vector& dims, std::vector& out) const { - const auto label_count = operand_row_buf_.size(); - if (out.size() != label_count) { - out.resize(label_count); - } - - const auto& operand_meta = operand_->metadata(); - - for (size_t i = 0; i < operand_meta.dimensions.size(); ++i) { - if (static_cast(i) == reduced_operand_index_) { - continue; - } - operand_dims_buf_[i] = dims[operand_to_out_[i]]; - } - operand_dims_buf_[reduced_operand_index_] = 1; - - const auto& reduced_dim = operand_meta.dimensions[reduced_operand_index_]; - int64_t start = 1; - int64_t end = reduced_dim.size; - if (reduced_dim.is_time_dimension()) { - const auto& tp = *reduced_dim.time; - const int64_t parent_idx = 
tp.parent_dimension_index; - const auto sizes = dimension_sizes_at_values(operand_meta, operand_dims_buf_); - end = sizes[reduced_operand_index_]; - if (parent_idx < 0) { - start = tp.initial_value; - } else { - const auto& parent_dim = operand_meta.dimensions[parent_idx]; - const int64_t parent_initial = parent_dim.is_time_dimension() ? parent_dim.time->initial_value : 1; - start = (operand_dims_buf_[parent_idx] == parent_initial) ? tp.initial_value : 1; - } - } - - std::vector sum_buf(label_count, 0.0); - std::vector count_buf(label_count, 0); - std::vector min_buf(label_count, std::numeric_limits::infinity()); - std::vector max_buf(label_count, -std::numeric_limits::infinity()); - - if (operation_ == Operation::Percentile) { - for (auto& scratch : percentile_scratch_) { - scratch.clear(); - } - } - - for (int64_t v = start; v <= end; ++v) { - operand_dims_buf_[reduced_operand_index_] = v; - operand_->compute_row(operand_dims_buf_, operand_row_buf_); - - for (size_t k = 0; k < label_count; ++k) { - const double value = operand_row_buf_[k]; - if (std::isnan(value)) { - continue; - } - switch (operation_) { - case Operation::Sum: - case Operation::Mean: - sum_buf[k] += value; - ++count_buf[k]; - break; - case Operation::Min: - if (value < min_buf[k]) { - min_buf[k] = value; - } - ++count_buf[k]; - break; - case Operation::Max: - if (value > max_buf[k]) { - max_buf[k] = value; - } - ++count_buf[k]; - break; - case Operation::Percentile: - percentile_scratch_[k].push_back(value); - break; - } - } - } - - const double nan_value = std::numeric_limits::quiet_NaN(); - for (size_t k = 0; k < label_count; ++k) { - switch (operation_) { - case Operation::Sum: - out[k] = (count_buf[k] > 0) ? sum_buf[k] : nan_value; - break; - case Operation::Mean: - out[k] = (count_buf[k] > 0) ? sum_buf[k] / static_cast(count_buf[k]) : nan_value; - break; - case Operation::Min: - out[k] = (count_buf[k] > 0) ? 
min_buf[k] : nan_value; - break; - case Operation::Max: - out[k] = (count_buf[k] > 0) ? max_buf[k] : nan_value; - break; - case Operation::Percentile: - out[k] = compute_percentile(percentile_scratch_[k], *parameter_); - break; - } - } -} - -void ExpressionAggregate::collect_input_files(std::vector& out) const { - operand_->collect_input_files(out); -} - -ExpressionAggregateAgents::Operation ExpressionAggregateAgents::parse_operation(const std::string& name) { - return parse_aggregation_operation_name(name, "aggregate_agents"); -} - -ExpressionAggregateAgents::ExpressionAggregateAgents(Operation operation, - std::shared_ptr operand, - std::optional parameter) - : operation_(operation), operand_(std::move(operand)), parameter_(parameter) { - validate_aggregation_param(operation_, parameter_, "aggregate_agents"); - - const auto& operand_meta = operand_->metadata(); - output_meta_ = operand_meta; - output_meta_.labels = {aggregation_operation_label(operation_)}; - output_meta_.validate(); - - operand_row_buf_.resize(operand_meta.labels.size()); -} - -const BinaryMetadata& ExpressionAggregateAgents::metadata() const { - return output_meta_; -} - -void ExpressionAggregateAgents::compute_row(const std::vector& dims, std::vector& out) const { - if (out.size() != 1) { - out.resize(1); - } - - operand_->compute_row(dims, operand_row_buf_); - - const double nan_value = std::numeric_limits::quiet_NaN(); - - switch (operation_) { - case Operation::Sum: { - double sum = 0.0; - int64_t count = 0; - for (double v : operand_row_buf_) { - if (std::isnan(v)) { - continue; - } - sum += v; - ++count; - } - out[0] = (count > 0) ? sum : nan_value; - break; - } - case Operation::Mean: { - double sum = 0.0; - int64_t count = 0; - for (double v : operand_row_buf_) { - if (std::isnan(v)) { - continue; - } - sum += v; - ++count; - } - out[0] = (count > 0) ? 
sum / static_cast(count) : nan_value; - break; - } - case Operation::Min: { - double m = std::numeric_limits::infinity(); - int64_t count = 0; - for (double v : operand_row_buf_) { - if (std::isnan(v)) { - continue; - } - if (v < m) { - m = v; - } - ++count; - } - out[0] = (count > 0) ? m : nan_value; - break; - } - case Operation::Max: { - double m = -std::numeric_limits::infinity(); - int64_t count = 0; - for (double v : operand_row_buf_) { - if (std::isnan(v)) { - continue; - } - if (v > m) { - m = v; - } - ++count; - } - out[0] = (count > 0) ? m : nan_value; - break; - } - case Operation::Percentile: { - std::vector scratch; - scratch.reserve(operand_row_buf_.size()); - for (double v : operand_row_buf_) { - if (!std::isnan(v)) { - scratch.push_back(v); - } - } - out[0] = compute_percentile(scratch, *parameter_); - break; - } - } -} - -void ExpressionAggregateAgents::collect_input_files(std::vector& out) const { - operand_->collect_input_files(out); -} - -ExpressionSelectAgents::ExpressionSelectAgents(std::shared_ptr operand, std::vector labels) - : operand_(std::move(operand)) { - const auto& operand_meta = operand_->metadata(); - - std::unordered_map label_to_index; - label_to_index.reserve(operand_meta.labels.size()); - for (size_t i = 0; i < operand_meta.labels.size(); ++i) { - label_to_index.emplace(operand_meta.labels[i], i); - } - - selected_indices_.reserve(labels.size()); - for (const auto& label : labels) { - auto it = label_to_index.find(label); - if (it == label_to_index.end()) { - throw std::runtime_error("Cannot select_agents: label not found: '" + label + "'"); - } - selected_indices_.push_back(it->second); - } - - output_meta_ = operand_meta; - output_meta_.labels = std::move(labels); - output_meta_.validate(); - - operand_row_buf_.resize(operand_meta.labels.size()); -} - -const BinaryMetadata& ExpressionSelectAgents::metadata() const { - return output_meta_; -} - -void ExpressionSelectAgents::compute_row(const std::vector& dims, std::vector& 
out) const { - operand_->compute_row(dims, operand_row_buf_); - if (out.size() != selected_indices_.size()) { - out.resize(selected_indices_.size()); - } - for (size_t i = 0; i < selected_indices_.size(); ++i) { - out[i] = operand_row_buf_[selected_indices_[i]]; - } -} - -void ExpressionSelectAgents::collect_input_files(std::vector& out) const { - operand_->collect_input_files(out); -} - -ExpressionRenameAgents::ExpressionRenameAgents(std::shared_ptr operand, - std::vector> mapping) - : operand_(std::move(operand)) { - const auto& operand_meta = operand_->metadata(); - - std::unordered_map rename_map; - std::unordered_map used; - rename_map.reserve(mapping.size()); - used.reserve(mapping.size()); - for (auto& entry : mapping) { - if (!rename_map.emplace(entry.first, std::move(entry.second)).second) { - throw std::runtime_error("Cannot rename_agents: duplicate key '" + entry.first + "'"); - } - used.emplace(entry.first, false); - } - - output_meta_ = operand_meta; - for (auto& label : output_meta_.labels) { - auto it = rename_map.find(label); - if (it != rename_map.end()) { - label = it->second; - used[it->first] = true; - } - } - - for (const auto& entry : used) { - if (!entry.second) { - throw std::runtime_error("Cannot rename_agents: label not found: '" + entry.first + "'"); - } - } - - output_meta_.validate(); -} - -const BinaryMetadata& ExpressionRenameAgents::metadata() const { - return output_meta_; -} - -void ExpressionRenameAgents::compute_row(const std::vector& dims, std::vector& out) const { - operand_->compute_row(dims, out); -} - -void ExpressionRenameAgents::collect_input_files(std::vector& out) const { - operand_->collect_input_files(out); -} - -} // namespace quiver diff --git a/src/expression/expression_rename_agents.cpp b/src/expression/expression_rename_agents.cpp new file mode 100644 index 00000000..aa51e075 --- /dev/null +++ b/src/expression/expression_rename_agents.cpp @@ -0,0 +1,59 @@ +#include "quiver/expression/expression_node.h" + +#include 
+#include +#include +#include +#include +#include +#include + +namespace quiver { + +ExpressionRenameAgents::ExpressionRenameAgents(std::shared_ptr operand, + std::vector> mapping) + : operand_(std::move(operand)) { + const auto& operand_meta = operand_->metadata(); + + std::unordered_map rename_map; + std::unordered_map used; + rename_map.reserve(mapping.size()); + used.reserve(mapping.size()); + for (auto& entry : mapping) { + if (!rename_map.emplace(entry.first, std::move(entry.second)).second) { + throw std::runtime_error("Cannot rename_agents: duplicate key '" + entry.first + "'"); + } + used.emplace(entry.first, false); + } + + output_meta_ = operand_meta; + for (auto& label : output_meta_.labels) { + auto it = rename_map.find(label); + if (it != rename_map.end()) { + label = it->second; + used[it->first] = true; + } + } + + for (const auto& entry : used) { + if (!entry.second) { + throw std::runtime_error("Cannot rename_agents: label not found: '" + entry.first + "'"); + } + } + + output_meta_.validate(); +} + +const BinaryMetadata& ExpressionRenameAgents::metadata() const { + return output_meta_; +} + +void ExpressionRenameAgents::compute_row(const std::vector& dims, std::vector& out) const { + operand_->compute_row(dims, out); +} + +void ExpressionRenameAgents::collect_input_files(std::vector& out) const { + operand_->collect_input_files(out); +} + +} // namespace quiver diff --git a/src/expression/expression_scalar.cpp b/src/expression/expression_scalar.cpp new file mode 100644 index 00000000..7f7942cb --- /dev/null +++ b/src/expression/expression_scalar.cpp @@ -0,0 +1,22 @@ +#include "quiver/expression/expression_node.h" + +#include +#include +#include + +namespace quiver { + +ExpressionScalar::ExpressionScalar(double value, BinaryMetadata broadcast_meta) + : value_(value), broadcast_meta_(std::move(broadcast_meta)) {} + +const BinaryMetadata& ExpressionScalar::metadata() const { + return broadcast_meta_; +} + +void ExpressionScalar::compute_row(const 
std::vector& /*dims*/, std::vector& out) const { + out.assign(broadcast_meta_.labels.size(), value_); +} + +void ExpressionScalar::collect_input_files(std::vector& /*out*/) const {} + +} // namespace quiver diff --git a/src/expression/expression_select_agents.cpp b/src/expression/expression_select_agents.cpp new file mode 100644 index 00000000..3938ef4c --- /dev/null +++ b/src/expression/expression_select_agents.cpp @@ -0,0 +1,57 @@ +#include "quiver/expression/expression_node.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace quiver { + +ExpressionSelectAgents::ExpressionSelectAgents(std::shared_ptr operand, std::vector labels) + : operand_(std::move(operand)) { + const auto& operand_meta = operand_->metadata(); + + std::unordered_map label_to_index; + label_to_index.reserve(operand_meta.labels.size()); + for (size_t i = 0; i < operand_meta.labels.size(); ++i) { + label_to_index.emplace(operand_meta.labels[i], i); + } + + selected_indices_.reserve(labels.size()); + for (const auto& label : labels) { + auto it = label_to_index.find(label); + if (it == label_to_index.end()) { + throw std::runtime_error("Cannot select_agents: label not found: '" + label + "'"); + } + selected_indices_.push_back(it->second); + } + + output_meta_ = operand_meta; + output_meta_.labels = std::move(labels); + output_meta_.validate(); + + operand_row_buf_.resize(operand_meta.labels.size()); +} + +const BinaryMetadata& ExpressionSelectAgents::metadata() const { + return output_meta_; +} + +void ExpressionSelectAgents::compute_row(const std::vector& dims, std::vector& out) const { + operand_->compute_row(dims, operand_row_buf_); + if (out.size() != selected_indices_.size()) { + out.resize(selected_indices_.size()); + } + for (size_t i = 0; i < selected_indices_.size(); ++i) { + out[i] = operand_row_buf_[selected_indices_[i]]; + } +} + +void ExpressionSelectAgents::collect_input_files(std::vector& out) const { + operand_->collect_input_files(out); +} + +} 
// namespace quiver diff --git a/src/expression/expression_ternary.cpp b/src/expression/expression_ternary.cpp new file mode 100644 index 00000000..ce7bf089 --- /dev/null +++ b/src/expression/expression_ternary.cpp @@ -0,0 +1,133 @@ +#include "quiver/expression/expression_node.h" + +#include "expression_helpers.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace quiver { + +double ExpressionTernary::apply(Operation operation, double condition, double then_value, double else_value) { + switch (operation) { + case Operation::IfElse: + if (std::isnan(condition)) { + return std::numeric_limits::quiet_NaN(); + } + return (condition != 0.0) ? then_value : else_value; + } + throw std::runtime_error("Cannot apply: unhandled ExpressionTernary::Operation variant"); +} + +ExpressionTernary::ExpressionTernary(Operation operation, + std::shared_ptr condition, + std::shared_ptr then_value, + std::shared_ptr else_value) + : operation_(operation), condition_(std::move(condition)), then_value_(std::move(then_value)), + else_value_(std::move(else_value)) { + const auto& condition_meta = condition_->metadata(); + const auto& then_meta = then_value_->metadata(); + const auto& else_meta = else_value_->metadata(); + + validate_unit_match(then_meta, else_meta); + validate_shape_compatibility(condition_meta, then_meta); + validate_shape_compatibility(then_meta, else_meta); + validate_shape_compatibility(condition_meta, else_meta); + + auto output_labels = compute_ternary_output_labels(condition_meta.labels, then_meta.labels, else_meta.labels); + broadcast_meta_ = build_ternary_broadcast_metadata(condition_meta, then_meta, else_meta, std::move(output_labels)); + broadcast_meta_.validate(); + + const auto& out_dims = broadcast_meta_.dimensions; + condition_dim_sizes_.assign(out_dims.size(), 0); + then_dim_sizes_.assign(out_dims.size(), 0); + else_dim_sizes_.assign(out_dims.size(), 0); + condition_to_out_.assign(condition_meta.dimensions.size(), -1); + 
then_to_out_.assign(then_meta.dimensions.size(), -1); + else_to_out_.assign(else_meta.dimensions.size(), -1); + + for (size_t out_i = 0; out_i < out_dims.size(); ++out_i) { + const auto ci = find_dim_index(condition_meta.dimensions, out_dims[out_i].name); + const auto ti = find_dim_index(then_meta.dimensions, out_dims[out_i].name); + const auto ei = find_dim_index(else_meta.dimensions, out_dims[out_i].name); + condition_dim_sizes_[out_i] = (ci >= 0) ? condition_meta.dimensions[ci].size : 0; + then_dim_sizes_[out_i] = (ti >= 0) ? then_meta.dimensions[ti].size : 0; + else_dim_sizes_[out_i] = (ei >= 0) ? else_meta.dimensions[ei].size : 0; + if (ci >= 0) + condition_to_out_[ci] = static_cast(out_i); + if (ti >= 0) + then_to_out_[ti] = static_cast(out_i); + if (ei >= 0) + else_to_out_[ei] = static_cast(out_i); + } + + condition_label_count_ = condition_meta.labels.size(); + then_label_count_ = then_meta.labels.size(); + else_label_count_ = else_meta.labels.size(); + + condition_dims_buf_.resize(condition_meta.dimensions.size()); + then_dims_buf_.resize(then_meta.dimensions.size()); + else_dims_buf_.resize(else_meta.dimensions.size()); + condition_buf_.resize(condition_label_count_); + then_buf_.resize(then_label_count_); + else_buf_.resize(else_label_count_); +} + +const BinaryMetadata& ExpressionTernary::metadata() const { + return broadcast_meta_; +} + +void ExpressionTernary::compute_row(const std::vector& dims, std::vector& out) const { + const auto out_label_count = broadcast_meta_.labels.size(); + if (out.size() != out_label_count) { + out.resize(out_label_count); + } + + for (size_t ci = 0; ci < condition_dims_buf_.size(); ++ci) { + const auto out_i = condition_to_out_[ci]; + auto coord = dims[out_i]; + if (condition_dim_sizes_[out_i] == 1) { + coord = 1; + } + condition_dims_buf_[ci] = coord; + } + for (size_t ti = 0; ti < then_dims_buf_.size(); ++ti) { + const auto out_i = then_to_out_[ti]; + auto coord = dims[out_i]; + if (then_dim_sizes_[out_i] == 1) { + 
coord = 1; + } + then_dims_buf_[ti] = coord; + } + for (size_t ei = 0; ei < else_dims_buf_.size(); ++ei) { + const auto out_i = else_to_out_[ei]; + auto coord = dims[out_i]; + if (else_dim_sizes_[out_i] == 1) { + coord = 1; + } + else_dims_buf_[ei] = coord; + } + + condition_->compute_row(condition_dims_buf_, condition_buf_); + then_value_->compute_row(then_dims_buf_, then_buf_); + else_value_->compute_row(else_dims_buf_, else_buf_); + + for (size_t k = 0; k < out_label_count; ++k) { + const size_t ck = (condition_label_count_ == 1) ? 0 : k; + const size_t tk = (then_label_count_ == 1) ? 0 : k; + const size_t ek = (else_label_count_ == 1) ? 0 : k; + out[k] = apply(operation_, condition_buf_[ck], then_buf_[tk], else_buf_[ek]); + } +} + +void ExpressionTernary::collect_input_files(std::vector& out) const { + condition_->collect_input_files(out); + then_value_->collect_input_files(out); + else_value_->collect_input_files(out); +} + +} // namespace quiver diff --git a/src/expression/expression_unary.cpp b/src/expression/expression_unary.cpp new file mode 100644 index 00000000..10fffc1e --- /dev/null +++ b/src/expression/expression_unary.cpp @@ -0,0 +1,52 @@ +#include "quiver/expression/expression_node.h" + +#include +#include +#include +#include +#include +#include + +namespace quiver { + +double ExpressionUnary::apply(Operation operation, double x) { + switch (operation) { + case Operation::Negate: + return -x; + case Operation::Abs: + return std::abs(x); + case Operation::Sqrt: + return std::sqrt(x); + case Operation::Log: + return std::log(x); + case Operation::Exp: + return std::exp(x); + } + throw std::runtime_error("Cannot apply: unhandled ExpressionUnary::Operation variant"); +} + +ExpressionUnary::ExpressionUnary(Operation operation, std::shared_ptr operand) + : operation_(operation), operand_(std::move(operand)) { + operand_row_buf_.resize(operand_->metadata().labels.size()); +} + +const BinaryMetadata& ExpressionUnary::metadata() const { + return 
operand_->metadata(); +} + +void ExpressionUnary::compute_row(const std::vector& dims, std::vector& out) const { + const auto n = operand_row_buf_.size(); + if (out.size() != n) { + out.resize(n); + } + operand_->compute_row(dims, operand_row_buf_); + for (size_t k = 0; k < n; ++k) { + out[k] = apply(operation_, operand_row_buf_[k]); + } +} + +void ExpressionUnary::collect_input_files(std::vector& out) const { + operand_->collect_input_files(out); +} + +} // namespace quiver From c994ef0a0f5891210718e08d700ebdd32f0b3b61 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Thu, 14 May 2026 07:28:20 -0300 Subject: [PATCH 11/13] refactor: aggregate ops take typed enums instead of strings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace string operation names with typed enums across C++, C API, and Julia binding for Expression::aggregate and aggregate_agents: C++: - Expression::aggregate signature: (string, string, optional) -> (string, ExpressionAggregate::Operation, optional) - Expression::aggregate_agents signature: (string, optional) -> (ExpressionAggregateAgents::Operation, optional) - Delete the static parse_operation methods on both aggregate classes - Delete unused parse_aggregation_operation_name template from expression_helpers.h C API: - Add quiver_expression_aggregate_operation_t and quiver_expression_aggregate_agents_operation_t (5 values each: SUM/MEAN/MIN/MAX/PERCENTILE) - Switch quiver_expression_aggregate and _aggregate_agents to take the new enum types - Add from_c() converters mirroring the existing dispatch() pattern Julia binding: - Tests reference Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_* directly - Regenerate c_api.jl via generator.bat Tests: - Update 20 C++ + 16 C + 22 Julia callsites - Delete three negative tests now compile-time-safe (unknown-op string and null-op pointer) Output labels in ExpressionAggregateAgents (e.g., "sum", "mean") remain strings derived via 
aggregation_operation_label() — that template is still consumed and kept. Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 4 +- bindings/julia/src/c_api.jl | 20 ++++- bindings/julia/src/expression.jl | 26 ++++++- bindings/julia/test/test_expression.jl | 54 +++++-------- include/quiver/c/expression/expression.h | 22 +++++- include/quiver/expression/expression.h | 5 +- include/quiver/expression/expression_node.h | 4 - src/c/expression/expression.cpp | 44 +++++++++-- src/expression/expression.cpp | 11 ++- src/expression/expression_aggregate.cpp | 4 - .../expression_aggregate_agents.cpp | 4 - src/expression/expression_helpers.h | 16 ---- tests/test_c_api_expression.cpp | 43 +++++------ tests/test_expression.cpp | 76 ++++++++----------- 14 files changed, 179 insertions(+), 154 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index bd58c29f..9ab54646 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -504,7 +504,7 @@ result.save("output"); // writes output.qvr + output.toml - Constructors: `Expression(const BinaryFile&)` (implicit, enables `bf_a + bf_b`), `Expression(shared_ptr)` - Accessors: `metadata()` - Materialize: `save(path)` — iterates via `first_dimensions`/`next_dimensions`, calls `compute_row()` per cell, writes to a new `.qvr`. Throws if `path` collides (after `weakly_canonical`) with any input file in the DAG. - - Aggregation: `aggregate(dimension, op, [parameter])` collapses a dimension; `aggregate_agents(op, [parameter])` collapses the label axis. `op` is one of `"sum" | "mean" | "min" | "max" | "percentile"` (string tags, validated in C++). `percentile` requires a `parameter` fraction in `[0, 1]`; nullary ops reject `parameter`. + - Aggregation: `aggregate(dimension, op, [parameter])` collapses a dimension; `aggregate_agents(op, [parameter])` collapses the label axis. 
`op` is the nested enum `ExpressionAggregate::Operation` (for `aggregate`) or `ExpressionAggregateAgents::Operation` (for `aggregate_agents`), each with `Sum / Mean / Min / Max / Percentile`. `Percentile` requires a `parameter` fraction in `[0, 1]`; nullary ops reject `parameter`. - Label-axis projection: `select_agents(labels)` keeps (and may reorder) a chosen subset of operand labels; `rename_agents(mapping)` rewrites labels in place via a partial `{old: new}` map. Both validate eagerly: `select_agents` throws if any requested label is absent; `rename_agents` throws on duplicate keys or unknown keys, and `BinaryMetadata::validate()` rejects renames that produce duplicate output labels. - Operator overloads (12 binary + 1 unary): `+ - * /` × {expr+expr, expr+double, double+expr}, plus unary `-expr`. - Free functions in `quiver::` for unary math: `abs(expr)`, `sqrt(expr)`, `log(expr)`, `exp(expr)`. @@ -521,7 +521,7 @@ result.save("output"); // writes output.qvr + output.toml - `ExpressionSelectAgents`: projects the operand onto a caller-supplied label list. Constructor pre-computes a `selected_indices_` table from operand-label → output-position, copies operand metadata with `labels` replaced, and calls `output_meta_.validate()` (which rejects duplicate output labels). Missing labels throw `"Cannot select_agents: label not found: ''"`. `compute_row` reads the operand row into a reusable buffer and gathers selected columns into `out`. - `ExpressionRenameAgents`: rewrites operand labels via a partial `{old: new}` mapping. Constructor builds a rename map (duplicate keys throw), walks operand labels swapping matched names, verifies every key was used (unmatched keys throw), and calls `output_meta_.validate()` (rejects collisions like `val1→val2` when `val2` already exists). `compute_row` forwards directly to the operand — count and order are unchanged so no per-row reshuffle is needed. 
- Validation is **eager** at construction for `ExpressionBinary`, `ExpressionTernary`, `ExpressionAggregate`, `ExpressionAggregateAgents`, `ExpressionSelectAgents`, `ExpressionRenameAgents` (units/dim sizes/time-dim properties/label sizes/initial datetimes for binary and ternary; dim existence + op/parameter consistency + output metadata validity for aggregations; label existence + uniqueness for label-axis projections). `ExpressionUnary` has no inputs to cross-validate so its constructor just sizes the row buffer. Computation is **lazy**: no I/O until `save()`. -- The binary-operation enum is nested as `ExpressionBinary::Operation`; the unary-operation enum is nested as `ExpressionUnary::Operation`; the ternary-operation enum is nested as `ExpressionTernary::Operation`. Aggregation operations are parsed via static `parse_operation(string)` methods on each aggregation class. Label-axis projection nodes (`ExpressionSelectAgents`, `ExpressionRenameAgents`) have no operation enum — their behavior is fully specified by the label list / rename map. The C API surface keeps its own stable enums `quiver_expression_operation_t`, `quiver_expression_unary_operation_t`, and `quiver_expression_ternary_operation_t`. +- All operation enums are nested in their owning class: `ExpressionBinary::Operation`, `ExpressionUnary::Operation`, `ExpressionTernary::Operation`, `ExpressionAggregate::Operation`, `ExpressionAggregateAgents::Operation`. The two aggregation enums are parallel types with identical values (`Sum / Mean / Min / Max / Percentile`). Label-axis projection nodes (`ExpressionSelectAgents`, `ExpressionRenameAgents`) have no operation enum — their behavior is fully specified by the label list / rename map. The C API mirrors this with five parallel enums: `quiver_expression_operation_t`, `quiver_expression_unary_operation_t`, `quiver_expression_ternary_operation_t`, `quiver_expression_aggregate_operation_t`, `quiver_expression_aggregate_agents_operation_t`. 
### LuaRunner Class Executes Lua scripts with database access: diff --git a/bindings/julia/src/c_api.jl b/bindings/julia/src/c_api.jl index d641beb0..527df240 100644 --- a/bindings/julia/src/c_api.jl +++ b/bindings/julia/src/c_api.jl @@ -658,6 +658,22 @@ end QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE = 0 end +@cenum quiver_expression_aggregate_operation_t::UInt32 begin + QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM = 0 + QUIVER_EXPRESSION_AGGREGATE_OPERATION_MEAN = 1 + QUIVER_EXPRESSION_AGGREGATE_OPERATION_MIN = 2 + QUIVER_EXPRESSION_AGGREGATE_OPERATION_MAX = 3 + QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE = 4 +end + +@cenum quiver_expression_aggregate_agents_operation_t::UInt32 begin + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM = 0 + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN = 1 + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MIN = 2 + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MAX = 3 + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_PERCENTILE = 4 +end + function quiver_expression_from_file(file, out) @ccall libquiver_c.quiver_expression_from_file(file::Ptr{quiver_binary_file_t}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t end @@ -695,11 +711,11 @@ function quiver_expression_get_metadata(expression, out) end function quiver_expression_aggregate(expression, dimension, operation, parameter, out) - @ccall libquiver_c.quiver_expression_aggregate(expression::Ptr{quiver_expression_t}, dimension::Ptr{Cchar}, operation::Ptr{Cchar}, parameter::Ptr{Cdouble}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t + @ccall libquiver_c.quiver_expression_aggregate(expression::Ptr{quiver_expression_t}, dimension::Ptr{Cchar}, operation::quiver_expression_aggregate_operation_t, parameter::Ptr{Cdouble}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t end function quiver_expression_aggregate_agents(expression, operation, parameter, out) - @ccall libquiver_c.quiver_expression_aggregate_agents(expression::Ptr{quiver_expression_t}, operation::Ptr{Cchar}, 
parameter::Ptr{Cdouble}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t + @ccall libquiver_c.quiver_expression_aggregate_agents(expression::Ptr{quiver_expression_t}, operation::quiver_expression_aggregate_agents_operation_t, parameter::Ptr{Cdouble}, out::Ptr{Ptr{quiver_expression_t}})::quiver_error_t end function quiver_expression_select_agents(expression, labels, label_count, out) diff --git a/bindings/julia/src/expression.jl b/bindings/julia/src/expression.jl index feb1ddf9..e0599638 100644 --- a/bindings/julia/src/expression.jl +++ b/bindings/julia/src/expression.jl @@ -133,7 +133,12 @@ function get_metadata(e::Expression) return Binary.Metadata(out[]) end -function aggregate(e::Expression, dimension::String, operation::String, parameter::Optional{Real} = nothing) +function aggregate( + e::Expression, + dimension::String, + operation::C.quiver_expression_aggregate_operation_t, + parameter::Optional{Real} = nothing, +) out = Ref{Ptr{C.quiver_expression}}(C_NULL) if parameter === nothing check(C.quiver_expression_aggregate(e.ptr, dimension, operation, C_NULL, out)) @@ -146,7 +151,11 @@ function aggregate(e::Expression, dimension::String, operation::String, paramete return Expression(out[]) end -function aggregate_agents(e::Expression, operation::String, parameter::Optional{Real} = nothing) +function aggregate_agents( + e::Expression, + operation::C.quiver_expression_aggregate_agents_operation_t, + parameter::Optional{Real} = nothing, +) out = Ref{Ptr{C.quiver_expression}}(C_NULL) if parameter === nothing check(C.quiver_expression_aggregate_agents(e.ptr, operation, C_NULL, out)) @@ -159,11 +168,20 @@ function aggregate_agents(e::Expression, operation::String, parameter::Optional{ return Expression(out[]) end -function aggregate(f::Binary.File, dimension::String, operation::String, parameter::Optional{Real} = nothing) +function aggregate( + f::Binary.File, + dimension::String, + operation::C.quiver_expression_aggregate_operation_t, + parameter::Optional{Real} = 
nothing, +) return aggregate(Expression(f), dimension, operation, parameter) end -function aggregate_agents(f::Binary.File, operation::String, parameter::Optional{Real} = nothing) +function aggregate_agents( + f::Binary.File, + operation::C.quiver_expression_aggregate_agents_operation_t, + parameter::Optional{Real} = nothing, +) return aggregate_agents(Expression(f), operation, parameter) end diff --git a/bindings/julia/test/test_expression.jl b/bindings/julia/test/test_expression.jl index b8b7f2ec..4d6070b6 100644 --- a/bindings/julia/test/test_expression.jl +++ b/bindings/julia/test/test_expression.jl @@ -968,7 +968,7 @@ end try write_fixture(path_a, (r, c, k) -> r * 10 + c + k) with_expr(path_a) do e - out = Quiver.aggregate(e, "row", "sum") + out = Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -987,7 +987,7 @@ end try write_fixture(path_a, (r, c, k) -> r * 10 + c + k) with_expr(path_a) do e - out = Quiver.aggregate(e, "row", "mean") + out = Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_MEAN) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1004,8 +1004,8 @@ end try write_fixture(path_a, (r, c, k) -> r * 10 + c + k) with_expr(path_a) do e - out_min = Quiver.aggregate(e, "row", "min") - out_max = Quiver.aggregate(e, "row", "max") + out_min = Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_MIN) + out_max = Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_MAX) Quiver.save(out_min, path_out) Quiver.save(out_max, path_out2) Quiver.close!(out_min) @@ -1024,7 +1024,7 @@ end try write_fixture(path_a, (r, c, k) -> r * 10 + c + k) with_expr(path_a) do e - out = Quiver.aggregate(e, "row", "percentile", 0.5) + out = Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, 0.5) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1041,7 +1041,7 @@ end try 
write_fixture(path_a, (r, c, k) -> r == 2 ? NaN : Float64(r * 10 + c + k)) with_expr(path_a) do e - out = Quiver.aggregate(e, "row", "sum") + out = Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1064,7 +1064,7 @@ end ) write_dense(path_a, md, [:row, :col, :depth], [3, 2, 2], 1, (_, _) -> 2.0) with_expr(path_a) do e - out = Quiver.aggregate(Quiver.aggregate(e, "row", "sum"), "col", "sum") + out = Quiver.aggregate(Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM), "col", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1083,7 +1083,7 @@ end write_fixture(path_b, (r, c, k) -> r + c + k) with_expr(path_a) do a with_expr(path_b) do b - out = Quiver.aggregate(a + b, "row", "sum") + out = Quiver.aggregate(a + b, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1101,19 +1101,7 @@ end try write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "nonexistent", "sum") - end - finally - cleanup(path_a) - end - end - - @testset "Aggregate unknown operation throws" begin - path_a = make_path("a") - try - write_fixture(path_a, (_, _, _) -> 1.0) - with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", "average") + @test_throws Quiver.DatabaseException Quiver.aggregate(e, "nonexistent", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM) end finally cleanup(path_a) @@ -1125,7 +1113,7 @@ end try write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", "percentile") + @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE) end finally cleanup(path_a) @@ -1137,7 +1125,7 @@ end try 
write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", "sum", 0.5) + @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, 0.5) end finally cleanup(path_a) @@ -1149,8 +1137,8 @@ end try write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", "percentile", 1.5) - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", "percentile", -0.1) + @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, 1.5) + @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, -0.1) end finally cleanup(path_a) @@ -1166,7 +1154,7 @@ end try write_fixture(path_a, (r, c, k) -> r * 10 + c + k) with_expr(path_a) do e - out = Quiver.aggregate_agents(e, "sum") + out = Quiver.aggregate_agents(e, Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM) md = Quiver.get_metadata(out) @test Quiver.Binary.get_labels(md) == ["sum"] Quiver.save(out, path_out) @@ -1185,7 +1173,7 @@ end try write_fixture(path_a, (r, c, k) -> r * 10 + c + k) with_expr(path_a) do e - out = Quiver.aggregate_agents(e, "mean") + out = Quiver.aggregate_agents(e, Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1202,7 +1190,7 @@ end try write_fixture(path_a, (r, c, k) -> r * 10 + c + k) with_expr(path_a) do e - out = Quiver.aggregate_agents(e, "percentile", 0.5) + out = Quiver.aggregate_agents(e, Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_PERCENTILE, 0.5) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1219,7 +1207,7 @@ end # Mark label k=1 as NaN; sum should fall back to the other label. write_fixture(path_a, (r, c, k) -> k == 1 ? 
NaN : Float64(r * 10 + c + k)) with_expr(path_a) do e - out = Quiver.aggregate_agents(e, "sum") + out = Quiver.aggregate_agents(e, Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1236,7 +1224,7 @@ end try write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - out = Quiver.aggregate_agents(e, "mean") + out = Quiver.aggregate_agents(e, Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN) md = Quiver.get_metadata(out) @test Quiver.Binary.get_labels(md) == ["mean"] @test Quiver.Binary.get_unit(md) == "MW" @@ -1256,7 +1244,7 @@ end try write_fixture(path_a, (_, _, _) -> 3.0) with_expr(path_a) do e - out = Quiver.aggregate_agents(Quiver.aggregate(e, "row", "sum"), "mean") + out = Quiver.aggregate_agents(Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM), Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1274,7 +1262,7 @@ end write_fixture(path_a, (r, c, k) -> r * 10 + c + k) file = Quiver.Binary.open_file(path_a; mode = 'r') try - out = Quiver.aggregate(file, "row", "sum") + out = Quiver.aggregate(file, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM) Quiver.save(out, path_out) Quiver.close!(out) finally @@ -1293,7 +1281,7 @@ end write_fixture(path_a, (r, c, k) -> r * 10 + c + k) file = Quiver.Binary.open_file(path_a; mode = 'r') try - out = Quiver.aggregate_agents(file, "mean") + out = Quiver.aggregate_agents(file, Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN) Quiver.save(out, path_out) Quiver.close!(out) finally diff --git a/include/quiver/c/expression/expression.h b/include/quiver/c/expression/expression.h index 4c40a9b9..aaaf21d2 100644 --- a/include/quiver/c/expression/expression.h +++ b/include/quiver/c/expression/expression.h @@ -34,6 +34,24 @@ typedef enum { QUIVER_EXPRESSION_TERNARY_OPERATION_IFELSE = 0, } quiver_expression_ternary_operation_t; +// 
Aggregate operation kind (dimension-axis reduction) +typedef enum { + QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM = 0, + QUIVER_EXPRESSION_AGGREGATE_OPERATION_MEAN = 1, + QUIVER_EXPRESSION_AGGREGATE_OPERATION_MIN = 2, + QUIVER_EXPRESSION_AGGREGATE_OPERATION_MAX = 3, + QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE = 4, +} quiver_expression_aggregate_operation_t; + +// Aggregate agents operation kind (label-axis reduction) +typedef enum { + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM = 0, + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN = 1, + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MIN = 2, + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MAX = 3, + QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_PERCENTILE = 4, +} quiver_expression_aggregate_agents_operation_t; + // Construction QUIVER_C_API quiver_error_t quiver_expression_from_file(quiver_binary_file_t* file, quiver_expression_t** out); @@ -73,12 +91,12 @@ QUIVER_C_API quiver_error_t quiver_expression_get_metadata(quiver_expression_t* // Non-null pointer supplies the value (required for percentile, in [0, 1]). 
QUIVER_C_API quiver_error_t quiver_expression_aggregate(quiver_expression_t* expression, const char* dimension, - const char* operation, + quiver_expression_aggregate_operation_t operation, const double* parameter, quiver_expression_t** out); QUIVER_C_API quiver_error_t quiver_expression_aggregate_agents(quiver_expression_t* expression, - const char* operation, + quiver_expression_aggregate_agents_operation_t operation, const double* parameter, quiver_expression_t** out); diff --git a/include/quiver/expression/expression.h b/include/quiver/expression/expression.h index 730e3826..8f56bdd1 100644 --- a/include/quiver/expression/expression.h +++ b/include/quiver/expression/expression.h @@ -25,10 +25,11 @@ class QUIVER_API Expression { void save(const std::string& path) const; Expression aggregate(const std::string& dimension, - const std::string& operation, + ExpressionAggregate::Operation operation, std::optional parameter = std::nullopt) const; - Expression aggregate_agents(const std::string& operation, std::optional parameter = std::nullopt) const; + Expression aggregate_agents(ExpressionAggregateAgents::Operation operation, + std::optional parameter = std::nullopt) const; Expression select_agents(const std::vector& labels) const; diff --git a/include/quiver/expression/expression_node.h b/include/quiver/expression/expression_node.h index 3b71ac69..cc0a3614 100644 --- a/include/quiver/expression/expression_node.h +++ b/include/quiver/expression/expression_node.h @@ -152,8 +152,6 @@ class QUIVER_API ExpressionAggregate final : public ExpressionNode { public: enum class Operation { Sum, Mean, Min, Max, Percentile }; - static Operation parse_operation(const std::string& name); - ExpressionAggregate(Operation operation, std::shared_ptr operand, std::string dimension_name, @@ -181,8 +179,6 @@ class QUIVER_API ExpressionAggregateAgents final : public ExpressionNode { public: enum class Operation { Sum, Mean, Min, Max, Percentile }; - static Operation parse_operation(const 
std::string& name); - ExpressionAggregateAgents(Operation operation, std::shared_ptr operand, std::optional parameter = std::nullopt); diff --git a/src/c/expression/expression.cpp b/src/c/expression/expression.cpp index 70d52b33..f5b2e3ce 100644 --- a/src/c/expression/expression.cpp +++ b/src/c/expression/expression.cpp @@ -53,6 +53,38 @@ quiver::Expression dispatch_ternary(quiver_expression_ternary_operation_t operat throw std::runtime_error("Cannot apply: unknown expression ternary operation"); } +quiver::ExpressionAggregate::Operation from_c(quiver_expression_aggregate_operation_t op) { + switch (op) { + case QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM: + return quiver::ExpressionAggregate::Operation::Sum; + case QUIVER_EXPRESSION_AGGREGATE_OPERATION_MEAN: + return quiver::ExpressionAggregate::Operation::Mean; + case QUIVER_EXPRESSION_AGGREGATE_OPERATION_MIN: + return quiver::ExpressionAggregate::Operation::Min; + case QUIVER_EXPRESSION_AGGREGATE_OPERATION_MAX: + return quiver::ExpressionAggregate::Operation::Max; + case QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE: + return quiver::ExpressionAggregate::Operation::Percentile; + } + throw std::runtime_error("Cannot aggregate: unknown operation enum value"); +} + +quiver::ExpressionAggregateAgents::Operation from_c(quiver_expression_aggregate_agents_operation_t op) { + switch (op) { + case QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM: + return quiver::ExpressionAggregateAgents::Operation::Sum; + case QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN: + return quiver::ExpressionAggregateAgents::Operation::Mean; + case QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MIN: + return quiver::ExpressionAggregateAgents::Operation::Min; + case QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MAX: + return quiver::ExpressionAggregateAgents::Operation::Max; + case QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_PERCENTILE: + return quiver::ExpressionAggregateAgents::Operation::Percentile; + } + throw std::runtime_error("Cannot 
aggregate_agents: unknown operation enum value"); +} + } // namespace extern "C" { @@ -207,14 +239,14 @@ QUIVER_C_API quiver_error_t quiver_expression_get_metadata(quiver_expression_t* QUIVER_C_API quiver_error_t quiver_expression_aggregate(quiver_expression_t* expression, const char* dimension, - const char* operation, + quiver_expression_aggregate_operation_t operation, const double* parameter, quiver_expression_t** out) { - QUIVER_REQUIRE(expression, dimension, operation, out); + QUIVER_REQUIRE(expression, dimension, out); try { std::optional p = parameter ? std::optional(*parameter) : std::nullopt; - *out = new quiver_expression(expression->expression.aggregate(dimension, operation, p)); + *out = new quiver_expression(expression->expression.aggregate(dimension, from_c(operation), p)); return QUIVER_OK; } catch (const std::bad_alloc&) { quiver_set_last_error("Memory allocation failed"); @@ -226,14 +258,14 @@ QUIVER_C_API quiver_error_t quiver_expression_aggregate(quiver_expression_t* exp } QUIVER_C_API quiver_error_t quiver_expression_aggregate_agents(quiver_expression_t* expression, - const char* operation, + quiver_expression_aggregate_agents_operation_t operation, const double* parameter, quiver_expression_t** out) { - QUIVER_REQUIRE(expression, operation, out); + QUIVER_REQUIRE(expression, out); try { std::optional p = parameter ? 
std::optional(*parameter) : std::nullopt; - *out = new quiver_expression(expression->expression.aggregate_agents(operation, p)); + *out = new quiver_expression(expression->expression.aggregate_agents(from_c(operation), p)); return QUIVER_OK; } catch (const std::bad_alloc&) { quiver_set_last_error("Memory allocation failed"); diff --git a/src/expression/expression.cpp b/src/expression/expression.cpp index d3cd3c0d..2ee31218 100644 --- a/src/expression/expression.cpp +++ b/src/expression/expression.cpp @@ -24,15 +24,14 @@ const BinaryMetadata& Expression::metadata() const { } Expression Expression::aggregate(const std::string& dimension, - const std::string& operation, + ExpressionAggregate::Operation operation, std::optional parameter) const { - auto op = ExpressionAggregate::parse_operation(operation); - return Expression(std::make_shared(op, node_, dimension, parameter)); + return Expression(std::make_shared(operation, node_, dimension, parameter)); } -Expression Expression::aggregate_agents(const std::string& operation, std::optional parameter) const { - auto op = ExpressionAggregateAgents::parse_operation(operation); - return Expression(std::make_shared(op, node_, parameter)); +Expression Expression::aggregate_agents(ExpressionAggregateAgents::Operation operation, + std::optional parameter) const { + return Expression(std::make_shared(operation, node_, parameter)); } Expression Expression::select_agents(const std::vector& labels) const { diff --git a/src/expression/expression_aggregate.cpp b/src/expression/expression_aggregate.cpp index 1a3f95f9..032ca1b0 100644 --- a/src/expression/expression_aggregate.cpp +++ b/src/expression/expression_aggregate.cpp @@ -15,10 +15,6 @@ namespace quiver { -ExpressionAggregate::Operation ExpressionAggregate::parse_operation(const std::string& name) { - return parse_aggregation_operation_name(name, "aggregate"); -} - ExpressionAggregate::ExpressionAggregate(Operation operation, std::shared_ptr operand, std::string dimension_name, 
diff --git a/src/expression/expression_aggregate_agents.cpp b/src/expression/expression_aggregate_agents.cpp index 016cef0a..15da69a6 100644 --- a/src/expression/expression_aggregate_agents.cpp +++ b/src/expression/expression_aggregate_agents.cpp @@ -13,10 +13,6 @@ namespace quiver { -ExpressionAggregateAgents::Operation ExpressionAggregateAgents::parse_operation(const std::string& name) { - return parse_aggregation_operation_name(name, "aggregate_agents"); -} - ExpressionAggregateAgents::ExpressionAggregateAgents(Operation operation, std::shared_ptr operand, std::optional parameter) diff --git a/src/expression/expression_helpers.h b/src/expression/expression_helpers.h index a553b822..ef699458 100644 --- a/src/expression/expression_helpers.h +++ b/src/expression/expression_helpers.h @@ -280,22 +280,6 @@ inline BinaryMetadata build_ternary_broadcast_metadata(const BinaryMetadata& con return out; } -template -Op parse_aggregation_operation_name(const std::string& name, const std::string& fn_label) { - if (name == "sum") - return Op::Sum; - if (name == "mean") - return Op::Mean; - if (name == "min") - return Op::Min; - if (name == "max") - return Op::Max; - if (name == "percentile") - return Op::Percentile; - throw std::runtime_error("Cannot " + fn_label + ": unknown operation '" + name + - "' (expected one of: sum, mean, min, max, percentile)"); -} - template std::string aggregation_operation_label(Op op) { switch (op) { diff --git a/tests/test_c_api_expression.cpp b/tests/test_c_api_expression.cpp index 80faa835..f8864ca4 100644 --- a/tests/test_c_api_expression.cpp +++ b/tests/test_c_api_expression.cpp @@ -1012,7 +1012,7 @@ TEST_F(ExpressionCApiFixture, AggregateSumOverDim) { auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate(a, "row", "sum", nullptr, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), QUIVER_OK); 
ASSERT_EQ(quiver_expression_save(agg, path_out.c_str()), QUIVER_OK); quiver_expression_close(a); quiver_expression_close(agg); @@ -1035,7 +1035,7 @@ TEST_F(ExpressionCApiFixture, AggregatePercentileWithParam) { auto* a = expr_from_file(path_a); const double p = 0.5; quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate(a, "row", "percentile", &p, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, &p, &agg), QUIVER_OK); ASSERT_EQ(quiver_expression_save(agg, path_out.c_str()), QUIVER_OK); quiver_expression_close(a); quiver_expression_close(agg); @@ -1052,7 +1052,7 @@ TEST_F(ExpressionCApiFixture, AggregateSumWithExtraParamReturnsError) { auto* a = expr_from_file(path_a); const double p = 0.5; quiver_expression_t* agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate(a, "row", "sum", &p, &agg), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, &p, &agg), QUIVER_ERROR); EXPECT_EQ(agg, nullptr); EXPECT_NE(std::string(quiver_get_last_error()).find("does not accept a parameter"), std::string::npos); quiver_expression_close(a); @@ -1062,7 +1062,7 @@ TEST_F(ExpressionCApiFixture, AggregatePercentileMissingParamReturnsError) { write_fixture(path_a, [](int, int, int) { return 1.0; }); auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate(a, "row", "percentile", nullptr, &agg), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, nullptr, &agg), QUIVER_ERROR); EXPECT_EQ(agg, nullptr); EXPECT_NE(std::string(quiver_get_last_error()).find("requires a parameter"), std::string::npos); quiver_expression_close(a); @@ -1072,27 +1072,18 @@ TEST_F(ExpressionCApiFixture, AggregateDimensionNotFoundReturnsError) { write_fixture(path_a, [](int, int, int) { return 1.0; }); auto* a = expr_from_file(path_a); quiver_expression_t* 
agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate(a, "nonexistent", "sum", nullptr, &agg), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate(a, "nonexistent", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), QUIVER_ERROR); EXPECT_EQ(agg, nullptr); EXPECT_NE(std::string(quiver_get_last_error()).find("Dimension not found"), std::string::npos); quiver_expression_close(a); } -TEST_F(ExpressionCApiFixture, AggregateUnknownOperationReturnsError) { - write_fixture(path_a, [](int, int, int) { return 1.0; }); - auto* a = expr_from_file(path_a); - quiver_expression_t* agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate(a, "row", "average", nullptr, &agg), QUIVER_ERROR); - EXPECT_EQ(agg, nullptr); - quiver_expression_close(a); -} - TEST_F(ExpressionCApiFixture, AggregateAgentsSumReducesLabels) { write_fixture(path_a, [](int r, int c, int k) { return static_cast(r * 10 + c + k); }); auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate_agents(a, "sum", nullptr, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate_agents(a, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM, nullptr, &agg), QUIVER_OK); // Verify output metadata: single label "sum", dims unchanged. 
quiver_binary_metadata_t* out_md = nullptr; @@ -1123,7 +1114,7 @@ TEST_F(ExpressionCApiFixture, AggregateAgentsPercentileWithParam) { auto* a = expr_from_file(path_a); const double p = 0.5; quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate_agents(a, "percentile", &p, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate_agents(a, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_PERCENTILE, &p, &agg), QUIVER_OK); ASSERT_EQ(quiver_expression_save(agg, path_out.c_str()), QUIVER_OK); quiver_expression_close(a); quiver_expression_close(agg); @@ -1138,10 +1129,12 @@ TEST_F(ExpressionCApiFixture, AggregateNullArguments) { auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate(nullptr, "row", "sum", nullptr, &agg), QUIVER_ERROR); - EXPECT_EQ(quiver_expression_aggregate(a, nullptr, "sum", nullptr, &agg), QUIVER_ERROR); - EXPECT_EQ(quiver_expression_aggregate(a, "row", nullptr, nullptr, &agg), QUIVER_ERROR); - EXPECT_EQ(quiver_expression_aggregate(a, "row", "sum", nullptr, nullptr), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate(nullptr, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), + QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate(a, nullptr, QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), + QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, nullptr), + QUIVER_ERROR); quiver_expression_close(a); } @@ -1151,9 +1144,11 @@ TEST_F(ExpressionCApiFixture, AggregateAgentsNullArguments) { auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate_agents(nullptr, "sum", nullptr, &agg), QUIVER_ERROR); - EXPECT_EQ(quiver_expression_aggregate_agents(a, nullptr, nullptr, &agg), QUIVER_ERROR); - EXPECT_EQ(quiver_expression_aggregate_agents(a, "sum", nullptr, nullptr), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate_agents(nullptr, 
QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM, nullptr, + &agg), + QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate_agents(a, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM, nullptr, nullptr), + QUIVER_ERROR); quiver_expression_close(a); } @@ -1167,7 +1162,7 @@ TEST_F(ExpressionCApiFixture, AggregateChainedWithBinary) { quiver_expression_t* sum = nullptr; ASSERT_EQ(quiver_expression_apply(QUIVER_EXPRESSION_OPERATION_ADD, a, b, &sum), QUIVER_OK); quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate(sum, "row", "sum", nullptr, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate(sum, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), QUIVER_OK); ASSERT_EQ(quiver_expression_save(agg, path_out.c_str()), QUIVER_OK); quiver_expression_close(a); quiver_expression_close(b); diff --git a/tests/test_expression.cpp b/tests/test_expression.cpp index a643d727..95ce3a89 100644 --- a/tests/test_expression.cpp +++ b/tests/test_expression.cpp @@ -992,7 +992,7 @@ TEST_F(ExpressionFixture, AggregateSumOverNonTimeDim) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression e = Expression(a).aggregate("row", "sum"); + Expression e = Expression(a).aggregate("row", ExpressionAggregate::Operation::Sum); e.save(path_out); // Output has only "col" dim (size 2) and labels [val1, val2] = 4 cells. 
@@ -1019,7 +1019,7 @@ TEST_F(ExpressionFixture, AggregateMeanOverNonTimeDim) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", "mean").save(path_out); + Expression(a).aggregate("row", ExpressionAggregate::Operation::Mean).save(path_out); auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 4u); @@ -1037,7 +1037,7 @@ TEST_F(ExpressionFixture, AggregateMinOverNonTimeDim) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", "min").save(path_out); + Expression(a).aggregate("row", ExpressionAggregate::Operation::Min).save(path_out); auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 4u); @@ -1054,7 +1054,7 @@ TEST_F(ExpressionFixture, AggregateMaxOverNonTimeDim) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", "max").save(path_out); + Expression(a).aggregate("row", ExpressionAggregate::Operation::Max).save(path_out); auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 4u); @@ -1071,7 +1071,7 @@ TEST_F(ExpressionFixture, AggregatePercentileOverNonTimeDim) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", "percentile", 0.5).save(path_out); + Expression(a).aggregate("row", ExpressionAggregate::Operation::Percentile, 0.5).save(path_out); auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 4u); @@ -1095,7 +1095,7 @@ TEST_F(ExpressionFixture, AggregateSumOverTimeDimSimple) { .set("labels", {"v1"})); write_qvr(path_a, md, [](const std::vector& dims, size_t) { return static_cast(dims[0]); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("year", "sum").save(path_out); + Expression(a).aggregate("year", ExpressionAggregate::Operation::Sum).save(path_out); auto vo = 
read_all_cells(path_out); // Output: [scenario(2)] × [v1] = 2 cells. Sum over year=1..3 = 6. @@ -1119,7 +1119,7 @@ TEST_F(ExpressionFixture, AggregateSumOverTimeDimVariable) { // Fill every cell with 1.0 → sum over block at each month equals the number of days. write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("block", "sum").save(path_out); + Expression(a).aggregate("block", ExpressionAggregate::Operation::Sum).save(path_out); auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 4u); // 4 months × 1 label @@ -1139,7 +1139,7 @@ TEST_F(ExpressionFixture, AggregateSumSkipsNaNs) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", "sum").save(path_out); + Expression(a).aggregate("row", ExpressionAggregate::Operation::Sum).save(path_out); auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 4u); @@ -1155,7 +1155,7 @@ TEST_F(ExpressionFixture, AggregateAllNaNRangeProducesNaN) { const double kNan = std::numeric_limits::quiet_NaN(); write_qvr(path_a, md, [kNan](const std::vector&, size_t) { return kNan; }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", "sum").save(path_out); + Expression(a).aggregate("row", ExpressionAggregate::Operation::Sum).save(path_out); auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 4u); @@ -1179,7 +1179,7 @@ TEST_F(ExpressionFixture, AggregateTimeDimRewireParents) { .set("labels", {"v1"})); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - auto out = Expression(a).aggregate("scenario", "sum"); + auto out = Expression(a).aggregate("scenario", ExpressionAggregate::Operation::Sum); const auto& m = out.metadata(); ASSERT_EQ(m.dimensions.size(), 2u); @@ -1203,7 +1203,7 @@ TEST_F(ExpressionFixture, AggregateReduceOutermostTimeDimWithChildren) { 
.set("labels", {"v1"})); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - auto out = Expression(a).aggregate("year", "sum"); + auto out = Expression(a).aggregate("year", ExpressionAggregate::Operation::Sum); const auto& m = out.metadata(); ASSERT_EQ(m.dimensions.size(), 1u); @@ -1217,36 +1217,29 @@ TEST_F(ExpressionFixture, AggregateDimensionNotFoundThrows) { auto md = make_simple_metadata(); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate("nonexistent", "sum"), std::runtime_error); -} - -TEST_F(ExpressionFixture, AggregateUnknownOperationThrows) { - auto md = make_simple_metadata(); - write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); - auto a = BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate("row", "average"), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate("nonexistent", ExpressionAggregate::Operation::Sum), std::runtime_error); } TEST_F(ExpressionFixture, AggregatePercentileMissingParamThrows) { auto md = make_simple_metadata(); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate("row", "percentile"), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate("row", ExpressionAggregate::Operation::Percentile), std::runtime_error); } TEST_F(ExpressionFixture, AggregateSumExtraParamThrows) { auto md = make_simple_metadata(); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate("row", "sum", 0.5), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate("row", ExpressionAggregate::Operation::Sum, 0.5), std::runtime_error); } TEST_F(ExpressionFixture, AggregatePercentileOutOfRangeThrows) { auto md = make_simple_metadata(); 
write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate("row", "percentile", 1.5), std::runtime_error); - EXPECT_THROW(Expression(a).aggregate("row", "percentile", -0.1), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate("row", ExpressionAggregate::Operation::Percentile, 1.5), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate("row", ExpressionAggregate::Operation::Percentile, -0.1), std::runtime_error); } TEST_F(ExpressionFixture, AggregateChained) { @@ -1261,7 +1254,7 @@ TEST_F(ExpressionFixture, AggregateChained) { .set("labels", {"v1"})); write_qvr(path_a, md, [](const std::vector&, size_t) { return 2.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", "sum").aggregate("col", "sum").save(path_out); + Expression(a).aggregate("row", ExpressionAggregate::Operation::Sum).aggregate("col", ExpressionAggregate::Operation::Sum).save(path_out); // Output dims = [depth(2)] × 1 label = 2 cells. Each cell sums 3 rows × 2 cols of 2.0 = 12.0. 
auto vo = read_all_cells(path_out); @@ -1280,7 +1273,7 @@ TEST_F(ExpressionFixture, AggregateComposedWithBinary) { }); auto a = BinaryFile::open_file(path_a, 'r'); auto b = BinaryFile::open_file(path_b, 'r'); - (Expression(a) + Expression(b)).aggregate("row", "sum").save(path_out); + (Expression(a) + Expression(b)).aggregate("row", ExpressionAggregate::Operation::Sum).save(path_out); auto vo = read_all_cells(path_out); ASSERT_EQ(vo.size(), 4u); @@ -1302,7 +1295,7 @@ TEST_F(ExpressionFixture, AgentSumReducesLabels) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - auto out = Expression(a).aggregate_agents("sum"); + auto out = Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Sum); out.save(path_out); const auto& m = out.metadata(); @@ -1328,7 +1321,7 @@ TEST_F(ExpressionFixture, AgentMeanReducesLabels) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - auto out = Expression(a).aggregate_agents("mean"); + auto out = Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Mean); out.save(path_out); const auto& m = out.metadata(); @@ -1346,7 +1339,7 @@ TEST_F(ExpressionFixture, AgentMinReducesLabels) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - auto out = Expression(a).aggregate_agents("min"); + auto out = Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Min); out.save(path_out); EXPECT_EQ(out.metadata().labels[0], "min"); @@ -1362,7 +1355,7 @@ TEST_F(ExpressionFixture, AgentMaxReducesLabels) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - auto out = Expression(a).aggregate_agents("max"); + auto out = Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Max); out.save(path_out); EXPECT_EQ(out.metadata().labels[0], "max"); @@ -1378,7 +1371,7 @@ 
TEST_F(ExpressionFixture, AgentPercentileReducesLabels) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - auto out = Expression(a).aggregate_agents("percentile", 0.5); + auto out = Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Percentile, 0.5); out.save(path_out); EXPECT_EQ(out.metadata().labels[0], "percentile"); @@ -1396,7 +1389,7 @@ TEST_F(ExpressionFixture, AgentSkipsNaNs) { return static_cast(dims[0] * 10 + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate_agents("sum").save(path_out); + Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Sum).save(path_out); auto vo = read_all_cells(path_out); // Only label k=1 contributes; sum = single value = 10r + c + 1. @@ -1409,7 +1402,7 @@ TEST_F(ExpressionFixture, AgentAllNaNProducesNaN) { const double kNan = std::numeric_limits::quiet_NaN(); write_qvr(path_a, md, [kNan](const std::vector&, size_t) { return kNan; }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate_agents("sum").save(path_out); + Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Sum).save(path_out); auto vo = read_all_cells(path_out); for (double v : vo) { @@ -1429,7 +1422,7 @@ TEST_F(ExpressionFixture, AgentPreservesDimensions) { .set("labels", {"v1", "v2", "v3"})); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - auto out = Expression(a).aggregate_agents("mean"); + auto out = Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Mean); const auto& m = out.metadata(); EXPECT_EQ(m.unit, "MW"); @@ -1444,33 +1437,26 @@ TEST_F(ExpressionFixture, AgentPreservesDimensions) { EXPECT_EQ(m.labels[0], "mean"); } -TEST_F(ExpressionFixture, AgentUnknownOperationThrows) { - auto md = make_simple_metadata(); - write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); - auto a 
= BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate_agents("average"), std::runtime_error); -} - TEST_F(ExpressionFixture, AgentPercentileMissingParamThrows) { auto md = make_simple_metadata(); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate_agents("percentile"), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Percentile), std::runtime_error); } TEST_F(ExpressionFixture, AgentPercentileOutOfRangeThrows) { auto md = make_simple_metadata(); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate_agents("percentile", 1.5), std::runtime_error); - EXPECT_THROW(Expression(a).aggregate_agents("percentile", -0.1), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Percentile, 1.5), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Percentile, -0.1), std::runtime_error); } TEST_F(ExpressionFixture, AgentChainedAfterAggregate) { auto md = make_simple_metadata(); write_qvr(path_a, md, [](const std::vector&, size_t) { return 3.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", "sum").aggregate_agents("mean").save(path_out); + Expression(a).aggregate("row", ExpressionAggregate::Operation::Sum).aggregate_agents(ExpressionAggregateAgents::Operation::Mean).save(path_out); // After reducing row(3) and agents(2): output dims=[col(2)], labels=["mean"] = 2 cells. // First sum over 3 rows of 3.0 → 9.0 in each (col, k). Then mean across 2 labels → 9.0. 
@@ -1486,7 +1472,7 @@ TEST_F(ExpressionFixture, AgentSaveProducesReadableFile) { return static_cast(dims[0] + dims[1] + static_cast(k)); }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate_agents("sum").save(path_out); + Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Sum).save(path_out); auto reopened = BinaryFile::open_file(path_out, 'r'); const auto& m = reopened.get_metadata(); From 7c301ff3cc9d6ebfffafb153a6c314562613b158 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Sun, 17 May 2026 20:02:35 -0300 Subject: [PATCH 12/13] Update --- src/expression/expression_aggregate.cpp | 3 +-- .../expression_aggregate_agents.cpp | 3 +-- src/expression/expression_binary.cpp | 3 +-- src/expression/expression_file.cpp | 3 +-- src/expression/expression_ternary.cpp | 3 +-- tests/test_c_api_expression.cpp | 27 ++++++++++++------- tests/test_expression.cpp | 16 ++++++++--- 7 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/expression/expression_aggregate.cpp b/src/expression/expression_aggregate.cpp index 032ca1b0..86bce676 100644 --- a/src/expression/expression_aggregate.cpp +++ b/src/expression/expression_aggregate.cpp @@ -1,7 +1,6 @@ -#include "quiver/expression/expression_node.h" - #include "expression_helpers.h" #include "quiver/binary/iteration.h" +#include "quiver/expression/expression_node.h" #include #include diff --git a/src/expression/expression_aggregate_agents.cpp b/src/expression/expression_aggregate_agents.cpp index 15da69a6..4fae40fc 100644 --- a/src/expression/expression_aggregate_agents.cpp +++ b/src/expression/expression_aggregate_agents.cpp @@ -1,6 +1,5 @@ -#include "quiver/expression/expression_node.h" - #include "expression_helpers.h" +#include "quiver/expression/expression_node.h" #include #include diff --git a/src/expression/expression_binary.cpp b/src/expression/expression_binary.cpp index 8a414f47..7038daa4 100644 --- a/src/expression/expression_binary.cpp +++ 
b/src/expression/expression_binary.cpp @@ -1,6 +1,5 @@ -#include "quiver/expression/expression_node.h" - #include "expression_helpers.h" +#include "quiver/expression/expression_node.h" #include #include diff --git a/src/expression/expression_file.cpp b/src/expression/expression_file.cpp index d9dd2a54..b116ba9d 100644 --- a/src/expression/expression_file.cpp +++ b/src/expression/expression_file.cpp @@ -1,6 +1,5 @@ -#include "quiver/expression/expression_node.h" - #include "quiver/binary/binary_file.h" +#include "quiver/expression/expression_node.h" #include #include diff --git a/src/expression/expression_ternary.cpp b/src/expression/expression_ternary.cpp index ce7bf089..09961ee9 100644 --- a/src/expression/expression_ternary.cpp +++ b/src/expression/expression_ternary.cpp @@ -1,6 +1,5 @@ -#include "quiver/expression/expression_node.h" - #include "expression_helpers.h" +#include "quiver/expression/expression_node.h" #include #include diff --git a/tests/test_c_api_expression.cpp b/tests/test_c_api_expression.cpp index f8864ca4..03dcf8ca 100644 --- a/tests/test_c_api_expression.cpp +++ b/tests/test_c_api_expression.cpp @@ -1012,7 +1012,8 @@ TEST_F(ExpressionCApiFixture, AggregateSumOverDim) { auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), + QUIVER_OK); ASSERT_EQ(quiver_expression_save(agg, path_out.c_str()), QUIVER_OK); quiver_expression_close(a); quiver_expression_close(agg); @@ -1035,7 +1036,8 @@ TEST_F(ExpressionCApiFixture, AggregatePercentileWithParam) { auto* a = expr_from_file(path_a); const double p = 0.5; quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, &p, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate(a, "row", 
QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, &p, &agg), + QUIVER_OK); ASSERT_EQ(quiver_expression_save(agg, path_out.c_str()), QUIVER_OK); quiver_expression_close(a); quiver_expression_close(agg); @@ -1062,7 +1064,8 @@ TEST_F(ExpressionCApiFixture, AggregatePercentileMissingParamReturnsError) { write_fixture(path_a, [](int, int, int) { return 1.0; }); auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, nullptr, &agg), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate(a, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, nullptr, &agg), + QUIVER_ERROR); EXPECT_EQ(agg, nullptr); EXPECT_NE(std::string(quiver_get_last_error()).find("requires a parameter"), std::string::npos); quiver_expression_close(a); @@ -1072,7 +1075,8 @@ TEST_F(ExpressionCApiFixture, AggregateDimensionNotFoundReturnsError) { write_fixture(path_a, [](int, int, int) { return 1.0; }); auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate(a, "nonexistent", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), QUIVER_ERROR); + EXPECT_EQ(quiver_expression_aggregate(a, "nonexistent", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), + QUIVER_ERROR); EXPECT_EQ(agg, nullptr); EXPECT_NE(std::string(quiver_get_last_error()).find("Dimension not found"), std::string::npos); quiver_expression_close(a); @@ -1083,7 +1087,8 @@ TEST_F(ExpressionCApiFixture, AggregateAgentsSumReducesLabels) { auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate_agents(a, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM, nullptr, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate_agents(a, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM, nullptr, &agg), + QUIVER_OK); // Verify output metadata: single label "sum", dims unchanged. 
quiver_binary_metadata_t* out_md = nullptr; @@ -1114,7 +1119,8 @@ TEST_F(ExpressionCApiFixture, AggregateAgentsPercentileWithParam) { auto* a = expr_from_file(path_a); const double p = 0.5; quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate_agents(a, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_PERCENTILE, &p, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate_agents(a, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_PERCENTILE, &p, &agg), + QUIVER_OK); ASSERT_EQ(quiver_expression_save(agg, path_out.c_str()), QUIVER_OK); quiver_expression_close(a); quiver_expression_close(agg); @@ -1144,9 +1150,9 @@ TEST_F(ExpressionCApiFixture, AggregateAgentsNullArguments) { auto* a = expr_from_file(path_a); quiver_expression_t* agg = nullptr; - EXPECT_EQ(quiver_expression_aggregate_agents(nullptr, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM, nullptr, - &agg), - QUIVER_ERROR); + EXPECT_EQ( + quiver_expression_aggregate_agents(nullptr, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM, nullptr, &agg), + QUIVER_ERROR); EXPECT_EQ(quiver_expression_aggregate_agents(a, QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_SUM, nullptr, nullptr), QUIVER_ERROR); @@ -1162,7 +1168,8 @@ TEST_F(ExpressionCApiFixture, AggregateChainedWithBinary) { quiver_expression_t* sum = nullptr; ASSERT_EQ(quiver_expression_apply(QUIVER_EXPRESSION_OPERATION_ADD, a, b, &sum), QUIVER_OK); quiver_expression_t* agg = nullptr; - ASSERT_EQ(quiver_expression_aggregate(sum, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), QUIVER_OK); + ASSERT_EQ(quiver_expression_aggregate(sum, "row", QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, nullptr, &agg), + QUIVER_OK); ASSERT_EQ(quiver_expression_save(agg, path_out.c_str()), QUIVER_OK); quiver_expression_close(a); quiver_expression_close(b); diff --git a/tests/test_expression.cpp b/tests/test_expression.cpp index 95ce3a89..06b1c466 100644 --- a/tests/test_expression.cpp +++ b/tests/test_expression.cpp @@ -1254,7 +1254,10 @@ 
TEST_F(ExpressionFixture, AggregateChained) { .set("labels", {"v1"})); write_qvr(path_a, md, [](const std::vector&, size_t) { return 2.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", ExpressionAggregate::Operation::Sum).aggregate("col", ExpressionAggregate::Operation::Sum).save(path_out); + Expression(a) + .aggregate("row", ExpressionAggregate::Operation::Sum) + .aggregate("col", ExpressionAggregate::Operation::Sum) + .save(path_out); // Output dims = [depth(2)] × 1 label = 2 cells. Each cell sums 3 rows × 2 cols of 2.0 = 12.0. auto vo = read_all_cells(path_out); @@ -1448,15 +1451,20 @@ TEST_F(ExpressionFixture, AgentPercentileOutOfRangeThrows) { auto md = make_simple_metadata(); write_qvr(path_a, md, [](const std::vector&, size_t) { return 1.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - EXPECT_THROW(Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Percentile, 1.5), std::runtime_error); - EXPECT_THROW(Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Percentile, -0.1), std::runtime_error); + EXPECT_THROW(Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Percentile, 1.5), + std::runtime_error); + EXPECT_THROW(Expression(a).aggregate_agents(ExpressionAggregateAgents::Operation::Percentile, -0.1), + std::runtime_error); } TEST_F(ExpressionFixture, AgentChainedAfterAggregate) { auto md = make_simple_metadata(); write_qvr(path_a, md, [](const std::vector&, size_t) { return 3.0; }); auto a = BinaryFile::open_file(path_a, 'r'); - Expression(a).aggregate("row", ExpressionAggregate::Operation::Sum).aggregate_agents(ExpressionAggregateAgents::Operation::Mean).save(path_out); + Expression(a) + .aggregate("row", ExpressionAggregate::Operation::Sum) + .aggregate_agents(ExpressionAggregateAgents::Operation::Mean) + .save(path_out); // After reducing row(3) and agents(2): output dims=[col(2)], labels=["mean"] = 2 cells. // First sum over 3 rows of 3.0 → 9.0 in each (col, k). 
Then mean across 2 labels → 9.0. From 9de30d811a81812e2d7bec40e4d15306021f89b8 Mon Sep 17 00:00:00 2001 From: raphasampaio Date: Sun, 17 May 2026 20:04:17 -0300 Subject: [PATCH 13/13] Update --- bindings/julia/test/test_expression.jl | 44 ++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/bindings/julia/test/test_expression.jl b/bindings/julia/test/test_expression.jl index 4d6070b6..6e668566 100644 --- a/bindings/julia/test/test_expression.jl +++ b/bindings/julia/test/test_expression.jl @@ -1064,7 +1064,11 @@ end ) write_dense(path_a, md, [:row, :col, :depth], [3, 2, 2], 1, (_, _) -> 2.0) with_expr(path_a) do e - out = Quiver.aggregate(Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM), "col", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM) + out = Quiver.aggregate( + Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM), + "col", + Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, + ) Quiver.save(out, path_out) return Quiver.close!(out) end @@ -1101,7 +1105,11 @@ end try write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "nonexistent", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM) + @test_throws Quiver.DatabaseException Quiver.aggregate( + e, + "nonexistent", + Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, + ) end finally cleanup(path_a) @@ -1113,7 +1121,11 @@ end try write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE) + @test_throws Quiver.DatabaseException Quiver.aggregate( + e, + "row", + Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, + ) end finally cleanup(path_a) @@ -1125,7 +1137,12 @@ end try write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", 
Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, 0.5) + @test_throws Quiver.DatabaseException Quiver.aggregate( + e, + "row", + Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM, + 0.5, + ) end finally cleanup(path_a) @@ -1137,8 +1154,18 @@ end try write_fixture(path_a, (_, _, _) -> 1.0) with_expr(path_a) do e - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, 1.5) - @test_throws Quiver.DatabaseException Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, -0.1) + @test_throws Quiver.DatabaseException Quiver.aggregate( + e, + "row", + Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, + 1.5, + ) + @test_throws Quiver.DatabaseException Quiver.aggregate( + e, + "row", + Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_PERCENTILE, + -0.1, + ) end finally cleanup(path_a) @@ -1244,7 +1271,10 @@ end try write_fixture(path_a, (_, _, _) -> 3.0) with_expr(path_a) do e - out = Quiver.aggregate_agents(Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM), Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN) + out = Quiver.aggregate_agents( + Quiver.aggregate(e, "row", Quiver.C.QUIVER_EXPRESSION_AGGREGATE_OPERATION_SUM), + Quiver.C.QUIVER_EXPRESSION_AGGREGATE_AGENTS_OPERATION_MEAN, + ) Quiver.save(out, path_out) return Quiver.close!(out) end