diff --git a/docs/language/reference/dataset_methods.md b/docs/language/reference/dataset_methods.md index ab4926e..26c43f7 100644 --- a/docs/language/reference/dataset_methods.md +++ b/docs/language/reference/dataset_methods.md @@ -19,9 +19,10 @@ The Substrait helper surface behind these methods is split by semantic role: | `with_column` | `def with_column(self, name: str, expr: ColumnExpr) -> Self` | Add or replace one projected column using a scalar expression. | | `group_by` | `def group_by(self, columns: list[ColumnExpr]) -> Self` | Define grouping keys using scalar expressions. | | `agg` | `def agg(self, measures: list[AggregateMeasure]) -> Self` | Apply aggregate measures over the current relation or current grouping. | +| `generate` | `def generate(self, generator: GeneratorApplication) -> Self` | Apply a relation-shaping generator such as `explode(...)` with explicit output aliases. | | `order_by` | `def order_by(self, columns: list[ColumnExpr]) -> Self` | Sort rows by scalar expressions or ordering helpers such as `asc(...)` and `desc(...)`. | | `limit` | `def limit(self, n: int) -> Self` | Cap row count. | -| `explode` | `def explode(self) -> Self` | Expand a nested list column into rows. | +| `explode` | `def explode(self) -> Self` | Emit the lower-level `EXPLODE` extension boundary without expression/schema metadata. | ## `with_column` @@ -67,6 +68,7 @@ def enrich(orders: LazyFrame[Order]) -> LazyFrame[Order]: - `join(...)` is constrained to same-carrier inputs and the boolean join predicate surface shown in the signature. - `select(...)` preserves projection shape; explicit projection lists are represented today through `with_column(...)` and scalar-expression builders. +- `generate(...)` preserves all input columns and appends generated output aliases for `explode`, `explode_outer`, `posexplode`, `posexplode_outer`, `inline`, `inline_outer`, `flatten`, and `stack` generator applications. Alias collisions are rejected during planning/lowering. 
- `DataFrame[T]` exposes materialized metadata and preview text; row-level accessors belong to the materialized DataFrame API surface. - Query-block and scoped DSL surfaces lower into these builder APIs rather than defining separate method semantics. diff --git a/docs/language/reference/functions/generators.md b/docs/language/reference/functions/generators.md new file mode 100644 index 0000000..e13451e --- /dev/null +++ b/docs/language/reference/functions/generators.md @@ -0,0 +1,43 @@ +# Generator and Table-Valued Functions (Reference) + +Generators are relation-shaping operations. They are registry-backed like scalar and aggregate helpers, but they return +`GeneratorApplication` values and must be applied through a relation method such as `generate(...)`. + +```incan +from pub::inql import LazyFrame +from pub::inql.functions import array, col, explode, inline, lit, named_struct +from models import Order + +def order_lines(orders: LazyFrame[Order]) -> LazyFrame[Order]: + return orders.generate(explode(col("line_items"), "line_item")) + +def fixed_items(orders: LazyFrame[Order]) -> LazyFrame[Order]: + rows = array([ + named_struct(["sku", "quantity"], [lit("A"), lit(1)]), + named_struct(["sku", "quantity"], [lit("B"), lit(2)]), + ]) + return orders.generate(inline(rows, ["sku", "quantity"])) +``` + +The explicit generator surface currently includes: + +| Function | Output aliases | Relation effect | +| --- | --- | --- | +| `explode(expr, as_)` | one value column | Emits one row per array element; null or empty inputs emit zero rows. | +| `explode_outer(expr, as_)` | one value column | Preserves the input row for null or empty inputs and emits a null generated value. | +| `posexplode(expr, position_as, value_as)` | position and value columns | Emits one row per array element with a zero-based position column. | +| `posexplode_outer(expr, position_as, value_as)` | position and value columns | Outer positional explode with the same zero-based position rule. 
| +| `inline(expr, output_columns)` | one column per struct field | Expands array-of-struct values into generated rows and declared output columns. | +| `inline_outer(expr, output_columns)` | one column per struct field | Outer inline with the same null/empty row preservation rule. | +| `flatten(expr, as_)` | one value column | Portable table-valued flatten for one array expression. | +| `stack(row_count, values, output_columns)` | declared output columns | Emits `row_count` generated rows from row-major scalar values. | + +Generator applications preserve input columns and append generated columns in declaration order. Generated aliases are +required, must be non-empty, and must not collide with existing input columns. + +The zero-argument `DataSet.explode()` method is a lower-level extension-boundary operation. It emits the registered +`EXPLODE` relation extension without carrying a source expression or generated output schema. Generator code should use +`generate(explode(...))` so the relation-shaping function identity, input expression, and output schema are explicit. + +Nested scalar helpers such as `array_flatten(...)` remain scalar expressions. They do not expand rows and are documented +on the [nested data functions](nested.md) page. The relation-shaping `flatten(...)` helper is intentionally separate. diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index 5bcbe86..0376399 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -7,11 +7,12 @@ Today the concrete shipped surfaces are documented here: - [Filter builders](../builders/filters.md) - [Aggregate builders](../builders/aggregates.md) - [Projection builders](../builders/projections.md) +- [Generator and table-valued functions](generators.md) - [Nested data functions](nested.md) The canonical scalar literal helper is `lit(...)`. 
Typed literal helpers construct the same scalar-expression representation. -The current registry-backed helper surface covers references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), function policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked public helpers provide the signature and, by default, the canonical name; metadata may override the canonical name only for source spelling constraints such as the reserved-word `mod` case. +The current registry-backed helper surface covers references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), function policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked public helpers provide the signature and, by default, the canonical name; metadata may override the canonical name only for source spelling constraints such as the reserved-word `mod` case. The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. 
Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. @@ -32,7 +33,8 @@ The registered helper surface currently includes: | `coalesce(...)`, `nullif(...)`, `case_when(...)` | scalar | registered Substrait mappings; `case_when(...)` lowers as built-in `IfThen` | | `in_(...)`, `between(...)` | scalar | built-in membership/range lowering (`SingularOrList` and `between`) | | `abs(...)`, `ceil(...)`, `floor(...)`, `round(...)` | scalar | registered Substrait math scalar mappings; `round(...)` is currently the single-argument form | -| `array(...)`, `cardinality(...)`, `array_contains(...)`, `arrays_overlap(...)`, `array_position(...)`, `element_at(...)`, `array_sort(...)`, `array_distinct(...)`, `array_except(...)`, `array_intersect(...)`, `array_union(...)`, `array_join(...)`, `array_slice(...)`, `array_reverse(...)`, `array_flatten(...)`, `map_from_arrays(...)`, `map_extract(...)`, `map_contains_key(...)`, `map_keys(...)`, `map_values(...)`, `map_entries(...)`, `named_struct(...)` | scalar | registered nested scalar helpers backed by Substrait extension mappings; `map_contains_key(...)` lowers as a documented predicate rewrite | +| `array(...)`, `cardinality(...)`, `array_contains(...)`, `arrays_overlap(...)`, `array_position(...)`, `array_range(...)`, `element_at(...)`, `array_sort(...)`, `array_distinct(...)`, `array_except(...)`, `array_intersect(...)`, `array_union(...)`, `array_join(...)`, `array_slice(...)`, `array_reverse(...)`, `array_flatten(...)`, `map_from_arrays(...)`, `map_extract(...)`, `map_contains_key(...)`, `map_keys(...)`, `map_values(...)`, `map_entries(...)`, `named_struct(...)` | scalar | registered nested scalar helpers backed by Substrait extension mappings; `array_range(...)` registers canonical `range` for 
positional generator lowering and `map_contains_key(...)` lowers as a documented predicate rewrite | +| `explode(...)`, `explode_outer(...)`, `posexplode(...)`, `posexplode_outer(...)`, `inline(...)`, `inline_outer(...)`, `flatten(...)`, `stack(...)` | generator | relation-extension mappings consumed by `generate(...)`; positional forms use zero-based positions | | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | | `sum(...)`, `count(...)`, `count_expr(...)`, `count_distinct(...)`, `count_if(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions for core aggregates plus compatibility rewrites for `count_expr(...)`, `count_distinct(...)`, and `count_if(...)`; core aggregates allow `DISTINCT` and aggregate-local `FILTER` where the aggregate shape is valid | diff --git a/docs/language/reference/functions/nested.md b/docs/language/reference/functions/nested.md index 644e1ad..88fcd22 100644 --- a/docs/language/reference/functions/nested.md +++ b/docs/language/reference/functions/nested.md @@ -20,6 +20,7 @@ Generator or table-valued operations such as row-expanding `explode(...)` are se | `array_intersect(left, right)` | Return elements shared by both arrays. | | `array_union(left, right)` | Return the union of both arrays. | | `array_join(array_expr, delimiter)` | Join a string array into one string. | +| `array_range(start, stop)` | Build a row-level integer array from `start` inclusive to `stop` exclusive. | | `array_slice(array_expr, start, stop)` | Return a one-based array slice using the backend adapter's slice contract. | | `array_reverse(array_expr)` | Reverse one array value. | | `array_flatten(array_expr)` | Flatten an array-of-arrays into one row-level array value. 
| @@ -54,5 +55,5 @@ projected = ( - Array indexing is one-based for `element_at(...)`, `array_position(...)`, and `array_slice(...)`. - `element_at(...)` currently maps to the portable array-element adapter path. Out-of-range behavior follows the current backend adapter's recoverable result until InQL has a richer static/runtime error-policy split for strict versus try-style element access. -- `array_flatten(...)` is intentionally named to avoid colliding with future table-valued or generator `flatten(...)` forms. +- `array_flatten(...)` is intentionally named to stay distinct from the relation-shaping generator `flatten(...)`. - Grouping or ordering by nested values is not documented as portable until equality and ordering semantics for arrays, maps, and structs are specified. diff --git a/docs/language/reference/substrait/operator_catalog.md b/docs/language/reference/substrait/operator_catalog.md index 4560185..327ad49 100644 --- a/docs/language/reference/substrait/operator_catalog.md +++ b/docs/language/reference/substrait/operator_catalog.md @@ -81,6 +81,9 @@ Core Substrait does not define a portable unnest or explode `Rel` at the logical Current package-level RFC 002 boundary registration: - `https://inql.io/extensions/v0.1/unnest.yaml#explode` +- `https://inql.io/extensions/v0.1/unnest.yaml#explode_outer` +- `https://inql.io/extensions/v0.1/unnest.yaml#posexplode` +- `https://inql.io/extensions/v0.1/unnest.yaml#posexplode_outer` ### Pivot / unpivot diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index 9dea513..fa2c710 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -15,6 +15,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Core scalar functions:** RFC 015 adds registry-backed scalar function applications and the first core helper slice for casts, comparisons, boolean logic, null/NaN predicates, arithmetic, conditionals, membership/range predicates, and ordering expressions. 
Implemented helpers lower to Substrait IR through registry metadata, built-in Rex shapes, or structural sort-field lowering; DataFusion remains the first execution adapter rather than the semantic boundary. - **Common scalar functions:** The first RFC 018 slice adds registry-backed math helpers for `abs(...)`, `ceil(...)`, `floor(...)`, and single-argument `round(...)`, with Substrait mappings and DataFusion-backed execution coverage. - **Nested data functions:** RFC 020 adds registry-backed scalar helpers for array construction/access, cardinality, containment, overlap, sorting, set-like operations, joining, slicing, reversing, scalar array flattening, map construction/access, map key/value/entry extraction, map key containment, and named struct construction. These helpers lower through Substrait extension metadata without introducing generator semantics, with representative DataFusion-backed Session coverage for composable array projection paths. +- **Generator functions:** RFC 021 adds registry-backed generator applications for `explode(...)`, `explode_outer(...)`, `posexplode(...)`, `posexplode_outer(...)`, `inline(...)`, `inline_outer(...)`, portable `flatten(...)`, and `stack(...)`. Generators remain relation-shaping operations applied with `generate(...)`; they preserve input columns, require explicit output aliases, lower through the current Substrait extension-relation gap encoding, and execute through the DataFusion Session adapter with concrete output-column materialization. - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation. 
- **Function extension policy:** InQL RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. diff --git a/docs/rfcs/021_generator_table_functions.md b/docs/rfcs/021_generator_table_functions.md index b33febb..65b01a8 100644 --- a/docs/rfcs/021_generator_table_functions.md +++ b/docs/rfcs/021_generator_table_functions.md @@ -1,6 +1,6 @@ # InQL RFC 021: Generator and table-valued functions -- **Status:** Draft +- **Status:** Implemented - **Created:** 2026-04-27 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -11,9 +11,9 @@ - InQL RFC 014 (function registry and catalog governance) - InQL RFC 020 (nested data functions) - **Issue:** [InQL #38](https://github.com/dannys-code-corner/InQL/issues/38) -- **RFC PR:** — +- **RFC PR:** [InQL #47](https://github.com/dannys-code-corner/InQL/pull/47) - **Written against:** Incan v0.2 -- **Shipped in:** — +- **Shipped in:** v0.1 ## Summary @@ -42,14 +42,15 @@ InQL already has an unnest/explode design direction through its Substrait work. ## Guide-level explanation (how authors think about it) -Authors should use generators when one input row may become multiple output rows: +Authors should use generators when one input row may become multiple output rows. 
In the current builder surface, +generators are constructed as explicit applications and then applied to a relation: ```incan -from pub::inql.functions import col +from pub::inql.functions import col, explode items = ( orders - .explode(col("line_items"), as_="line_item") + .generate(explode(col("line_items"), "line_item")) .select(["order_id", "line_item"]) ) ``` @@ -64,13 +65,13 @@ Generator functions must be registry entries with function class `generator` or `explode_outer(array_expr)` must preserve the input row when the input array is null or empty and must produce a null generated value according to its output schema. -`posexplode(array_expr)` and `posexplode_outer(array_expr)` must include a positional output column in addition to the generated element. The position origin must be specified before this RFC reaches Planned status. +`posexplode(array_expr)` and `posexplode_outer(array_expr)` must include a positional output column in addition to the generated element. Positional output is zero-based because `posexplode` follows the Spark-compatible naming convention rather than InQL's one-based scalar collection indexing rule. `inline(array_of_struct_expr)` must expand each struct element into output columns. `inline_outer` must preserve outer rows for null or empty input according to the outer generator rule. `stack` must construct multiple output rows from explicit expressions according to a declared row count and output schema. -`flatten` must be treated as a table-valued/generator operation when supported. Its exact input type, recursive behavior, path behavior, and output columns must be specified before it reaches Planned status. +`flatten` is a table-valued/generator operation in the portable one-array form. Snowflake-style recursive/path flattening is not part of the portable core; scalar `array_flatten(...)` remains part of RFC 020 and does not change row cardinality. 
Every generator must define output column names, output types, nullability, interaction with existing columns, and aliasing requirements. Name collisions must be diagnosed unless an explicit overwrite or qualification rule applies. @@ -78,15 +79,15 @@ Every generator must define output column names, output types, nullability, inte ### Syntax -Generators may appear as dataframe relation methods, query-block clauses, or table-valued function forms. Regardless of syntax, they must lower to relation-shaping operations. +Generators may appear as dataframe relation methods, query-block clauses, or table-valued function forms. Regardless of syntax, they must lower to relation-shaping operations. The builder API uses `generate(generator)` so generator identity, input expressions, and output schema are explicit. The zero-argument `DataSet.explode()` method remains a lower-level extension-boundary operation rather than the RFC 021 generator surface. ### Semantics -Generator output schema is part of the relation schema after the generator operation. Generators may preserve input columns, replace a nested column with generated columns, or produce a new relation depending on the function and syntax, but the behavior must be explicit. +Generator output schema is part of the relation schema after the generator operation. The initial portable generator applications preserve all input columns and append generated output columns in declaration order. Generated aliases are required, must be non-empty, and must not collide with existing columns. ### Interaction with other InQL surfaces -`query {}` may expose an `EXPLODE` clause or table-valued function syntax. Dataframe APIs may expose `.explode(...)` and related methods. Both must use the same generator semantics. +`query {}` may expose an `EXPLODE` clause or table-valued function syntax when the query surface is available. Dataframe APIs expose the same semantic target through `generate(...)` and registry-backed generator helpers. 
Both use the same generator semantics. ### Compatibility / migration @@ -112,11 +113,16 @@ Existing unnest/explode behavior should align with this RFC. If current behavior - **Execution / interchange** — Prism and Substrait lowering must represent cardinality changes and output schemas faithfully. - **Documentation** — generator docs should explain cardinality and schema effects before listing helper names. -## Unresolved questions +## Design Decisions -- Should positional generators use zero-based or one-based positions? -- Should `.explode(...)` preserve all input columns by default? -- What aliasing syntax should be required for generated output columns? -- What subset of Snowflake-style `flatten` behavior belongs in portable InQL versus a warehouse compatibility extension? +### Resolved - +- Positional generators use zero-based positions for compatibility with the `posexplode` naming convention. +- Explicit generator applications preserve all input columns by default and append generated output columns. +- Generated aliases are required at builder construction time. +- Snowflake-style recursive/path `flatten` remains outside the portable core until its output schema and compatibility category are specified separately. +- `explode`, `explode_outer`, `posexplode`, `posexplode_outer`, `inline`, `inline_outer`, portable `flatten`, and `stack` are implemented as registry-backed generator applications with Substrait relation-extension metadata and DataFusion-backed Session execution. + +### Remaining + +- No RFC 021 generator semantics remain open. Query-block syntax itself is owned by RFC 003; when that surface lands, its generator clauses must lower to the implemented `GeneratorApplication` model rather than defining a separate generator path. 
diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md index 32ac4e0..e5d01f8 100644 --- a/docs/rfcs/README.md +++ b/docs/rfcs/README.md @@ -27,7 +27,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la | [018][rfc-018] | In Progress | Common scalar function catalog | | | [019][rfc-019] | Draft | Window functions | | | [020][rfc-020] | Implemented | Nested data functions | | -| [021][rfc-021] | Draft | Generator and table-valued functions | | +| [021][rfc-021] | Implemented | Generator and table-valued functions | | | [022][rfc-022] | Draft | Semi-structured and format functions | | | [023][rfc-023] | Draft | Approximate and sketch functions | | | [024][rfc-024] | Implemented | Function extension policy | | diff --git a/src/dataset/mod.incn b/src/dataset/mod.incn index fa850bd..e9b31b1 100644 --- a/src/dataset/mod.incn +++ b/src/dataset/mod.incn @@ -22,6 +22,7 @@ The current method-chain surface in this module is the explicit builder-based AP - `with_column(name: str, expr: ColumnExpr)` - `group_by(columns: list[ColumnExpr])` - `agg(measures: list[AggregateMeasure])` +- `generate(generator: GeneratorApplication)` - plus the structural operators `join`, `select`, `order_by`, `limit`, and `explode` Illustrative current-shape examples: @@ -53,6 +54,7 @@ See also: from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr from dataset.materialization import DataFrameMaterialization from substrait.errors import SubstraitLoweringError @@ -63,6 +65,7 @@ from dataset.ops import ( agg_ds_of_columns, explode_ds, filter_ds_of_columns, + generate_ds_of_columns, group_by_ds_of_columns, join_ds, limit_ds, @@ -76,6 +79,7 @@ from prism import ( prism_cursor_apply_agg, prism_cursor_apply_explode, prism_cursor_apply_filter, + prism_cursor_apply_generate, prism_cursor_apply_group_by, prism_cursor_apply_join, 
prism_cursor_apply_limit, @@ -98,6 +102,7 @@ pub trait DataSet[T with Clone]: def with_column(self, name: str, expr: ColumnExpr) -> Self def group_by(self, columns: list[ColumnExpr]) -> Self def agg(self, measures: list[AggregateMeasure]) -> Self + def generate(self, generator: GeneratorApplication) -> Self def order_by(self, columns: list[ColumnExpr]) -> Self def limit(self, n: int) -> Self def explode(self) -> Self @@ -207,6 +212,12 @@ pub class DataFrame[T with Clone] with BoundedDataSet: agg_ds_of_columns(self._substrait_rel, self.planned_columns(), measures), ) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new DataFrame with a generator stage and stale materialization cleared.""" + return _data_frame_with_invalidated_materialization( + generate_ds_of_columns(self._substrait_rel, self.planned_columns(), generator), + ) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataFrame with an ordering stage and stale materialization cleared.""" return _data_frame_with_invalidated_materialization( @@ -288,6 +299,10 @@ pub class LazyFrame[T with Clone] with BoundedDataSet: """Return one new lazy carrier with an appended aggregation stage.""" return LazyFrame(_cursor=prism_cursor_apply_agg(self._cursor, measures)) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new lazy carrier with an appended generator stage.""" + return LazyFrame(_cursor=prism_cursor_apply_generate(self._cursor, generator)) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new lazy carrier with an appended ordering stage.""" return LazyFrame(_cursor=prism_cursor_apply_order_by(self._cursor, columns)) @@ -430,6 +445,17 @@ pub class DataStream[T with Clone] with UnboundedDataSet: ), ) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new DataStream with a generator stage.""" + return DataStream( + _row_schema_marker=self._row_schema_marker.clone(), + 
_substrait_rel=generate_ds_of_columns( + self._substrait_rel, + relation_output_columns(self._substrait_rel.clone()), + generator, + ), + ) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataStream with an ordering stage.""" return DataStream( diff --git a/src/dataset/ops.incn b/src/dataset/ops.incn index 5319f4d..675eee5 100644 --- a/src/dataset/ops.incn +++ b/src/dataset/ops.incn @@ -8,8 +8,9 @@ views stay aligned with the lowered relation tree. from rust::substrait::proto import Rel from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment -from substrait.function_extensions import explode_extension_uri +from substrait.function_extensions import EXPLODE_EXTENSION_URI from substrait.inspect import relation_output_columns from substrait.relations import ( aggregate_rel_of_columns, @@ -19,6 +20,7 @@ from substrait.relations import ( join_rel, project_rel_of_columns, sort_rel_of_columns, + generator_rel_of_columns, ) @@ -122,6 +124,16 @@ pub def agg_ds_of_columns(rel: Rel, input_columns: list[str], measures: list[Agg return aggregate_rel_of_columns(rel, input_columns, [], measures) +pub def generate_ds(rel: Rel, generator: GeneratorApplication) -> Rel: + """Apply one relation-shaping generator to a relation.""" + return generate_ds_of_columns(rel, relation_output_columns(rel.clone()), generator) + + +pub def generate_ds_of_columns(rel: Rel, input_columns: list[str], generator: GeneratorApplication) -> Rel: + """Apply one relation-shaping generator using explicit input-column names.""" + return generator_rel_of_columns(rel, input_columns, generator) + + pub def order_by_ds(rel: Rel, columns: list[ColumnExpr]) -> Rel: """ Apply dataset-level ordering intent to one relation. 
@@ -165,4 +177,4 @@ pub def explode_ds(rel: Rel) -> Rel: Returns: A relation shaped as the registered explode extension over the input relation. """ - return extension_single_rel(rel, explode_extension_uri()) + return extension_single_rel(rel, EXPLODE_EXTENSION_URI) diff --git a/src/function_registry.incn b/src/function_registry.incn index c4b9dd9..f8af3da 100644 --- a/src/function_registry.incn +++ b/src/function_registry.incn @@ -7,7 +7,7 @@ usually come from checked public helper metadata. Decorator metadata may overrid source helper name is constrained by the host language, such as the `modulo(...)` helper for canonical `mod`. """ -from substrait.function_extensions import function_extension_uri +from substrait.function_extensions import FUNCTION_EXTENSION_URI pub const CORE_FUNCTION_NAMESPACE: str = "inql.functions" @@ -76,6 +76,7 @@ pub enum SubstraitMappingKind(str): CoreFunction = "core_function" ExtensionFunction = "extension_function" + RelationExtension = "relation_extension" Rewrite = "rewrite" StructuralFunction = "structural_function" @@ -290,7 +291,7 @@ pub def extension_mapping(function_name: str, anchor: u32) -> SubstraitMapping: """Build one registered Substrait extension-function mapping.""" return SubstraitMapping( kind=SubstraitMappingKind.ExtensionFunction, - uri=function_extension_uri(), + uri=FUNCTION_EXTENSION_URI, function_name=function_name, anchor=anchor, rewrite="", @@ -298,6 +299,18 @@ pub def extension_mapping(function_name: str, anchor: u32) -> SubstraitMapping: ) +pub def relation_extension_mapping(function_name: str, uri: str) -> SubstraitMapping: + """Build one registered Substrait relation-extension mapping.""" + return SubstraitMapping( + kind=SubstraitMappingKind.RelationExtension, + uri=uri, + function_name=function_name, + anchor=0, + rewrite="", + detail="extension_single", + ) + + pub def core_mapping(function_name: str) -> SubstraitMapping: """Build one mapping for a built-in Substrait Rex shape rather than an extension 
function declaration.""" return SubstraitMapping( diff --git a/src/functions/generators/explode.incn b/src/functions/generators/explode.incn new file mode 100644 index 0000000..2835c04 --- /dev/null +++ b/src/functions/generators/explode.incn @@ -0,0 +1,42 @@ +"""Inner explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import register_function +from generator_builders import GeneratorApplication, explode as explode_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import EXPLODE_EXTENSION_URI + + +@register_function(deterministic_spec( + function_class=FunctionClass.Generator, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + substrait=relation_extension_mapping("explode", EXPLODE_EXTENSION_URI), +)) +pub def explode(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """ + Build an inner row-expanding generator for array values. + + Examples: + generated = explode(col("line_items"), "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + as_: Output alias for the generated value column. 
+ """ + return explode_builder(expr, as_) + + +module tests: + from projection_builders import col + def test_explode_builds_generator_application() -> None: + generator = explode(col("line_items"), "line_item") + assert generator.canonical_name == "explode" + assert generator.output_columns[0] == "line_item" diff --git a/src/functions/generators/explode_outer.incn b/src/functions/generators/explode_outer.incn new file mode 100644 index 0000000..66348f5 --- /dev/null +++ b/src/functions/generators/explode_outer.incn @@ -0,0 +1,42 @@ +"""Outer explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import register_function +from generator_builders import GeneratorApplication, explode_outer as explode_outer_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import EXPLODE_OUTER_EXTENSION_URI + + +@register_function(deterministic_spec( + function_class=FunctionClass.Generator, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + substrait=relation_extension_mapping("explode_outer", EXPLODE_OUTER_EXTENSION_URI), +)) +pub def explode_outer(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """ + Build an outer row-expanding generator for array values. + + Examples: + generated = explode_outer(col("line_items"), "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + as_: Output alias for the generated nullable value column. 
+ """ + return explode_outer_builder(expr, as_) + + +module tests: + from projection_builders import col + def test_explode_outer_builds_outer_generator_application() -> None: + generator = explode_outer(col("line_items"), "line_item") + assert generator.canonical_name == "explode_outer" + assert generator.is_outer diff --git a/src/functions/generators/flatten.incn b/src/functions/generators/flatten.incn new file mode 100644 index 0000000..cbb10a1 --- /dev/null +++ b/src/functions/generators/flatten.incn @@ -0,0 +1,42 @@ +"""Portable table-valued flatten generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import register_function +from generator_builders import GeneratorApplication, flatten as flatten_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import FLATTEN_EXTENSION_URI + + +@register_function(deterministic_spec( + function_class=FunctionClass.Generator, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + substrait=relation_extension_mapping("flatten", FLATTEN_EXTENSION_URI), +)) +pub def flatten(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """ + Build a portable row-expanding flatten generator for array values. + + Examples: + generated = flatten(col("nested_items"), "item") + + Parameters: + expr: Array expression to expand into generated rows. + as_: Output alias for the generated value column. 
+ """ + return flatten_builder(expr, as_) + + +module tests: + from projection_builders import col + def test_flatten_builds_table_valued_generator_application() -> None: + generator = flatten(col("nested_items"), "item") + assert generator.canonical_name == "flatten" + assert generator.output_columns[0] == "item" diff --git a/src/functions/generators/inline.incn b/src/functions/generators/inline.incn new file mode 100644 index 0000000..0d2cc78 --- /dev/null +++ b/src/functions/generators/inline.incn @@ -0,0 +1,42 @@ +"""Inner struct-expanding generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import register_function +from generator_builders import GeneratorApplication, inline as inline_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import INLINE_EXTENSION_URI + + +@register_function(deterministic_spec( + function_class=FunctionClass.Generator, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + substrait=relation_extension_mapping("inline", INLINE_EXTENSION_URI), +)) +pub def inline(expr: ColumnExpr, output_columns: list[str]) -> GeneratorApplication: + """ + Build an inner generator that expands array-of-struct values into generated columns. + + Examples: + generated = inline(col("line_items"), ["sku", "quantity"]) + + Parameters: + expr: Array-of-struct expression to expand into generated rows. + output_columns: Output aliases for the struct fields in order. 
+ """ + return inline_builder(expr, output_columns) + + +module tests: + from projection_builders import col + def test_inline_builds_struct_expanding_generator_application() -> None: + generator = inline(col("line_items"), ["sku", "quantity"]) + assert generator.canonical_name == "inline" + assert generator.output_columns == ["sku", "quantity"] diff --git a/src/functions/generators/inline_outer.incn b/src/functions/generators/inline_outer.incn new file mode 100644 index 0000000..809707f --- /dev/null +++ b/src/functions/generators/inline_outer.incn @@ -0,0 +1,42 @@ +"""Outer struct-expanding generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import register_function +from generator_builders import GeneratorApplication, inline_outer as inline_outer_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import INLINE_OUTER_EXTENSION_URI + + +@register_function(deterministic_spec( + function_class=FunctionClass.Generator, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + substrait=relation_extension_mapping("inline_outer", INLINE_OUTER_EXTENSION_URI), +)) +pub def inline_outer(expr: ColumnExpr, output_columns: list[str]) -> GeneratorApplication: + """ + Build an outer generator that expands array-of-struct values into nullable generated columns. + + Examples: + generated = inline_outer(col("line_items"), ["sku", "quantity"]) + + Parameters: + expr: Array-of-struct expression to expand into generated rows. + output_columns: Output aliases for the struct fields in order. 
+ """ + return inline_outer_builder(expr, output_columns) + + +module tests: + from projection_builders import col + def test_inline_outer_builds_outer_struct_expanding_generator_application() -> None: + generator = inline_outer(col("line_items"), ["sku", "quantity"]) + assert generator.canonical_name == "inline_outer" + assert generator.is_outer diff --git a/src/functions/generators/mod.incn b/src/functions/generators/mod.incn new file mode 100644 index 0000000..36753c3 --- /dev/null +++ b/src/functions/generators/mod.incn @@ -0,0 +1,10 @@ +"""Relation-shaping generator helpers.""" + +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.flatten import flatten +pub from functions.generators.inline import inline +pub from functions.generators.inline_outer import inline_outer +pub from functions.generators.posexplode import posexplode +pub from functions.generators.posexplode_outer import posexplode_outer +pub from functions.generators.stack import stack diff --git a/src/functions/generators/posexplode.incn b/src/functions/generators/posexplode.incn new file mode 100644 index 0000000..682103e --- /dev/null +++ b/src/functions/generators/posexplode.incn @@ -0,0 +1,44 @@ +"""Inner positional explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import register_function +from generator_builders import GeneratorApplication, posexplode as posexplode_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import POSEXPLODE_EXTENSION_URI + + +@register_function(deterministic_spec( + function_class=FunctionClass.Generator, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + 
substrait=relation_extension_mapping("posexplode", POSEXPLODE_EXTENSION_URI), +)) +pub def posexplode(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """ + Build an inner row-expanding generator with a zero-based position column. + + Examples: + generated = posexplode(col("line_items"), "position", "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + position_as: Output alias for the zero-based position column. + value_as: Output alias for the generated value column. + """ + return posexplode_builder(expr, position_as, value_as) + + +module tests: + from projection_builders import col + def test_posexplode_builds_positional_generator_application() -> None: + generator = posexplode(col("line_items"), "position", "line_item") + assert generator.canonical_name == "posexplode" + assert generator.position_origin == 0 + assert generator.output_columns[0] == "position" diff --git a/src/functions/generators/posexplode_outer.incn b/src/functions/generators/posexplode_outer.incn new file mode 100644 index 0000000..1cdda29 --- /dev/null +++ b/src/functions/generators/posexplode_outer.incn @@ -0,0 +1,44 @@ +"""Outer positional explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import register_function +from generator_builders import GeneratorApplication, posexplode_outer as posexplode_outer_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import POSEXPLODE_OUTER_EXTENSION_URI + + +@register_function(deterministic_spec( + function_class=FunctionClass.Generator, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + substrait=relation_extension_mapping("posexplode_outer", POSEXPLODE_OUTER_EXTENSION_URI), +)) +pub def posexplode_outer(expr: 
ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """ + Build an outer row-expanding generator with a zero-based position column. + + Examples: + generated = posexplode_outer(col("line_items"), "position", "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + position_as: Output alias for the zero-based position column. + value_as: Output alias for the generated nullable value column. + """ + return posexplode_outer_builder(expr, position_as, value_as) + + +module tests: + from projection_builders import col + def test_posexplode_outer_builds_outer_positional_generator_application() -> None: + generator = posexplode_outer(col("line_items"), "position", "line_item") + assert generator.canonical_name == "posexplode_outer" + assert generator.is_outer + assert generator.output_columns[1] == "line_item" diff --git a/src/functions/generators/stack.incn b/src/functions/generators/stack.incn new file mode 100644 index 0000000..8803167 --- /dev/null +++ b/src/functions/generators/stack.incn @@ -0,0 +1,44 @@ +"""Row-major stack generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import register_function +from generator_builders import GeneratorApplication, stack as stack_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import STACK_EXTENSION_URI + + +@register_function(deterministic_spec( + function_class=FunctionClass.Generator, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + substrait=relation_extension_mapping("stack", STACK_EXTENSION_URI), +)) +pub def stack(row_count: int, values: list[ColumnExpr], output_columns: list[str]) -> GeneratorApplication: + """ + Build a generator that creates multiple output rows from row-major scalar values. 
+ + Examples: + generated = stack(2, [col("left_a"), col("right_a"), col("left_b"), col("right_b")], ["left", "right"]) + + Parameters: + row_count: Number of generated rows per input row. + values: Row-major scalar values, with `row_count * len(output_columns)` entries. + output_columns: Output aliases for each generated column. + """ + return stack_builder(row_count, values, output_columns) + + +module tests: + from projection_builders import col + def test_stack_builds_row_major_generator_application() -> None: + generator = stack(2, [col("a"), col("b"), col("c"), col("d")], ["left", "right"]) + assert generator.canonical_name == "stack" + assert generator.row_count == 2 + assert generator.output_columns == ["left", "right"] diff --git a/src/functions/mod.incn b/src/functions/mod.incn index f2471bd..a6f89e2 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -45,6 +45,7 @@ pub from functions.nested.array_flatten import array_flatten pub from functions.nested.array_intersect import array_intersect pub from functions.nested.array_join import array_join pub from functions.nested.array_position import array_position +pub from functions.nested.array_range import array_range pub from functions.nested.array_reverse import array_reverse pub from functions.nested.array_slice import array_slice pub from functions.nested.array_sort import array_sort @@ -59,6 +60,14 @@ pub from functions.nested.map_from_arrays import map_from_arrays pub from functions.nested.map_keys import map_keys pub from functions.nested.map_values import map_values pub from functions.nested.named_struct import named_struct +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.flatten import flatten +pub from functions.generators.inline import inline +pub from functions.generators.inline_outer import inline_outer +pub from functions.generators.posexplode import posexplode +pub from 
functions.generators.posexplode_outer import posexplode_outer +pub from functions.generators.stack import stack pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div diff --git a/src/functions/nested/array_range.incn b/src/functions/nested/array_range.incn new file mode 100644 index 0000000..c06741a --- /dev/null +++ b/src/functions/nested/array_range.incn @@ -0,0 +1,34 @@ +"""Array range construction helper used by positional generator lowering.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import register_function, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import RANGE_FUNCTION_ANCHOR + + +@register_function(deterministic_spec( + canonical_name="range", + function_class=FunctionClass.Scalar, + lifecycle=FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + null_behavior=FunctionNullBehavior.DependsOnInputs, + substrait=extension_mapping("range", RANGE_FUNCTION_ANCHOR), +)) +pub def array_range(start: ColumnExpr, stop: ColumnExpr) -> ColumnExpr: + """ + Build an array of integer positions from `start` inclusive to `stop` exclusive. + + Examples: + positions = array_range(int_lit(0), cardinality(col("items"))) + + Parameters: + start: Inclusive starting position expression. + stop: Exclusive stopping position expression. 
+ """ + return registered_application("range", [start, stop]) diff --git a/src/functions/nested/mod.incn b/src/functions/nested/mod.incn index bdbdff1..b532867 100644 --- a/src/functions/nested/mod.incn +++ b/src/functions/nested/mod.incn @@ -8,6 +8,7 @@ pub from functions.nested.array_flatten import array_flatten pub from functions.nested.array_intersect import array_intersect pub from functions.nested.array_join import array_join pub from functions.nested.array_position import array_position +pub from functions.nested.array_range import array_range pub from functions.nested.array_reverse import array_reverse pub from functions.nested.array_slice import array_slice pub from functions.nested.array_sort import array_sort diff --git a/src/generator_builders.incn b/src/generator_builders.incn new file mode 100644 index 0000000..18a179a --- /dev/null +++ b/src/generator_builders.incn @@ -0,0 +1,259 @@ +""" +Relation-shaping generator builder surface. + +Generators are not scalar expressions: they may change row cardinality and append output columns. This module carries +the authoring intent through Dataset, Prism, and Substrait boundaries without making generators valid in ordinary +row-level expression positions. 
+""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import function_ref_for +from functions.nested.array import array +from functions.nested.array_range import array_range +from functions.nested.cardinality import cardinality +from functions.nested.named_struct import named_struct +from projection_builders import ColumnExpr, int_expr + + +@derive(Clone) +pub enum GeneratorKind(str): + """Supported relation-shaping generator kinds.""" + + Explode = "explode" + ExplodeOuter = "explode_outer" + PosExplode = "posexplode" + PosExplodeOuter = "posexplode_outer" + Inline = "inline" + InlineOuter = "inline_outer" + Flatten = "flatten" + Stack = "stack" + + +@derive(Clone) +pub model GeneratorApplication: + """One registry-backed relation-shaping generator application.""" + + pub kind: GeneratorKind + pub function_ref: str + pub canonical_name: str + pub expr: ColumnExpr + pub arguments: list[ColumnExpr] + pub output_columns: list[str] + pub preserves_input_columns: bool + pub is_outer: bool + pub position_origin: int + pub row_count: int + + +pub def explode(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """Build an inner `explode` generator that appends one value column.""" + return _generator_application("explode", GeneratorKind.Explode, expr, [expr], [as_], true, false, 0, 0) + + +pub def explode_outer(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """Build an outer `explode` generator that appends one nullable value column.""" + return _generator_application("explode_outer", GeneratorKind.ExplodeOuter, expr, [expr], [as_], true, true, 0, 0) + + +pub def posexplode(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """Build an inner positional explode generator with zero-based positions.""" + return _generator_application( + "posexplode", + GeneratorKind.PosExplode, + expr, + [_position_range_expr(expr), expr], + [position_as, value_as], + true, + false, + 0, + 0, + ) + + +pub def 
posexplode_outer(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """Build an outer positional explode generator with zero-based positions.""" + return _generator_application( + "posexplode_outer", + GeneratorKind.PosExplodeOuter, + expr, + [_position_range_expr(expr), expr], + [position_as, value_as], + true, + true, + 0, + 0, + ) + + +pub def inline(expr: ColumnExpr, output_columns: list[str]) -> GeneratorApplication: + """Build an inner generator that expands an array-of-struct expression into columns.""" + return _generator_application("inline", GeneratorKind.Inline, expr, [expr], output_columns, true, false, 0, 0) + + +pub def inline_outer(expr: ColumnExpr, output_columns: list[str]) -> GeneratorApplication: + """Build an outer generator that expands an array-of-struct expression into nullable columns.""" + return _generator_application( + "inline_outer", + GeneratorKind.InlineOuter, + expr, + [expr], + output_columns, + true, + true, + 0, + 0, + ) + + +pub def flatten(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """Build a portable row-expanding flatten generator for one array expression.""" + return _generator_application("flatten", GeneratorKind.Flatten, expr, [expr], [as_], true, false, 0, 0) + + +pub def stack(row_count: int, values: list[ColumnExpr], output_columns: list[str]) -> GeneratorApplication: + """Build a generator that turns row-major scalar values into `row_count` generated rows.""" + _validate_stack_shape(row_count, values, output_columns) + # Encode stack as an array of named structs so it can share the same unnest-and-expand execution path as inline. 
+ stacked_rows = _stack_row_exprs(row_count, values, output_columns) + stacked_array = array(stacked_rows) + return _generator_application( + "stack", + GeneratorKind.Stack, + stacked_array, + [stacked_array], + output_columns, + true, + false, + 0, + row_count, + ) + + +pub def generator_output_columns(input_columns: list[str], generator: GeneratorApplication) -> list[str]: + """Return output columns after applying one generator to the provided input columns.""" + mut output_columns: list[str] = [] + if generator.preserves_input_columns: + output_columns.extend(input_columns) + for output_column in generator.output_columns: + if _contains_text(output_columns, output_column): + message = f"generator output column `{output_column}` conflicts with an existing column" + return raise_value_error(message) + output_columns.append(output_column) + return output_columns + + +pub def generator_primary_output_column(generator: GeneratorApplication) -> str: + """Return the primary generated value column for inspection and tests.""" + if len(generator.output_columns) == 0: + return "" + return generator.output_columns[len(generator.output_columns) - 1] + + +def _generator_application( + canonical_name: str, + kind: GeneratorKind, + expr: ColumnExpr, + arguments: list[ColumnExpr], + output_columns: list[str], + preserves_input_columns: bool, + is_outer: bool, + position_origin: int, + row_count: int, +) -> GeneratorApplication: + """Build one generator application after validating declared output aliases.""" + _validate_output_columns(canonical_name, output_columns) + return GeneratorApplication( + kind=kind, + function_ref=function_ref_for(canonical_name), + canonical_name=canonical_name, + expr=expr, + arguments=arguments, + output_columns=output_columns, + preserves_input_columns=preserves_input_columns, + is_outer=is_outer, + position_origin=position_origin, + row_count=row_count, + ) + + +def _validate_output_columns(canonical_name: str, output_columns: list[str]) -> None: 
+ """Validate mandatory generator output aliases.""" + if len(output_columns) == 0: + message = f"{canonical_name} requires at least one output alias" + return raise_value_error(message) + mut seen: list[str] = [] + for output_column in output_columns: + if len(output_column) == 0: + message = f"{canonical_name} output aliases must be non-empty" + return raise_value_error(message) + if _contains_text(seen, output_column): + message = f"{canonical_name} output alias `{output_column}` is duplicated" + return raise_value_error(message) + seen.append(output_column) + return + + +def _validate_stack_shape(row_count: int, values: list[ColumnExpr], output_columns: list[str]) -> None: + """Validate the row-major `stack(...)` value matrix.""" + if row_count <= 0: + return raise_value_error("stack row_count must be greater than zero") + if len(output_columns) == 0: + return raise_value_error("stack requires at least one output alias") + expected_values = row_count * len(output_columns) + if len(values) != expected_values: + message = f"stack requires exactly {expected_values} values for {row_count} rows and {len(output_columns)} columns" + return raise_value_error(message) + return + + +def _stack_row_exprs(row_count: int, values: list[ColumnExpr], output_columns: list[str]) -> list[ColumnExpr]: + """Return one named struct expression per declared `stack(...)` output row.""" + column_count = len(output_columns) + mut rows: list[ColumnExpr] = [] + for row_idx in range(row_count): + # Values are row-major: for two output columns, [a, b, c, d] becomes rows (a, b) and (c, d). 
+ start_idx = row_idx * column_count + row_values = [values[start_idx + col_idx] for col_idx in range(column_count)] + rows.append(named_struct(output_columns, row_values)) + return rows + + +def _position_range_expr(expr: ColumnExpr) -> ColumnExpr: + """Build the zero-based position array paired with positional generators.""" + # Positional generators lower as two equally sized arrays: generated positions and original values. + return array_range(int_expr(0), cardinality(expr)) + + +def _contains_text(values: list[str], expected: str) -> bool: + """Return whether a string list contains a value.""" + for value in values: + if value == expected: + return true + return false + + +module tests: + from projection_builders import col, column_expr_name + def test_explode_application_records_function_identity_and_output_column() -> None: + generator = explode(col("line_items"), "line_item") + assert generator.kind == GeneratorKind.Explode + assert generator.canonical_name == "explode" + assert generator.function_ref == "inql.functions.explode" + assert column_expr_name(generator.expr) == "line_items" + assert len(generator.arguments) == 1 + assert generator.output_columns[0] == "line_item" + assert generator.preserves_input_columns + assert not generator.is_outer + def test_posexplode_uses_zero_based_position_origin() -> None: + generator = posexplode(col("line_items"), "pos", "line_item") + assert generator.kind == GeneratorKind.PosExplode + assert generator.position_origin == 0 + assert len(generator.arguments) == 2 + assert generator.output_columns[0] == "pos" + assert generator.output_columns[1] == "line_item" + def test_stack_records_row_count_and_single_stacked_argument() -> None: + generator = stack(2, [col("a"), col("b"), col("c"), col("d")], ["left", "right"]) + assert generator.kind == GeneratorKind.Stack + assert generator.row_count == 2 + assert len(generator.arguments) == 1 + assert generator.output_columns == ["left", "right"] diff --git a/src/lib.incn 
b/src/lib.incn index 51136d6..b508418 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -6,8 +6,24 @@ Consumers depend on this package via `[dependencies]` and import with `from pub: """ pub from dataset import BoundedDataSet, DataFrame, DataSet, DataStream, LazyFrame, UnboundedDataSet -pub from dataset.ops import agg_ds, explode_ds, filter_ds, group_by_ds, join_ds, limit_ds, order_by_ds, select_ds +pub from dataset.ops import ( + agg_ds, + explode_ds, + filter_ds, + generate_ds, + group_by_ds, + join_ds, + limit_ds, + order_by_ds, + select_ds, +) pub from aggregate_builders import AggregateKind, AggregateMeasure +pub from generator_builders import ( + GeneratorApplication, + GeneratorKind, + generator_output_columns, + generator_primary_output_column, +) pub from projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -66,6 +82,7 @@ pub from functions.nested.array_flatten import array_flatten pub from functions.nested.array_intersect import array_intersect pub from functions.nested.array_join import array_join pub from functions.nested.array_position import array_position +pub from functions.nested.array_range import array_range pub from functions.nested.array_reverse import array_reverse pub from functions.nested.array_slice import array_slice pub from functions.nested.array_sort import array_sort @@ -80,6 +97,14 @@ pub from functions.nested.map_from_arrays import map_from_arrays pub from functions.nested.map_keys import map_keys pub from functions.nested.map_values import map_values pub from functions.nested.named_struct import named_struct +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.flatten import flatten +pub from functions.generators.inline import inline +pub from functions.generators.inline_outer import inline_outer +pub from functions.generators.posexplode import posexplode +pub from functions.generators.posexplode_outer import posexplode_outer +pub 
from functions.generators.stack import stack pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div @@ -143,6 +168,7 @@ pub from function_registry import ( function_ref_for, namespaced_function_ref, rejected_function_policy, + relation_extension_mapping, rewrite_mapping, sort_field_mapping, structural_mapping, @@ -185,6 +211,8 @@ pub from substrait.relations import ( extension_single_rel, fetch_rel, filter_rel, + generator_rel, + generator_rel_of_columns, join_rel, join_rel_of_kind, project_rel, @@ -212,6 +240,8 @@ pub from substrait.inspect import ( aggregate_measure_function_names, aggregate_measure_invocation_names, aggregate_measure_sort_counts, + plan_extension_urn_anchor_at, + plan_extension_urn_count, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -222,9 +252,16 @@ pub from substrait.inspect import ( set_operation_name, ) pub from substrait.function_extensions import ( - explode_extension_uri, - function_extension_uri, + EXPLODE_EXTENSION_URI, + EXPLODE_OUTER_EXTENSION_URI, + FLATTEN_EXTENSION_URI, + FUNCTION_EXTENSION_URI, + INLINE_EXTENSION_URI, + INLINE_OUTER_EXTENSION_URI, + POSEXPLODE_EXTENSION_URI, + POSEXPLODE_OUTER_EXTENSION_URI, registered_substrait_extension_uris, + STACK_EXTENSION_URI, ) pub from substrait.conformance_catalog import ( ConformanceCapabilityTags, diff --git a/src/prism/lower.incn b/src/prism/lower.incn index 9ae303c..490dd7e 100644 --- a/src/prism/lower.incn +++ b/src/prism/lower.incn @@ -3,13 +3,14 @@ from rust::substrait::proto import Plan, Rel from rust::incan_stdlib::errors import raise_value_error from prism.output_columns import rewritten_output_columns -from substrait.function_extensions import explode_extension_uri +from substrait.function_extensions import EXPLODE_EXTENSION_URI from substrait.plans import plan_from_root_relation from substrait.relations import ( extension_single_rel, fetch_rel, join_rel, 
read_named_table_rel, + try_generator_rel_of_columns, sort_rel_of_columns, try_aggregate_rel_of_columns, try_filter_rel_of_columns, @@ -118,6 +119,12 @@ def _try_lower_node(view: PrismOptimizedView, node_id: int) -> Result[Rel, Subst [], node.aggregate_measures, ) + PrismNodeKind.Generate => + return try_generator_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, + rewritten_output_columns(view, node.input_ids[0]), + node.generator_applications[0], + ) PrismNodeKind.OrderBy => return Ok( sort_rel_of_columns( @@ -128,4 +135,4 @@ def _try_lower_node(view: PrismOptimizedView, node_id: int) -> Result[Rel, Subst ) PrismNodeKind.Limit => return Ok(fetch_rel(_try_lower_node(view, node.input_ids[0])?, 0, node.limit_count)) PrismNodeKind.Explode => - return Ok(extension_single_rel(_try_lower_node(view, node.input_ids[0])?, explode_extension_uri())) + return Ok(extension_single_rel(_try_lower_node(view, node.input_ids[0])?, EXPLODE_EXTENSION_URI)) diff --git a/src/prism/mod.incn b/src/prism/mod.incn index 229cbaa..3564d35 100644 --- a/src/prism/mod.incn +++ b/src/prism/mod.incn @@ -13,6 +13,7 @@ This façade keeps one stable internal import surface while the implementation i from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure from filter_builders import always_true +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment from prism.lower import ( lower_prism_tip as lower_prism_tip_impl, @@ -69,6 +70,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -87,6 +89,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return 
PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -102,6 +105,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -119,6 +123,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -136,6 +141,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[with_column_assignment(name, expr)], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -153,6 +159,7 @@ pub class PrismCursor[T with Clone]: group_columns=columns, sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -170,6 +177,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=measures, + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -187,6 +195,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=columns, aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -204,6 +213,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -221,6 +231,25 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], 
aggregate_measures=[], + generator_applications=[], + projection_assignments=[], + ) + return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) + + def generate(self, generator: GeneratorApplication) -> Self: + """Append one explicit generator node and return the derived tip.""" + next_tip_id = append_node( + store_id=self.store_id, + kind=PrismNodeKind.Generate, + input_ids=[self.tip_id], + named_table="", + join_predicate=false, + filter_predicate=always_true(), + limit_count=0, + group_columns=[], + sort_columns=[], + aggregate_measures=[], + generator_applications=[generator], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -264,6 +293,7 @@ pub def prism_cursor_named_table[T with Clone](table_name: str) -> PrismCursor[T group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=store_id, tip_id=tip_id, _type_marker=[]) @@ -325,6 +355,14 @@ pub def prism_cursor_apply_explode[T with Clone](cursor: PrismCursor[T]) -> Pris return cursor.explode() +pub def prism_cursor_apply_generate[T with Clone]( + cursor: PrismCursor[T], + generator: GeneratorApplication, +) -> PrismCursor[T]: + """Apply one explicit generator through Prism.""" + return cursor.generate(generator) + + pub def prism_cursor_output_columns[T with Clone](cursor: PrismCursor[T]) -> list[str]: """Return plan-time output columns for one cursor tip.""" return cursor.planned_columns() diff --git a/src/prism/output_columns.incn b/src/prism/output_columns.incn index f1de58c..d1cfa06 100644 --- a/src/prism/output_columns.incn +++ b/src/prism/output_columns.incn @@ -3,6 +3,7 @@ from prism.store import node_at from prism.rewrite import rewritten_node_at from prism.types import PrismNodeKind, PrismOptimizedView, PrismStoreId +from generator_builders import generator_output_columns from projection_builders import ColumnExpr, 
project_output_columns, scalar_expr_output_name from substrait.inspect import aggregate_measure_output_names from substrait.schema_registry import named_table_columns @@ -27,6 +28,11 @@ pub def authored_output_columns(store_id: PrismStoreId, tip_id: int) -> list[str return authored_output_columns(store_id, node.input_ids[0]) if node.kind == PrismNodeKind.Project: return project_output_columns(authored_output_columns(store_id, node.input_ids[0]), node.projection_assignments) + if node.kind == PrismNodeKind.Generate: + return generator_output_columns( + authored_output_columns(store_id, node.input_ids[0]), + node.generator_applications[0], + ) if node.kind == PrismNodeKind.Join: # Join output columns preserve the conventional left-then-right relation order. # We keep both sides verbatim here; duplicate names are part of the current output shape and are resolved later @@ -59,6 +65,11 @@ pub def rewritten_output_columns(view: PrismOptimizedView, node_id: int) -> list return rewritten_output_columns(view, node.input_ids[0]) if node.kind == PrismNodeKind.Project: return project_output_columns(rewritten_output_columns(view, node.input_ids[0]), node.projection_assignments) + if node.kind == PrismNodeKind.Generate: + return generator_output_columns( + rewritten_output_columns(view, node.input_ids[0]), + node.generator_applications[0], + ) if node.kind == PrismNodeKind.Join: # Rewritten views keep the same left-then-right join column order as authored views # so output-column inference stays stable across Prism rewrite passes. 
diff --git a/src/prism/rewrite.incn b/src/prism/rewrite.incn index 6247b0b..419f968 100644 --- a/src/prism/rewrite.incn +++ b/src/prism/rewrite.incn @@ -168,6 +168,7 @@ def _build_collapsed_limit_node( group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) @@ -204,6 +205,7 @@ def _build_collapsed_project_node( group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=merged_assignments, ) @@ -240,6 +242,7 @@ def _build_collapsed_aggregate_node( group_columns=[], sort_columns=[], aggregate_measures=merged_measures, + generator_applications=[], projection_assignments=[], ) @@ -274,6 +277,7 @@ def _build_collapsed_order_by_node( group_columns=[], sort_columns=node.sort_columns, aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) @@ -291,6 +295,7 @@ def _build_rewritten_node(node: PrismNode, remapped_inputs: list[int], rewritten group_columns=node.group_columns, sort_columns=node.sort_columns, aggregate_measures=node.aggregate_measures, + generator_applications=node.generator_applications, projection_assignments=node.projection_assignments, ) @@ -336,6 +341,7 @@ def _compact_optimized_view(view: PrismOptimizedView) -> PrismOptimizedView: group_columns=old_node.group_columns, sort_columns=old_node.sort_columns, aggregate_measures=old_node.aggregate_measures, + generator_applications=old_node.generator_applications, projection_assignments=old_node.projection_assignments, ), ) diff --git a/src/prism/store.incn b/src/prism/store.incn index d620574..e451ade 100644 --- a/src/prism/store.incn +++ b/src/prism/store.incn @@ -1,6 +1,7 @@ """Append-only Prism store allocation, storage, reachability, and cross-store adoption.""" from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -54,6 +55,7 @@ pub def append_node( 
group_columns: list[ColumnExpr], sort_columns: list[ColumnExpr], aggregate_measures: list[AggregateMeasure], + generator_applications: list[GeneratorApplication], projection_assignments: list[ProjectionAssignment], ) -> int: """ @@ -73,6 +75,7 @@ pub def append_node( group_columns=group_columns, sort_columns=sort_columns, aggregate_measures=aggregate_measures, + generator_applications=generator_applications, projection_assignments=projection_assignments, ) prism_stored_nodes.append(PrismStoredNode(store_id_raw=store_id.0, node=appended)) @@ -119,10 +122,12 @@ pub def adopt_cursor_subgraph( adopted_group_columns = [column for column in source_node.group_columns] adopted_sort_columns = [column for column in source_node.sort_columns] adopted_measures = [measure for measure in source_node.aggregate_measures] + adopted_generators = [generator for generator in source_node.generator_applications] adopted_assignments = [assignment for assignment in source_node.projection_assignments] target_group_columns = [column for column in source_node.group_columns] target_sort_columns = [column for column in source_node.sort_columns] target_measures = [measure for measure in source_node.aggregate_measures] + target_generators = [generator for generator in source_node.generator_applications] target_assignments = [assignment for assignment in source_node.projection_assignments] adopted_id = append_node( store_id=target_store_id, @@ -135,6 +140,7 @@ pub def adopt_cursor_subgraph( group_columns=adopted_group_columns, sort_columns=adopted_sort_columns, aggregate_measures=adopted_measures, + generator_applications=adopted_generators, projection_assignments=adopted_assignments, ) target_store_nodes.append( @@ -149,6 +155,7 @@ pub def adopt_cursor_subgraph( group_columns=target_group_columns, sort_columns=target_sort_columns, aggregate_measures=target_measures, + generator_applications=target_generators, projection_assignments=target_assignments, ), ) @@ -232,6 +239,11 @@ def 
_nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema return false if not _aggregate_measure_lists_structurally_equal(candidate.aggregate_measures, source_node.aggregate_measures): return false + if not _generator_application_lists_structurally_equal( + candidate.generator_applications, + source_node.generator_applications, + ): + return false if not _projection_assignments_structurally_equal( candidate.projection_assignments, source_node.projection_assignments, @@ -271,6 +283,48 @@ def _aggregate_measures_structurally_equal(left: AggregateMeasure, right: Aggreg return _column_expr_lists_structurally_equal(left.ordering, right.ordering) +def _generator_application_lists_structurally_equal( + left: list[GeneratorApplication], + right: list[GeneratorApplication], +) -> bool: + """Return whether two generator-application lists carry identical relation-shaping semantics.""" + if len(left) != len(right): + return false + for idx in range(len(left)): + if not _generator_applications_structurally_equal(left[idx], right[idx]): + return false + return true + + +def _generator_applications_structurally_equal(left: GeneratorApplication, right: GeneratorApplication) -> bool: + """Return whether two generator applications carry identical registry identity and schema effects.""" + if left.kind != right.kind: + return false + if left.function_ref != right.function_ref: + return false + if left.canonical_name != right.canonical_name: + return false + if left.preserves_input_columns != right.preserves_input_columns: + return false + if left.is_outer != right.is_outer: + return false + if left.position_origin != right.position_origin: + return false + if not _text_lists_structurally_equal(left.output_columns, right.output_columns): + return false + return _column_exprs_structurally_equal(left.expr, right.expr) + + +def _text_lists_structurally_equal(left: list[str], right: list[str]) -> bool: + """Return whether two string lists are structurally equivalent.""" 
+ if len(left) != len(right): + return false + for idx in range(len(left)): + if left[idx] != right[idx]: + return false + return true + + def _filter_predicates_structurally_equal(left: ColumnExpr, right: ColumnExpr) -> bool: """Return whether two filter scalar expressions are structurally equivalent.""" return _column_exprs_structurally_equal(left, right) diff --git a/src/prism/types.incn b/src/prism/types.incn index a5573cf..59472c1 100644 --- a/src/prism/types.incn +++ b/src/prism/types.incn @@ -1,6 +1,7 @@ """Shared Prism types that define the internal planning substrate contract.""" from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment @@ -17,6 +18,7 @@ pub enum PrismNodeKind(str): Project = "Project" GroupBy = "GroupBy" Aggregate = "Aggregate" + Generate = "Generate" OrderBy = "OrderBy" Limit = "Limit" Explode = "Explode" @@ -41,6 +43,7 @@ pub model PrismNode: pub group_columns: list[ColumnExpr] pub sort_columns: list[ColumnExpr] pub aggregate_measures: list[AggregateMeasure] + pub generator_applications: list[GeneratorApplication] pub projection_assignments: list[ProjectionAssignment] diff --git a/src/session/datafusion_backend.incn b/src/session/datafusion_backend.incn index eb13a7b..1746d8f 100644 --- a/src/session/datafusion_backend.incn +++ b/src/session/datafusion_backend.incn @@ -2,10 +2,31 @@ import std.async from rust::prost import Message +from rust::std::boxed import Box from rust::std::primitive import i64 as RustI64, usize as RustUsize -from rust::substrait::proto import Plan +from rust::std::sync import Arc +from rust::substrait::proto import ( + AggregateRel, + CrossRel, + ExtensionSingleRel, + FetchRel, + FilterRel, + JoinRel, + Plan, + ProjectRel, + Rel, + SetRel, + SortRel, +) +from rust::substrait::proto::rel import RelType +from rust::substrait::proto::join_rel import JoinType +from rust::datafusion::common import Column, 
UnnestOptions +from rust::datafusion::arrow::record_batch import RecordBatch +from rust::datafusion::datasource import MemTable +from rust::datafusion::dataframe import DataFrame as RustDataFrame from rust::datafusion::execution::context import SessionContext from rust::datafusion::execution::options import ArrowReadOptions +from rust::datafusion::logical_expr import LogicalPlanBuilder from rust::datafusion::prelude import CsvReadOptions, ParquetReadOptions from rust::datafusion::dataframe import DataFrameWriteOptions from rust::datafusion_substrait::substrait::proto import Plan as ConsumerPlan @@ -13,7 +34,22 @@ from rust::datafusion_substrait::logical_plan::consumer import from_substrait_pl from backends import SourceKind, TableSource from dataset.materialization import DataFrameMaterialization from session.backend_types import BackendError, BackendErrorKind, BackendRegistration, backend_error -from substrait.inspect import root_names +from substrait.function_extensions import ( + EXPLODE_EXTENSION_URI, + EXPLODE_OUTER_EXTENSION_URI, + FLATTEN_EXTENSION_URI, + INLINE_EXTENSION_URI, + INLINE_OUTER_EXTENSION_URI, + POSEXPLODE_EXTENSION_URI, + POSEXPLODE_OUTER_EXTENSION_URI, + STACK_EXTENSION_URI, +) +from substrait.generator_payload import GeneratorExtensionPayload, decode_generator_extension_payload +from substrait.inspect import relation_output_columns, root_names, root_rel +from substrait.plans import plan_from_root_relation +from substrait.relations import project_rel_with_expressions, read_named_table_rel +from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind +from substrait.schema_registry import register_named_table_schema @derive(Clone) @@ -25,25 +61,29 @@ enum DataFusionSourceRegistration(str): Arrow = "arrow" +@derive(Clone) +model MaterializedGeneratorRelation: + """One generator relation that must be bridged through a DataFusion temp view for execution.""" + + path: str + rel: Rel + extension: ExtensionSingleRel + + pub async def 
datafusion_execute_async( registrations: list[BackendRegistration], plan: Plan, ) -> Result[None, BackendError]: """Execute one Substrait plan via DataFusion and discard collected rows.""" ctx = SessionContext.new() - consumer_plan = _consumer_plan_from_current_plan(plan)? await _register_sources(ctx, registrations)? - state = ctx.state() - match await from_substrait_plan(state, consumer_plan): - Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): - Ok(df) => match await df.collect(): - Ok(_) => return Ok(None) - Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - + match await _dataframe_from_plan(ctx, plan): + Ok(df) => match await df.collect(): + Ok(_) => return Ok(None) Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(err) pub async def datafusion_collect_materialization_async( @@ -53,32 +93,27 @@ pub async def datafusion_collect_materialization_async( """Execute one Substrait plan via DataFusion and return structured DataFrame materialization.""" ctx = SessionContext.new() resolved_columns = root_names(plan.clone()) - consumer_plan = _consumer_plan_from_current_plan(plan)? await _register_sources(ctx, registrations)? - state = ctx.state() - match await from_substrait_plan(state, consumer_plan): - Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): - Ok(df) => match await df.clone().collect(): - Ok(batches) => match await df.to_string(): - Ok(rendered) => - mut row_count = 0 - for batch in batches: - row_count += _rust_usize_to_int(batch.num_rows())? 
- return Ok( - DataFrameMaterialization( - resolved_columns=resolved_columns, - row_count=row_count, - preview_text=rendered, - ), - ) - Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - + match await _dataframe_from_plan(ctx, plan): + Ok(df) => match await df.clone().collect(): + Ok(batches) => match await df.to_string(): + Ok(rendered) => + mut row_count = 0 + for batch in batches: + row_count += _rust_usize_to_int(batch.num_rows())? + return Ok( + DataFrameMaterialization( + resolved_columns=resolved_columns, + row_count=row_count, + preview_text=rendered, + ), + ) Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(err) pub async def datafusion_write_csv_async( @@ -88,19 +123,14 @@ pub async def datafusion_write_csv_async( ) -> Result[None, BackendError]: """Execute one plan and write result rows to a CSV sink URI via DataFusion.""" ctx = SessionContext.new() - consumer_plan = _consumer_plan_from_current_plan(plan)? await _register_sources(ctx, registrations)? 
- state = ctx.state() - match await from_substrait_plan(state, consumer_plan): - Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): - Ok(df) => match await df.write_csv(uri, DataFrameWriteOptions.new(), None): - Ok(_) => return Ok(None) - Err(err) => return Err(backend_error(BackendErrorKind.BackendSinkError, err.to_string())) + match await _dataframe_from_plan(ctx, plan): + Ok(df) => match await df.write_csv(uri, DataFrameWriteOptions.new(), None): + Ok(_) => return Ok(None) + Err(err) => return Err(backend_error(BackendErrorKind.BackendSinkError, err.to_string())) - Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - - Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(err) pub async def datafusion_write_parquet_async( @@ -110,21 +140,643 @@ pub async def datafusion_write_parquet_async( ) -> Result[None, BackendError]: """Execute one plan and write result rows to a Parquet sink URI via DataFusion.""" ctx = SessionContext.new() - consumer_plan = _consumer_plan_from_current_plan(plan)? await _register_sources(ctx, registrations)? + match await _dataframe_from_plan(ctx, plan): + Ok(df) => match await df.write_parquet(uri, DataFrameWriteOptions.new(), None): + Ok(_) => return Ok(None) + Err(err) => return Err(backend_error(BackendErrorKind.BackendSinkError, err.to_string())) + + Err(err) => return Err(err) + + +async def _dataframe_from_plan(ctx: SessionContext, plan: Plan) -> Result[RustDataFrame, BackendError]: + """Build a DataFusion DataFrame, including InQL-owned generator extension roots.""" + root = root_rel(plan.clone()) + # DataFusion's stock Substrait consumer cannot plan InQL generator ExtensionSingleRel nodes. Root generators can + # execute directly; nested generators are first materialized and then replaced by execution-only temp reads. 
+ match _generator_extension(root): + Some(extension) => return await _dataframe_from_generator_extension(ctx, extension) + None => + materializations = _collect_materialized_generator_relations(root_rel(plan.clone()), "root") + if len(materializations) == 0: + return await _dataframe_from_standard_plan(ctx, plan) + for materialization in materializations: + # The path key must match the later rewrite traversal so each temp read targets the exact generated + # subtree that was collected here. + table_name = _materialized_generator_table_name(materialization.path) + df = await _dataframe_from_generator_extension(ctx.clone(), materialization.extension)? + await _register_materialized_generator_dataframe(ctx.clone(), table_name, df)? + rewritten_root = _replace_materialized_generator_relations(root_rel(plan.clone()), "root") + rewritten_plan = plan_from_root_relation(rewritten_root, root_names(plan)) + return await _dataframe_from_standard_plan(ctx, rewritten_plan) + + +async def _dataframe_from_standard_plan(ctx: SessionContext, plan: Plan) -> Result[RustDataFrame, BackendError]: + """Build a DataFusion DataFrame through the stock Substrait consumer.""" + consumer_plan = _consumer_plan_from_current_plan(plan)? 
state = ctx.state() match await from_substrait_plan(state, consumer_plan): - Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): - Ok(df) => match await df.write_parquet(uri, DataFrameWriteOptions.new(), None): - Ok(_) => return Ok(None) - Err(err) => return Err(backend_error(BackendErrorKind.BackendSinkError, err.to_string())) + Ok(logical_plan) => + match await ctx.execute_logical_plan(logical_plan): + Ok(df) => return Ok(df) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) - Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) +async def _dataframe_from_generator_extension( + ctx: SessionContext, + extension: ExtensionSingleRel, +) -> Result[RustDataFrame, BackendError]: + """Build a DataFusion DataFrame for one InQL generator relation-extension root.""" + payload = _decode_generator_payload(extension.clone())? + child = _generator_input_rel(extension.clone())? + child_columns = relation_output_columns(child.clone()) + # Generator arguments are ordinary InQL scalar expressions. Project them into stable temp columns first so the + # DataFusion unnest API can operate on columns rather than on arbitrary Substrait expressions. + temp_columns = _generator_temp_columns(len(payload.arguments)) + projected_names = _generator_projected_root_names( + child_columns, + temp_columns.clone(), + extension.clone(), + payload.clone(), + )? + projected_rel = project_rel_with_expressions(child, payload.arguments) + projected_plan = plan_from_root_relation(projected_rel, projected_names) + projected_df = await _dataframe_from_standard_plan(ctx, projected_plan)? 
+ return _apply_generator_dataframe_shape(projected_df, extension, payload, temp_columns) + + +def _collect_materialized_generator_relations(rel: Rel, path: str) -> list[MaterializedGeneratorRelation]: + """Return generator relations nested below a non-generator root.""" + match _generator_extension(rel.clone()): + Some(extension) => return [MaterializedGeneratorRelation(path=path, rel=rel, extension=extension)] + None => return _collect_materialized_generator_children(rel, path) + + +def _collect_materialized_generator_children(rel: Rel, path: str) -> list[MaterializedGeneratorRelation]: + """Return generator relations below direct children of one relation.""" + # This traversal must stay in lock-step with _replace_materialized_generator_children; the path segments are the + # shared key between the discovery pass and the execution-plan rewrite pass. + mut materializations: list[MaterializedGeneratorRelation] = [] + match rel.rel_type.clone(): + Some(RelType.Filter(filter_rel)) => + filter = filter_rel.as_ref().clone() + match filter.input: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_filter"), + ) + None => pass + Some(RelType.Project(project_rel)) => + project = project_rel.as_ref().clone() + match project.input: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_project"), + ) + None => pass + Some(RelType.Join(join_rel)) => + join = join_rel.as_ref().clone() + match join.left: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_join_left"), + ) + None => pass + match join.right: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_join_right"), + ) + None => pass + Some(RelType.Cross(cross_rel)) => + cross = 
cross_rel.as_ref().clone() + match cross.left: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_cross_left"), + ) + None => pass + match cross.right: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_cross_right"), + ) + None => pass + Some(RelType.Aggregate(aggregate_rel)) => + aggregate = aggregate_rel.as_ref().clone() + match aggregate.input: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_aggregate"), + ) + None => pass + Some(RelType.Sort(sort_rel)) => + sort = sort_rel.as_ref().clone() + match sort.input: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_sort"), + ) + None => pass + Some(RelType.Fetch(fetch_rel)) => + fetch = fetch_rel.as_ref().clone() + match fetch.input: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_fetch"), + ) + None => pass + Some(RelType.Set(set_rel)) => + for idx, input in enumerate(set_rel.inputs): + _extend_materializations( + materializations, + _collect_materialized_generator_relations(input, f"{path}_set_{idx}"), + ) + Some(RelType.ExtensionSingle(extension_rel)) => + extension = extension_rel.as_ref().clone() + match extension.input: + Some(child) => _extend_materializations( + materializations, + _collect_materialized_generator_relations(child.as_ref().clone(), f"{path}_extension"), + ) + None => pass + _ => pass + return materializations + + +def _extend_materializations( + mut target: list[MaterializedGeneratorRelation], + values: list[MaterializedGeneratorRelation], +) -> None: + """Append collected materialization records while satisfying list clone requirements.""" + for 
value in values: + target.append(value.clone()) + return + + +def _replace_materialized_generator_relations(rel: Rel, path: str) -> Rel: + """Replace previously materialized generator relation nodes with temp-view reads.""" + # This is an adapter-only rewrite. The InQL semantic plan remains a generator extension relation; only DataFusion's + # execution plan sees the temporary table scan. + match _generator_extension(rel.clone()): + Some(_) => return read_named_table_rel(_materialized_generator_table_name(path)) + None => return _replace_materialized_generator_children(rel, path) + + +def _replace_materialized_generator_children(rel: Rel, path: str) -> Rel: + """Recursively rewrite generator-bearing children while preserving the surrounding relation node.""" + # Keep this dispatcher shallow. The relation-specific helpers own proto reconstruction details; this function owns + # only the traversal choice. + match rel.rel_type.clone(): + Some(RelType.Filter(filter_rel)) => + return _replace_filter_generator_child(rel, filter_rel.as_ref().clone(), path) + Some(RelType.Project(project_rel)) => + return _replace_project_generator_child(rel, project_rel.as_ref().clone(), path) + Some(RelType.Join(join_rel)) => return _replace_join_generator_children(join_rel.as_ref().clone(), path) + Some(RelType.Cross(cross_rel)) => return _replace_cross_generator_children(cross_rel.as_ref().clone(), path) + Some(RelType.Aggregate(aggregate_rel)) => + return _replace_aggregate_generator_child(rel, aggregate_rel.as_ref().clone(), path) + Some(RelType.Sort(sort_rel)) => return _replace_sort_generator_child(rel, sort_rel.as_ref().clone(), path) + Some(RelType.Fetch(fetch_rel)) => return _replace_fetch_generator_child(rel, fetch_rel.as_ref().clone(), path) + Some(RelType.Set(set_rel)) => return _replace_set_generator_children(set_rel, path) + Some(RelType.ExtensionSingle(extension_rel)) => + return _replace_extension_generator_child(rel, extension_rel.as_ref().clone(), path) + _ => return 
rel
+
+
+def _rewritten_generator_child(child: Box[Rel], path: str) -> Box[Rel]:
+    """Return one boxed child after replacing any materialized generator subtree beneath it."""
+    return Box.new(_replace_materialized_generator_relations(child.as_ref().clone(), path))
+
+
+def _replace_filter_generator_child(original: Rel, filter: FilterRel, path: str) -> Rel:
+    """Rebuild a FilterRel with its input child rewritten when present."""
+    match filter.input:
+        Some(child) =>
+            return Rel(
+                rel_type=Some(
+                    RelType.Filter(
+                        Box.new(
+                            FilterRel(
+                                common=filter.common,
+                                input=Some(_rewritten_generator_child(child, f"{path}_filter")),
+                                condition=filter.condition,
+                                advanced_extension=filter.advanced_extension,
+                            ),
+                        ),
+                    ),
+                ),
+            )
+        None => return original
+
+
+def _replace_project_generator_child(original: Rel, project: ProjectRel, path: str) -> Rel:
+    """Rebuild a ProjectRel with its input child rewritten when present."""
+    match project.input:
+        Some(child) =>
+            return Rel(
+                rel_type=Some(
+                    RelType.Project(
+                        Box.new(
+                            ProjectRel(
+                                common=project.common,
+                                input=Some(_rewritten_generator_child(child, f"{path}_project")),
+                                expressions=project.expressions,
+                                advanced_extension=project.advanced_extension,
+                            ),
+                        ),
+                    ),
+                ),
+            )
+        None => return original
+
+
+def _replace_join_generator_children(join: JoinRel, path: str) -> Rel:
+    """
+    Rebuild a JoinRel with rewritten left and right inputs.
+
+    Each present input is rewritten under the `{path}_join_left` / `{path}_join_right` path suffix; an absent input
+    stays `None`. Every other JoinRel field, including the join type, is copied from `join` unchanged.
+    """
+    mut left = join.left
+    mut right = join.right
+    match left:
+        Some(child) =>
+            left = Some(_rewritten_generator_child(child, f"{path}_join_left"))
+        None => pass
+    match right:
+        Some(child) =>
+            right = Some(_rewritten_generator_child(child, f"{path}_join_right"))
+        None => pass
+    # This rewrite only swaps child inputs, so the join type must be carried through as-is. Forcing Inner here would
+    # silently turn a LEFT (or any other non-inner) join into an inner join whenever the plan is rewritten.
+    return Rel(
+        rel_type=Some(
+            RelType.Join(
+                Box.new(
+                    JoinRel(
+                        common=join.common,
+                        left=left,
+                        right=right,
+                        expression=join.expression,
+                        post_join_filter=join.post_join_filter,
+                        type=join.type,
+                        advanced_extension=join.advanced_extension,
+                    ),
+                ),
+            ),
+        ),
+    )
+
+
+def _replace_cross_generator_children(cross: CrossRel, path: str) -> Rel:
+    """Rebuild a CrossRel with rewritten left and right inputs."""
+    mut left = cross.left
+    mut right = cross.right
+    match left:
+        Some(child) =>
+            left = Some(_rewritten_generator_child(child, f"{path}_cross_left"))
+        None => pass
+    match right:
+        Some(child) =>
+            right = Some(_rewritten_generator_child(child, f"{path}_cross_right"))
+        None => pass
+    return Rel(
+        rel_type=Some(
+            RelType.Cross(
+                Box.new(
+                    CrossRel(common=cross.common, left=left, right=right, advanced_extension=cross.advanced_extension),
+                ),
+            ),
+        ),
+    )
+
+
+def _replace_aggregate_generator_child(original: Rel, aggregate: AggregateRel, path: str) -> Rel:
+    """Rebuild an AggregateRel with its input child rewritten when present."""
+    match aggregate.input:
+        Some(child) =>
+            return Rel(
+                rel_type=Some(
+                    RelType.Aggregate(
+                        Box.new(
+                            AggregateRel(
+                                common=aggregate.common,
+                                input=Some(_rewritten_generator_child(child, f"{path}_aggregate")),
+                                groupings=aggregate.groupings,
+                                measures=aggregate.measures,
+                                grouping_expressions=aggregate.grouping_expressions,
+                                advanced_extension=aggregate.advanced_extension,
+                            ),
+                        ),
+                    ),
+                ),
+            )
+        None => return original
+
+
+def _replace_sort_generator_child(original: Rel, sort: SortRel, path: str) -> Rel:
+    """Rebuild a SortRel with its input child rewritten when present."""
+    match sort.input:
+        Some(child) =>
+            return Rel(
+                rel_type=Some(
+                    RelType.Sort(
+                        Box.new(
+                            SortRel(
+                                common=sort.common,
+                                input=Some(_rewritten_generator_child(child, f"{path}_sort")),
+                                sorts=sort.sorts,
+                                advanced_extension=sort.advanced_extension,
+                            ),
+                        ),
+                    ),
+                ),
+            )
+        None => return original
+
+
+def 
_replace_fetch_generator_child(original: Rel, fetch: FetchRel, path: str) -> Rel: + """Rebuild a FetchRel with its input child rewritten when present.""" + match fetch.input: + Some(child) => + return Rel( + rel_type=Some( + RelType.Fetch( + Box.new( + FetchRel( + common=fetch.common, + input=Some(_rewritten_generator_child(child, f"{path}_fetch")), + advanced_extension=fetch.advanced_extension, + offset_mode=fetch.offset_mode, + count_mode=fetch.count_mode, + ), + ), + ), + ), + ) + None => return original + + +def _replace_set_generator_children(set_rel: SetRel, path: str) -> Rel: + """Rebuild a SetRel with every input child rewritten.""" + mut inputs: list[Rel] = [] + for idx, input in enumerate(set_rel.inputs): + rewritten = _replace_materialized_generator_relations(input, f"{path}_set_{idx}") + inputs.append(rewritten.clone()) + return Rel( + rel_type=Some( + RelType.Set( + SetRel( + common=set_rel.common, + inputs=inputs, + op=set_rel.op, + advanced_extension=set_rel.advanced_extension, + ), + ), + ), + ) + + +def _replace_extension_generator_child(original: Rel, extension: ExtensionSingleRel, path: str) -> Rel: + """Rebuild a non-generator ExtensionSingleRel with its input child rewritten when present.""" + match extension.input: + Some(child) => + return Rel( + rel_type=Some( + RelType.ExtensionSingle( + Box.new( + ExtensionSingleRel( + common=extension.common, + input=Some(_rewritten_generator_child(child, f"{path}_extension")), + detail=extension.detail, + ), + ), + ), + ), + ) + None => return original + + +def _materialized_generator_table_name(path: str) -> str: + """Return the internal DataFusion temp-view name for one materialized generator relation path.""" + return f"__inql_generator_materialized_{path}" + + +async def _register_materialized_generator_dataframe( + ctx: SessionContext, + table_name: str, + df: RustDataFrame, +) -> Result[None, BackendError]: + """Collect one generated DataFrame into a MemTable so Substrait reads see concrete table 
scans.""" + match await df.clone().collect(): + Ok(batches) => + if len(batches) == 0: + # Empty generator output still needs a registered relation for the rewritten temp ReadRel to resolve. + match ctx.register_table(f"{table_name}", df.into_view()): + Ok(_) => return Ok(None) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + _register_materialized_generator_schema_from_batch(f"{table_name}", batches[0]) + schema = batches[0].schema() + # Non-empty outputs are frozen into a MemTable so the stock consumer reads a concrete table instead of an + # InQL-specific generator extension relation. + match MemTable.try_new(schema, [batches]): + Ok(table) => + match ctx.register_table(f"{table_name}", Arc.new(table)): + Ok(_) => return Ok(None) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + + +def _register_materialized_generator_schema_from_batch(table_name: str, batch: RecordBatch) -> None: + """Mirror one DataFusion materialized generator schema into the Substrait named-table schema registry.""" + # The rewritten ReadRel uses the same schema inference path as user-registered tables, so mirror DataFusion's batch + # schema into that registry before the temporary plan is lowered. 
+    schema = batch.schema()
+    mut columns: list[RowColumnSpec] = []
+    for field in schema.fields():
+        columns.append(
+            RowColumnSpec(
+                name=field.name().to_string(),
+                kind=_substrait_kind_for_arrow_type_name(field.data_type().to_string()),
+                nullable=field.is_nullable(),
+            ),
+        )
+    register_named_table_schema(table_name, columns)
+    return
+
+
+def _substrait_kind_for_arrow_type_name(type_name: str) -> SubstraitPrimitiveKind:
+    """
+    Map one DataFusion/Arrow primitive type name into InQL's minimal Substrait schema kinds.
+
+    Every signed and unsigned integer width maps to I64 and every float width to F64. "Boolean" maps to Bool and any
+    name containing "Timestamp" maps to Timestamp. All other type names fall back to String.
+    """
+    # Narrow widths are listed explicitly: without them Int8/Int16/UInt8/UInt16/Float16 columns would reach the String
+    # fallback and the mirrored named-table schema would disagree with the materialized batches.
+    if type_name == "Int64" or type_name == "Int32" or type_name == "Int16" or type_name == "Int8":
+        return SubstraitPrimitiveKind.I64
+    # NOTE(review): UInt64 values above the signed 64-bit range are not representable as I64 - confirm this is accepted.
+    if type_name == "UInt64" or type_name == "UInt32" or type_name == "UInt16" or type_name == "UInt8":
+        return SubstraitPrimitiveKind.I64
+    if type_name == "Float64" or type_name == "Float32" or type_name == "Float16":
+        return SubstraitPrimitiveKind.F64
+    if type_name == "Boolean":
+        return SubstraitPrimitiveKind.Bool
+    # Timestamp type names carry unit/timezone parameters, so match on the substring rather than the full name.
+    if type_name.contains("Timestamp"):
+        return SubstraitPrimitiveKind.Timestamp
+    return SubstraitPrimitiveKind.String
+
+
+def _generator_projected_root_names(
+    child_columns: list[str],
+    temp_columns: list[str],
+    extension: ExtensionSingleRel,
+    payload: GeneratorExtensionPayload,
+) -> Result[list[str], BackendError]:
+    """Return temporary root names, including nested struct field names when Substrait requires them."""
+    mut projected_names = _extended_columns(child_columns, temp_columns)
+    match extension.detail:
+        Some(detail) =>
+            if _generator_expands_struct(detail.type_url):
+                # Struct-expanding generators need field aliases present before DataFusion unnests the temporary struct
+                # column; otherwise the consumer cannot name the generated fields deterministically.
+ for output_column in payload.output_columns: + projected_names.append(f"{output_column}") + return Ok(projected_names) + None => + return Err(backend_error(BackendErrorKind.BackendPlanningError, "generator extension is missing detail")) + + +def _generator_extension(rel: Rel) -> Option[ExtensionSingleRel]: + """Return the root generator extension relation when the plan root is an InQL generator.""" + match rel.rel_type: + Some(RelType.ExtensionSingle(extension)) => + if let Some(detail) = extension.detail.clone(): + if _is_generator_extension_uri(detail.type_url): + return Some(extension.as_ref().clone()) + return None + _ => return None + + +def _generator_input_rel(extension: ExtensionSingleRel) -> Result[Rel, BackendError]: + """Return the child relation for one generator extension.""" + match extension.input: + Some(child) => return Ok(child.as_ref().clone()) + None => return Err(backend_error(BackendErrorKind.BackendPlanningError, "generator extension is missing input")) + + +def _decode_generator_payload(extension: ExtensionSingleRel) -> Result[GeneratorExtensionPayload, BackendError]: + """Decode one InQL generator extension payload.""" + match extension.detail: + Some(detail) => + match decode_generator_extension_payload(detail.value): + Ok(payload) => return Ok(payload) + Err(message) => return Err(backend_error(BackendErrorKind.BackendPlanningError, message)) + None => + return Err(backend_error(BackendErrorKind.BackendPlanningError, "generator extension is missing detail")) + + +def _apply_generator_dataframe_shape( + df: RustDataFrame, + extension: ExtensionSingleRel, + payload: GeneratorExtensionPayload, + temp_columns: list[str], +) -> Result[RustDataFrame, BackendError]: + """Apply DataFusion unnest and aliasing for one lowered generator payload.""" + match extension.detail: + Some(detail) => + # The first unnest converts arrays into rows. Inline/stack then need a second unnest to split the generated + # struct values into separate output columns. 
+ mut current = _unnest_columns(df, temp_columns.clone(), _preserve_generator_nulls(detail.type_url))? + if _generator_expands_struct(detail.type_url): + current = _unnest_columns(current, [temp_columns[0]], _preserve_generator_nulls(detail.type_url))? + return _rename_struct_generator_outputs(current, temp_columns[0], payload.output_columns) + return _rename_flat_generator_outputs(current, temp_columns, payload.output_columns) + None => + return Err(backend_error(BackendErrorKind.BackendPlanningError, "generator extension is missing detail")) + + +def _unnest_columns(df: RustDataFrame, columns: list[str], preserve_nulls: bool) -> Result[RustDataFrame, BackendError]: + """Call DataFusion unnest with the InQL inner/outer null-preservation policy.""" + options = UnnestOptions.new().with_preserve_nulls(preserve_nulls) + mut datafusion_columns: list[Column] = [] + for column in columns: + datafusion_column = Column.from_name(f"{column}") + datafusion_columns.append(datafusion_column.clone()) + # DataFusion exposes unnest through LogicalPlanBuilder rather than as a direct DataFrame method, so split the + # DataFrame and rebuild it around the transformed logical plan. 
+ parts = df.into_parts() + state = parts.0 + logical_plan = parts.1 + match LogicalPlanBuilder.from(logical_plan).unnest_columns_with_options(datafusion_columns, options): + Ok(builder) => + match builder.build(): + Ok(next_plan) => return Ok(RustDataFrame.new(state, next_plan)) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) +def _rename_flat_generator_outputs( + df: RustDataFrame, + temp_columns: list[str], + output_columns: list[str], +) -> Result[RustDataFrame, BackendError]: + """Rename generated scalar/list outputs from temporary names to InQL aliases.""" + if len(temp_columns) != len(output_columns): + return Err(backend_error(BackendErrorKind.BackendPlanningError, "generator payload/output arity mismatch")) + mut current = df + for idx, output_column in enumerate(output_columns): + match current.with_column_renamed(temp_columns[idx], output_column): + Ok(next_df) => + current = next_df + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + return Ok(current) + + +def _rename_struct_generator_outputs( + df: RustDataFrame, + temp_column: str, + output_columns: list[str], +) -> Result[RustDataFrame, BackendError]: + """Rename generated struct fields from temporary qualified names to InQL aliases.""" + mut current = df + # DataFusion names unnested struct fields as `temp.field`; the public InQL result should expose only the declared + # generator output aliases. 
+    for output_column in output_columns:
+        field_name = f"{temp_column}.{output_column}"
+        match current.with_column_renamed(field_name, output_column):
+            Ok(next_df) =>
+                current = next_df
+            Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string()))
+    return Ok(current)
+
+
+def _extended_columns(input_columns: list[str], appended_columns: list[str]) -> list[str]:
+    """Concatenate the input column names with the temporary generator column names, input columns first."""
+    mut merged: list[str] = []
+    merged.extend(input_columns)
+    merged.extend(appended_columns)
+    return merged
+
+
+def _generator_temp_columns(argument_count: int) -> list[str]:
+    """Build one `__inql_generator_arg_{idx}` alias per lowered generator argument, in argument order."""
+    mut aliases: list[str] = []
+    for idx in range(argument_count):
+        aliases.append(f"__inql_generator_arg_{idx}")
+    return aliases
+
+
+def _preserve_generator_nulls(extension_uri: str) -> bool:
+    """Report whether this generator URI is one of the outer variants that keep null/empty inputs during unnest."""
+    if extension_uri == EXPLODE_OUTER_EXTENSION_URI:
+        return true
+    if extension_uri == POSEXPLODE_OUTER_EXTENSION_URI:
+        return true
+    return extension_uri == INLINE_OUTER_EXTENSION_URI
+
+
+def _generator_expands_struct(extension_uri: str) -> bool:
+    """Report whether this generator URI (inline, inline_outer, stack) spreads a struct value over several columns."""
+    if extension_uri == INLINE_EXTENSION_URI:
+        return true
+    if extension_uri == INLINE_OUTER_EXTENSION_URI:
+        return true
+    return extension_uri == STACK_EXTENSION_URI
+
+
+def _is_generator_extension_uri(extension_uri: str) -> bool:
+    """Report whether an ExtensionSingle URI is one of the InQL generator extension URIs."""
+    generator_uris = [
+        EXPLODE_EXTENSION_URI,
+        EXPLODE_OUTER_EXTENSION_URI,
+        POSEXPLODE_EXTENSION_URI,
+        POSEXPLODE_OUTER_EXTENSION_URI,
+        INLINE_EXTENSION_URI,
+        INLINE_OUTER_EXTENSION_URI,
+        FLATTEN_EXTENSION_URI,
+        STACK_EXTENSION_URI,
+    ]
+    for generator_uri in generator_uris:
+        if extension_uri == generator_uri:
+            return true
+    return false
+
+
 async def _register_sources(
     ctx: 
SessionContext, registrations: list[BackendRegistration], diff --git a/src/substrait/expr_lowering.incn b/src/substrait/expr_lowering.incn index d5dcd72..21384eb 100644 --- a/src/substrait/expr_lowering.incn +++ b/src/substrait/expr_lowering.incn @@ -281,6 +281,12 @@ def _resolved_scalar_function_application_expr( f"{entry.function_ref} is only valid in {entry.substrait.function_name} context", ), ) + SubstraitMappingKind.RelationExtension => + return Err( + invalid_scalar_expression( + f"{entry.function_ref} is a relation-shaping generator and must be applied through generate(...)", + ), + ) SubstraitMappingKind.Rewrite => return Err( invalid_scalar_expression( diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index f4efeb4..575c285 100644 --- a/src/substrait/extensions.incn +++ b/src/substrait/extensions.incn @@ -6,6 +6,7 @@ expression trees. """ from rust::incan_stdlib::errors import raise_value_error +from rust::std::primitive import u32 as RustU32 from rust::substrait::proto import AggregateFunction, Expression, FunctionArgument, Rel, SortField from rust::substrait::proto::extensions import SimpleExtensionDeclaration, SimpleExtensionUrn from rust::substrait::proto::extensions::simple_extension_declaration import ExtensionFunction, MappingType @@ -15,7 +16,7 @@ from rust::substrait::proto::rel import RelType from function_registry import FunctionClass, SubstraitMappingKind from functions.registry import function_registry_entries from substrait.errors import SubstraitLoweringError, invalid_scalar_expression -from substrait.function_extensions import ExtensionFunctionKind, FunctionExtensionSpec, function_extension_uri +from substrait.function_extensions import ExtensionFunctionKind, FUNCTION_EXTENSION_URI, FunctionExtensionSpec from substrait.traversal import relation_children @@ -28,7 +29,23 @@ model ExtensionUrnSpec: const FUNCTION_EXTENSION_URN_ANCHOR: u32 = 0 -const RELATION_EXTENSION_URN_ANCHOR: u32 = 1 + + +def 
_to_extension_urn_anchor(value: int) -> RustU32: + """Convert a small extension-URN anchor into the protobuf field type.""" + match RustU32.try_from(value): + Ok(converted) => return converted + Err(_) => + message = f"extension URN anchor {value} does not fit Rust u32" + return raise_value_error(message) + + +def _has_extension_urn_spec(specs: list[ExtensionUrnSpec], urn: str) -> bool: + """Return whether a plan-level extension URN list already contains one URI.""" + for spec in specs: + if spec.urn == urn: + return true + return false pub def aggregate_function_name_from_anchor(anchor: u32) -> str: @@ -406,7 +423,11 @@ pub def extension_urns_for_rel(rel: Rel) -> list[SimpleExtensionUrn]: """Collect extension URNs required by one relation subtree.""" mut specs: list[ExtensionUrnSpec] = [] if _function_extension_urn_is_required(rel.clone()): - specs.append(ExtensionUrnSpec(anchor=FUNCTION_EXTENSION_URN_ANCHOR, urn=function_extension_uri())) + specs.append(ExtensionUrnSpec(anchor=FUNCTION_EXTENSION_URN_ANCHOR, urn=FUNCTION_EXTENSION_URI)) + mut relation_anchor_count = 0 for urn in _collect_extension_urn_strings(rel): - specs.append(ExtensionUrnSpec(anchor=RELATION_EXTENSION_URN_ANCHOR, urn=urn)) + if _has_extension_urn_spec(specs, urn): + continue + relation_anchor_count += 1 + specs.append(ExtensionUrnSpec(anchor=_to_extension_urn_anchor(relation_anchor_count), urn=urn)) return [SimpleExtensionUrn(extension_urn_anchor=spec.anchor, urn=spec.urn) for spec in specs] diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn index 490f93c..40bf4da 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -75,20 +75,18 @@ pub const MAP_EXTRACT_FUNCTION_ANCHOR: u32 = 48 pub const NAMED_STRUCT_FUNCTION_ANCHOR: u32 = 49 pub const ARRAY_HAS_ANY_FUNCTION_ANCHOR: u32 = 50 pub const ARRAY_FLATTEN_FUNCTION_ANCHOR: u32 = 51 -const FUNCTION_EXTENSION_URI: str = 
"https://inql.io/extensions/v0.1/functions.yaml" -const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" - - -pub def function_extension_uri() -> str: - """Return the registered extension URI used for shared function anchors.""" - return FUNCTION_EXTENSION_URI - - -pub def explode_extension_uri() -> str: - """Return the registered extension URI used for EXPLODE-style gap encoding.""" - return EXPLODE_EXTENSION_URI +pub const RANGE_FUNCTION_ANCHOR: u32 = 52 +pub const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" +pub const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" +pub const EXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode_outer" +pub const POSEXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#posexplode" +pub const POSEXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#posexplode_outer" +pub const INLINE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#inline" +pub const INLINE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#inline_outer" +pub const FLATTEN_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#flatten" +pub const STACK_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/table_functions.yaml#stack" pub def registered_substrait_extension_uris() -> list[str]: """Return the registered extension URIs used by current package-level Substrait lowering.""" - return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI] + return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI, EXPLODE_OUTER_EXTENSION_URI, POSEXPLODE_EXTENSION_URI, POSEXPLODE_OUTER_EXTENSION_URI, INLINE_EXTENSION_URI, INLINE_OUTER_EXTENSION_URI, FLATTEN_EXTENSION_URI, STACK_EXTENSION_URI] diff --git a/src/substrait/generator_payload.incn b/src/substrait/generator_payload.incn new file mode 100644 index 0000000..c25b234 --- /dev/null +++ 
b/src/substrait/generator_payload.incn @@ -0,0 +1,160 @@ +"""Binary payload for InQL generator relation-extension nodes.""" + +from rust::incan_stdlib::errors import raise_value_error +from rust::prost import Message +from rust::std::string import String as RustString +from rust::substrait::proto import Expression +from std.io import BytesIO, Endian, IoError, _BytesIO + +# ASCII "INQLGEN1"; this lets decoders reject unrelated relation-extension payloads before reading length fields. +const _MAGIC: list[int] = [73, 78, 81, 76, 71, 69, 78, 49] + + +@derive(Clone) +pub model GeneratorExtensionPayload: + """Decoded generator relation-extension payload.""" + + pub row_count: int + pub output_columns: list[str] + pub arguments: list[Expression] + + +pub def encode_generator_extension_payload( + row_count: int, + output_columns: list[str], + arguments: list[Expression], +) -> bytes: + """Encode generator metadata and lowered scalar arguments into stable bytes.""" + # Layout: magic, row_count, output names, then each lowered Substrait argument as a length-prefixed proto blob. + # Keeping this adapter-neutral lets Substrait remain the boundary while DataFusion decodes only execution details. 
+ out = BytesIO() + _write_magic(out) + _write_u64_le(out, row_count) + _write_u64_le(out, len(output_columns)) + for output_column in output_columns: + encoded_name = RustString.from(f"{output_column}").into_bytes() + _write_u64_le(out, len(encoded_name)) + _write_payload_bytes(out, encoded_name) + _write_u64_le(out, len(arguments)) + for argument in arguments: + encoded = argument.encode_to_vec() + _write_u64_le(out, len(encoded)) + _write_payload_bytes(out, encoded) + return out.getvalue() + + +pub def decode_generator_extension_payload(data: bytes) -> Result[GeneratorExtensionPayload, str]: + """Decode generator relation-extension payload bytes.""" + if len(data) < len(_MAGIC) + 24: + return Err("generator extension payload is too short") + reader = BytesIO(data) + _validate_magic(_read_payload_bytes(reader, len(_MAGIC), "generator extension payload has an invalid prefix")?)? + # Decode in the same strict order as encode and reject trailing bytes so payload evolution is explicit. + row_count = _read_u64_le(reader, "generator extension payload ended before row count")? + output_column_count = _read_u64_le(reader, "generator extension payload ended before output column count")? + mut output_columns: list[str] = [] + for _ in range(output_column_count): + column_len = _read_u64_le(reader, "generator extension payload ended before output column length")? + column_bytes = _read_payload_bytes( + reader, + column_len, + "generator extension payload ended inside an output column", + )? + match RustString.from_utf8(column_bytes): + Ok(column_name) => output_columns.append(column_name.to_string()) + Err(err) => return Err(err.to_string()) + argument_count = _read_u64_le(reader, "generator extension payload ended before argument count")? + mut arguments: list[Expression] = [] + for _ in range(argument_count): + argument_len = _read_u64_le(reader, "generator extension payload ended before argument length")? 
+ argument_bytes = _read_payload_bytes( + reader, + argument_len, + "generator extension payload ended inside an argument", + )? + match Expression.decode(argument_bytes.as_slice()): + Ok(argument) => arguments.append(argument.clone()) + Err(err) => return Err(err.to_string()) + if reader.remaining() != 0: + return Err("generator extension payload contains trailing bytes") + return Ok(GeneratorExtensionPayload(row_count=row_count, output_columns=output_columns, arguments=arguments)) + + +def _validate_magic(magic: bytes) -> Result[None, str]: + """Validate the generator payload prefix.""" + for idx in range(len(_MAGIC)): + if int(magic[idx]) != _MAGIC[idx]: + return Err("generator extension payload has an invalid prefix") + return Ok(None) + + +def _read_payload_bytes(reader: _BytesIO, size: int, message: str) -> Result[bytes, str]: + """Read an exact payload segment and map stdlib I/O errors to payload errors.""" + match reader.read_exact(size): + Ok(data) => return Ok(data) + Err(_) => return Err(message) + + +def _read_u64_le(reader: _BytesIO, message: str) -> Result[int, str]: + """Read one little-endian u64 length/count field through std.io.""" + match _read_u64_io(reader): + Ok(value) => return Ok(int(value)) + Err(_) => return Err(message) + + +def _read_u64_io(reader: _BytesIO) -> Result[u64, IoError]: + """Read one little-endian u64 with explicit result type context for the overloaded std.io method.""" + value: u64 = reader.read(Endian.Little)? 
+ return Ok(value) + + +def _write_payload_bytes(out: _BytesIO, data: bytes) -> None: + """Write payload bytes through std.io.""" + match out.write_bytes(data): + Ok(_) => return + Err(err) => return _raise_payload_write_error(err) + + +def _write_magic(out: _BytesIO) -> None: + """Write the generator payload magic through std.io.""" + _write_payload_bytes(out, b"INQLGEN1") + return + + +def _write_u64_le(out: _BytesIO, value: int) -> None: + """Write one non-negative integer as a little-endian u64 through std.io.""" + match out.write(_payload_u64(value), Endian.Little): + Ok(_) => return + Err(err) => return _raise_payload_write_error(err) + + +def _payload_u64(value: int) -> u64: + """Convert one payload length/count to u64 before std.io writes it.""" + maybe_value: Option[u64] = value.try_resize() + match maybe_value: + Some(converted) => return converted + None => return raise_value_error("generator extension payload length/count does not fit u64") + + +def _raise_payload_write_error(err: IoError) -> None: + """Raise an unexpected in-memory payload write failure.""" + message = "generator extension payload write failed: " + err.message() + return raise_value_error(message) + + +module tests: + from std.testing import assert_is_err, assert_is_ok + def test_generator_payload_round_trips_empty_payload() -> None: + encoded = encode_generator_extension_payload(3, ["left", "right"], []) + decoded = assert_is_ok(decode_generator_extension_payload(encoded), "payload should decode") + assert decoded.row_count == 3 + assert decoded.output_columns == ["left", "right"] + assert len(decoded.arguments) == 0 + def test_generator_payload_rejects_invalid_prefix() -> None: + writer = BytesIO() + _write_payload_bytes(writer, b"BADGEN!!") + _write_u64_le(writer, 0) + _write_u64_le(writer, 0) + _write_u64_le(writer, 0) + err = assert_is_err(decode_generator_extension_payload(writer.getvalue()), "invalid prefix should fail") + assert err == "generator extension payload has an 
invalid prefix" diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index c814b85..6c542b7 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -19,8 +19,19 @@ from rust::substrait::proto::set_rel import SetOp from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateMeasure from projection_builders import scalar_expr_output_name -from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit +from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit, rust_u32_to_int from substrait.extensions import aggregate_function_name_from_anchor +from substrait.function_extensions import ( + EXPLODE_EXTENSION_URI, + EXPLODE_OUTER_EXTENSION_URI, + FLATTEN_EXTENSION_URI, + INLINE_EXTENSION_URI, + INLINE_OUTER_EXTENSION_URI, + POSEXPLODE_EXTENSION_URI, + POSEXPLODE_OUTER_EXTENSION_URI, + STACK_EXTENSION_URI, +) +from substrait.generator_payload import decode_generator_extension_payload from substrait.schema_registry import named_table_columns, unknown_named_struct from substrait.traversal import relation_children @@ -176,7 +187,12 @@ def _relation_output_columns(rel: Rel) -> list[str]: None => return [] Some(RelType.ExtensionSingle(extension_rel)) => match extension_rel.input: - Some(child) => return _relation_output_columns(child.as_ref().clone()) + Some(child) => + input_columns = _relation_output_columns(child.as_ref().clone()) + match extension_rel.detail: + Some(detail) => + return _extension_single_output_columns(input_columns, detail.type_url, detail.value) + None => return input_columns None => return [] Some(RelType.Join(join_rel)) => mut names: list[str] = [] @@ -209,6 +225,29 @@ pub def relation_output_columns(rel: Rel) -> list[str]: return _relation_output_columns(rel) +def _extension_single_output_columns(input_columns: list[str], extension_uri: str, payload: bytes) -> list[str]: + """Return 
best-effort output columns for known extension-single relation encodings.""" + mut columns: list[str] = [] + columns.extend(input_columns) + # Prefer the RFC 021 payload when present. Some low-level extension-boundary tests still create URI-only nodes, so + # retain the URI fallback below for those structurally valid but less informative relations. + match decode_generator_extension_payload(payload): + Ok(decoded) => + columns.extend(decoded.output_columns) + return columns + Err(_) => pass + if extension_uri == EXPLODE_EXTENSION_URI or extension_uri == EXPLODE_OUTER_EXTENSION_URI: + columns.append("value") + elif extension_uri == POSEXPLODE_EXTENSION_URI or extension_uri == POSEXPLODE_OUTER_EXTENSION_URI: + columns.append("position") + columns.append("value") + elif extension_uri == FLATTEN_EXTENSION_URI: + columns.append("value") + elif extension_uri == INLINE_EXTENSION_URI or extension_uri == INLINE_OUTER_EXTENSION_URI or extension_uri == STACK_EXTENSION_URI: + columns.append("value") + return columns + + pub def aggregate_measure_function_names(rel: Rel) -> list[str]: """Return aggregate function names used by a top-level AggregateRel, otherwise empty.""" match rel.rel_type: @@ -444,3 +483,15 @@ pub def plan_has_extension_urn(plan: Plan, extension_uri: str) -> bool: if urn.urn == extension_uri: return true return false + + +pub def plan_extension_urn_count(plan: Plan) -> int: + """Return the number of extension URN declarations carried by one plan.""" + return len(plan.extension_urns) + + +pub def plan_extension_urn_anchor_at(plan: Plan, index: int) -> int: + """Return one extension URN anchor as an Incan integer for tests and diagnostics.""" + if index < 0 or index >= len(plan.extension_urns): + return -1 + return rust_u32_to_int(plan.extension_urns[index].extension_urn_anchor) diff --git a/src/substrait/mod.incn b/src/substrait/mod.incn index 16e0f38..c310700 100644 --- a/src/substrait/mod.incn +++ b/src/substrait/mod.incn @@ -26,10 +26,13 @@ pub from 
substrait.relations import ( fetch_rel, filter_rel, filter_rel_of_columns, + generator_rel, + generator_rel_of_columns, join_rel, join_rel_of_kind, project_rel, project_rel_of_columns, + project_rel_with_expressions, read_local_files_rel, read_local_parquet_rel, read_named_table_rel, @@ -58,6 +61,8 @@ pub from substrait.inspect import ( aggregate_measure_invocation_names, aggregate_measure_output_names, aggregate_measure_sort_counts, + plan_extension_urn_anchor_at, + plan_extension_urn_count, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -75,7 +80,14 @@ pub from substrait.inspect import ( source_named_table_name, ) pub from substrait.function_extensions import ( - explode_extension_uri, - function_extension_uri, + EXPLODE_EXTENSION_URI, + EXPLODE_OUTER_EXTENSION_URI, + FLATTEN_EXTENSION_URI, + FUNCTION_EXTENSION_URI, + INLINE_EXTENSION_URI, + INLINE_OUTER_EXTENSION_URI, + POSEXPLODE_EXTENSION_URI, + POSEXPLODE_OUTER_EXTENSION_URI, registered_substrait_extension_uris, + STACK_EXTENSION_URI, ) diff --git a/src/substrait/relations.incn b/src/substrait/relations.incn index 849beba..2af68b6 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -46,6 +46,7 @@ from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure from function_registry import FunctionClass, FunctionRegistryEntry, SubstraitMappingKind from functions.registry import function_registry_entry +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, ScalarFunctionApplicationExpr, col from substrait.expr_lowering import ( bool_expr, @@ -57,6 +58,7 @@ from substrait.expr_lowering import ( string_expr, ) from substrait.errors import SubstraitLoweringError, invalid_scalar_expression +from substrait.generator_payload import encode_generator_extension_payload from substrait.inspect import relation_output_columns from 
substrait.schema_registry import named_table_base_schema, unknown_named_struct @@ -81,6 +83,15 @@ model ResolvedRelationExpression: expr: Expression +@derive(Clone) +model ResolvedGeneratorApplication: + """One generator application resolved against input columns and registry metadata.""" + + generator: GeneratorApplication + entry: FunctionRegistryEntry + arguments: list[Expression] + + pub enum SubstraitJoinKind: Inner Left @@ -158,7 +169,14 @@ def _rel_reference(reference: ReferenceRel) -> Rel: def _rel_extension_single(input: Rel, extension_uri: str) -> Rel: """Wrap one input relation in an ExtensionSingleRel with the provided URI.""" - detail = Any(type_url=extension_uri, value=[]) + detail = Any(type_url=extension_uri, value=b"") + rel = ExtensionSingleRel(common=Some(_direct_common()), input=Some(Box.new(input)), detail=Some(detail)) + return Rel(rel_type=Some(RelType.ExtensionSingle(Box.new(rel)))) + + +def _rel_extension_single_with_payload(input: Rel, extension_uri: str, payload: bytes) -> Rel: + """Wrap one input relation in an ExtensionSingleRel with relation-extension payload bytes.""" + detail = Any(type_url=extension_uri, value=payload) rel = ExtensionSingleRel(common=Some(_direct_common()), input=Some(Box.new(input)), detail=Some(detail)) return Rel(rel_type=Some(RelType.ExtensionSingle(Box.new(rel)))) @@ -259,6 +277,61 @@ def _validate_aggregate_modifiers(measure: ResolvedAggregateMeasure) -> Result[N return Ok(None) +def _generator_registry_entry(generator: GeneratorApplication) -> Result[FunctionRegistryEntry, SubstraitLoweringError]: + """Resolve one generator registry entry and validate its semantic class.""" + match function_registry_entry(generator.function_ref): + Some(entry) => + if entry.function_class != FunctionClass.Generator: + return Err(invalid_scalar_expression(f"{entry.function_ref} is not registered as a generator function")) + if entry.substrait.kind != SubstraitMappingKind.RelationExtension: + return Err( + 
invalid_scalar_expression(f"{entry.function_ref} does not declare a relation-extension mapping"), + ) + return Ok(entry) + None => + return Err(invalid_scalar_expression(f"missing generator registry entry for `{generator.canonical_name}`")) + + +def _resolved_generator( + generator: GeneratorApplication, + input_columns: list[str], +) -> Result[ResolvedGeneratorApplication, SubstraitLoweringError]: + """Resolve one generator application against input-column names.""" + _validate_generator_output_columns(input_columns, generator.clone())? + return Ok( + ResolvedGeneratorApplication( + generator=generator.clone(), + entry=_generator_registry_entry(generator.clone())?, + arguments=[scalar_expr(input_columns, argument)? for argument in generator.arguments], + ), + ) + + +def _validate_generator_output_columns( + input_columns: list[str], + generator: GeneratorApplication, +) -> Result[None, SubstraitLoweringError]: + """Validate generator output columns against the current input relation shape.""" + mut output_columns: list[str] = [] + if generator.preserves_input_columns: + output_columns.extend(input_columns) + for output_column in generator.output_columns: + if _contains_text(output_columns, output_column): + return Err( + invalid_scalar_expression(f"generator output column `{output_column}` conflicts with an existing column"), + ) + output_columns.append(output_column) + return Ok(None) + + +def _contains_text(values: list[str], expected: str) -> bool: + """Return whether a string list contains a value.""" + for value in values: + if value == expected: + return true + return false + + def _aggregate_function_reference(measure: ResolvedAggregateMeasure) -> Result[u32, SubstraitLoweringError]: """Resolve one aggregate measure through declaration-side registry metadata.""" match _aggregate_registry_entry(measure): @@ -515,6 +588,19 @@ pub def try_project_rel_of_columns( ) +pub def project_rel_with_expressions(input: Rel, expressions: list[Expression]) -> Rel: + 
"""Append already-lowered Substrait expressions to a relation.""" + # The DataFusion adapter uses this to evaluate generator arguments into temporary columns before unnesting. + return _rel_project( + ProjectRel( + common=Some(_direct_common()), + input=Some(Box.new(input)), + expressions=expressions, + advanced_extension=None, + ), + ) + + pub def join_rel(left: Rel, right: Rel, on_predicate: bool) -> Rel: """Wrap two child relations in an inner `JoinRel`.""" return join_rel_of_kind(left, right, on_predicate, SubstraitJoinKind.Inner) @@ -603,6 +689,38 @@ pub def try_aggregate_rel_of_columns( ) +pub def generator_rel(input: Rel, generator: GeneratorApplication) -> Rel: + """Wrap a child relation in a generator relation-extension node.""" + return _lowered_rel_or_raise(try_generator_rel(input, generator)) + + +pub def try_generator_rel(input: Rel, generator: GeneratorApplication) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a generator relation-extension node.""" + return try_generator_rel_of_columns(input.clone(), relation_output_columns(input), generator) + + +pub def generator_rel_of_columns(input: Rel, input_columns: list[str], generator: GeneratorApplication) -> Rel: + """Wrap a child relation in a generator relation-extension node using explicit input-column names.""" + return _lowered_rel_or_raise(try_generator_rel_of_columns(input, input_columns, generator)) + + +pub def try_generator_rel_of_columns( + input: Rel, + input_columns: list[str], + generator: GeneratorApplication, +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a generator relation-extension node using explicit input-column names.""" + resolved = _resolved_generator(generator, input_columns)? + # Keep generator semantics in Substrait IR. Non-standard generator details travel in the extension payload instead + # of being copied into backend-only side channels. 
+ payload = encode_generator_extension_payload( + resolved.generator.row_count, + resolved.generator.output_columns, + resolved.arguments, + ) + return Ok(_rel_extension_single_with_payload(input, resolved.entry.substrait.uri, payload)) + + pub def sort_rel(input: Rel) -> Rel: """Wrap a child relation in `SortRel` using the first known output column as the default sort key.""" input_columns = relation_output_columns(input.clone()) diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index a8d82e7..4966967 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -16,19 +16,36 @@ from functions import ( count_distinct, count_if, eq, + explode, + explode_outer, + flatten, float_expr, int_expr, int_lit, + inline, + inline_outer, lit, max, min, mul, + posexplode, + posexplode_outer, str_expr, str_lit, sum, + stack, ) from projection_builders import ColumnExprKind, column_expr_kind, column_expr_name -from substrait.function_extensions import explode_extension_uri +from substrait.function_extensions import ( + EXPLODE_EXTENSION_URI, + EXPLODE_OUTER_EXTENSION_URI, + FLATTEN_EXTENSION_URI, + INLINE_EXTENSION_URI, + INLINE_OUTER_EXTENSION_URI, + POSEXPLODE_EXTENSION_URI, + POSEXPLODE_OUTER_EXTENSION_URI, + STACK_EXTENSION_URI, +) from substrait.inspect import plan_contains_relation_kind, plan_has_extension_urn, relation_kind_name, root_rel from substrait.plans import plan_encoded_len, plan_from_named_table, plan_from_root_relation from substrait.relations import read_named_table_rel @@ -387,7 +404,7 @@ def test_dataset_ops__api_lowered_boundary_facts_stay_stable() -> None: # -- Assert -- assert relation_kind_name(root_rel(joined_plan)) == "JoinRel", "canonical join function should still lower to a JoinRel root" - assert plan_has_extension_urn(exploded_plan, explode_extension_uri()), "explode method should emit the registered extension URI" + assert plan_has_extension_urn(exploded_plan, EXPLODE_EXTENSION_URI), "explode method should emit the registered 
extension URI" def test_lazy_frame__immutable_branching_and_origin_mapping_hold() -> None: @@ -421,6 +438,14 @@ def test_lazy_frame__independent_roots_can_join_and_lower() -> None: def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None: # -- Arrange -- _register_order_schema("orders") + register_named_table_schema( + "orders_generator_dataset", + [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false), RowColumnSpec( + name="line_items", + kind=SubstraitPrimitiveKind.String, + nullable=true, + )], + ) projected: LazyFrame[Order] = lazy_frame_named_table("orders").select() grouped: LazyFrame[Order] = lazy_frame_named_table("orders").group_by([col("id")]) @@ -429,6 +454,30 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None ordered: LazyFrame[Order] = lazy_frame_named_table("orders").order_by([col("id")]) limited: LazyFrame[Order] = lazy_frame_named_table("orders").limit(10) exploded: LazyFrame[Order] = lazy_frame_named_table("orders").explode() + generated: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + explode(col("line_items"), "line_item"), + ) + generated_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + explode_outer(col("line_items"), "line_item"), + ) + generated_positional: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + posexplode(col("line_items"), "position", "line_item"), + ) + generated_positional_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + posexplode_outer(col("line_items"), "position", "line_item"), + ) + generated_inline: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + inline(col("line_items"), ["sku", "quantity"]), + ) + generated_inline_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + inline_outer(col("line_items"), ["sku", "quantity"]), + ) + 
generated_flatten: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + flatten(col("line_items"), "line_item"), + ) + generated_stack: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + stack(2, [str_lit("sku_a"), int_lit(1), str_lit("sku_b"), int_lit(2)], ["sku", "quantity"]), + ) # -- Assert -- assert relation_kind_name(root_rel(projected.to_substrait_plan())) == "ProjectRel", "select should lower through the project boundary shape" @@ -436,7 +485,17 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None assert relation_kind_name(root_rel(aggregated.to_substrait_plan())) == "AggregateRel", "agg should lower to AggregateRel" assert relation_kind_name(root_rel(ordered.to_substrait_plan())) == "SortRel", "order_by should lower to SortRel" assert relation_kind_name(root_rel(limited.to_substrait_plan())) == "FetchRel", "limit should lower to FetchRel" - assert plan_has_extension_urn(exploded.to_substrait_plan(), explode_extension_uri()), "explode should keep emitting the registered extension boundary" + assert plan_has_extension_urn(exploded.to_substrait_plan(), EXPLODE_EXTENSION_URI), "explode should keep emitting the registered extension boundary" + assert relation_kind_name(root_rel(generated.to_substrait_plan())) == "ExtensionSingleRel", "generate should lower through the relation extension boundary" + assert generated.planned_columns() == ["id", "line_items", "line_item"], "generate should append declared output aliases" + assert plan_has_extension_urn(generated_outer.to_substrait_plan(), EXPLODE_OUTER_EXTENSION_URI), "outer explode should use its relation extension URI" + assert plan_has_extension_urn(generated_positional.to_substrait_plan(), POSEXPLODE_EXTENSION_URI), "posexplode should use its relation extension URI" + assert plan_has_extension_urn(generated_positional_outer.to_substrait_plan(), POSEXPLODE_OUTER_EXTENSION_URI), "posexplode_outer should use its relation 
extension URI" + assert generated_inline.planned_columns() == ["id", "line_items", "sku", "quantity"], "inline should append all declared struct output aliases" + assert plan_has_extension_urn(generated_inline.to_substrait_plan(), INLINE_EXTENSION_URI), "inline should use its relation extension URI" + assert plan_has_extension_urn(generated_inline_outer.to_substrait_plan(), INLINE_OUTER_EXTENSION_URI), "inline_outer should use its relation extension URI" + assert plan_has_extension_urn(generated_flatten.to_substrait_plan(), FLATTEN_EXTENSION_URI), "flatten should use its relation extension URI" + assert plan_has_extension_urn(generated_stack.to_substrait_plan(), STACK_EXTENSION_URI), "stack should use its relation extension URI" def test_lazy_frame__deeper_independent_roots_still_lower_with_stable_shapes() -> None: @@ -482,4 +541,4 @@ def test_lazy_frame__canonical_rewrite_keeps_extension_boundary_for_explode() -> plan = exploded.to_substrait_plan() # -- Assert -- - assert plan_has_extension_urn(plan, explode_extension_uri()), "canonical rewrites must keep explode extension URI emission intact" + assert plan_has_extension_urn(plan, EXPLODE_EXTENSION_URI), "canonical rewrites must keep explode extension URI emission intact" diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 129158b..c448202 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -19,6 +19,7 @@ from functions import ( array_intersect, array_join, array_position, + array_range, array_reverse, array_slice, array_sort, @@ -45,6 +46,9 @@ from functions import ( eq, equal_null, element_at, + explode, + explode_outer, + flatten, floor, float_expr, function_registry_canonical_names, @@ -62,6 +66,8 @@ from functions import ( is_not_nan, is_not_null, is_null, + inline, + inline_outer, lit, lt, lte, @@ -81,12 +87,15 @@ from functions import ( not_, nullif, or_, + posexplode, + posexplode_outer, registered_substrait_mapped_function_refs, 
round, str_expr, str_lit, sub, sum, + stack, try_cast, ) from function_registry import ( @@ -162,9 +171,18 @@ from substrait.function_extensions import ( NULLIF_FUNCTION_ANCHOR, OR_FUNCTION_ANCHOR, ROUND_FUNCTION_ANCHOR, + STACK_EXTENSION_URI, SUBTRACT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR, - function_extension_uri, + EXPLODE_EXTENSION_URI, + EXPLODE_OUTER_EXTENSION_URI, + FLATTEN_EXTENSION_URI, + FUNCTION_EXTENSION_URI, + INLINE_EXTENSION_URI, + INLINE_OUTER_EXTENSION_URI, + POSEXPLODE_EXTENSION_URI, + POSEXPLODE_OUTER_EXTENSION_URI, + RANGE_FUNCTION_ANCHOR, ) @@ -223,12 +241,12 @@ def _local_entry_by_namespace_and_name_or_fail( def _expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", 
"not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "range", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer", "inline", "inline_outer", "flatten", "stack"] def _expected_substrait_mapped_names() -> list[str]: """Return helpers with concrete Substrait extension-function mappings.""" - return ["sum", "count", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct"] + return ["sum", "count", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "range", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", 
"map_from_arrays", "map_keys", "map_values", "named_struct"] def _exercise_current_public_helpers() -> None: @@ -301,6 +319,7 @@ def _exercise_current_public_helpers() -> None: array_intersect(tags, backup_tags) array_join(tags, str_lit("|")) array_position(tags, str_lit("paid")) + array_range(int_lit(0), cardinality(tags)) array_reverse(tags) array_slice(tags, int_lit(1), int_lit(2)) array_sort(tags) @@ -314,6 +333,14 @@ def _exercise_current_public_helpers() -> None: map_keys(attr_map) map_values(attr_map) named_struct(["status", "amount"], [status, amount]) + explode(tags, "tag") + explode_outer(tags, "tag") + posexplode(tags, "position", "tag") + posexplode_outer(tags, "position", "tag") + inline(array([named_struct(["status", "amount"], [status, amount])]), ["status", "amount"]) + inline_outer(array([named_struct(["status", "amount"], [status, amount])]), ["status", "amount"]) + flatten(tags, "tag") + stack(2, [status, amount, str_lit("fallback"), int_lit(0)], ["status", "amount"]) return @@ -342,11 +369,20 @@ def _assert_extension_mapping(canonical_name: str, function_name: str, anchor: u mapped_refs = registered_substrait_mapped_function_refs() assert _contains_text(mapped_refs, function_ref_for(canonical_name)), f"{canonical_name} should be in the Substrait extension mapping set" assert entry.substrait.kind == SubstraitMappingKind.ExtensionFunction, f"{canonical_name} should use a Substrait extension function" - assert entry.substrait.uri == function_extension_uri(), f"{canonical_name} should use the shared function extension URI" + assert entry.substrait.uri == FUNCTION_EXTENSION_URI, f"{canonical_name} should use the shared function extension URI" assert entry.substrait.function_name == function_name, f"{canonical_name} should use the registered extension name" assert entry.substrait.anchor == anchor, f"{canonical_name} should carry the stable Substrait anchor" +def _assert_relation_extension_mapping(canonical_name: str, function_name: str, extension_uri: 
str) -> None: + """Assert one generator helper declares a relation-extension mapping.""" + entry = _entry_or_fail(function_ref_for(canonical_name)) + assert entry.function_class == FunctionClass.Generator, f"{canonical_name} should be classified as a generator" + assert entry.substrait.kind == SubstraitMappingKind.RelationExtension, f"{canonical_name} should use a relation extension" + assert entry.substrait.uri == extension_uri, f"{canonical_name} should carry the registered relation extension URI" + assert entry.substrait.function_name == function_name, f"{canonical_name} should use the registered extension name" + + def _assert_core_mapping(canonical_name: str, function_name: str) -> None: """Assert one helper declares the expected built-in Substrait Rex mapping.""" entry = _entry_or_fail(function_ref_for(canonical_name)) @@ -601,6 +637,7 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("array_intersect", "array_intersect", ARRAY_INTERSECT_FUNCTION_ANCHOR) _assert_extension_mapping("array_join", "array_to_string", ARRAY_TO_STRING_FUNCTION_ANCHOR) _assert_extension_mapping("array_position", "array_position", ARRAY_POSITION_FUNCTION_ANCHOR) + _assert_extension_mapping("range", "range", RANGE_FUNCTION_ANCHOR) _assert_extension_mapping("array_reverse", "array_reverse", ARRAY_REVERSE_FUNCTION_ANCHOR) _assert_extension_mapping("array_slice", "array_slice", ARRAY_SLICE_FUNCTION_ANCHOR) _assert_extension_mapping("array_sort", "array_sort", ARRAY_SORT_FUNCTION_ANCHOR) @@ -616,6 +653,22 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("named_struct", "named_struct", NAMED_STRUCT_FUNCTION_ANCHOR) +def test_function_registry__generator_helpers_are_relation_extensions() -> None: + """Assert generator helpers are registry entries without scalar or aggregate extension anchors.""" + # -- Arrange -- + _exercise_current_public_helpers() + + # -- Act / Assert 
-- + _assert_relation_extension_mapping("explode", "explode", EXPLODE_EXTENSION_URI) + _assert_relation_extension_mapping("explode_outer", "explode_outer", EXPLODE_OUTER_EXTENSION_URI) + _assert_relation_extension_mapping("posexplode", "posexplode", POSEXPLODE_EXTENSION_URI) + _assert_relation_extension_mapping("posexplode_outer", "posexplode_outer", POSEXPLODE_OUTER_EXTENSION_URI) + _assert_relation_extension_mapping("inline", "inline", INLINE_EXTENSION_URI) + _assert_relation_extension_mapping("inline_outer", "inline_outer", INLINE_OUTER_EXTENSION_URI) + _assert_relation_extension_mapping("flatten", "flatten", FLATTEN_EXTENSION_URI) + _assert_relation_extension_mapping("stack", "stack", STACK_EXTENSION_URI) + + def test_function_registry__ordering_helpers_are_contextual_sort_fields() -> None: """Assert RFC 015 ordering helpers are modeled as sort-field context helpers.""" # -- Arrange -- diff --git a/tests/test_generator_functions.incn b/tests/test_generator_functions.incn new file mode 100644 index 0000000..e91e768 --- /dev/null +++ b/tests/test_generator_functions.incn @@ -0,0 +1,88 @@ +"""Tests for registry-backed generator and table-valued function builders.""" + +from std.testing import assert_raises +from generator_builders import GeneratorKind, generator_output_columns, generator_primary_output_column +from functions import col, explode, explode_outer, flatten, inline, inline_outer, posexplode, posexplode_outer, stack + + +def test_generator_functions__explode_family_builds_relation_applications() -> None: + # -- Arrange -- + items = col("line_items") + + # -- Act -- + inner = explode(items, "line_item") + outer = explode_outer(items, "line_item") + positional = posexplode(items, "position", "line_item") + positional_outer = posexplode_outer(items, "position", "line_item") + flattened = flatten(items, "line_item") + inlined = inline(items, ["sku", "quantity"]) + inlined_outer = inline_outer(items, ["sku", "quantity"]) + stacked = stack(2, [col("left_a"), 
col("right_a"), col("left_b"), col("right_b")], ["left", "right"]) + + # -- Assert -- + assert inner.kind == GeneratorKind.Explode + assert outer.kind == GeneratorKind.ExplodeOuter + assert positional.kind == GeneratorKind.PosExplode + assert positional_outer.kind == GeneratorKind.PosExplodeOuter + assert flattened.kind == GeneratorKind.Flatten + assert inlined.kind == GeneratorKind.Inline + assert inlined_outer.kind == GeneratorKind.InlineOuter + assert stacked.kind == GeneratorKind.Stack + assert not inner.is_outer + assert outer.is_outer + assert not inlined.is_outer + assert inlined_outer.is_outer + assert positional.position_origin == 0 + assert positional_outer.position_origin == 0 + assert stacked.row_count == 2 + assert generator_primary_output_column(inner) == "line_item" + assert generator_primary_output_column(positional) == "line_item" + assert generator_primary_output_column(inlined) == "quantity" + + +def test_generator_functions__output_columns_preserve_input_then_append_aliases() -> None: + # -- Arrange -- + input_columns = ["id", "line_items"] + + # -- Act -- + exploded_columns = generator_output_columns(input_columns, explode(col("line_items"), "line_item")) + positional_columns = generator_output_columns(input_columns, posexplode(col("line_items"), "position", "line_item")) + inline_columns = generator_output_columns(input_columns, inline(col("line_items"), ["sku", "quantity"])) + stack_columns = generator_output_columns( + input_columns, + stack(2, [col("left_a"), col("right_a"), col("left_b"), col("right_b")], ["left", "right"]), + ) + + # -- Assert -- + assert exploded_columns == ["id", "line_items", "line_item"] + assert positional_columns == ["id", "line_items", "position", "line_item"] + assert inline_columns == ["id", "line_items", "sku", "quantity"] + assert stack_columns == ["id", "line_items", "left", "right"] + + +def _call_generator_with_input_collision() -> None: + """Call generator output inference with a generated name that 
collides with input.""" + generator_output_columns(["id", "line_items"], explode(col("line_items"), "id")) + return + + +def test_generator_functions__output_alias_collisions_are_rejected() -> None: + # -- Arrange -- + call = _call_generator_with_input_collision + + # -- Act / Assert -- + assert_raises[ValueError](call) + + +def _call_stack_with_invalid_shape() -> None: + """Call stack with a value count that does not fill the declared rows and columns.""" + stack(2, [col("a"), col("b"), col("c")], ["left", "right"]) + return + + +def test_generator_functions__stack_shape_is_validated() -> None: + # -- Arrange -- + call = _call_stack_with_invalid_shape + + # -- Act / Assert -- + assert_raises[ValueError](call) diff --git a/tests/test_prism.incn b/tests/test_prism.incn index 83334e5..785f870 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -1,9 +1,10 @@ """Internal Prism engine tests against the shared-store cursor substrate.""" -from functions import always_false, always_true, col, count, lit, mul, sum +from functions import always_false, always_true, col, count, count_expr, explode, lit, mul, sum from prism import ( PrismCursor, prism_cursor_apply_filter, + prism_cursor_apply_generate, prism_cursor_apply_limit, prism_cursor_apply_select, prism_cursor_authored_node_count, @@ -35,6 +36,17 @@ def _register_projection_test_schema(table_name: str) -> None: register_named_table_schema(table_name, [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false)]) +def _register_generator_test_schema(table_name: str) -> None: + register_named_table_schema( + table_name, + [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false), RowColumnSpec( + name="line_items", + kind=SubstraitPrimitiveKind.String, + nullable=true, + )], + ) + + def test_prism__branching_keeps_base_reachable_history_small() -> None: # -- Arrange -- base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) @@ -206,6 +218,7 @@ def 
test_prism__cross_store_adoption_keeps_distinct_aggregate_modifier_state() - def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: # -- Arrange -- _register_projection_test_schema(str("orders")) + _register_generator_test_schema(str("orders_generator_prism")) projected: PrismCursor[Order] = prism_cursor_named_table(str("orders")).select() grouped: PrismCursor[Order] = prism_cursor_named_table(str("orders")).group_by([col("id")]) @@ -214,6 +227,9 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by([col("id")]) limited: PrismCursor[Order] = prism_cursor_named_table(str("orders")).limit(10) exploded: PrismCursor[Order] = prism_cursor_named_table(str("orders")).explode() + generated: PrismCursor[Order] = prism_cursor_named_table(str("orders_generator_prism")).generate( + explode(col("line_items"), "line_item"), + ) # -- Assert -- assert prism_cursor_tip_kind_name(projected) == str("Project"), "select should append a native project node" @@ -222,6 +238,8 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: assert prism_cursor_tip_kind_name(ordered) == str("OrderBy"), "order_by should append a native sort node" assert prism_cursor_tip_kind_name(limited) == str("Limit"), "limit should append a native limit node" assert prism_cursor_tip_kind_name(exploded) == str("Explode"), "explode should append a native explode node" + assert prism_cursor_tip_kind_name(generated) == str("Generate"), "generate should append a native generator node" + assert prism_cursor_output_columns(generated) == ["id", "line_items", "line_item"], "generate should append declared output aliases" def test_prism__rewrite_eliminates_filter_true_by_default() -> None: @@ -332,3 +350,21 @@ def test_prism__cursor_methods_match_apply_helpers() -> None: assert relation_kind_name(root_rel(via_methods.to_substrait_plan())) == relation_kind_name( 
root_rel(via_helpers.to_substrait_plan()), ), "method and helper paths should lower to equivalent root relation kinds" + + +def test_prism__generate_method_matches_apply_helper() -> None: + # -- Arrange -- + _register_generator_test_schema("orders_generator_apply") + base: PrismCursor[Order] = prism_cursor_named_table(str("orders_generator_apply")) + generator = explode(col("line_items"), "line_item") + + # -- Act -- + via_method = base.generate(generator) + via_helper = prism_cursor_apply_generate(base, generator) + + # -- Assert -- + assert prism_cursor_tip_kind_name(via_method) == prism_cursor_tip_kind_name(via_helper), "method and helper paths should produce the same generator node kind" + assert prism_cursor_output_columns(via_method) == ["id", "line_items", "line_item"], "generator helper should preserve planned output columns" + assert relation_kind_name(root_rel(via_method.to_substrait_plan())) == relation_kind_name( + root_rel(via_helper.to_substrait_plan()), + ), "generator method and helper paths should lower to equivalent root relation kinds" diff --git a/tests/test_session_generators.incn b/tests/test_session_generators.incn new file mode 100644 index 0000000..e6bcfc7 --- /dev/null +++ b/tests/test_session_generators.incn @@ -0,0 +1,251 @@ +"""End-to-end Session generator execution tests over the DataFusion backend.""" + +from dataset import DataFrame, LazyFrame +from functions import ( + array, + col, + eq, + explode, + explode_outer, + flatten, + inline, + inline_outer, + lit, + named_struct, + posexplode, + posexplode_outer, + stack, +) +from session import Session +from std.testing import assert_is_ok, fail_t +from projection_builders import ColumnExpr + + +@derive(Clone) +pub model AggregateOrder: + pub customer_id: str + pub amount: int + + +const AGGREGATE_ORDERS_CSV_FIXTURE: str = "tests/fixtures/aggregate_orders.csv" + + +def _collect_or_fail(mut session: Session, generated: LazyFrame[AggregateOrder]) -> DataFrame[AggregateOrder]: + """Collect 
a generated aggregate-order frame or fail with the backend diagnostic."""
+    # Ok yields the collected frame; Err forwards the error's message to `fail_t`.
+    match session.collect(generated):
+        Ok(df) => return df
+        Err(err) => return fail_t(err.error_message())
+
+
+def _preview_line_contains_all(line: str, expected_cells: list[str]) -> bool:
+    """Return whether one rendered preview row contains every expected cell value."""
+    # NOTE(review): each cell is checked with `line.contains(cell)` against the whole line, so a match is not
+    # tied to a column position. If `contains` is a plain substring test, repeated or overlapping expectations
+    # (for example "A" listed twice, or "1" alongside "10") are satisfied by the same text -- confirm intent.
+    for cell in expected_cells:
+        if not line.contains(cell):
+            return false
+    return true
+
+
+def _assert_preview_row_contains(payload: str, expected_cells: list[str], context: str) -> None:
+    """Assert one rendered preview row carries the expected materialized cells together."""
+    # Scan the payload one "\n"-separated line at a time and stop at the first line holding every expected cell.
+    for line in payload.split("\n"):
+        if _preview_line_contains_all(line, expected_cells):
+            return
+    # No line matched: hand the caller-supplied context message to `fail_t`.
+    return fail_t(context)
+
+
+def _aggregate_orders(mut session: Session) -> LazyFrame[AggregateOrder]:
+    """Register and return the aggregate-order fixture."""
+    # Reads the CSV fixture under the table name "aggregate_orders"; `assert_is_ok` carries the failure message.
+    return assert_is_ok(
+        session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE),
+        "aggregate orders fixture should load",
+    )
+
+
+def _tags() -> ColumnExpr:
+    """Return a simple two-element string array expression used by generator tests."""
+    # Element 0 is the literal "paid"; element 1 is the `customer_id` column.
+    return array([lit("paid"), col("customer_id")])
+
+
+def test_session_generators__collect_executes_explode_family() -> None:
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+
+    # -- Act --
+    df = _collect_or_fail(session, lazy.generate(explode(_tags(), "tag")))
+    payload = df.preview_text()
+    resolved = df.resolved_columns()
+
+    # -- Assert --
+    assert df.row_count() == 6, "explode should emit two generated rows per fixture input row"
+    assert resolved == ["customer_id", "amount", "tag"], "explode should append one generated value column"
+    _assert_preview_row_contains(
+        payload,
+        ["A", "10", "paid"],
+        "explode should materialize the literal tag for customer A",
+    )
+    # NOTE(review): "A" is listed twice below; with per-cell `contains` checks the second "A" does not require a
+    # separate `tag` cell -- TODO confirm this still distinguishes the customer-tag row from the "paid" row.
+    _assert_preview_row_contains(payload, ["A", "10", "A"], "explode should materialize the customer tag for customer A")
+
_assert_preview_row_contains(payload, ["B", "7", "B"], "explode should materialize the customer tag for customer B")
+
+
+def test_session_generators__collect_executes_explode_outer_family() -> None:
+    # NOTE(review): `_tags()` builds a fixed two-element array, so only the non-empty path of `explode_outer` is
+    # exercised here; its behaviour for null or empty inputs is not covered by this test.
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+
+    # -- Act --
+    df = _collect_or_fail(session, lazy.generate(explode_outer(_tags(), "tag")))
+    payload = df.preview_text()
+
+    # -- Assert --
+    assert df.row_count() == 6, "explode_outer should match explode for non-empty fixture arrays"
+    _assert_preview_row_contains(
+        payload,
+        ["A", "10", "paid"],
+        "explode_outer should materialize the literal tag for customer A",
+    )
+    _assert_preview_row_contains(
+        payload,
+        ["B", "7", "B"],
+        "explode_outer should materialize the customer tag for customer B",
+    )
+
+
+def test_session_generators__collect_executes_positional_explode_family() -> None:
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+
+    # -- Act --
+    df = _collect_or_fail(session, lazy.generate(posexplode(_tags(), "position", "tag")))
+    payload = df.preview_text()
+    resolved = df.resolved_columns()
+
+    # -- Assert --
+    assert df.row_count() == 6, "posexplode should emit two generated rows per fixture input row"
+    assert resolved == ["customer_id", "amount", "position", "tag"], "posexplode should append position then value columns"
+    # NOTE(review): "0" and "1" also occur inside the amount cell "10"; if the preview helper's `contains` is a
+    # substring test, these position expectations are already met by the amount column -- TODO confirm.
+    _assert_preview_row_contains(
+        payload,
+        ["A", "10", "0", "paid"],
+        "posexplode should materialize zero-based position 0",
+    )
+    _assert_preview_row_contains(payload, ["A", "10", "1", "A"], "posexplode should materialize zero-based position 1")
+
+
+def test_session_generators__collect_executes_positional_outer_explode_family() -> None:
+    # Same two-element `_tags()` input as the other explode tests, so only the non-empty outer path runs.
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+
+    # -- Act --
+    df = _collect_or_fail(session, lazy.generate(posexplode_outer(_tags(), "position", "tag")))
+    payload = df.preview_text()
+
+    # -- Assert --
+    assert df.row_count()
== 6, "posexplode_outer should match posexplode for non-empty fixture arrays"
+    _assert_preview_row_contains(payload, ["B", "7", "1", "B"], "posexplode_outer should materialize positional rows")
+
+
+def test_session_generators__collect_executes_portable_flatten() -> None:
+    # `flatten` is applied to the same `_tags()` array as the explode tests and is asserted to give the same
+    # shape: six rows and a single appended `tag` column.
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+
+    # -- Act --
+    df = _collect_or_fail(session, lazy.generate(flatten(_tags(), "tag")))
+    payload = df.preview_text()
+    resolved = df.resolved_columns()
+
+    # -- Assert --
+    assert df.row_count() == 6, "flatten should emit one generated row per array element"
+    assert resolved == ["customer_id", "amount", "tag"], "flatten should append one generated value column"
+    _assert_preview_row_contains(payload, ["A", "10", "paid"], "flatten should materialize the literal tag")
+
+
+def test_session_generators__collect_executes_inline_family() -> None:
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+    # Two literal structs with fields `sku` and `quantity`: ("A", 1) and ("B", 2).
+    rows = array(
+        [named_struct(["sku", "quantity"], [lit("A"), lit(1)]), named_struct(["sku", "quantity"], [lit("B"), lit(2)])],
+    )
+
+    # -- Act --
+    df = _collect_or_fail(session, lazy.generate(inline(rows, ["sku", "quantity"])))
+    payload = df.preview_text()
+    resolved = df.resolved_columns()
+
+    # -- Assert --
+    assert df.row_count() == 6, "inline should emit one row per struct element"
+    assert resolved == ["customer_id", "amount", "sku", "quantity"], "inline should append declared struct output columns"
+    # NOTE(review): in the first check "A" is repeated and "1" also occurs inside "10", so the per-cell `contains`
+    # match may not depend on the generated `sku`/`quantity` cells -- TODO confirm. The second check ("B", "2") does.
+    _assert_preview_row_contains(payload, ["A", "10", "A", "1"], "inline should materialize the first struct row")
+    _assert_preview_row_contains(payload, ["A", "10", "B", "2"], "inline should materialize the second struct row")
+
+
+def test_session_generators__collect_executes_inline_outer_family() -> None:
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+    rows = array(
+        [named_struct(["sku", "quantity"], [lit("A"), lit(1)]),
named_struct(["sku", "quantity"], [lit("B"), lit(2)])],
+    )
+
+    # -- Act --
+    df = _collect_or_fail(session, lazy.generate(inline_outer(rows, ["sku", "quantity"])))
+    payload = df.preview_text()
+
+    # -- Assert --
+    # Only the non-empty path of `inline_outer` is checked: `rows` always holds two structs.
+    assert df.row_count() == 6, "inline_outer should match inline for non-empty fixture arrays"
+    _assert_preview_row_contains(payload, ["B", "7", "B", "2"], "inline_outer should materialize struct rows")
+
+
+def test_session_generators__collect_executes_stack() -> None:
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+
+    # -- Act --
+    # stack(2, ...) receives four expressions and two output aliases; the assertions below expect them laid out
+    # row-major as (customer_id, amount) and ("fixed", 99) under `label` / `value`.
+    df = _collect_or_fail(
+        session,
+        lazy.generate(stack(2, [col("customer_id"), col("amount"), lit("fixed"), lit(99)], ["label", "value"])),
+    )
+    payload = df.preview_text()
+    resolved = df.resolved_columns()
+
+    # -- Assert --
+    assert df.row_count() == 6, "stack should emit the declared number of rows per fixture input row"
+    assert resolved == ["customer_id", "amount", "label", "value"], "stack should append all declared output columns"
+    # NOTE(review): the first check repeats "A" and "10"; with per-cell `contains` matching it may be satisfied
+    # by the input columns alone -- TODO confirm it really pins the generated `label` / `value` cells.
+    _assert_preview_row_contains(payload, ["A", "10", "A", "10"], "stack should materialize row-major source values")
+    _assert_preview_row_contains(
+        payload,
+        ["A", "10", "fixed", "99"],
+        "stack should materialize row-major literal values",
+    )
+
+
+def test_session_generators__generated_relations_compose_with_limit() -> None:
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+
+    # -- Act --
+    df = _collect_or_fail(session, lazy.generate(explode(_tags(), "tag")).limit(4))
+    payload = df.preview_text()
+
+    # -- Assert --
+    assert df.row_count() == 4, "limit after generate should execute over generated rows"
+    # NOTE(review): no `order_by` precedes `limit(4)`, so which four rows survive presumably depends on backend
+    # row order; the row check below assumes the A/10/paid row comes first -- TODO confirm this is stable.
+    _assert_preview_row_contains(payload, ["A", "10", "paid"], "composed generator plan should keep first generated row")
+
+
+def test_session_generators__generated_columns_can_feed_filter() -> None:
+    # -- Arrange --
+    mut session = Session.default()
+    lazy = _aggregate_orders(session)
+
+    # -- Act --
+    df
= _collect_or_fail(session, lazy.generate(explode(_tags(), "tag")).filter(eq(col("tag"), lit("paid"))))
+    payload = df.preview_text()
+
+    # -- Assert --
+    # The filter references the generated `tag` column; three rows are expected to keep the "paid" value.
+    assert df.row_count() == 3, "filter after generate should evaluate against generated columns"
+    _assert_preview_row_contains(payload, ["A", "10", "paid"], "filter should retain generated rows with the paid tag")
diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn
index 3bd3b3b..da38c46 100644
--- a/tests/test_substrait_plan.incn
+++ b/tests/test_substrait_plan.incn
@@ -73,15 +73,24 @@ from projection_builders import ColumnExpr, with_column_assignment
 from substrait.errors import SubstraitLoweringErrorKind
 from substrait.expr_lowering import scalar_expr
 from substrait.function_extensions import (
-    explode_extension_uri,
-    function_extension_uri,
+    EXPLODE_EXTENSION_URI,
+    EXPLODE_OUTER_EXTENSION_URI,
+    FLATTEN_EXTENSION_URI,
+    FUNCTION_EXTENSION_URI,
+    INLINE_EXTENSION_URI,
+    INLINE_OUTER_EXTENSION_URI,
+    POSEXPLODE_EXTENSION_URI,
+    POSEXPLODE_OUTER_EXTENSION_URI,
     registered_substrait_extension_uris,
+    STACK_EXTENSION_URI,
 )
 from substrait.inspect import (
     aggregate_measure_filter_flags,
     aggregate_measure_function_names,
     aggregate_measure_invocation_names,
     aggregate_measure_sort_counts,
+    plan_extension_urn_anchor_at,
+    plan_extension_urn_count,
     plan_contains_relation_kind,
     plan_has_extension_urn,
     read_kind_name,
@@ -588,15 +597,24 @@ def test_plan__reference_rel_preserves_subtree_ordinal() -> None:
 
 def test_plan__extension_urns_are_surfaced() -> None:
     # -- Arrange --
-    extension_uri = explode_extension_uri()
+    extension_uri = EXPLODE_EXTENSION_URI
     rel = extension_single_rel(read_named_table_rel("orders"), extension_uri)
+    # Wrap the same named-table read in two extension relations with different URIs so one plan carries both.
+    nested = extension_single_rel(
+        extension_single_rel(read_named_table_rel("orders"), EXPLODE_EXTENSION_URI),
+        POSEXPLODE_EXTENSION_URI,
+    )
 
     # -- Act --
     plan = plan_from_root_relation(rel, ["id"])
+    nested_plan = plan_from_root_relation(nested, ["id", "position", "value"])
 
     # -- Assert --
     assert plan_has_extension_urn(plan, extension_uri), "extension relation should populate extension URNs"
     assert plan_contains_relation_kind(plan, "ExtensionSingleRel"), "extension root should remain inspectable"
+    assert plan_has_extension_urn(nested_plan, EXPLODE_EXTENSION_URI), "nested extension plans should include child extension URNs"
+    assert plan_has_extension_urn(nested_plan, POSEXPLODE_EXTENSION_URI), "nested extension plans should include root extension URNs"
+    assert plan_extension_urn_count(nested_plan) == 2, "different relation extension URIs should be declared once each"
+    assert plan_extension_urn_anchor_at(nested_plan, 0) != plan_extension_urn_anchor_at(nested_plan, 1), "relation extension URNs should use distinct anchors"
 
 
 def test_plan__revision_pin_and_extension_registry_are_exported() -> None:
@@ -610,9 +628,16 @@ def test_plan__revision_pin_and_extension_registry_are_exported() -> None:
     # -- Assert --
     assert tag == "v0.63.0", "revision helpers should expose the currently targeted Substrait release tag"
     assert producer == "inql-rfc002", "revision helpers should expose the package producer label"
-    assert len(registered) == 2, "current package boundary should register both extension URIs"
-    assert registered[0] == function_extension_uri(), "registry should include the shared function extension URI first"
-    assert registered[1] == explode_extension_uri(), "registry should include the emitted explode extension URI"
+    # The registry is checked positionally: the shared function URI at index 0, then the eight generator URIs.
+    assert len(registered) == 9, "current package boundary should register function and generator extension URIs"
+    assert registered[0] == FUNCTION_EXTENSION_URI, "registry should include the shared function extension URI first"
+    assert registered[1] == EXPLODE_EXTENSION_URI, "registry should include the emitted explode extension URI"
+    assert registered[2] == EXPLODE_OUTER_EXTENSION_URI, "registry should include the outer explode extension URI"
+    assert registered[3] == POSEXPLODE_EXTENSION_URI, "registry should include the positional explode extension URI"
+    assert registered[4] == POSEXPLODE_OUTER_EXTENSION_URI, "registry should include the outer positional explode extension URI"
+    assert registered[5] == INLINE_EXTENSION_URI, "registry should include the inline extension URI"
+    assert registered[6] == INLINE_OUTER_EXTENSION_URI, "registry should include the outer inline extension URI"
+    assert registered[7] == FLATTEN_EXTENSION_URI, "registry should include the flatten extension URI"
+    assert registered[8] == STACK_EXTENSION_URI, "registry should include the stack extension URI"
 
 
 def test_conformance__core_scenarios_validate_emission_output() -> None: