diff --git a/docs/language/reference/dataset_methods.md b/docs/language/reference/dataset_methods.md
index 9ee0963..ab4926e 100644
--- a/docs/language/reference/dataset_methods.md
+++ b/docs/language/reference/dataset_methods.md
@@ -19,7 +19,7 @@ The Substrait helper surface behind these methods is split by semantic role:
 | `with_column` | `def with_column(self, name: str, expr: ColumnExpr) -> Self` | Add or replace one projected column using a scalar expression. |
 | `group_by` | `def group_by(self, columns: list[ColumnExpr]) -> Self` | Define grouping keys using scalar expressions. |
 | `agg` | `def agg(self, measures: list[AggregateMeasure]) -> Self` | Apply aggregate measures over the current relation or current grouping. |
-| `order_by` | `def order_by(self) -> Self` | Preserve order-planning shape for the package sort boundary. |
+| `order_by` | `def order_by(self, columns: list[ColumnExpr]) -> Self` | Sort rows by scalar expressions or ordering helpers such as `asc(...)` and `desc(...)`. |
 | `limit` | `def limit(self, n: int) -> Self` | Cap row count. |
 | `explode` | `def explode(self) -> Self` | Expand a nested list column into rows. |
 
diff --git a/docs/language/reference/execution_context.md b/docs/language/reference/execution_context.md
index 1ac5099..7ea00d9 100644
--- a/docs/language/reference/execution_context.md
+++ b/docs/language/reference/execution_context.md
@@ -7,6 +7,8 @@ This page documents the public execution surface in the InQL package. Normative
 - `Session` is the public execution context for registration, binding, execution, collection, and writes.
 - `SessionBuilder` configures a `Session` before construction.
 - `SessionError` is the typed error surface for registration, planning, execution, materialization, and sink failures.
+- `BackendSelection` is the portable backend selection envelope stored by a session.
+- `BackendOption` carries adapter-specific configuration without adding one field per backend to `Session`.
 - `backends.DataFusion()` is the current reference backend configuration entry point.
 
 ## Construction
@@ -15,6 +17,7 @@ This page documents the public execution surface in the InQL package. Normative
 | ------------------------------------------------------------------ | ------------------------------------------------------------------- |
 | `Session.default()` | Create a session with the default backend and default configuration |
 | `Session.builder()` | Create a builder for backend selection and configuration |
+| `Session.builder().with_backend(selection).build()` | Build a session from a portable backend-selection envelope |
 | `Session.builder().with_datafusion(backends.DataFusion()).build()` | Build an explicit DataFusion-backed session |
 
 ## Read and registration surface
@@ -74,7 +77,7 @@ If no active session exists when a convenience API needs one, the operation fail
 
 ## Backend note
 
-DataFusion is the implemented execution backend. The public builder/configuration surface is designed so additional backends can be added without changing the `Session` entry point.
+DataFusion is the implemented execution backend. `Session` stores a backend kind plus encoded options, lowers work to Substrait, and dispatches through an internal backend adapter boundary. DataFusion is the first adapter behind that boundary; it is not the shape of the `Session` state.
 
 ## Related docs
 
diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md
index fd2827b..38d7cb2 100644
--- a/docs/language/reference/functions/index.md
+++ b/docs/language/reference/functions/index.md
@@ -10,18 +10,25 @@ Today the concrete shipped surfaces are documented here:
 
 The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation.
 
-The current public helper surface is also registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, while the concrete public helper entries are produced by `FUNCTION_REGISTRY.add(...)` decorators in `src/functions.incn`. Each entry exposes a stable function reference such as `inql.functions.col`, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), signature facts, function class, null behavior, alias policy, and Substrait mapping metadata.
+The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, ordering, and aggregates. Each runtime entry exposes a stable function reference such as `inql.functions.col`, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), function class, null behavior, alias policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature.
 
-The registry is the source for machine-readable function facts. Docstrings remain human-facing explanation, while argument names, type rules, lifecycle facts, and Substrait mappings come from typed registry metadata and public helper signatures. The `registry-metadata` check validates that runtime registry entries produced by decorators still agree with checked API metadata for decorator canonical names, argument names, argument types, and return types. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows.
+The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows.
 
-The first registered helpers are:
+The registered helper surface currently includes:
 
 | Function | Registry class | Mapping |
 | --- | --- | --- |
 | `col(...)` | scalar | deterministic field-reference rewrite |
 | `lit(...)`, `int_expr(...)`, `float_expr(...)`, `str_expr(...)`, `bool_expr(...)`, `int_lit(...)`, `str_lit(...)`, `bool_lit(...)` | scalar | deterministic literal rewrites |
-| `add(...)`, `mul(...)`, `eq(...)`, `gt(...)` | scalar | registered Substrait extension functions |
 | `always_true()`, `always_false()` | scalar | deterministic boolean-literal rewrites |
+| `cast(...)`, `try_cast(...)` | scalar | built-in Substrait `Cast` Rex shapes; `try_cast` uses return-null failure behavior |
+| `add(...)`, `sub(...)`, `mul(...)`, `div(...)`, `modulo(...)`, `neg(...)` | scalar | registered Substrait scalar mappings; `modulo(...)` registers canonical `mod` |
+| `eq(...)`, `ne(...)`, `lt(...)`, `lte(...)`, `gt(...)`, `gte(...)`, `equal_null(...)` | scalar | registered Substrait scalar mappings; `equal_null(...)` lowers as null-safe equality |
+| `and_(...)`, `or_(...)`, `not_(...)` | scalar | registered Substrait boolean mappings |
+| `is_null(...)`, `is_not_null(...)`, `is_nan(...)`, `is_not_nan(...)` | scalar | registered predicate mappings; `is_not_nan(...)` lowers as `not(is_nan(...))` |
+| `coalesce(...)`, `nullif(...)`, `case_when(...)` | scalar | registered Substrait mappings; `case_when(...)` lowers as built-in `IfThen` |
+| `in_(...)`, `between(...)` | scalar | built-in membership/range lowering (`SingularOrList` and `between`) |
+| `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` |
 | `sum(...)`, `count()` | aggregate | registered Substrait extension functions |
 
 Future ANSI-style families should grow under this section instead of bloating `dataset_types` or `dataset_methods`.
diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md
index 7a5ca2a..3ac9691 100644
--- a/docs/release_notes/v0_1.md
+++ b/docs/release_notes/v0_1.md
@@ -12,11 +12,12 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable).
 - **Authoring:** method-chain lowering into a real Substrait boundary today, with `query {}` work still ahead.
 - **Aggregates:** builder-based `col`, `sum`, and `count` helpers now lower grouped and global aggregates through Prism, Substrait, and Session execution.
 - **Scalar expressions:** RFC 012 unifies filter predicates, computed projection values, grouping keys, and aggregate inputs around one `ColumnExpr` surface with canonical `lit(...)` and typed literal helpers.
-- **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, signature facts, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation.
+- **Core scalar functions:** RFC 015 adds registry-backed scalar function applications and the first core helper slice for casts, comparisons, boolean logic, null/NaN predicates, arithmetic, conditionals, membership/range predicates, and ordering expressions. Implemented helpers lower to Substrait IR through registry metadata, built-in Rex shapes, or structural sort-field lowering; DataFusion remains the first execution adapter rather than the semantic boundary.
+- **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation.
 - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution.
 - **Substrait internals:** RFC 002 helpers are now split into focused owner modules for relation building, plan assembly, inspection, schema registry, extension bookkeeping, and expression lowering instead of one `substrait.plan` godmodule.
 - **Prism:** `LazyFrame` lowering applies safe canonical rewrites (`Filter(true)` elimination and adjacent `Limit`/`Project`/`OrderBy` collapse) before RFC 002 plan emission.
-- **Execution:** Session-oriented read, execute, and write (reference backend per RFC 004), with `collect(...)` now producing structured `DataFrame` materialization metadata plus preview text instead of treating rendered text as the canonical contract.
+- **Execution:** Session-oriented read, execute, and write (reference backend per RFC 004), with `collect(...)` now producing structured `DataFrame` materialization metadata plus preview text instead of treating rendered text as the canonical contract. Session execution dispatch now routes through a backend adapter boundary over Substrait plans; DataFusion remains the first adapter rather than being encoded directly into Session state.
 - **Documentation:** Current package behavior is documented under `docs/language/`, while RFCs remain design records rather than implementation diaries.
 
 Pipe-forward (`|>`) is specified in RFC 005 but **out of scope** for v0.1.
diff --git a/docs/rfcs/014_function_registry.md b/docs/rfcs/014_function_registry.md
index bc71094..e625801 100644
--- a/docs/rfcs/014_function_registry.md
+++ b/docs/rfcs/014_function_registry.md
@@ -34,7 +34,7 @@ This RFC defines the InQL function registry: the single source of truth for scal
 2. A function belongs to one function class: scalar, aggregate, window, generator, table-valued, partition transform, or extension-only.
 3. A function signature defines accepted argument shapes, type coercion, return type rules, null behavior, error behavior, and determinism.
 4. A function entry records the required Substrait interchange strategy; backend availability must be declared by adapters and must not redefine the InQL semantic contract.
-5. A function entry is registered by attaching one `FUNCTION_REGISTRY.add(...)` decorator to a normal public helper; the decorator call supplies the canonical name and typed machine-readable metadata, while the helper docstring carries human-facing explanation.
+5. A function entry is registered by attaching one `function_registry.add(...)` decorator to a normal public helper; the decorator call supplies the canonical name and typed machine-readable metadata, while the helper docstring carries human-facing explanation.
 6. A function entry records lifecycle metadata such as introduced, changed, deprecated, removed, and replacement versions.
 
 ## Motivation
@@ -51,7 +51,7 @@ This is also necessary for diagnostics. If a function is known to InQL but canno
 - Require registered functions to carry Incan-standard human-facing docstrings without making docstring section policy part of this RFC.
 - Define version lifecycle metadata for generated docs and compatibility planning.
 - Define Substrait interchange requirements for portable core functions.
-- Require explicit diagnostics for unknown, ambiguous, unsupported, or incorrectly used functions.
+- Require explicit diagnostics for unknown, ambiguous, incorrectly used, or backend-rejected functions.
 - Require mechanical validation that public helper decorators, public helper signatures, and typed registry entries do not drift.
 - Provide the governance model that later catalog RFCs use when adding functions.
 
@@ -75,7 +75,7 @@ cleaned = orders.with_column("normalized_email", lower(trim(col("email"))))
 summary = cleaned.group_by([col("customer_id")]).agg([count(), avg(col("amount"))])
 ```
 
-The author does not need to know whether `avg` maps to a core Substrait function, a Substrait extension URI, or a semantics-preserving Substrait rewrite. The author does need clear feedback if a function is known but cannot be represented by the current portable interchange contract.
+The author does not need to know whether `avg` maps to a core Substrait function, a Substrait extension URI, or a semantics-preserving Substrait rewrite. The author does need clear feedback if a function is known but used in the wrong query context, or if an execution adapter cannot consume the emitted Substrait representation.
 
 ## Reference-level explanation (precise rules)
 
@@ -102,7 +102,7 @@ Each portable core function must declare a Substrait interchange strategy. The s
 - core Substrait expression or function
 - registered Substrait extension function
 - deterministic rewrite to supported Substrait expressions
-- explicitly unsupported until a Substrait mapping exists
+- structural relation-context lowering, such as sort-field helpers consumed by `SortRel`
 
 Prism must only accept portable core function calls that can be represented by the active InQL/Substrait contract. A function with no valid Substrait mapping must remain Draft, extension-only, or rejected for portable core until that mapping exists.
 
@@ -110,21 +110,21 @@ Execution backends must adapt from the Substrait representation rather than rede
 
 Each registered function must declare lifecycle metadata. The minimum lifecycle field is the InQL package version where the function was introduced. If a function's signature, semantics, alias set, Substrait mapping, or documentation contract changes in a user-visible way, the registry must record a versioned change entry. Deprecated functions must record the deprecation version, replacement guidance when a replacement exists, and removal status if removal is planned or completed.
 
-Each registered function must have a typed registry entry for machine-readable metadata and an Incan-standard docstring for human-facing explanation. For ordinary public built-in functions, the canonical declaration shape is a normal public helper annotated with `FUNCTION_REGISTRY.add(...)`. The decorator call registers the helper, derives its stable function reference from the canonical name, and supplies the typed machine-readable metadata. The checked InQL package source and typed registry data are the source from which compiler-facing metadata, generated docs, diagnostics metadata, and lowering tables are produced. Generated registry entries may exist for mechanically produced functions, and explicit registry objects may exist for advanced extension cases, but the registry must not depend on arbitrary body inspection, stringly alias metadata, or prose inference.
+Each registered function must have a typed registry entry for non-derivable machine metadata and an Incan-standard docstring for human-facing explanation. For ordinary public built-in functions, the canonical declaration shape is a normal public helper annotated with `function_registry.add(...)`. The decorator call registers the helper, derives its stable function reference from the canonical helper name, and supplies typed metadata that cannot be recovered from the public declaration. The checked helper declaration is the source for name, parameters, parameter types, and return type. The checked InQL package source and typed registry data are the source from which compiler-facing metadata, generated docs, diagnostics metadata, and lowering tables are produced. Generated registry entries may exist for mechanically produced functions, and explicit registry objects may exist for advanced extension cases, but the registry must not depend on arbitrary body inspection, stringly alias metadata, or prose inference.
 
-This RFC intentionally defines required metadata shapes rather than exact enum, model, class, or tagged-union implementations. The implementation may represent lifecycle, signatures, behavior categories, and Substrait mappings as enums, models, classes, generated records, or another typed representation, as long as the resulting normalized function catalog exposes the same fields to docs, typechecking, diagnostics, Prism, and backend capability checks.
+This RFC intentionally defines required metadata shapes rather than exact enum, model, class, or tagged-union implementations. The implementation may represent lifecycle, declaration-derived signatures, behavior categories, and Substrait mappings as enums, models, classes, generated records, or another typed representation, as long as the resulting normalized function catalog exposes the same fields to docs, typechecking, diagnostics, Prism, and backend capability checks.
 
 At minimum, a registered function's machine metadata must include:
 
 - lifecycle: introduced version, zero or more versioned changes, optional deprecation metadata, optional removal metadata, and replacement guidance when relevant
-- signature: argument names, argument type expressions or type-family constraints, required/optional/variadic/literal-only constraints, default values where supported, and return type rule
+- signature: argument names, argument type expressions or type-family constraints, required/optional/variadic constraints, default values where supported, and return type rule, derived from the checked public helper declaration whenever possible
 - classification: function class such as scalar, aggregate, window, generator, table-valued, partition transform, or extension-only
 - behavior: normalized determinism, null behavior, and error behavior categories, including strict versus `try_` behavior where relevant
-- interchange: Substrait mapping category, Substrait function or extension reference when applicable, rewrite description when applicable, and unsupported reason when no mapping exists
+- interchange: Substrait mapping category, Substrait function or extension reference when applicable, rewrite description when applicable, and structural relation context when the helper is consumed outside scalar Rex lowering
 
-The registry implementation must include a validation path that checks the public API surface against the typed registry. The validation must fail if a public helper is decorated with a canonical name that does not produce a registry entry, if a registry entry for an ordinary built-in function has no matching public helper, if the decorator canonical name and registry `function_ref` disagree, or if the helper's checked signature drifts from the registry signature. This validation is part of the RFC scope, not an optional future cleanup.
+The registry implementation must include a validation path that checks the public API surface against the typed registry metadata. The validation must fail if a public helper is decorated with a canonical name that is not projected through the public facade, if a decorated ordinary built-in function has no matching public helper, if canonical names are duplicated, or if checked API metadata cannot expose the helper signature needed by the generated catalog. This validation is part of the RFC scope, not an optional future cleanup.
 
-Generated Markdown must preserve the canonical registry facts and must use docstrings as the source for simple explanation and examples. Argument names, argument types, default values, accepted argument shapes, and return types must be derived from typed registry metadata and public helper signatures rather than copied from prose. Hand-written reference pages may add longer conceptual explanation, additional examples, or migration notes, but they must not contradict parsed docstrings and registry metadata.
+Generated Markdown must preserve the canonical registry facts and must use docstrings as the source for simple explanation, parameter intent, and examples. Argument names, argument types, default values, accepted argument shapes, and return types must be derived from checked public helper signatures rather than copied from prose. Hand-written reference pages may add longer conceptual explanation, additional examples, or migration notes, but they must not contradict parsed docstrings and registry metadata.
 
 ## Design details
 
@@ -142,9 +142,9 @@ The registry defines meaning, not just names. Backend-specific behavior may be u
 
 Function documentation is part of the registry contract, but exact required docstring sections are governed by the repository's implementation standards rather than by this RFC. Public registered functions must use Incan-standard docstrings as the canonical human-written format. Docstrings explain behavior, examples, and author intent; they must not be the source for argument shape, return type rules, null behavior, error behavior, determinism, lifecycle status, or Substrait mapping.
 
-For ordinary built-in functions, the declaration-site `FUNCTION_REGISTRY.add(...)` decorator is the canonical registration surface. The public helper should be ordinary code that delegates to the existing expression or aggregate builder, so authors call a normal function while tooling inspects explicit typed metadata from the decorator call. Docs, LSP, typechecking, Prism, and Substrait lowering must all inspect the same resulting function catalog entry produced from the checked source and typed registry metadata.
+For ordinary built-in functions, the declaration-site `function_registry.add(...)` decorator is the canonical registration surface. The public helper should be ordinary code that delegates to the existing expression or aggregate builder, so authors call a normal function while tooling inspects explicit typed metadata from the decorator call. Docs, LSP, typechecking, Prism, and Substrait lowering must all inspect the same resulting function catalog entry produced from the checked source and typed registry metadata.
 
-The registration decorator must be the single declaration-side registry event for ordinary built-ins. It links exactly one helper symbol to exactly one stable function reference and receives the typed function spec that records lifecycle, determinism, null behavior, error behavior, alias policy, and Substrait mapping. Backend capability declarations consume those facts; they do not redefine InQL function semantics.
+The registration decorator must be the single declaration-side registry event for ordinary built-ins. It links exactly one helper symbol to exactly one stable function reference and receives the typed function spec that records lifecycle, determinism, null behavior, error behavior, alias policy, and Substrait mapping. Helper name and signature come from the checked declaration. Backend capability declarations consume those facts; they do not redefine InQL function semantics.
 
 Compatibility aliases must be real callable symbols rather than strings inside a function spec. For example, `mean = avg` should make `mean` an alias of the registered `avg` helper. The function catalog may record that alias after name resolution, but the aggregate spec must not contain `aliases=["mean"]`. Backend spellings and backend aliases remain backend capability concerns.
 
@@ -153,27 +153,26 @@ Generated reference pages must render lifecycle metadata in a consistent form. A
 The public helper shape should stay compact enough to preserve authoring ergonomics while still making machine facts inspectable. Incan-standard docstrings are the canonical standard for explanation and examples; typed registry entries are canonical for machine facts. The following shape is illustrative only; constructor names, decorator names, enum/model/class boundaries, and helper implementation details may change:
 
 ```incan
-@FUNCTION_REGISTRY.add(
-    "avg",
-    deterministic_spec(
-        FunctionClass.Aggregate,
-        FunctionLifecycle(
-            since=v0_2,
-            changed=[
-                FunctionChange(version=v0_3, note="Added decimal return type rule."),
-            ],
-            deprecated=None,
-        ),
-        signature([required_arg("expr", "ScalarExpr[number]")], "AggregateMeasure[number]"),
-        FunctionNullBehavior.NullSkippingAggregate,
-        extension_mapping("avg", AVG_FUNCTION_ANCHOR),
+@function_registry.add(deterministic_spec(
+    FunctionClass.Aggregate,
+    FunctionLifecycle(
+        since=v0_2,
+        changed=[
+            FunctionChange(version=v0_3, note="Added decimal return type rule."),
+        ],
+        deprecated=None,
     ),
+    FunctionNullBehavior.NullSkippingAggregate,
+    extension_mapping("avg", AVG_FUNCTION_ANCHOR),
+))
 pub def avg(expr: ScalarExpr[number]) -> AggregateMeasure[number]:
     """
     Return the average non-null numeric value in each group.
 
-    Args:
+    Examples:
+        average_order_value = avg(col("amount"))
+
+    Parameters:
         expr: Numeric scalar expression evaluated for each input row.
 
     Returns:
@@ -223,13 +222,13 @@ Existing helper names such as `sum`, `count`, `add`, `mul`, `eq`, and `gt` may c
 ### Phase 1: Registry metadata model
 
 - Add typed package-owned registry metadata for the current public function surface.
-- Represent canonical names, stable function references, function class, lifecycle, signature, behavior categories, alias policy, and Substrait mapping category.
+- Represent canonical names, stable function references, function class, lifecycle, declaration-derived signature, behavior categories, alias policy, and Substrait mapping category.
 - Provide lookup helpers for registry entries by function reference and canonical name.
 
 ### Phase 2: Public helper registration
 
 - Convert the current `functions` module surface from bare aliases to public helper functions where registration metadata needs to attach to the public call surface.
-- Link each registered helper to exactly one stable function reference through the `FUNCTION_REGISTRY.add(...)` decorator.
+- Link each registered helper to exactly one stable function reference through the `function_registry.add(...)` decorator.
 - Preserve existing helper behavior and import names.
 
 ### Phase 3: Docs and tests
@@ -279,8 +278,8 @@ Existing helper names such as `sum`, `count`, `add`, `mul`, `eq`, and `gt` may c
 
 - **Registry ownership:** the checked InQL package source is the source of truth. Compiler-facing metadata, generated docs metadata, diagnostics metadata, and lowering tables are derived from checked package source and typed registry data. The compiler must not maintain an independent InQL function registry.
 - **Authoring DX:** ordinary function authors should write normal public helpers and attach one registry decorator whose arguments contain the canonical name and typed function spec. The registry derives the stable function reference from that canonical name, avoiding a separate authored constant or central list.
-- **Decorator capability:** Incan issue #636 / PR #637 is required for decorator-authored helpers because checked API metadata must preserve source signatures for decorated functions. Incan issue #638 / PR #641 is required for decorator string argument materialization. Incan issue #640 / PR #643 provides generic signature-preserving decorator factories. Incan issue #645 is required for method-call decorators such as `FUNCTION_REGISTRY.add(...)`. The RFC design is one registry method decorator attached to the public helper.
+- **Decorator capability:** Incan issue #636 / PR #637 is required for decorator-authored helpers because checked API metadata must preserve source signatures for decorated functions. Incan issue #638 / PR #641 is required for decorator string argument materialization. Incan issue #640 / PR #643 provides generic signature-preserving decorator factories. Incan issue #645 is required for method-call decorators such as `function_registry.add(...)`. The RFC design is one registry method decorator attached to the public helper.
 - **Lifecycle constants:** typed lifecycle metadata uses immutable version constants such as `v0_1`. These are `const` model values, not mutable registry state or generated strings.
 - **Alias policy:** core semantic aliases may be available through normal public imports when they are real callable aliases of the canonical function. Dialect, warehouse, Spark, Snowflake, dbt, and backend compatibility aliases require explicit opt-in modules.
 - **Docstrings:** exact docstring section requirements are not an RFC 014 concern. Public registered functions must follow the repository's Incan-standard docstring policy, but registry metadata and public helper signatures own machine facts.
-- **Substrait mapping:** typed registry entries must represent whether a function maps to a core Substrait function, a registered extension function, a deterministic rewrite, or an explicit unsupported state. Backend capability declarations consume that mapping; they do not redefine InQL semantics.
+- **Substrait mapping:** typed registry entries must represent whether a function maps to a core Substrait function, a registered extension function, a deterministic rewrite, or structural relation-context lowering. Backend capability declarations consume the emitted Substrait representation; they do not redefine InQL semantics.
diff --git a/docs/rfcs/015_core_scalar_functions.md b/docs/rfcs/015_core_scalar_functions.md
index e437c58..9b19ed1 100644
--- a/docs/rfcs/015_core_scalar_functions.md
+++ b/docs/rfcs/015_core_scalar_functions.md
@@ -1,6 +1,6 @@
 # InQL RFC 015: Core scalar functions and operators
 
-- **Status:** Draft
+- **Status:** Implemented
 - **Created:** 2026-04-27
 - **Author(s):** Danny Meijer (@dannymeijer)
 - **Related:**
@@ -10,8 +10,8 @@
   - InQL RFC 003 (`query {}` blocks and relational authoring)
 - **Issue:** [InQL #32](https://github.com/dannys-code-corner/InQL/issues/32)
 - **RFC PR:** —
-- **Written against:** Incan v0.2
-- **Shipped in:** —
+- **Written against:** Incan v0.3-era InQL
+- **Shipped in:** v0.1
 
 ## Summary
 
@@ -44,13 +44,13 @@ The core slice should be intentionally small. Functions such as advanced trigono
 Authors should be able to write everyday filters and computed columns without switching helper families:
 
 ```incan
-from pub::inql.functions import add, and_, cast, col, coalesce, gt, is_not_null, lit, lower, mul
+from pub::inql.functions import add, and_, cast, col, coalesce, gt, is_not_null, lit, mul
 
 enriched = (
     orders
     .filter(and_(is_not_null(col("customer_id")), gt(col("amount"), lit(0))))
     .with_column("amount_cents", mul(cast(col("amount"), "int"), lit(100)))
-    .with_column("normalized_status", coalesce(lower(col("status")), lit("unknown")))
+    .with_column("status_or_unknown", coalesce([col("status"), lit("unknown")]))
 )
 ```
 
@@ -68,7 +68,7 @@ Ordinary equality and comparison must follow SQL-style three-valued null behavio
 
 Boolean operators must define three-valued logic for nullable boolean operands. If InQL exposes host-language operator sugar later, that sugar must preserve the same truth table.
 
-Arithmetic functions must define numeric type promotion and overflow behavior. If exact overflow behavior remains backend-dependent in Draft, implementations must reject unsupported or ambiguous cases rather than silently changing semantics.
+Arithmetic functions must define numeric type promotion and overflow behavior. If exact overflow behavior remains backend-dependent in Draft, implementations must reject ambiguous cases rather than silently changing semantics.
 
 Ordering expressions must include ascending, descending, ascending-null-first, ascending-null-last, descending-null-first, and descending-null-last forms. Default null placement must be documented and must not vary silently by backend.
 
@@ -120,11 +120,99 @@ Typed literal helpers such as `int_expr`, `float_expr`, `str_expr`, `bool_expr`,
 - **Execution / interchange** — Prism and Substrait lowering must preserve casts, null-safe equality, boolean logic, ordering null placement, and `try_` behavior.
- **Documentation** — scalar function reference docs should distinguish canonical names from typed helper entrypoints. -## Unresolved questions +## Implementation Plan -- Should the canonical boolean helper names be `and_`, `or_`, and `not_`, or should InQL expose different names because these collide with host-language keywords? -- Should `try_cast` return null on conversion failure, or should InQL eventually support a typed recoverable error result? -- What exact numeric promotion table should InQL use for mixed integer, decimal, and floating arithmetic? -- Should `in_` require literal lists initially, or should it also accept relation-valued subqueries in this RFC? +### Phase 1: Registry-backed scalar application model - +- Keep structural scalar nodes for column references and typed literals. +- Replace bespoke scalar function/operator expression variants with one registry-backed scalar function application node. +- Ensure public function kind and mapping metadata come from registry entries rather than a parallel function-kind switchboard. + +### Phase 2: Public core scalar helpers and registry metadata + +- Add the RFC 015 public helper names in one-helper-per-module `src/functions//.incn` files using declaration-side `@function_registry.add(...)` decorators. +- Keep non-derivable machine metadata in decorator specs; derive names and signatures from checked helper declarations. +- Preserve typed literal helper compatibility while routing through the canonical literal representation. +- Use Substrait mapping metadata only for real IR lowering: extension functions, built-in Rex shapes, deterministic rewrites, or structural relation contexts such as sort fields. + +### Phase 3: Lowering and diagnostics + +- Resolve scalar function application lowering through registry mapping metadata. +- Preserve existing correct lowerings for the scalar functions already represented by current Substrait extension mappings. 
+- Add registry-driven extension mappings, built-in Rex lowerings, and structural sort-field lowerings for the rest of the core scalar slice. +- Return clear invalid-context diagnostics when structural helpers such as `asc(...)` are used outside their valid query context. + +### Phase 4: Tests and docs + +- Add tests for helper imports, expression shape, registry metadata, scalar lowering, structural ordering lowering, invalid-context diagnostics, and literal helper cleanup. +- Update user-facing function docs and release notes. +- Run the registry metadata check because it protects the RFC 014 declaration-side registry contract. + +## Progress Checklist + +### Spec / lifecycle + +- [x] RFC 014 registry baseline is merged and available as the implementation baseline. +- [x] RFC 015 issue exists and is linked. +- [x] RFC 015 moved to In Progress with implementation plan and checklist. +- [x] Design Decisions record the current implementation-slice answers. + +### Expression model + +- [x] Keep structural scalar nodes for `ColumnRefExpr` and literals. +- [x] Replace bespoke function/operator expression variants with registry-backed scalar function application. +- [x] Make public expression kind/function metadata derive from the registry-backed application. +- [x] Preserve typed literal helper compatibility through the canonical literal representation. + +### Public helpers / registry + +- [x] Register `cast` and `try_cast`. +- [x] Register comparisons: `eq`, `ne`, `lt`, `lte`, `gt`, `gte`, `equal_null`. +- [x] Register boolean logic: `and_`, `or_`, `not_`. +- [x] Register null and NaN predicates: `is_null`, `is_not_null`, `is_nan`, `is_not_nan`. +- [x] Register arithmetic: `add`, `sub`, `mul`, `div`, `mod`, `neg`. +- [x] Register conditionals: `coalesce`, `nullif`, `case_when`. +- [x] Register predicates: `in_`, `between`. +- [x] Register ordering helpers: `asc`, `desc`, `asc_nulls_first`, `asc_nulls_last`, `desc_nulls_first`, `desc_nulls_last`. 
+ +### Lowering / interchange + +- [x] Resolve supported scalar function lowering through registry mapping metadata. +- [x] Keep existing supported Substrait extension lowering for `add`, `mul`, `eq`, and `gt`. +- [x] Add honest current lowerings for casts, comparisons, boolean logic, null/NaN predicates, arithmetic, conditionals, membership predicates, and range predicates. +- [x] Emit invalid-context diagnostics for ordering helpers used as standalone scalar expressions. +- [x] Lower ordering helpers into Substrait `SortRel.sorts` when used through `order_by(...)`. + +### Tests + +- [x] Registry tests cover all RFC 015 helpers and mapping categories. +- [x] Scalar expression tests prove helper calls share one function application node. +- [x] Lowering tests cover scalar helpers, structural ordering helpers, and invalid ordering contexts. +- [x] Regression tests prove adding one scalar function does not require separate kind/lowering switchboards. + +### Docs / release + +- [x] Update function reference docs for the core scalar helper surface. +- [x] Update release notes. +- [x] Confirm no package version bump is required for the unreleased v0.1 package line. + +### Verification + +- [x] `make test-style` +- [x] focused scalar/function registry/Substrait tests +- [x] `make fmt-check` +- [x] `make registry-metadata` +- [x] `make build` +- [x] `make test` +- [x] `make smoke-consumer` + +## Design Decisions + +- **Boolean helper names:** the canonical boolean helper names are `and_`, `or_`, and `not_`; the trailing underscore avoids host-language keyword collisions while keeping SQL-familiar names recognizable. +- **`try_cast` failure result:** `try_cast` uses null-on-conversion-failure semantics for this RFC. Typed recoverable error values remain a future extension point rather than part of the core scalar slice. 
+- **Numeric promotion boundary:** the current v0.3-era InQL package records numeric helper intent and checked helper signatures but does not introduce a package-local numeric promotion table. Lowering must only use mappings that are currently represented honestly; ambiguous numeric behavior must remain explicit instead of silently choosing backend-dependent behavior. +- **`in_` scope:** `in_` accepts literal or expression lists in this RFC. Relation-valued subquery membership is a future query-surface feature and is out of scope for this core scalar package slice. +- **Registry-backed applications:** structural scalar nodes remain appropriate for `ColumnRefExpr` and typed literals, but function/operator calls such as `add`, `mul`, `eq`, `gt`, `and_`, `or_`, and `cast` are represented as registry-backed scalar function applications rather than one bespoke model per function forever. +- **Current lowering boundary:** the implemented RFC 015 slice lowers through registered Substrait extension mappings, built-in Substrait Rex shapes (`Cast`, `SingularOrList`, and `IfThen`), and structural `SortRel.sorts` lowering for ordering helpers. DataFusion is the first execution adapter that consumes the emitted Substrait plan; it does not define the portable helper semantics. +- **Module layout:** each public helper lives in its own `src/functions//.incn` module with helper-local docs, decorator metadata, and inline tests. Family directories group references, literals, casts, operators, predicates, conditionals, ordering, aggregates, and formatting helpers for readable source ownership and future generated docs. `src/functions/mod.incn` remains the public import facade; all-surface catalog validation uses checked API metadata projections rather than runtime loader hooks. +- **Modulo spelling:** `mod` remains the canonical registry function name, but the Incan public helper is `modulo(...)` because `mod` is reserved by the language. 
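
The "registry-backed applications" and "modulo spelling" decisions above can be illustrated with a small sketch. It is written in Python as a stand-in, not in Incan, and it is not the InQL implementation: `ColumnRef`, `Literal`, `FunctionApplication`, `REGISTRY`, `register`, and the mapping labels `"extension"` and `"builtin_rex"` are hypothetical names chosen for illustration. The only points it tries to show are that every helper call produces the same application node shape, and that a public helper name (`modulo`) may differ from its canonical registry name (`mod`).

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


# Structural scalar nodes (hypothetical stand-ins for column refs and literals).
@dataclass(frozen=True)
class ColumnRef:
    name: str


@dataclass(frozen=True)
class Literal:
    value: object


# One shared node for every function/operator call, keyed by canonical name.
@dataclass(frozen=True)
class FunctionApplication:
    canonical_name: str
    args: Tuple[object, ...]


# Canonical name -> mapping category. The category labels are illustrative only.
REGISTRY: Dict[str, str] = {}


def register(canonical_name: str, mapping: str) -> Callable:
    """Register a helper under a canonical name and make it build an application node."""

    def decorate(helper: Callable) -> Callable:
        if canonical_name in REGISTRY:
            raise ValueError(f"duplicate canonical function name: {canonical_name}")
        REGISTRY[canonical_name] = mapping

        def wrapper(*args: object) -> FunctionApplication:
            # No per-function expression variant: the registry entry is the identity.
            return FunctionApplication(canonical_name, tuple(args))

        wrapper.__name__ = helper.__name__
        wrapper.__doc__ = helper.__doc__
        return wrapper

    return decorate


@register("add", "extension")
def add(left, right):
    """Add two scalar expressions."""


@register("mod", "extension")
def modulo(left, right):
    """Public helper spelled `modulo`; the canonical registry name stays `mod`."""


@register("cast", "builtin_rex")
def cast(expr, target):
    """Cast a scalar expression to a target type name."""
```

In this sketch, adding another scalar function means adding one decorated helper; nothing else needs a new branch, which is the property the regression tests in the checklist are described as protecting.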
diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md index e07d29c..fac71de 100644 --- a/docs/rfcs/README.md +++ b/docs/rfcs/README.md @@ -21,7 +21,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la | [012][rfc-012] | Implemented | Unified scalar expression surface | | | [013][rfc-013] | Planned | Function catalog program | | | [014][rfc-014] | Implemented | Function registry and catalog governance | | -| [015][rfc-015] | Draft | Core scalar functions and operators | | +| [015][rfc-015] | Implemented | Core scalar functions and operators | | | [016][rfc-016] | Draft | Core aggregate functions | | | [017][rfc-017] | Draft | Aggregate modifiers | | | [018][rfc-018] | Draft | Common scalar function catalog | | diff --git a/incan.lock b/incan.lock index c1d45a0..2a6d844 100644 --- a/incan.lock +++ b/incan.lock @@ -3,7 +3,7 @@ [incan] format = 1 -incan-version = "0.3.0-rc6" +incan-version = "0.3.0-rc17" deps-fingerprint = "sha256:424fd53b12f0e810ffe3ba4188a82203bb011b5dd0769ad2b33ab1b854027487" cargo-features = [] cargo-no-default-features = false @@ -370,9 +370,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "base64" @@ -454,9 +454,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "byteorder" @@ -1824,14 +1824,14 @@ dependencies = [ [[package]] name = "incan_core" -version = "0.3.0-rc6" +version = "0.3.0-rc17" dependencies = [ "serde", ] [[package]] name = "incan_derive" 
-version = "0.3.0-rc6" +version = "0.3.0-rc17" dependencies = [ "proc-macro2", "quote", @@ -1840,7 +1840,7 @@ dependencies = [ [[package]] name = "incan_stdlib" -version = "0.3.0-rc6" +version = "0.3.0-rc17" dependencies = [ "incan_core", "incan_derive", @@ -1862,7 +1862,7 @@ dependencies = [ [[package]] name = "inql" -version = "0.3.0-rc6" +version = "0.3.0-rc17" dependencies = [ "byteorder", "datafusion", @@ -1909,9 +1909,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" dependencies = [ "cfg-if", "futures-util", @@ -2043,9 +2043,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" [[package]] name = "lz4_flex" @@ -3218,9 +3218,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" dependencies = [ "cfg-if", "once_cell", @@ -3231,9 +3231,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.71" +version = "0.4.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" dependencies = [ "js-sys", "wasm-bindgen", @@ -3241,9 +3241,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = 
"0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3251,9 +3251,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" dependencies = [ "bumpalo", "proc-macro2", @@ -3264,9 +3264,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" dependencies = [ "unicode-ident", ] diff --git a/scripts/check_function_registry_metadata.incn b/scripts/check_function_registry_metadata.incn index f79fd1f..1cf9918 100644 --- a/scripts/check_function_registry_metadata.incn +++ b/scripts/check_function_registry_metadata.incn @@ -3,12 +3,9 @@ from std.fs import Path from std.json import JsonValue from std.testing import fail_t -from functions import function_registry_entries -from function_registry import FunctionRegistryEntry const METADATA_PATH: str = "target/function_registry_api_metadata.json" -const FUNCTION_REF_PREFIX: str = "inql.functions." 
-const REGISTRY_DECORATOR: str = "FUNCTION_REGISTRY.add" +const REGISTRY_DECORATOR: str = "function_registry.add" def metadata_fail[T](message: str) -> T: @@ -49,6 +46,14 @@ def require_string_field(value: JsonValue, key: str, context: str) -> str: return require_string(require_field(value, key, context), f"{context}.{key}") +def optional_field(value: JsonValue, key: str) -> Option[JsonValue]: + """Return one optional JSON object field.""" + fields = require_object(value, key) + if has_key(fields, key): + return Some(require_field(value, key, key)) + return None + + def optional_string_field(value: JsonValue, key: str) -> Option[str]: """Return one optional JSON string field.""" fields = require_object(value, key) @@ -80,22 +85,38 @@ def contains_text(items: list[str], expected: str) -> bool: return false -def same_text_set(left: list[str], right: list[str]) -> bool: - """Return whether two string lists contain the same values, ignoring order.""" +def module_named(package_value: JsonValue, name: str) -> Option[JsonValue]: + """Return one checked API metadata module by single-segment module path.""" + module_items = require_array(require_field(package_value, "modules", "package"), "package.modules") + for module in module_items: + module_path_items = require_array(require_field(module, "module_path", "module"), "module.module_path") + if len(module_path_items) == 1 and require_string(module_path_items[0], "module.module_path[0]") == name: + return Some(module) + return None + + +def paths_equal(left: list[str], right: list[str]) -> bool: + """Return whether two module paths contain the same segments.""" if len(left) != len(right): return false - for item in left: - if not contains_text(right, item): + for index in range(len(left)): + if left[index] != right[index]: return false return true -def module_named(package: JsonValue, name: str) -> Option[JsonValue]: - """Return one checked API metadata module by single-segment module path.""" - modules = 
require_array(require_field(package, "modules", "package"), "package.modules") - for module in modules: - module_path = require_array(require_field(module, "module_path", "module"), "module.module_path") - if len(module_path) == 1 and require_string(module_path[0], "module.module_path[0]") == name: +def module_metadata_path(module: JsonValue) -> list[str]: + """Return one module metadata path as string segments.""" + module_path = require_field(module, "module_path", "module") + module_path_items = require_array(module_path, "module.module_path") + return [require_string(segment, "module.module_path") for segment in module_path_items] + + +def module_by_path(package_value: JsonValue, path: list[str]) -> Option[JsonValue]: + """Return one checked API metadata module by full module path.""" + module_items = require_array(require_field(package_value, "modules", "package"), "package.modules") + for module in module_items: + if paths_equal(module_metadata_path(module), path): return Some(module) return None @@ -112,25 +133,71 @@ def declarations_by_name(module: JsonValue, kind: str) -> dict[str, JsonValue]: return declarations -def helper_name_from_function_ref(function_ref: str) -> str: - """Return the public helper name encoded by one stable function reference.""" - parts = function_ref.split(FUNCTION_REF_PREFIX) - if len(parts) != 2 or parts[0] != "": - return metadata_fail[str](f"function ref must start with {FUNCTION_REF_PREFIX}, found {function_ref}") +def declaration_named(module: JsonValue, name: str) -> Option[JsonValue]: + """Return one declaration by name from a module.""" + items = require_array(require_field(module, "declarations", "module"), "module.declarations") + for declaration in items: + if require_string_field(declaration, "name", "declaration") == name: + return Some(declaration) + return None + - helper_name = parts[1] - if len(helper_name) == 0 or helper_name.contains("."): - return metadata_fail[str](f"function ref must name one public helper, 
found {function_ref}") - return helper_name +def resolve_declaration_path( + package_value: JsonValue, + target_path: list[str], + context: str, + depth: int, +) -> Option[JsonValue]: + """Resolve an alias target path to its final declaration.""" + if depth > 8: + return metadata_fail[Option[JsonValue]](f"{context}: alias resolution exceeded maximum depth") + if len(target_path) < 2: + return metadata_fail[Option[JsonValue]]("alias target path must contain a module path and declaration name") + + module_path_value = target_path[:-1] + declaration_name = target_path[len(target_path) - 1] + target_module: JsonValue = match module_by_path(package_value, module_path_value): + Some(module) => module + None => return None + declaration: JsonValue = match declaration_named(target_module, declaration_name): + Some(item) => item + None => return None -def registry_decorators(function: JsonValue) -> list[JsonValue]: - """Return registry decorators attached to one public helper.""" - mut decorators: list[JsonValue] = [] - for decorator in require_array(require_field(function, "decorators", "function"), "function.decorators"): - if decorator_source_matches(decorator): - decorators.append(decorator) - return decorators + declaration_kind = require_string_field(declaration, "kind", context) + if declaration_kind == "function": + return Some(declaration) + if declaration_kind == "alias": + nested_target = require_field(declaration, "target_path", context) + nested_context = f"{context}.target_path" + nested_items = require_array(nested_target, nested_context) + nested_path = [require_string(segment, nested_context) for segment in nested_items] + return resolve_declaration_path(package_value, nested_path, context, depth + 1) + return None + + +def public_helper_declarations(package_value: JsonValue, functions_module: JsonValue) -> dict[str, JsonValue]: + """Return root public helper declarations, resolving facade aliases through family modules.""" + mut functions = 
declarations_by_name(functions_module, "function") + aliases = declarations_by_name(functions_module, "alias") + for alias_name in aliases.keys(): + alias = aliases[alias_name] + target_path_value = require_field(alias, "target_path", alias_name) + target_context = f"{alias_name}.target_path" + target_path_items = require_array(target_path_value, target_context) + target_path = [require_string(segment, target_context) for segment in target_path_items] + match resolve_declaration_path(package_value, target_path, alias_name, 0): + Some(function) => + functions[alias_name] = function + None => pass + return functions + + +def registry_decorators(value: JsonValue) -> list[JsonValue]: + """Return registry decorators attached to one metadata declaration or projection.""" + decorators = require_field(value, "decorators", "decorated value") + decorator_items = require_array(decorators, "decorated value.decorators") + return [decorator for decorator in decorator_items if decorator_source_matches(decorator)] def decorator_source_matches(decorator: JsonValue) -> bool: @@ -157,117 +224,130 @@ def decorator_canonical_name(decorator: JsonValue, context: str) -> str: return require_string_field(literal, "value", context) -def metadata_type_names(value: JsonValue, context: str) -> list[str]: - """Return normalized leaf type names from API metadata type JSON.""" - fields = require_object(value, context) - - if has_key(fields, "Named"): - named = require_field(value, "Named", context) - return [require_string_field(named, "name", context)] - - if has_key(fields, "Applied"): - applied = require_field(value, "Applied", context) - return applied_type_names(applied, context) - - if has_key(fields, "TypeParam"): - param = require_field(value, "TypeParam", context) - return [require_string_field(param, "name", context)] - - return metadata_fail[list[str]](f"{context}: unsupported type metadata shape") - - -def applied_type_names(applied: JsonValue, context: str) -> list[str]: - """Return 
normalized leaf type names from an applied type metadata node.""" - applied_name = require_string_field(applied, "name", context) - args = require_array(require_field(applied, "args", context), f"{context}.args") - if applied_name == "Union": - mut names: list[str] = [] - for arg in args: - leaves = metadata_type_names(arg, context) - for leaf in leaves: - names.append(leaf) - return names - return [applied_name] +def projected_function(alias: JsonValue, alias_name: str) -> JsonValue: + """Return the rc17 projected function payload for one public facade alias.""" + match optional_field(alias, "projected_function"): + Some(projection) => return projection + None => return metadata_fail[JsonValue](f"{alias_name}: public function alias is missing projected_function") -def type_matches_rule(value: JsonValue, expected_rule: str, context: str) -> bool: - """Return whether an API metadata type matches one registry type-rule string.""" - actual = metadata_type_names(value, context) - if expected_rule.contains("|"): - return same_text_set(actual, expected_rule.split(" | ")) - return len(actual) == 1 and actual[0] == expected_rule +def projected_callable(projection: JsonValue, alias_name: str) -> JsonValue: + """Return the projected callable signature payload for one facade alias.""" + return require_field(projection, "callable", f"{alias_name}.projected_function") -def check_entry(entry: FunctionRegistryEntry, functions: dict[str, JsonValue]) -> None: - """Validate one runtime registry entry against checked public API metadata.""" - expected_name = helper_name_from_function_ref(entry.function_ref) - if expected_name != entry.canonical_name: - return metadata_fail[None](f"{entry.function_ref}: function ref and canonical name disagree") - if not has_key(functions, entry.canonical_name): - return metadata_fail[None](f"{entry.function_ref}: missing public helper `{entry.canonical_name}`") - - function = functions[entry.canonical_name] - decorators = 
registry_decorators(function) - if len(decorators) != 1: - return metadata_fail[None]( - f"{entry.canonical_name}: expected exactly one registry decorator, found {len(decorators)}", - ) - decorator_name = decorator_canonical_name(decorators[0], entry.canonical_name) - if decorator_name != entry.canonical_name: - return metadata_fail[None](f"{entry.canonical_name}: decorator registered `{decorator_name}`") - - params = require_array(require_field(function, "params", entry.canonical_name), f"{entry.canonical_name}.params") - if len(params) != len(entry.signature.args): - return metadata_fail[None](f"{entry.canonical_name}: registry arity does not match public helper") - +def assert_callable_has_signature(callable: JsonValue, context: str) -> None: + """Assert one callable metadata node carries checked parameter and return type data.""" + callable_name = require_string_field(callable, "name", context) + assert len(callable_name) > 0, f"{context}: callable name must not be empty" + params = require_array(require_field(callable, "params", context), f"{context}.params") for index in range(len(params)): - param = params[index] - expected_arg = entry.signature.args[index] - param_name = require_string_field(param, "name", entry.canonical_name) - if param_name != expected_arg.name: - return metadata_fail[None](f"{entry.canonical_name}: arg {index} name mismatch") - param_type = require_field(param, "ty", entry.canonical_name) - if not type_matches_rule(param_type, expected_arg.type_rule, f"{entry.canonical_name}.{param_name}"): - return metadata_fail[None]( - f"{entry.canonical_name}.{param_name}: type does not match registry rule `{expected_arg.type_rule}`", - ) - - return_type = require_field(function, "return_type", entry.canonical_name) - if not type_matches_rule(return_type, entry.signature.return_type_rule, f"{entry.canonical_name}.return"): - return metadata_fail[None]( - f"{entry.canonical_name}: return type does not match registry rule 
`{entry.signature.return_type_rule}`", + param_name = require_string_field(params[index], "name", context) + param_type = require_field(params[index], "ty", context) + param_type_fields = require_object(param_type, f"{context}.{param_name}") + assert len(param_name) > 0, f"{context}: checked parameter name must not be empty" + assert len(param_type_fields.keys()) > 0, f"{context}.{param_name}: checked parameter type is empty" + return_type = require_field(callable, "return_type", context) + return_type_fields = require_object(return_type, f"{context}.return") + assert len(return_type_fields.keys()) > 0, f"{context}: checked return type is empty" + + +def assert_decorated_callable_matches_source(decorator: JsonValue, source_function: JsonValue, context: str) -> None: + """Assert decorator metadata carries the source callable signature it decorates.""" + decorated = require_field(decorator, "decorated_callable", f"{context}.decorator") + decorated_name = require_string_field(decorated, "name", context) + source_name = require_string_field(source_function, "name", context) + if decorated_name != source_name: + return metadata_fail[None](f"{context}: decorated callable `{decorated_name}` does not match `{source_name}`") + assert_callable_has_signature(decorated, f"{context}.decorated_callable") + + +def docstring_param_names(function: JsonValue, context: str) -> list[str]: + """Return parameter names documented in parsed docstring sections.""" + sections = require_field(function, "docstring_sections", context) + params = require_array(require_field(sections, "params", context), f"{context}.docstring_sections.params") + for param in params: + name = require_string_field(param, "name", context) + description = require_string_field(param, "description", context) + assert len(description) > 0, f"{context}.{name}: docstring parameter description must not be empty" + return [require_string_field(param, "name", context) for param in params] + + +def 
assert_docstring_documents_params(function: JsonValue, callable: JsonValue, context: str) -> None: + """Assert every checked parameter has human-facing docstring intent.""" + documented_names = docstring_param_names(function, context) + for param in require_array(require_field(callable, "params", context), f"{context}.params"): + param_name = require_string_field(param, "name", context) + if not contains_text(documented_names, param_name): + return metadata_fail[None](f"{context}.{param_name}: registered helper parameter lacks docstring text") + + +def check_registered_helper(alias_name: str, alias: JsonValue, source_function: JsonValue) -> list[str]: + """Validate one public helper alias and return the canonical registry names it exposes.""" + projection = projected_function(alias, alias_name) + callable = projected_callable(projection, alias_name) + assert_callable_has_signature(callable, f"{alias_name}.projected_callable") + + decorators = registry_decorators(projection) + if len(decorators) == 0: + return [] + if len(decorators) != 1: + return metadata_fail[list[str]]( + f"{alias_name}: expected exactly one registry decorator, found {len(decorators)}", ) - -def check_unmatched_decorators(functions: dict[str, JsonValue], seen_names: dict[str, bool]) -> None: - """Fail if a function has a registry decorator but no runtime registry entry.""" - for function_name in functions.keys(): - decorators = registry_decorators(functions[function_name]) - if len(decorators) > 0 and not has_key(seen_names, function_name): - return metadata_fail[None](f"{function_name}: registry decorator did not produce a runtime registry entry") - - -def check_metadata(package: JsonValue, entries: list[FunctionRegistryEntry]) -> int: - """Validate runtime registry entries against checked API metadata.""" - functions_module: JsonValue = match module_named(package, "functions"): + canonical_name = decorator_canonical_name(decorators[0], alias_name) + 
assert_decorated_callable_matches_source(decorators[0], source_function, canonical_name) + + match optional_string_field(source_function, "docstring"): + Some(docstring) => + if not docstring.contains("Examples:"): + return metadata_fail[list[str]](f"{canonical_name}: registered helper docstring must include examples") + None => return metadata_fail[list[str]](f"{canonical_name}: registered helper must have a docstring") + assert_docstring_documents_params(source_function, callable, canonical_name) + + return [canonical_name] + + +def check_unmatched_decorators(package_value: JsonValue, seen_names: dict[str, bool]) -> None: + """Fail if a function has a registry decorator but no public projected facade entry.""" + module_items = require_array(require_field(package_value, "modules", "package"), "package.modules") + for module in module_items: + functions = declarations_by_name(module, "function") + for function_name in functions.keys(): + function = functions[function_name] + decorators = registry_decorators(function) + for decorator in decorators: + canonical_name = decorator_canonical_name(decorator, function_name) + if not has_key(seen_names, canonical_name): + return metadata_fail[None]( + f"{canonical_name}: registry decorator is not exposed through the public functions facade", + ) + + +def check_metadata(package_value: JsonValue) -> int: + """Validate registry decorators through checked API metadata projections.""" + functions_module: JsonValue = match module_named(package_value, "functions"): Some(module) => module None => return metadata_fail[int]("missing checked API metadata module: functions") - functions = declarations_by_name(functions_module, "function") - mut seen_refs: dict[str, bool] = {} + functions = public_helper_declarations(package_value, functions_module) + aliases = declarations_by_name(functions_module, "alias") mut seen_names: dict[str, bool] = {} + mut helper_count = 0 + + for alias_name in aliases.keys(): + if not has_key(functions, 
alias_name): + continue - for entry in entries: - if has_key(seen_refs, entry.function_ref): - return metadata_fail[int](f"duplicate function ref in registry: {entry.function_ref}") - seen_refs[entry.function_ref] = true - check_entry(entry, functions) - if has_key(seen_names, entry.canonical_name): - return metadata_fail[int](f"duplicate canonical function name in registry: {entry.canonical_name}") - seen_names[entry.canonical_name] = true + for canonical_name in check_registered_helper(alias_name, aliases[alias_name], functions[alias_name]): + if has_key(seen_names, canonical_name): + return metadata_fail[int](f"duplicate canonical function name in registry metadata: {canonical_name}") + seen_names[canonical_name] = true + helper_count += 1 - check_unmatched_decorators(functions, seen_names) - return len(entries) + check_unmatched_decorators(package_value, seen_names) + return helper_count def main() -> None: @@ -276,9 +356,9 @@ def main() -> None: Ok(text) => text Err(err) => return metadata_fail[None](f"could not read {METADATA_PATH}: {err.message()}") - package: JsonValue = match JsonValue.parse(source): + parsed_package: JsonValue = match JsonValue.parse(source): Ok(value) => value Err(err) => return metadata_fail[None](f"could not parse {METADATA_PATH}: {err.message()}") - helper_count = check_metadata(package, function_registry_entries()) + helper_count = check_metadata(parsed_package) println(f"function registry metadata check passed ({helper_count} helpers)") diff --git a/src/backends.incn b/src/backends.incn index a4f0f7a..c0bab02 100644 --- a/src/backends.incn +++ b/src/backends.incn @@ -1,9 +1,9 @@ """ -Backend configuration objects for RFC 004. +Backend configuration objects for session execution. The root `Session` API stays portable, while backend-specific configuration lives under `pub::inql.backends`. 
-This module intentionally starts small: DataFusion is the first real backend configuration object and the only -implemented backend in the current slice. +Backend selection is an adapter-neutral kind plus options envelope so adding another backend does not require adding +one field per implementation to the core Session state. """ @@ -23,6 +23,18 @@ pub enum BackendKind(str): DataFusionEngine = "datafusion" +@derive(Clone) +pub class BackendOption: + """One stringly-encoded backend option carried by the portable backend-selection envelope.""" + + pub key: str + pub value: str + + def clone(self) -> Self: + """Return one cloned backend option.""" + return BackendOption(key=self.key, value=self.value) + + @derive(Clone) pub enum SourceKind(str): Csv = "csv" @@ -47,16 +59,29 @@ pub class BackendSelection: """Portable backend selection envelope stored by `Session` and its builder.""" pub kind: BackendKind - pub datafusion: DataFusion + pub options: list[BackendOption] def clone(self) -> Self: - """Return one cloned backend selection preserving engine-specific options.""" - return BackendSelection(kind=self.kind, datafusion=self.datafusion) + """Return one cloned backend selection preserving encoded backend options.""" + return BackendSelection(kind=self.kind, options=[option.clone() for option in self.options]) pub def default_backend_selection() -> BackendSelection: """Return the default backend selection used by Session defaults/builders.""" - return BackendSelection(kind=BackendKind.DataFusionEngine, datafusion=DataFusion(enable_optimizer=true)) + return datafusion_backend_selection(DataFusion(enable_optimizer=true)) + + +pub def datafusion_backend_selection(backend: DataFusion) -> BackendSelection: + """Encode DataFusion backend options into one portable backend-selection envelope.""" + return BackendSelection( + kind=BackendKind.DataFusionEngine, + options=[BackendOption(key="enable_optimizer", value=_bool_option_value(backend.enable_optimizer))], + ) + + +pub 
def datafusion_backend_from_selection(selection: BackendSelection) -> DataFusion: + """Decode DataFusion options from one backend-selection envelope.""" + return DataFusion(enable_optimizer=_backend_bool_option(selection, "enable_optimizer", true)) pub def backend_kind_name(selection: BackendSelection) -> str: @@ -69,6 +94,21 @@ pub def source_kind_name(kind: SourceKind) -> str: return kind.value() +def _bool_option_value(value: bool) -> str: + """Encode one bool backend option as stable text.""" + if value: + return "true" + return "false" + + +def _backend_bool_option(selection: BackendSelection, key: str, default_value: bool) -> bool: + """Read one bool backend option from an encoded backend selection.""" + for option in selection.options: + if option.key == key: + return option.value == "true" + return default_value + + pub def csv_source(uri: str) -> TableSource: """Build one CSV table source descriptor.""" return TableSource(source_kind=SourceKind.Csv, uri=uri) diff --git a/src/dataset/mod.incn b/src/dataset/mod.incn index 6f4ebb5..0d34d4c 100644 --- a/src/dataset/mod.incn +++ b/src/dataset/mod.incn @@ -1,5 +1,5 @@ """ -Dataset carriers for InQL (RFC 001). +Dataset carriers for InQL. 
This module defines the *author-facing* type hierarchy used to carry schema-parameterized tabular data through relational pipelines: @@ -66,7 +66,7 @@ from dataset.ops import ( group_by_ds_of_columns, join_ds, limit_ds, - order_by_ds, + order_by_ds_of_columns, select_ds_of_columns, with_column_ds, ) @@ -98,7 +98,7 @@ pub trait DataSet[T with Clone]: def with_column(self, name: str, expr: ColumnExpr) -> Self def group_by(self, columns: list[ColumnExpr]) -> Self def agg(self, measures: list[AggregateMeasure]) -> Self - def order_by(self) -> Self + def order_by(self, columns: list[ColumnExpr]) -> Self def limit(self, n: int) -> Self def explode(self) -> Self @@ -207,9 +207,11 @@ pub class DataFrame[T with Clone] with BoundedDataSet: agg_ds_of_columns(self._substrait_rel, self.planned_columns(), measures), ) - def order_by(self) -> Self: + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataFrame with an ordering stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization(order_by_ds(self._substrait_rel)) + return _data_frame_with_invalidated_materialization( + order_by_ds_of_columns(self._substrait_rel, self.planned_columns(), columns), + ) def limit(self, n: int) -> Self: """Return one new DataFrame with a row-limit stage and stale materialization cleared.""" @@ -286,9 +288,9 @@ pub class LazyFrame[T with Clone] with BoundedDataSet: """Return one new lazy carrier with an appended aggregation stage.""" return LazyFrame(_cursor=prism_cursor_apply_agg(self._cursor, measures)) - def order_by(self) -> Self: + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new lazy carrier with an appended ordering stage.""" - return LazyFrame(_cursor=prism_cursor_apply_order_by(self._cursor)) + return LazyFrame(_cursor=prism_cursor_apply_order_by(self._cursor, columns)) def limit(self, n: int) -> Self: """Return one new lazy carrier with an appended row-limit stage.""" @@ -428,11 +430,15 @@ pub class 
DataStream[T with Clone] with UnboundedDataSet: ), ) - def order_by(self) -> Self: + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataStream with an ordering stage.""" return DataStream( _row_schema_marker=self._row_schema_marker.clone(), - _substrait_rel=order_by_ds(self._substrait_rel), + _substrait_rel=order_by_ds_of_columns( + self._substrait_rel, + relation_output_columns(self._substrait_rel.clone()), + columns, + ), ) def limit(self, n: int) -> Self: diff --git a/src/dataset/ops.incn b/src/dataset/ops.incn index 0367186..5319f4d 100644 --- a/src/dataset/ops.incn +++ b/src/dataset/ops.incn @@ -9,7 +9,7 @@ views stay aligned with the lowered relation tree. from rust::substrait::proto import Rel from aggregate_builders import AggregateMeasure from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment -from substrait.extensions import explode_extension_uri +from substrait.function_extensions import explode_extension_uri from substrait.inspect import relation_output_columns from substrait.relations import ( aggregate_rel_of_columns, @@ -18,7 +18,7 @@ from substrait.relations import ( filter_rel_of_columns, join_rel, project_rel_of_columns, - sort_rel, + sort_rel_of_columns, ) @@ -122,17 +122,23 @@ pub def agg_ds_of_columns(rel: Rel, input_columns: list[str], measures: list[Agg return aggregate_rel_of_columns(rel, input_columns, [], measures) -pub def order_by_ds(rel: Rel) -> Rel: +pub def order_by_ds(rel: Rel, columns: list[ColumnExpr]) -> Rel: """ Apply dataset-level ordering intent to one relation. Args: rel: Input relation to sort. + columns: Sort key expressions, optionally wrapped with ordering helpers. Returns: A relation shaped as a sort over the input relation. 
""" - return sort_rel(rel) + return order_by_ds_of_columns(rel, relation_output_columns(rel.clone()), columns) + + +pub def order_by_ds_of_columns(rel: Rel, input_columns: list[str], columns: list[ColumnExpr]) -> Rel: + """Apply dataset-level ordering intent using explicit input-column names.""" + return sort_rel_of_columns(rel, input_columns, columns) pub def limit_ds(rel: Rel, n: int) -> Rel: diff --git a/src/filter_builders.incn b/src/filter_builders.incn index e73bd47..40edbce 100644 --- a/src/filter_builders.incn +++ b/src/filter_builders.incn @@ -5,7 +5,9 @@ These names provide predicate-friendly helper spellings over the same scalar-exp projections, grouping keys, and aggregate inputs. """ -from projection_builders import ColumnExpr, bool_expr, eq as scalar_eq, gt as scalar_gt, int_expr, str_expr +from functions.operators.eq import eq as scalar_eq +from functions.operators.gt import gt as scalar_gt +from projection_builders import ColumnExpr, bool_expr, int_expr, str_expr pub def int_lit(value: int) -> ColumnExpr: diff --git a/src/function_registry.incn b/src/function_registry.incn index 8507f70..b4bcb51 100644 --- a/src/function_registry.incn +++ b/src/function_registry.incn @@ -1,11 +1,12 @@ """ Package-owned function registry metadata for the current public InQL function surface. -Public helpers register themselves by calling `FUNCTION_REGISTRY.add(...)` as their decorator. The decorator method owns -the typed metadata and records one normalized entry while returning the helper unchanged. +Public helpers register themselves by calling `function_registry.add(...)` as their decorator. The decorator method owns +non-derivable machine metadata and records one runtime entry while returning the helper unchanged. Public helper names +and signatures are checked API metadata facts, not second copies in this runtime shape. 
""" -from substrait.extensions import function_extension_uri +from substrait.function_extensions import function_extension_uri const FUNCTION_REF_PREFIX: str = "inql.functions." @@ -16,6 +17,7 @@ pub enum FunctionClass(str): Scalar = "scalar" Aggregate = "aggregate" + Ordering = "ordering" Window = "window" Generator = "generator" TableValued = "table_valued" @@ -64,7 +66,7 @@ pub enum SubstraitMappingKind(str): CoreFunction = "core_function" ExtensionFunction = "extension_function" Rewrite = "rewrite" - Unsupported = "unsupported" + StructuralFunction = "structural_function" @derive(Clone) @@ -106,25 +108,6 @@ pub model FunctionLifecycle: pub deprecated: Option[FunctionDeprecation] -@derive(Clone) -pub model FunctionArg: - """One registry-visible function argument shape.""" - - pub name: str - pub type_rule: str - pub required: bool - pub literal_only: bool - - -@derive(Clone) -pub model FunctionSignature: - """Registry-visible callable signature metadata.""" - - pub args: list[FunctionArg] - pub return_type_rule: str - pub variadic: bool - - @derive(Clone) pub model SubstraitMapping: """Portable interchange mapping metadata for one registered function.""" @@ -134,18 +117,17 @@ pub model SubstraitMapping: pub function_name: str pub anchor: u32 pub rewrite: str - pub unsupported_reason: str + pub detail: str @derive(Clone) pub model FunctionSpec: - """Machine-readable contract supplied to the registry decorator.""" + """Machine-readable function facts supplied to the registry decorator.""" pub function_class: FunctionClass pub aliases: list[str] pub alias_policy: FunctionAliasPolicy pub lifecycle: FunctionLifecycle - pub signature: FunctionSignature pub determinism: FunctionDeterminism pub null_behavior: FunctionNullBehavior pub error_behavior: FunctionErrorBehavior @@ -154,7 +136,7 @@ pub model FunctionSpec: @derive(Clone) pub model FunctionRegistryEntry: - """Normalized metadata for one registered InQL function.""" + """Runtime projection for one registered InQL 
function.""" pub function_ref: str pub canonical_name: str @@ -162,7 +144,6 @@ pub model FunctionRegistryEntry: pub aliases: list[str] pub alias_policy: FunctionAliasPolicy pub lifecycle: FunctionLifecycle - pub signature: FunctionSignature pub determinism: FunctionDeterminism pub null_behavior: FunctionNullBehavior pub error_behavior: FunctionErrorBehavior @@ -202,7 +183,6 @@ pub class FunctionRegistry: aliases=spec.aliases, alias_policy=spec.alias_policy, lifecycle=spec.lifecycle, - signature=spec.signature, determinism=spec.determinism, null_behavior=spec.null_behavior, error_behavior=spec.error_behavior, @@ -247,21 +227,6 @@ pub def function_ref_for(canonical_name: str) -> str: return f"{FUNCTION_REF_PREFIX}{canonical_name}" -pub def required_arg(name: str, type_rule: str) -> FunctionArg: - """Build one required non-literal-only function argument metadata record.""" - return FunctionArg(name=name, type_rule=type_rule, required=true, literal_only=false) - - -pub def literal_arg(name: str, type_rule: str) -> FunctionArg: - """Build one required literal-only function argument metadata record.""" - return FunctionArg(name=name, type_rule=type_rule, required=true, literal_only=true) - - -pub def signature(args: list[FunctionArg], return_type_rule: str) -> FunctionSignature: - """Build one fixed-arity function signature metadata record.""" - return FunctionSignature(args=args, return_type_rule=return_type_rule, variadic=false) - - pub def extension_mapping(function_name: str, anchor: u32) -> SubstraitMapping: """Build one registered Substrait extension-function mapping.""" return SubstraitMapping( @@ -270,7 +235,19 @@ pub def extension_mapping(function_name: str, anchor: u32) -> SubstraitMapping: function_name=function_name, anchor=anchor, rewrite="", - unsupported_reason="", + detail="", + ) + + +pub def core_mapping(function_name: str) -> SubstraitMapping: + """Build one mapping for a built-in Substrait Rex shape rather than an extension function declaration.""" + 
return SubstraitMapping( + kind=SubstraitMappingKind.CoreFunction, + uri="", + function_name=function_name, + anchor=0, + rewrite="", + detail="", ) @@ -282,14 +259,37 @@ pub def rewrite_mapping(rewrite: str) -> SubstraitMapping: function_name="", anchor=0, rewrite=rewrite, - unsupported_reason="", + detail="", + ) + + +pub def structural_mapping(context: str) -> SubstraitMapping: + """Build one mapping for helpers lowered by a relation-specific Substrait context.""" + return SubstraitMapping( + kind=SubstraitMappingKind.StructuralFunction, + uri="", + function_name=context, + anchor=0, + rewrite="", + detail="", + ) + + +pub def sort_field_mapping(direction: str) -> SubstraitMapping: + """Build one structural sort-field mapping with explicit direction/null-placement detail.""" + return SubstraitMapping( + kind=SubstraitMappingKind.StructuralFunction, + uri="", + function_name="sort_field", + anchor=0, + rewrite="", + detail=direction, ) pub def deterministic_spec( function_class: FunctionClass, lifecycle: FunctionLifecycle, - signature: FunctionSignature, null_behavior: FunctionNullBehavior, substrait: SubstraitMapping, ) -> FunctionSpec: @@ -299,7 +299,6 @@ pub def deterministic_spec( aliases=[], alias_policy=FunctionAliasPolicy.CoreImport, lifecycle=lifecycle, - signature=signature, determinism=FunctionDeterminism.Deterministic, null_behavior=null_behavior, error_behavior=FunctionErrorBehavior.Typechecked, diff --git a/src/functions.incn b/src/functions.incn deleted file mode 100644 index 6b5d372..0000000 --- a/src/functions.incn +++ /dev/null @@ -1,326 +0,0 @@ -""" -Explicit builder helpers for the current InQL-only relational surface. - -These symbols **must** be imported by authors. The current package slice uses explicit builder functions as the -semantic target for future compiler sugar and query-block lowering. 
-""" - -from aggregate_builders import count as count_builder, sum as sum_builder -from filter_builders import ( - bool_lit as bool_lit_builder, - always_false as always_false_builder, - always_true as always_true_builder, - eq as eq_builder, - gt as gt_builder, - int_lit as int_lit_builder, - str_lit as str_lit_builder, -) -from projection_builders import ( - add as add_builder, - bool_expr as bool_expr_builder, - col as col_builder, - float_expr as float_expr_builder, - int_expr as int_expr_builder, - lit as lit_builder, - mul as mul_builder, - str_expr as str_expr_builder, -) -from aggregate_builders import AggregateMeasure -from dataset import DataFrame, LazyFrame -from function_registry import ( - FunctionClass, - FunctionLifecycle, - FunctionNullBehavior, - FunctionRegistry, - FunctionRegistryEntry, - SubstraitMappingKind, - deterministic_spec, - extension_mapping, - literal_arg, - required_arg, - rewrite_mapping, - signature, - v0_1, -) -from projection_builders import ColumnExpr -from substrait.extensions import ( - ADD_FUNCTION_ANCHOR, - COUNT_FUNCTION_ANCHOR, - EQUAL_FUNCTION_ANCHOR, - GT_FUNCTION_ANCHOR, - MULTIPLY_FUNCTION_ANCHOR, - SUM_FUNCTION_ANCHOR, -) - -pub static FUNCTION_REGISTRY: FunctionRegistry = FunctionRegistry.new() - - -@FUNCTION_REGISTRY.add("col", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([required_arg("name", "str")], "ColumnExpr"), - FunctionNullBehavior.DependsOnInputs, - rewrite_mapping("direct field reference selection"), -)) -pub def col(name: str) -> ColumnExpr: - """Build one named column reference expression.""" - return col_builder(name) - - -@FUNCTION_REGISTRY.add("lit", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([literal_arg("value", "int | float | str | bool")], "ColumnExpr"), - FunctionNullBehavior.NonNullLiteral, - rewrite_mapping("typed Substrait literal expression"), -)) 
-pub def lit(value: Union[int, float, str, bool]) -> ColumnExpr: - """Build one canonical scalar literal expression.""" - return lit_builder(value) - - -@FUNCTION_REGISTRY.add("sum", deterministic_spec( - FunctionClass.Aggregate, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([required_arg("expr", "ColumnExpr")], "AggregateMeasure"), - FunctionNullBehavior.NullSkippingAggregate, - extension_mapping("sum", SUM_FUNCTION_ANCHOR), -)) -pub def sum(expr: ColumnExpr) -> AggregateMeasure: - """Build one `sum` aggregate measure over a scalar expression.""" - return sum_builder(expr) - - -@FUNCTION_REGISTRY.add("count", deterministic_spec( - FunctionClass.Aggregate, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([], "AggregateMeasure"), - FunctionNullBehavior.NullSkippingAggregate, - extension_mapping("count", COUNT_FUNCTION_ANCHOR), -)) -pub def count() -> AggregateMeasure: - """Build one zero-argument `count` aggregate measure.""" - return count_builder() - - -@FUNCTION_REGISTRY.add("int_expr", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([literal_arg("value", "int")], "ColumnExpr"), - FunctionNullBehavior.NonNullLiteral, - rewrite_mapping("typed integer literal expression"), -)) -pub def int_expr(value: int) -> ColumnExpr: - """Build one integer literal expression.""" - return int_expr_builder(value) - - -@FUNCTION_REGISTRY.add("float_expr", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([literal_arg("value", "float")], "ColumnExpr"), - FunctionNullBehavior.NonNullLiteral, - rewrite_mapping("typed floating-point literal expression"), -)) -pub def float_expr(value: float) -> ColumnExpr: - """Build one float literal expression.""" - return float_expr_builder(value) - - -@FUNCTION_REGISTRY.add("str_expr", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, 
changed=[], deprecated=None), - signature([literal_arg("value", "str")], "ColumnExpr"), - FunctionNullBehavior.NonNullLiteral, - rewrite_mapping("typed string literal expression"), -)) -pub def str_expr(value: str) -> ColumnExpr: - """Build one string literal expression.""" - return str_expr_builder(value) - - -@FUNCTION_REGISTRY.add("bool_expr", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([literal_arg("value", "bool")], "ColumnExpr"), - FunctionNullBehavior.NonNullLiteral, - rewrite_mapping("typed boolean literal expression"), -)) -pub def bool_expr(value: bool) -> ColumnExpr: - """Build one boolean literal expression.""" - return bool_expr_builder(value) - - -@FUNCTION_REGISTRY.add("add", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([required_arg("left", "ColumnExpr"), required_arg("right", "ColumnExpr")], "ColumnExpr"), - FunctionNullBehavior.DependsOnInputs, - extension_mapping("add", ADD_FUNCTION_ANCHOR), -)) -pub def add(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one binary addition expression.""" - return add_builder(left, right) - - -@FUNCTION_REGISTRY.add("mul", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([required_arg("left", "ColumnExpr"), required_arg("right", "ColumnExpr")], "ColumnExpr"), - FunctionNullBehavior.DependsOnInputs, - extension_mapping("multiply", MULTIPLY_FUNCTION_ANCHOR), -)) -pub def mul(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one binary multiply expression.""" - return mul_builder(left, right) - - -@FUNCTION_REGISTRY.add("int_lit", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([literal_arg("value", "int")], "ColumnExpr"), - FunctionNullBehavior.NonNullLiteral, - rewrite_mapping("filter-helper integer 
literal expression"), -)) -pub def int_lit(value: int) -> ColumnExpr: - """Build one integer scalar literal through the filter-helper naming style.""" - return int_lit_builder(value) - - -@FUNCTION_REGISTRY.add("str_lit", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([literal_arg("value", "str")], "ColumnExpr"), - FunctionNullBehavior.NonNullLiteral, - rewrite_mapping("filter-helper string literal expression"), -)) -pub def str_lit(value: str) -> ColumnExpr: - """Build one string scalar literal through the filter-helper naming style.""" - return str_lit_builder(value) - - -@FUNCTION_REGISTRY.add("bool_lit", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([literal_arg("value", "bool")], "ColumnExpr"), - FunctionNullBehavior.NonNullLiteral, - rewrite_mapping("filter-helper boolean literal expression"), -)) -pub def bool_lit(value: bool) -> ColumnExpr: - """Build one boolean scalar literal through the filter-helper naming style.""" - return bool_lit_builder(value) - - -@FUNCTION_REGISTRY.add("always_true", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([], "ColumnExpr"), - FunctionNullBehavior.Predicate, - rewrite_mapping("boolean true literal predicate"), -)) -pub def always_true() -> ColumnExpr: - """Build one no-op scalar predicate that canonical rewrite can eliminate.""" - return always_true_builder() - - -@FUNCTION_REGISTRY.add("always_false", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([], "ColumnExpr"), - FunctionNullBehavior.Predicate, - rewrite_mapping("boolean false literal predicate"), -)) -pub def always_false() -> ColumnExpr: - """Build one scalar predicate that rejects every row.""" - return always_false_builder() - - -@FUNCTION_REGISTRY.add("eq", deterministic_spec( - 
FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([required_arg("left", "ColumnExpr"), required_arg("right", "ColumnExpr")], "ColumnExpr"), - FunctionNullBehavior.Predicate, - extension_mapping("equal", EQUAL_FUNCTION_ANCHOR), -)) -pub def eq(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one equality predicate scalar expression.""" - return eq_builder(left, right) - - -@FUNCTION_REGISTRY.add("gt", deterministic_spec( - FunctionClass.Scalar, - FunctionLifecycle(since=v0_1, changed=[], deprecated=None), - signature([required_arg("left", "ColumnExpr"), required_arg("right", "ColumnExpr")], "ColumnExpr"), - FunctionNullBehavior.Predicate, - extension_mapping("gt", GT_FUNCTION_ANCHOR), -)) -pub def gt(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one greater-than predicate scalar expression.""" - return gt_builder(left, right) - - -pub def function_registry_entries() -> list[FunctionRegistryEntry]: - """Return the normalized checked function registry for the current package surface.""" - return FUNCTION_REGISTRY.entries - - -pub def function_registry_entry(function_ref: str) -> Option[FunctionRegistryEntry]: - """Return the registry entry for one stable function reference when it is known.""" - for entry in FUNCTION_REGISTRY.entries: - if entry.function_ref == function_ref: - return Some(entry) - return None - - -pub def function_registry_entry_by_name(canonical_name: str) -> Option[FunctionRegistryEntry]: - """Return the registry entry for one canonical public function name when it is known.""" - for entry in FUNCTION_REGISTRY.entries: - if entry.canonical_name == canonical_name: - return Some(entry) - return None - - -pub def function_registry_function_refs() -> list[str]: - """Return registered function references in stable registry order.""" - return FUNCTION_REGISTRY.function_refs() - - -pub def function_registry_canonical_names() -> list[str]: - """Return registered canonical function 
names in stable registry order.""" - return FUNCTION_REGISTRY.canonical_names() - - -pub def function_registry_entry_count() -> int: - """Return the number of registered function entries.""" - return FUNCTION_REGISTRY.entry_count() - - -pub def registered_substrait_mapped_function_refs() -> list[str]: - """Return function references with a concrete Substrait extension mapping.""" - return [entry.function_ref for entry in FUNCTION_REGISTRY.entries if entry.substrait.kind == SubstraitMappingKind.ExtensionFunction] - - -pub def display[T with Clone](data: LazyFrame[T]) -> None: - """Collect one LazyFrame through the active Session and render a preview.""" - match data.clone().collect(): - Ok(df) => _render_data_frame(df) - Err(err) => println(f"InQL display failed: {err.error_message()}") - - -def _join_columns(columns: list[str]) -> str: - """Join column names for one compact display header line.""" - return ", ".join(columns) - - -def _render_data_frame[T with Clone](data: DataFrame[T]) -> None: - """Render one collected DataFrame preview using structured materialization metadata.""" - preview = data.preview_text() - columns = data.columns() - println("InQL display") - if len(columns) > 0: - println(f"columns: [{_join_columns(columns)}]") - println(f"rows: {data.row_count()}") - if len(preview) == 0: - println("(empty)") - return - println(preview) diff --git a/src/functions/aggregates/count.incn b/src/functions/aggregates/count.incn new file mode 100644 index 0000000..b3b5503 --- /dev/null +++ b/src/functions/aggregates/count.incn @@ -0,0 +1,39 @@ +""" +Count aggregate helper. + +`count` is registered as a zero-argument aggregate with a concrete Substrait extension mapping. 
+""" + +from aggregate_builders import AggregateKind, AggregateMeasure, count as count_builder +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry +from substrait.function_extensions import COUNT_FUNCTION_ANCHOR + + +@function_registry.add("count", deterministic_spec( + FunctionClass.Aggregate, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NullSkippingAggregate, + extension_mapping("count", COUNT_FUNCTION_ANCHOR), +)) +pub def count() -> AggregateMeasure: + """ + Build a row-count aggregate measure. + + Examples: + rows = count() + """ + return count_builder() + + +module tests: + def test_count_builds_count_aggregate_measure() -> None: + measure = count() + assert measure.kind == AggregateKind.Count diff --git a/src/functions/aggregates/mod.incn b/src/functions/aggregates/mod.incn new file mode 100644 index 0000000..7c607df --- /dev/null +++ b/src/functions/aggregates/mod.incn @@ -0,0 +1,4 @@ +"""Aggregate function helpers.""" + +pub from functions.aggregates.sum import sum +pub from functions.aggregates.count import count diff --git a/src/functions/aggregates/sum.incn b/src/functions/aggregates/sum.incn new file mode 100644 index 0000000..daff4e9 --- /dev/null +++ b/src/functions/aggregates/sum.incn @@ -0,0 +1,44 @@ +""" +Sum aggregate helper. + +`sum` records its aggregate signature and current Substrait extension anchor at the helper declaration site. 
+""" + +from aggregate_builders import AggregateKind, AggregateMeasure, sum as sum_builder +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr +from substrait.function_extensions import SUM_FUNCTION_ANCHOR + + +@function_registry.add("sum", deterministic_spec( + FunctionClass.Aggregate, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NullSkippingAggregate, + extension_mapping("sum", SUM_FUNCTION_ANCHOR), +)) +pub def sum(expr: ColumnExpr) -> AggregateMeasure: + """ + Build a sum aggregate measure. + + Examples: + revenue = sum(col("amount")) + + Parameters: + expr: Numeric expression to aggregate. + """ + return sum_builder(expr) + + +module tests: + from projection_builders import col + def test_sum_builds_sum_aggregate_measure() -> None: + measure = sum(col("amount")) + assert measure.kind == AggregateKind.Sum diff --git a/src/functions/casts/cast.incn b/src/functions/casts/cast.incn new file mode 100644 index 0000000..196c546 --- /dev/null +++ b/src/functions/casts/cast.incn @@ -0,0 +1,56 @@ +""" +Cast scalar helper. + +`cast` records the target type as a scalar application option and lowers to Substrait's built-in Cast Rex shape. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + core_mapping, + deterministic_spec, + v0_1, +) +from functions.registry import function_registry, registered_application_with_options +from projection_builders import ColumnExpr, scalar_function_option + + +@function_registry.add("cast", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + core_mapping("cast"), +)) +pub def cast(expr: ColumnExpr, target_type: str) -> ColumnExpr: + """ + Build a cast expression. 
+ + `cast` represents a required conversion. Backends should report conversion failures according to their normal cast + semantics. + + Examples: + amount_text = cast(col("amount"), "str") + + Parameters: + expr: Expression to convert. + target_type: InQL type spelling to cast to. + """ + return registered_application_with_options("cast", [expr], [scalar_function_option("target_type", target_type)]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + column_expr_option_value, + ) + def test_cast_records_target_type_option() -> None: + expr = cast(col("amount"), "decimal(10,2)") + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "cast" + assert column_expr_argument_count(expr) == 1 + assert column_expr_option_value(expr, "target_type") == "decimal(10,2)" diff --git a/src/functions/casts/mod.incn b/src/functions/casts/mod.incn new file mode 100644 index 0000000..3f7e875 --- /dev/null +++ b/src/functions/casts/mod.incn @@ -0,0 +1,4 @@ +"""Scalar cast helpers.""" + +pub from functions.casts.cast import cast +pub from functions.casts.try_cast import try_cast diff --git a/src/functions/casts/try_cast.incn b/src/functions/casts/try_cast.incn new file mode 100644 index 0000000..b393ea5 --- /dev/null +++ b/src/functions/casts/try_cast.incn @@ -0,0 +1,55 @@ +""" +Try-cast scalar helper. + +`try_cast` mirrors `cast` metadata and lowers to Substrait's built-in Cast Rex shape with return-null failure behavior. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + core_mapping, + deterministic_spec, + v0_1, +) +from functions.registry import function_registry, registered_application_with_options +from projection_builders import ColumnExpr, scalar_function_option + + +@function_registry.add("try_cast", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + core_mapping("try_cast"), +)) +pub def try_cast(expr: ColumnExpr, target_type: str) -> ColumnExpr: + """ + Build a null-on-failure cast expression. + + `try_cast` represents a best-effort conversion where invalid values become null instead of failing the query. + + Examples: + parsed_amount = try_cast(col("amount_text"), "float") + + Parameters: + expr: Expression to convert. + target_type: InQL type spelling to cast to. + """ + return registered_application_with_options("try_cast", [expr], [scalar_function_option("target_type", target_type)]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + column_expr_option_value, + ) + def test_try_cast_records_target_type_option() -> None: + expr = try_cast(col("amount"), "float64") + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "try_cast" + assert column_expr_argument_count(expr) == 1 + assert column_expr_option_value(expr, "target_type") == "float64" diff --git a/src/functions/conditionals/case_when.incn b/src/functions/conditionals/case_when.incn new file mode 100644 index 0000000..b238c8d --- /dev/null +++ b/src/functions/conditionals/case_when.incn @@ -0,0 +1,78 @@ +""" +Searched case helper. + +`case_when` records condition/result pairs in one generic application node. Lowering derives the pair boundary from the +validated flattened argument list. 
+""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + core_mapping, + deterministic_spec, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr + + +@function_registry.add("case_when", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + core_mapping("if_then"), +)) +pub def case_when(conditions: list[ColumnExpr], results: list[ColumnExpr], otherwise: ColumnExpr) -> ColumnExpr: + """ + Build a searched conditional expression. + + Conditions and results are paired by index. The `otherwise` expression is used when none of the conditions match. + + Examples: + bucket = case_when( + [gt(col("amount"), lit(1000)), gt(col("amount"), lit(100))], + [lit("large"), lit("medium")], + lit("small"), + ) + + Parameters: + conditions: Predicate expressions evaluated in order. + results: Result expressions paired with `conditions`. + otherwise: Result expression used when no condition matches. 
+ """ + if len(conditions) == 0: + return raise_value_error("case_when requires at least one condition/result pair") + if len(conditions) != len(results): + return raise_value_error("case_when requires one result for each condition") + mut arguments: list[ColumnExpr] = [] + arguments.extend(conditions) + arguments.extend(results) + arguments.append(otherwise) + return registered_application("case_when", arguments) + + +module tests: + from std.testing import assert_raises + from projection_builders import ( + ColumnExprKind, + bool_expr, + str_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_case_when_records_condition_and_result_arguments() -> None: + expr = case_when([bool_expr(true)], [str_expr("large")], str_expr("small")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "case_when" + assert column_expr_argument_count(expr) == 3 + def _call_case_when_without_conditions() -> None: + case_when([], [], str_expr("fallback")) + def _call_case_when_with_mismatched_results() -> None: + case_when([bool_expr(true)], [], str_expr("fallback")) + def test_case_when_rejects_empty_conditions() -> None: + assert_raises[ValueError](_call_case_when_without_conditions) + def test_case_when_rejects_mismatched_results() -> None: + assert_raises[ValueError](_call_case_when_with_mismatched_results) diff --git a/src/functions/conditionals/coalesce.incn b/src/functions/conditionals/coalesce.incn new file mode 100644 index 0000000..9bc8044 --- /dev/null +++ b/src/functions/conditionals/coalesce.incn @@ -0,0 +1,63 @@ +""" +Coalesce conditional helper. + +`coalesce` accepts a non-empty expression list and stores the call as a variadic registry-backed scalar application. 
+""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import COALESCE_FUNCTION_ANCHOR + + +@function_registry.add("coalesce", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("coalesce", COALESCE_FUNCTION_ANCHOR), +)) +pub def coalesce(values: list[ColumnExpr]) -> ColumnExpr: + """ + Build a first-non-null expression. + + `coalesce` requires at least one candidate expression. It returns the first candidate that evaluates to a non-null + value according to backend null semantics. + + Examples: + display_name = coalesce([col("preferred_name"), col("legal_name"), lit("unknown")]) + + Parameters: + values: Candidate expressions in priority order. 
+ """ + if len(values) == 0: + return raise_value_error("coalesce requires at least one scalar expression") + return registered_application("coalesce", values) + + +module tests: + from std.testing import assert_raises + from projection_builders import ( + ColumnExprKind, + col, + str_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_coalesce_builds_variadic_registered_application() -> None: + expr = coalesce([col("status"), str_expr("unknown")]) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "coalesce" + assert column_expr_argument_count(expr) == 2 + def _call_empty_coalesce() -> None: + coalesce([]) + def test_coalesce_rejects_empty_values() -> None: + assert_raises[ValueError](_call_empty_coalesce) diff --git a/src/functions/conditionals/mod.incn b/src/functions/conditionals/mod.incn new file mode 100644 index 0000000..f87523e --- /dev/null +++ b/src/functions/conditionals/mod.incn @@ -0,0 +1,5 @@ +"""Conditional scalar helpers.""" + +pub from functions.conditionals.coalesce import coalesce +pub from functions.conditionals.nullif import nullif +pub from functions.conditionals.case_when import case_when diff --git a/src/functions/conditionals/nullif.incn b/src/functions/conditionals/nullif.incn new file mode 100644 index 0000000..2f758a2 --- /dev/null +++ b/src/functions/conditionals/nullif.incn @@ -0,0 +1,53 @@ +""" +Null-if conditional helper. + +`nullif` is represented as a binary registry-backed scalar application. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import NULLIF_FUNCTION_ANCHOR + + +@function_registry.add("nullif", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("nullif", NULLIF_FUNCTION_ANCHOR), +)) +pub def nullif(expr: ColumnExpr, null_expr: ColumnExpr) -> ColumnExpr: + """ + Build an expression that returns null when two expressions are equal. + + Examples: + normalized_discount = nullif(col("discount_code"), lit("")) + + Parameters: + expr: Expression returned when the comparison does not match. + null_expr: Expression compared with `expr` to decide whether to return null. + """ + return registered_application("nullif", [expr, null_expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + str_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_nullif_builds_registered_application() -> None: + expr = nullif(col("status"), str_expr("")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "nullif" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/formatting/display.incn b/src/functions/formatting/display.incn new file mode 100644 index 0000000..b3560bf --- /dev/null +++ b/src/functions/formatting/display.incn @@ -0,0 +1,48 @@ +""" +Interactive LazyFrame display helper. + +`display` lives in its own module to keep the public function facade regular and easy to scan. 
+""" + +from dataset import DataFrame, LazyFrame + + +pub def display[T with Clone](data: LazyFrame[T]) -> None: + """ + Collect a `LazyFrame` and render a compact preview. + + `display` is a public convenience helper, not a registry-backed scalar function. + + Examples: + display(query) + + Parameters: + data: Lazy relational plan to collect and print. + """ + match data.clone().collect(): + Ok(df) => _render_data_frame(df) + Err(err) => println(f"InQL display failed: {err.error_message()}") + + +def _join_columns(columns: list[str]) -> str: + """Join column names for one compact display header line.""" + return ", ".join(columns) + + +def _render_data_frame[T with Clone](data: DataFrame[T]) -> None: + """Render one collected DataFrame preview using structured materialization metadata.""" + preview = data.preview_text() + columns = data.columns() + println("InQL display") + if len(columns) > 0: + println(f"columns: [{_join_columns(columns)}]") + println(f"rows: {data.row_count()}") + if len(preview) == 0: + println("(empty)") + return + println(preview) + + +module tests: + def test_join_columns_renders_compact_header() -> None: + assert _join_columns(["id", "amount"]) == "id, amount" diff --git a/src/functions/formatting/mod.incn b/src/functions/formatting/mod.incn new file mode 100644 index 0000000..4299fa8 --- /dev/null +++ b/src/functions/formatting/mod.incn @@ -0,0 +1,3 @@ +"""Interactive formatting and display helpers.""" + +pub from functions.formatting.display import display diff --git a/src/functions/literals/always_false.incn b/src/functions/literals/always_false.incn new file mode 100644 index 0000000..8c784d0 --- /dev/null +++ b/src/functions/literals/always_false.incn @@ -0,0 +1,38 @@ +""" +Constant false predicate helper. + +`always_false` is registered as a predicate helper but remains a structural boolean literal. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, bool_expr + + +@function_registry.add("always_false", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + rewrite_mapping("boolean false literal predicate"), +)) +pub def always_false() -> ColumnExpr: + """ + Build a predicate that rejects every row. + + Examples: + disabled_filter = always_false() + """ + return bool_expr(false) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_always_false_uses_structural_bool_literal() -> None: + assert column_expr_kind(always_false()) == ColumnExprKind.BoolLiteral diff --git a/src/functions/literals/always_true.incn b/src/functions/literals/always_true.incn new file mode 100644 index 0000000..808730a --- /dev/null +++ b/src/functions/literals/always_true.incn @@ -0,0 +1,39 @@ +""" +Constant true predicate helper. + +`always_true` is registered as a predicate helper but remains a structural boolean literal so filter rewrites can treat +it as a no-op predicate. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, bool_expr + + +@function_registry.add("always_true", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + rewrite_mapping("boolean true literal predicate"), +)) +pub def always_true() -> ColumnExpr: + """ + Build a predicate that accepts every row. 
+ + Examples: + default_filter = always_true() + """ + return bool_expr(true) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_always_true_uses_structural_bool_literal() -> None: + assert column_expr_kind(always_true()) == ColumnExprKind.BoolLiteral diff --git a/src/functions/literals/bool_expr.incn b/src/functions/literals/bool_expr.incn new file mode 100644 index 0000000..e5eb87f --- /dev/null +++ b/src/functions/literals/bool_expr.incn @@ -0,0 +1,41 @@ +""" +Boolean literal expression helper. + +`bool_expr` is the typed literal spelling for boolean scalar values and is also used by predicate constants. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, bool_expr as bool_expr_builder + + +@function_registry.add("bool_expr", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + rewrite_mapping("typed boolean literal expression"), +)) +pub def bool_expr(value: bool) -> ColumnExpr: + """ + Build a boolean literal expression. + + Examples: + literal_true = bool_expr(true) + + Parameters: + value: Boolean value to embed in the expression tree. + """ + return bool_expr_builder(value) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_bool_expr_uses_structural_bool_literal() -> None: + assert column_expr_kind(bool_expr(true)) == ColumnExprKind.BoolLiteral diff --git a/src/functions/literals/bool_lit.incn b/src/functions/literals/bool_lit.incn new file mode 100644 index 0000000..519934e --- /dev/null +++ b/src/functions/literals/bool_lit.incn @@ -0,0 +1,41 @@ +""" +Filter-style boolean literal helper. 
+ +`bool_lit` keeps the predicate-friendly spelling while returning the same structural boolean literal as `bool_expr`. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, bool_expr + + +@function_registry.add("bool_lit", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + rewrite_mapping("filter-helper boolean literal expression"), +)) +pub def bool_lit(value: bool) -> ColumnExpr: + """ + Build a boolean literal expression using the filter-helper naming style. + + Examples: + explicit_true = bool_lit(true) + + Parameters: + value: Boolean value to embed in the expression tree. + """ + return bool_expr(value) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_bool_lit_uses_structural_bool_literal() -> None: + assert column_expr_kind(bool_lit(false)) == ColumnExprKind.BoolLiteral diff --git a/src/functions/literals/float_expr.incn b/src/functions/literals/float_expr.incn new file mode 100644 index 0000000..116869e --- /dev/null +++ b/src/functions/literals/float_expr.incn @@ -0,0 +1,41 @@ +""" +Float literal expression helper. + +`float_expr` keeps floating-point literals as structural scalar nodes so lowering can emit a literal directly. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, float_expr as float_expr_builder + + +@function_registry.add("float_expr", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + rewrite_mapping("typed floating-point literal expression"), +)) +pub def float_expr(value: float) -> ColumnExpr: + """ + Build a floating-point literal expression. + + Examples: + discounted = mul(col("price"), float_expr(0.9)) + + Parameters: + value: Floating-point value to embed in the expression tree. + """ + return float_expr_builder(value) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_float_expr_uses_structural_float_literal() -> None: + assert column_expr_kind(float_expr(1.5)) == ColumnExprKind.FloatLiteral diff --git a/src/functions/literals/int_expr.incn b/src/functions/literals/int_expr.incn new file mode 100644 index 0000000..828e907 --- /dev/null +++ b/src/functions/literals/int_expr.incn @@ -0,0 +1,42 @@ +""" +Integer literal expression helper. + +`int_expr` is a typed spelling over the canonical structural literal model used by projections, filters, and aggregate +inputs. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, int_expr as int_expr_builder + + +@function_registry.add("int_expr", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + rewrite_mapping("typed integer literal expression"), +)) +pub def int_expr(value: int) -> ColumnExpr: + """ + Build an integer literal expression. + + Examples: + high_value = gt(col("amount"), int_expr(1000)) + + Parameters: + value: Integer value to embed in the expression tree. + """ + return int_expr_builder(value) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_int_expr_uses_structural_int_literal() -> None: + assert column_expr_kind(int_expr(42)) == ColumnExprKind.IntLiteral diff --git a/src/functions/literals/int_lit.incn b/src/functions/literals/int_lit.incn new file mode 100644 index 0000000..1ae1633 --- /dev/null +++ b/src/functions/literals/int_lit.incn @@ -0,0 +1,42 @@ +""" +Filter-style integer literal helper. + +`int_lit` preserves the existing predicate-builder spelling while routing through the same canonical structural literal +representation as `int_expr`. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, int_expr + + +@function_registry.add("int_lit", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + rewrite_mapping("filter-helper integer literal expression"), +)) +pub def int_lit(value: int) -> ColumnExpr: + """ + Build an integer literal expression using the filter-helper naming style. + + Examples: + high_value = gt(col("amount"), int_lit(1000)) + + Parameters: + value: Integer value to embed in the expression tree. + """ + return int_expr(value) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_int_lit_uses_structural_int_literal() -> None: + assert column_expr_kind(int_lit(7)) == ColumnExprKind.IntLiteral diff --git a/src/functions/literals/lit.incn b/src/functions/literals/lit.incn new file mode 100644 index 0000000..f5c2528 --- /dev/null +++ b/src/functions/literals/lit.incn @@ -0,0 +1,47 @@ +""" +Canonical literal helper. + +`lit` chooses the structural literal node for the host value type. It is registered for documentation and metadata but +does not create a scalar function application. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, lit as lit_builder + + +@function_registry.add("lit", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + rewrite_mapping("typed Substrait literal expression"), +)) +pub def lit(value: Union[int, float, str, bool]) -> ColumnExpr: + """ + Build a canonical scalar literal expression. + + `lit` chooses the structural literal node for the provided host value type instead of wrapping literals as scalar + function calls. + + Examples: + paid = eq(col("status"), lit("paid")) + expensive = gt(col("amount"), lit(100)) + + Parameters: + value: Literal value to embed in the expression tree. + """ + return lit_builder(value) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_lit_selects_structural_literal_node() -> None: + assert column_expr_kind(lit(7)) == ColumnExprKind.IntLiteral + assert column_expr_kind(lit("paid")) == ColumnExprKind.StringLiteral diff --git a/src/functions/literals/mod.incn b/src/functions/literals/mod.incn new file mode 100644 index 0000000..45909ae --- /dev/null +++ b/src/functions/literals/mod.incn @@ -0,0 +1,12 @@ +"""Literal and constant scalar helpers.""" + +pub from functions.literals.lit import lit +pub from functions.literals.int_expr import int_expr +pub from functions.literals.float_expr import float_expr +pub from functions.literals.str_expr import str_expr +pub from functions.literals.bool_expr import bool_expr +pub from functions.literals.int_lit import int_lit +pub from functions.literals.str_lit import str_lit +pub from functions.literals.bool_lit import bool_lit +pub from functions.literals.always_true import always_true +pub from 
functions.literals.always_false import always_false diff --git a/src/functions/literals/str_expr.incn b/src/functions/literals/str_expr.incn new file mode 100644 index 0000000..9587e65 --- /dev/null +++ b/src/functions/literals/str_expr.incn @@ -0,0 +1,41 @@ +""" +String literal expression helper. + +`str_expr` is the typed literal spelling for string values and shares the same structural node as `lit("...")`. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, str_expr as str_expr_builder + + +@function_registry.add("str_expr", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + rewrite_mapping("typed string literal expression"), +)) +pub def str_expr(value: str) -> ColumnExpr: + """ + Build a string literal expression. + + Examples: + paid = eq(col("status"), str_expr("paid")) + + Parameters: + value: String value to embed in the expression tree. + """ + return str_expr_builder(value) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_str_expr_uses_structural_string_literal() -> None: + assert column_expr_kind(str_expr("paid")) == ColumnExprKind.StringLiteral diff --git a/src/functions/literals/str_lit.incn b/src/functions/literals/str_lit.incn new file mode 100644 index 0000000..624acc4 --- /dev/null +++ b/src/functions/literals/str_lit.incn @@ -0,0 +1,41 @@ +""" +Filter-style string literal helper. + +`str_lit` preserves existing filter helper imports while sharing the canonical string literal node. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, str_expr + + +@function_registry.add("str_lit", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NonNullLiteral, + rewrite_mapping("filter-helper string literal expression"), +)) +pub def str_lit(value: str) -> ColumnExpr: + """ + Build a string literal expression using the filter-helper naming style. + + Examples: + paid = eq(col("status"), str_lit("paid")) + + Parameters: + value: String value to embed in the expression tree. + """ + return str_expr(value) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind + def test_str_lit_uses_structural_string_literal() -> None: + assert column_expr_kind(str_lit("paid")) == ColumnExprKind.StringLiteral diff --git a/src/functions/mod.incn b/src/functions/mod.incn new file mode 100644 index 0000000..b5eb257 --- /dev/null +++ b/src/functions/mod.incn @@ -0,0 +1,100 @@ +""" +Explicit builder helpers for the current InQL relational surface. + +Each helper lives in a logical family submodule with declaration-side registry metadata and local inline tests. This +facade keeps the public import surface stable without owning a second machine-readable registry list. 
+""" + +from function_registry import FunctionRegistryEntry +from functions.registry import ( + function_registry_canonical_names as raw_function_registry_canonical_names, + function_registry_entries as raw_function_registry_entries, + function_registry_entry as raw_function_registry_entry, + function_registry_entry_by_name as raw_function_registry_entry_by_name, + function_registry_entry_count as raw_function_registry_entry_count, + function_registry_function_refs as raw_function_registry_function_refs, + registered_substrait_mapped_function_refs as raw_registered_substrait_mapped_function_refs, +) +pub from functions.registry import function_registry +pub from functions.references.col import col +pub from functions.literals.always_false import always_false +pub from functions.literals.always_true import always_true +pub from functions.literals.bool_expr import bool_expr +pub from functions.literals.bool_lit import bool_lit +pub from functions.literals.float_expr import float_expr +pub from functions.literals.int_expr import int_expr +pub from functions.literals.int_lit import int_lit +pub from functions.literals.lit import lit +pub from functions.literals.str_expr import str_expr +pub from functions.literals.str_lit import str_lit +pub from functions.aggregates.count import count +pub from functions.aggregates.sum import sum +pub from functions.operators.add import add +pub from functions.operators.and_ import and_ +pub from functions.operators.div import div +pub from functions.operators.eq import eq +pub from functions.operators.equal_null import equal_null +pub from functions.operators.gt import gt +pub from functions.operators.gte import gte +pub from functions.operators.lt import lt +pub from functions.operators.lte import lte +pub from functions.operators.modulo import modulo +pub from functions.operators.mul import mul +pub from functions.operators.ne import ne +pub from functions.operators.neg import neg +pub from functions.operators.not_ import not_ +pub 
from functions.operators.or_ import or_ +pub from functions.operators.sub import sub +pub from functions.casts.cast import cast +pub from functions.casts.try_cast import try_cast +pub from functions.predicates.between import between +pub from functions.predicates.in_ import in_ +pub from functions.predicates.is_nan import is_nan +pub from functions.predicates.is_not_nan import is_not_nan +pub from functions.predicates.is_not_null import is_not_null +pub from functions.predicates.is_null import is_null +pub from functions.conditionals.case_when import case_when +pub from functions.conditionals.coalesce import coalesce +pub from functions.conditionals.nullif import nullif +pub from functions.ordering.asc import asc +pub from functions.ordering.asc_nulls_first import asc_nulls_first +pub from functions.ordering.asc_nulls_last import asc_nulls_last +pub from functions.ordering.desc import desc +pub from functions.ordering.desc_nulls_first import desc_nulls_first +pub from functions.ordering.desc_nulls_last import desc_nulls_last +pub from functions.formatting.display import display + + +pub def function_registry_entries() -> list[FunctionRegistryEntry]: + """Return runtime registry entries for helpers loaded in the current process.""" + return raw_function_registry_entries() + + +pub def function_registry_entry(function_ref: str) -> Option[FunctionRegistryEntry]: + """Return a loaded registry entry for one stable function reference when it is known.""" + return raw_function_registry_entry(function_ref) + + +pub def function_registry_entry_by_name(canonical_name: str) -> Option[FunctionRegistryEntry]: + """Return a loaded registry entry for one canonical public function name when it is known.""" + return raw_function_registry_entry_by_name(canonical_name) + + +pub def function_registry_function_refs() -> list[str]: + """Return loaded function references in runtime registry order.""" + return raw_function_registry_function_refs() + + +pub def 
function_registry_canonical_names() -> list[str]: + """Return loaded canonical function names in runtime registry order.""" + return raw_function_registry_canonical_names() + + +pub def function_registry_entry_count() -> int: + """Return the number of loaded runtime registry entries.""" + return raw_function_registry_entry_count() + + +pub def registered_substrait_mapped_function_refs() -> list[str]: + """Return loaded function references with a concrete Substrait extension mapping.""" + return raw_registered_substrait_mapped_function_refs() diff --git a/src/functions/operators/add.incn b/src/functions/operators/add.incn new file mode 100644 index 0000000..c383a93 --- /dev/null +++ b/src/functions/operators/add.incn @@ -0,0 +1,55 @@ +""" +Addition scalar helper. + +`add` is a registry-backed scalar application. Its Substrait extension name and anchor live in the decorator metadata, +not in a lowering switchboard. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import ADD_FUNCTION_ANCHOR + + +@function_registry.add("add", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("add", ADD_FUNCTION_ANCHOR), +)) +pub def add(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build an addition expression. + + Examples: + total = add(col("subtotal"), col("tax")) + adjusted = add(col("amount"), lit(5)) + + Parameters: + left: Numeric expression on the left side. + right: Numeric expression to add to `left`. 
+ """ + return registered_application("add", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_add_builds_registered_application() -> None: + expr = add(col("amount"), int_expr(1)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "add" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/and_.incn b/src/functions/operators/and_.incn new file mode 100644 index 0000000..5274088 --- /dev/null +++ b/src/functions/operators/and_.incn @@ -0,0 +1,55 @@ +""" +Boolean conjunction helper. + +`and_` uses a trailing underscore to avoid host-language keyword collisions while keeping the canonical helper spelling +stable for InQL users. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import AND_FUNCTION_ANCHOR + + +@function_registry.add("and_", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("and", AND_FUNCTION_ANCHOR), +)) +pub def and_(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a boolean conjunction. + + Use `and_` instead of `and` because `and` is reserved by the host language. + + Examples: + eligible = and_(gt(col("amount"), lit(1000)), eq(col("status"), lit("paid"))) + + Parameters: + left: Predicate expression on the left side. + right: Predicate expression that must also be true. 
+ """ + return registered_application("and_", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + bool_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_and_builds_registered_application() -> None: + expr = and_(bool_expr(true), bool_expr(false)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "and_" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/div.incn b/src/functions/operators/div.incn new file mode 100644 index 0000000..575d585 --- /dev/null +++ b/src/functions/operators/div.incn @@ -0,0 +1,53 @@ +""" +Division scalar helper. + +`div` is represented as a generic registry-backed scalar application and lowers through `divide`. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import DIVIDE_FUNCTION_ANCHOR + + +@function_registry.add("div", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("divide", DIVIDE_FUNCTION_ANCHOR), +)) +pub def div(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a division expression. + + Examples: + unit_margin = div(col("margin"), col("quantity")) + + Parameters: + left: Numerator expression. + right: Denominator expression. 
+ """ + return registered_application("div", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_div_builds_registered_application() -> None: + expr = div(col("amount"), int_expr(2)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "div" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/eq.incn b/src/functions/operators/eq.incn new file mode 100644 index 0000000..83a4c8f --- /dev/null +++ b/src/functions/operators/eq.incn @@ -0,0 +1,54 @@ +""" +Equality predicate helper. + +`eq` is a registry-backed predicate application using the existing Substrait equality extension anchor. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import EQUAL_FUNCTION_ANCHOR + + +@function_registry.add("eq", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("equal", EQUAL_FUNCTION_ANCHOR), +)) +pub def eq(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build an equality predicate. + + Examples: + paid = eq(col("status"), lit("paid")) + high_value_paid = and_(paid, gt(col("amount"), lit(100))) + + Parameters: + left: Expression on the left side of the comparison. + right: Expression to compare against `left`. 
+ """ + return registered_application("eq", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + str_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_eq_builds_registered_application() -> None: + expr = eq(col("status"), str_expr("paid")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "eq" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/equal_null.incn b/src/functions/operators/equal_null.incn new file mode 100644 index 0000000..e15aaa1 --- /dev/null +++ b/src/functions/operators/equal_null.incn @@ -0,0 +1,55 @@ +""" +Null-safe equality predicate helper. + +`equal_null` exposes null-safe equality intent and lowers through the Substrait `is_not_distinct_from` mapping. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import IS_NOT_DISTINCT_FROM_FUNCTION_ANCHOR + + +@function_registry.add("equal_null", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("is_not_distinct_from", IS_NOT_DISTINCT_FROM_FUNCTION_ANCHOR), +)) +pub def equal_null(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a null-safe equality predicate. + + `equal_null` treats two null values as equal, unlike ordinary equality semantics where null usually propagates. + + Examples: + same_optional_code = equal_null(col("source_code"), col("target_code")) + + Parameters: + left: Nullable expression on the left side. + right: Nullable expression to compare against `left`. 
+ """ + return registered_application("equal_null", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + str_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_equal_null_builds_registered_application() -> None: + expr = equal_null(col("status"), str_expr("paid")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "equal_null" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/gt.incn b/src/functions/operators/gt.incn new file mode 100644 index 0000000..d60e325 --- /dev/null +++ b/src/functions/operators/gt.incn @@ -0,0 +1,53 @@ +""" +Greater-than predicate helper. + +`gt` is currently lowerable through the registered Substrait extension mapping. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import GT_FUNCTION_ANCHOR + + +@function_registry.add("gt", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("gt", GT_FUNCTION_ANCHOR), +)) +pub def gt(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a greater-than predicate. + + Examples: + high_value = gt(col("amount"), lit(1000)) + + Parameters: + left: Expression expected to be greater than `right`. + right: Lower-bound expression. 
+ """ + return registered_application("gt", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_gt_builds_registered_application() -> None: + expr = gt(col("amount"), int_expr(10)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "gt" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/gte.incn b/src/functions/operators/gte.incn new file mode 100644 index 0000000..9cf9b4a --- /dev/null +++ b/src/functions/operators/gte.incn @@ -0,0 +1,53 @@ +""" +Greater-than-or-equal predicate helper. + +`gte` completes the core comparison family without introducing a dedicated expression variant. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import GTE_FUNCTION_ANCHOR + + +@function_registry.add("gte", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("gte", GTE_FUNCTION_ANCHOR), +)) +pub def gte(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a greater-than-or-equal predicate. + + Examples: + qualifies = gte(col("score"), lit(90)) + + Parameters: + left: Expression expected to be greater than or equal to `right`. + right: Inclusive lower-bound expression. 
+ """ + return registered_application("gte", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_gte_builds_registered_application() -> None: + expr = gte(col("amount"), int_expr(10)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "gte" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/lt.incn b/src/functions/operators/lt.incn new file mode 100644 index 0000000..734329e --- /dev/null +++ b/src/functions/operators/lt.incn @@ -0,0 +1,53 @@ +""" +Less-than predicate helper. + +`lt` uses the registry-backed scalar application node and lowers through the Substrait `lt` extension mapping. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import LT_FUNCTION_ANCHOR + + +@function_registry.add("lt", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("lt", LT_FUNCTION_ANCHOR), +)) +pub def lt(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a less-than predicate. + + Examples: + below_limit = lt(col("amount"), lit(1000)) + + Parameters: + left: Expression expected to be less than `right`. + right: Upper-bound expression. 
+ """ + return registered_application("lt", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_lt_builds_registered_application() -> None: + expr = lt(col("amount"), int_expr(10)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "lt" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/lte.incn b/src/functions/operators/lte.incn new file mode 100644 index 0000000..c9dab03 --- /dev/null +++ b/src/functions/operators/lte.incn @@ -0,0 +1,53 @@ +""" +Less-than-or-equal predicate helper. + +`lte` is registered as a predicate and uses the generic scalar application node. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import LTE_FUNCTION_ANCHOR + + +@function_registry.add("lte", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("lte", LTE_FUNCTION_ANCHOR), +)) +pub def lte(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a less-than-or-equal predicate. + + Examples: + within_limit = lte(col("amount"), lit(1000)) + + Parameters: + left: Expression expected to be less than or equal to `right`. + right: Inclusive upper-bound expression. 
+ """ + return registered_application("lte", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_lte_builds_registered_application() -> None: + expr = lte(col("amount"), int_expr(10)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "lte" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/mod.incn b/src/functions/operators/mod.incn new file mode 100644 index 0000000..31e3781 --- /dev/null +++ b/src/functions/operators/mod.incn @@ -0,0 +1,18 @@ +"""Arithmetic, comparison, and boolean operator helpers.""" + +pub from functions.operators.add import add +pub from functions.operators.sub import sub +pub from functions.operators.mul import mul +pub from functions.operators.div import div +pub from functions.operators.modulo import modulo +pub from functions.operators.neg import neg +pub from functions.operators.eq import eq +pub from functions.operators.ne import ne +pub from functions.operators.lt import lt +pub from functions.operators.lte import lte +pub from functions.operators.gt import gt +pub from functions.operators.gte import gte +pub from functions.operators.equal_null import equal_null +pub from functions.operators.and_ import and_ +pub from functions.operators.or_ import or_ +pub from functions.operators.not_ import not_ diff --git a/src/functions/operators/modulo.incn b/src/functions/operators/modulo.incn new file mode 100644 index 0000000..93b74e8 --- /dev/null +++ b/src/functions/operators/modulo.incn @@ -0,0 +1,56 @@ +""" +Modulo scalar helper. + +`mod` is a reserved Incan word, so the public source helper is `modulo(...)` while the registry canonical name remains +`mod`. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import MODULUS_FUNCTION_ANCHOR + + +@function_registry.add("mod", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("modulus", MODULUS_FUNCTION_ANCHOR), +)) +pub def modulo(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a modulo expression. + + The registry name is `mod`; the public helper is `modulo` because `mod` is reserved by Incan modules. + + Examples: + shard = modulo(col("customer_id"), lit(16)) + + Parameters: + left: Dividend expression. + right: Divisor expression. + """ + return registered_application("mod", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_modulo_builds_registered_mod_application() -> None: + expr = modulo(col("amount"), int_expr(2)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "mod" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/mul.incn b/src/functions/operators/mul.incn new file mode 100644 index 0000000..05a7113 --- /dev/null +++ b/src/functions/operators/mul.incn @@ -0,0 +1,53 @@ +""" +Multiplication scalar helper. + +`mul` uses the public InQL spelling while declaring the Substrait extension name `multiply` in its registry metadata. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import MULTIPLY_FUNCTION_ANCHOR + + +@function_registry.add("mul", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("multiply", MULTIPLY_FUNCTION_ANCHOR), +)) +pub def mul(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a multiplication expression. + + Examples: + extended_price = mul(col("quantity"), col("unit_price")) + + Parameters: + left: Numeric expression on the left side. + right: Numeric expression to multiply with `left`. + """ + return registered_application("mul", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_mul_builds_registered_application() -> None: + expr = mul(col("amount"), int_expr(2)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "mul" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/ne.incn b/src/functions/operators/ne.incn new file mode 100644 index 0000000..87a54bc --- /dev/null +++ b/src/functions/operators/ne.incn @@ -0,0 +1,54 @@ +""" +Not-equal predicate helper. + +`ne` is modeled as the same registry-backed scalar application shape as other binary predicates and lowers through the +Substrait `not_equal` extension mapping. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import NOT_EQUAL_FUNCTION_ANCHOR + + +@function_registry.add("ne", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("not_equal", NOT_EQUAL_FUNCTION_ANCHOR), +)) +pub def ne(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build an inequality predicate. + + Examples: + not_cancelled = ne(col("status"), lit("cancelled")) + + Parameters: + left: Expression on the left side of the comparison. + right: Expression that must not equal `left`. + """ + return registered_application("ne", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + str_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_ne_builds_registered_application() -> None: + expr = ne(col("status"), str_expr("void")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "ne" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/neg.incn b/src/functions/operators/neg.incn new file mode 100644 index 0000000..9f55591 --- /dev/null +++ b/src/functions/operators/neg.incn @@ -0,0 +1,51 @@ +""" +Unary numeric negation helper. + +`neg` is modeled as a unary scalar function application rather than a dedicated negation node. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import NEGATE_FUNCTION_ANCHOR + + +@function_registry.add("neg", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("negate", NEGATE_FUNCTION_ANCHOR), +)) +pub def neg(expr: ColumnExpr) -> ColumnExpr: + """ + Build a numeric negation expression. + + Examples: + refund_amount = neg(col("charge_amount")) + + Parameters: + expr: Numeric expression to negate. + """ + return registered_application("neg", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_neg_builds_registered_application() -> None: + expr = neg(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "neg" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/operators/not_.incn b/src/functions/operators/not_.incn new file mode 100644 index 0000000..cbbd538 --- /dev/null +++ b/src/functions/operators/not_.incn @@ -0,0 +1,53 @@ +""" +Boolean negation helper. + +`not_` uses the generic unary scalar application shape and keeps keyword-safe source spelling. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import NOT_FUNCTION_ANCHOR + + +@function_registry.add("not_", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("not", NOT_FUNCTION_ANCHOR), +)) +pub def not_(expr: ColumnExpr) -> ColumnExpr: + """ + Build a boolean negation. + + Use `not_` instead of `not` because `not` is reserved by the host language. + + Examples: + active = not_(eq(col("status"), lit("archived"))) + + Parameters: + expr: Predicate expression to negate. + """ + return registered_application("not_", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + bool_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_not_builds_registered_application() -> None: + expr = not_(bool_expr(true)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "not_" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/operators/or_.incn b/src/functions/operators/or_.incn new file mode 100644 index 0000000..6bda19e --- /dev/null +++ b/src/functions/operators/or_.incn @@ -0,0 +1,54 @@ +""" +Boolean disjunction helper. + +`or_` is a registry-backed boolean predicate helper that lowers through the Substrait `or` extension mapping. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import OR_FUNCTION_ANCHOR + + +@function_registry.add("or_", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("or", OR_FUNCTION_ANCHOR), +)) +pub def or_(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a boolean disjunction. + + Use `or_` instead of `or` because `or` is reserved by the host language. + + Examples: + visible = or_(eq(col("visibility"), lit("public")), eq(col("owner_id"), col("viewer_id"))) + + Parameters: + left: Predicate expression on the left side. + right: Predicate expression that may also make the result true. + """ + return registered_application("or_", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + bool_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_or_builds_registered_application() -> None: + expr = or_(bool_expr(true), bool_expr(false)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "or_" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/operators/sub.incn b/src/functions/operators/sub.incn new file mode 100644 index 0000000..c806e34 --- /dev/null +++ b/src/functions/operators/sub.incn @@ -0,0 +1,53 @@ +""" +Subtraction scalar helper. + +`sub` is part of the core arithmetic surface and lowers through the Substrait `subtract` extension mapping. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import SUBTRACT_FUNCTION_ANCHOR + + +@function_registry.add("sub", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + extension_mapping("subtract", SUBTRACT_FUNCTION_ANCHOR), +)) +pub def sub(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """ + Build a subtraction expression. + + Examples: + remaining = sub(col("budget"), col("spent")) + + Parameters: + left: Numeric expression to subtract from. + right: Numeric expression to subtract. + """ + return registered_application("sub", [left, right]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_sub_builds_registered_application() -> None: + expr = sub(col("amount"), int_expr(1)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "sub" + assert column_expr_argument_count(expr) == 2 diff --git a/src/functions/ordering/asc.incn b/src/functions/ordering/asc.incn new file mode 100644 index 0000000..f19a7eb --- /dev/null +++ b/src/functions/ordering/asc.incn @@ -0,0 +1,50 @@ +""" +Ascending ordering helper. + +`asc` records ordering intent for `order_by(...)` and lowers to a Substrait sort field in sort context. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + sort_field_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr + + +@function_registry.add("asc", deterministic_spec( + FunctionClass.Ordering, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + sort_field_mapping("asc_nulls_first"), +)) +pub def asc(expr: ColumnExpr) -> ColumnExpr: + """ + Build an ascending ordering expression. + + Examples: + by_amount = orders.order_by([asc(col("amount"))]) + + Parameters: + expr: Expression used as the sort key. + """ + return registered_application("asc", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_asc_builds_registered_application() -> None: + expr = asc(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "asc" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/ordering/asc_nulls_first.incn b/src/functions/ordering/asc_nulls_first.incn new file mode 100644 index 0000000..718ef77 --- /dev/null +++ b/src/functions/ordering/asc_nulls_first.incn @@ -0,0 +1,50 @@ +""" +Ascending nulls-first ordering helper. + +`asc_nulls_first` preserves explicit null placement metadata as a registry-backed helper. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + sort_field_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr + + +@function_registry.add("asc_nulls_first", deterministic_spec( + FunctionClass.Ordering, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + sort_field_mapping("asc_nulls_first"), +)) +pub def asc_nulls_first(expr: ColumnExpr) -> ColumnExpr: + """ + Build an ascending ordering expression with nulls first. + + Examples: + by_optional_rank = ranks.order_by([asc_nulls_first(col("rank"))]) + + Parameters: + expr: Expression used as the sort key. + """ + return registered_application("asc_nulls_first", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_asc_nulls_first_builds_registered_application() -> None: + expr = asc_nulls_first(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "asc_nulls_first" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/ordering/asc_nulls_last.incn b/src/functions/ordering/asc_nulls_last.incn new file mode 100644 index 0000000..a68d7c2 --- /dev/null +++ b/src/functions/ordering/asc_nulls_last.incn @@ -0,0 +1,50 @@ +""" +Ascending nulls-last ordering helper. + +`asc_nulls_last` is a unary registry-backed helper consumed by `order_by(...)` sort-field lowering. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + sort_field_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr + + +@function_registry.add("asc_nulls_last", deterministic_spec( + FunctionClass.Ordering, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + sort_field_mapping("asc_nulls_last"), +)) +pub def asc_nulls_last(expr: ColumnExpr) -> ColumnExpr: + """ + Build an ascending ordering expression with nulls last. + + Examples: + by_optional_rank = ranks.order_by([asc_nulls_last(col("rank"))]) + + Parameters: + expr: Expression used as the sort key. + """ + return registered_application("asc_nulls_last", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_asc_nulls_last_builds_registered_application() -> None: + expr = asc_nulls_last(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "asc_nulls_last" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/ordering/desc.incn b/src/functions/ordering/desc.incn new file mode 100644 index 0000000..aaaa997 --- /dev/null +++ b/src/functions/ordering/desc.incn @@ -0,0 +1,50 @@ +""" +Descending ordering helper. + +`desc` records descending order intent for `order_by(...)` and lowers to a Substrait sort field in sort context. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + sort_field_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr + + +@function_registry.add("desc", deterministic_spec( + FunctionClass.Ordering, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + sort_field_mapping("desc_nulls_last"), +)) +pub def desc(expr: ColumnExpr) -> ColumnExpr: + """ + Build a descending ordering expression. + + Examples: + newest_first = orders.order_by([desc(col("created_at"))]) + + Parameters: + expr: Expression used as the sort key. + """ + return registered_application("desc", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_desc_builds_registered_application() -> None: + expr = desc(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "desc" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/ordering/desc_nulls_first.incn b/src/functions/ordering/desc_nulls_first.incn new file mode 100644 index 0000000..ebaf9b9 --- /dev/null +++ b/src/functions/ordering/desc_nulls_first.incn @@ -0,0 +1,50 @@ +""" +Descending nulls-first ordering helper. + +`desc_nulls_first` keeps ordering metadata in the same scalar helper family as `asc` and `desc`. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + sort_field_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr + + +@function_registry.add("desc_nulls_first", deterministic_spec( + FunctionClass.Ordering, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + sort_field_mapping("desc_nulls_first"), +)) +pub def desc_nulls_first(expr: ColumnExpr) -> ColumnExpr: + """ + Build a descending ordering expression with nulls first. + + Examples: + by_optional_score = scores.order_by([desc_nulls_first(col("score"))]) + + Parameters: + expr: Expression used as the sort key. + """ + return registered_application("desc_nulls_first", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_desc_nulls_first_builds_registered_application() -> None: + expr = desc_nulls_first(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "desc_nulls_first" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/ordering/desc_nulls_last.incn b/src/functions/ordering/desc_nulls_last.incn new file mode 100644 index 0000000..28d6114 --- /dev/null +++ b/src/functions/ordering/desc_nulls_last.incn @@ -0,0 +1,50 @@ +""" +Descending nulls-last ordering helper. + +`desc_nulls_last` is the null-placement variant for descending order in `order_by(...)` sort-field lowering. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + sort_field_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr + + +@function_registry.add("desc_nulls_last", deterministic_spec( + FunctionClass.Ordering, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + sort_field_mapping("desc_nulls_last"), +)) +pub def desc_nulls_last(expr: ColumnExpr) -> ColumnExpr: + """ + Build a descending ordering expression with nulls last. + + Examples: + by_optional_score = scores.order_by([desc_nulls_last(col("score"))]) + + Parameters: + expr: Expression used as the sort key. + """ + return registered_application("desc_nulls_last", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_desc_nulls_last_builds_registered_application() -> None: + expr = desc_nulls_last(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "desc_nulls_last" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/ordering/mod.incn b/src/functions/ordering/mod.incn new file mode 100644 index 0000000..e28e9bc --- /dev/null +++ b/src/functions/ordering/mod.incn @@ -0,0 +1,8 @@ +"""Ordering expression helpers.""" + +pub from functions.ordering.asc import asc +pub from functions.ordering.desc import desc +pub from functions.ordering.asc_nulls_first import asc_nulls_first +pub from functions.ordering.asc_nulls_last import asc_nulls_last +pub from functions.ordering.desc_nulls_first import desc_nulls_first +pub from functions.ordering.desc_nulls_last import desc_nulls_last diff --git a/src/functions/predicates/between.incn b/src/functions/predicates/between.incn new file mode 100644 index 
0000000..9bfc4f6 --- /dev/null +++ b/src/functions/predicates/between.incn @@ -0,0 +1,54 @@ +""" +Inclusive range predicate helper. + +`between` is a ternary registry-backed scalar application over expression, lower bound, and upper bound. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import BETWEEN_FUNCTION_ANCHOR + + +@function_registry.add("between", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("between", BETWEEN_FUNCTION_ANCHOR), +)) +pub def between(expr: ColumnExpr, lower: ColumnExpr, upper: ColumnExpr) -> ColumnExpr: + """ + Build an inclusive range predicate. + + Examples: + target_month = between(col("order_day"), lit(1), lit(31)) + + Parameters: + expr: Expression whose value is tested. + lower: Inclusive lower bound. + upper: Inclusive upper bound. + """ + return registered_application("between", [expr, lower, upper]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + int_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_between_builds_registered_application() -> None: + expr = between(col("amount"), int_expr(1), int_expr(10)) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "between" + assert column_expr_argument_count(expr) == 3 diff --git a/src/functions/predicates/in_.incn b/src/functions/predicates/in_.incn new file mode 100644 index 0000000..1f512a8 --- /dev/null +++ b/src/functions/predicates/in_.incn @@ -0,0 +1,62 @@ +""" +Membership predicate helper. 
+ +`in_` uses keyword-safe source spelling and stores the value-count option alongside the scalar arguments. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + core_mapping, + deterministic_spec, + v0_1, +) +from functions.registry import function_registry, registered_application_with_options +from projection_builders import ColumnExpr, scalar_function_option + + +@function_registry.add("in_", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + core_mapping("singular_or_list"), +)) +pub def in_(expr: ColumnExpr, values: list[ColumnExpr]) -> ColumnExpr: + """ + Build a membership predicate. + + Use `in_` instead of `in` because `in` is reserved by the host language. + + Examples: + active_region = in_(col("region"), [lit("emea"), lit("amer")]) + + Parameters: + expr: Expression whose value is tested. + values: Candidate values to match against. + """ + mut arguments = [expr] + arguments.extend(values) + return registered_application_with_options( + "in_", + arguments, + [scalar_function_option("value_count", f"{len(values)}")], + ) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + str_expr, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + column_expr_option_value, + ) + def test_in_records_value_count_option() -> None: + expr = in_(col("status"), [str_expr("paid"), str_expr("open")]) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "in_" + assert column_expr_argument_count(expr) == 3 + assert column_expr_option_value(expr, "value_count") == "2" diff --git a/src/functions/predicates/is_nan.incn b/src/functions/predicates/is_nan.incn new file mode 100644 index 0000000..b4e51e6 --- /dev/null +++ b/src/functions/predicates/is_nan.incn @@ -0,0 +1,51 @@ +""" +NaN-test predicate helper. 
+ +`is_nan` captures floating-point NaN test intent and lowers through the Substrait `is_nan` mapping. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import IS_NAN_FUNCTION_ANCHOR + + +@function_registry.add("is_nan", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("is_nan", IS_NAN_FUNCTION_ANCHOR), +)) +pub def is_nan(expr: ColumnExpr) -> ColumnExpr: + """ + Build a predicate that checks whether a floating-point expression is NaN. + + Examples: + invalid_ratio = is_nan(col("conversion_rate")) + + Parameters: + expr: Floating-point expression to test. + """ + return registered_application("is_nan", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_is_nan_builds_registered_application() -> None: + expr = is_nan(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "is_nan" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/predicates/is_not_nan.incn b/src/functions/predicates/is_not_nan.incn new file mode 100644 index 0000000..68639c8 --- /dev/null +++ b/src/functions/predicates/is_not_nan.incn @@ -0,0 +1,52 @@ +""" +Non-NaN-test predicate helper. + +`is_not_nan` is the public convenience spelling for `not_(is_nan(...))`. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from functions.operators.not_ import not_ +from functions.predicates.is_nan import is_nan +from projection_builders import ColumnExpr + + +@function_registry.add("is_not_nan", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + rewrite_mapping("not_(is_nan(expr))"), +)) +pub def is_not_nan(expr: ColumnExpr) -> ColumnExpr: + """ + Build a predicate that checks whether a floating-point expression is not NaN. + + Examples: + usable_ratio = is_not_nan(col("conversion_rate")) + + Parameters: + expr: Floating-point expression to test. + """ + return not_(is_nan(expr)) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_is_not_nan_builds_canonical_rewrite_expression() -> None: + expr = is_not_nan(col("amount")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "not_" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/predicates/is_not_null.incn b/src/functions/predicates/is_not_null.incn new file mode 100644 index 0000000..6537970 --- /dev/null +++ b/src/functions/predicates/is_not_null.incn @@ -0,0 +1,51 @@ +""" +Non-null-test predicate helper. + +`is_not_null` is represented as a unary registry-backed predicate and lowers through `is_not_null`. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import IS_NOT_NULL_FUNCTION_ANCHOR + + +@function_registry.add("is_not_null", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("is_not_null", IS_NOT_NULL_FUNCTION_ANCHOR), +)) +pub def is_not_null(expr: ColumnExpr) -> ColumnExpr: + """ + Build a predicate that checks whether an expression is not null. + + Examples: + has_email = is_not_null(col("email")) + + Parameters: + expr: Expression to test for a non-null value. + """ + return registered_application("is_not_null", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_is_not_null_builds_registered_application() -> None: + expr = is_not_null(col("status")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "is_not_null" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/predicates/is_null.incn b/src/functions/predicates/is_null.incn new file mode 100644 index 0000000..89d4601 --- /dev/null +++ b/src/functions/predicates/is_null.incn @@ -0,0 +1,51 @@ +""" +Null-test predicate helper. + +`is_null` captures null-test intent as a registry-backed unary predicate and lowers through `is_null`. 
+""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + extension_mapping, + v0_1, +) +from functions.registry import function_registry, registered_application +from projection_builders import ColumnExpr +from substrait.function_extensions import IS_NULL_FUNCTION_ANCHOR + + +@function_registry.add("is_null", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.Predicate, + extension_mapping("is_null", IS_NULL_FUNCTION_ANCHOR), +)) +pub def is_null(expr: ColumnExpr) -> ColumnExpr: + """ + Build a predicate that checks whether an expression is null. + + Examples: + missing_email = is_null(col("email")) + + Parameters: + expr: Expression to test for null. + """ + return registered_application("is_null", [expr]) + + +module tests: + from projection_builders import ( + ColumnExprKind, + col, + column_expr_argument_count, + column_expr_function_name, + column_expr_kind, + ) + def test_is_null_builds_registered_application() -> None: + expr = is_null(col("status")) + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction + assert column_expr_function_name(expr) == "is_null" + assert column_expr_argument_count(expr) == 1 diff --git a/src/functions/predicates/mod.incn b/src/functions/predicates/mod.incn new file mode 100644 index 0000000..90c17f7 --- /dev/null +++ b/src/functions/predicates/mod.incn @@ -0,0 +1,8 @@ +"""Null, NaN, membership, and range predicate helpers.""" + +pub from functions.predicates.is_null import is_null +pub from functions.predicates.is_not_null import is_not_null +pub from functions.predicates.is_nan import is_nan +pub from functions.predicates.is_not_nan import is_not_nan +pub from functions.predicates.in_ import in_ +pub from functions.predicates.between import between diff --git a/src/functions/references/col.incn b/src/functions/references/col.incn new file mode 100644 index 0000000..4579306 
--- /dev/null +++ b/src/functions/references/col.incn @@ -0,0 +1,45 @@ +""" +Column reference helper. + +`col` stays a structural scalar expression because column references are not scalar function applications. The registry +entry records that this helper lowers as a direct field-reference rewrite. +""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + rewrite_mapping, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, col as col_builder + + +@function_registry.add("col", deterministic_spec( + FunctionClass.Scalar, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + rewrite_mapping("direct field reference selection"), +)) +pub def col(name: str) -> ColumnExpr: + """ + Build a named column reference. + + Examples: + amount = col("amount") + status = col("status") + + Parameters: + name: Column name as it appears in the input relation. + """ + return col_builder(name) + + +module tests: + from projection_builders import ColumnExprKind, column_expr_kind, column_expr_name + def test_col_builds_structural_column_reference() -> None: + expr = col("customer_id") + assert column_expr_kind(expr) == ColumnExprKind.Column + assert column_expr_name(expr) == "customer_id" diff --git a/src/functions/references/mod.incn b/src/functions/references/mod.incn new file mode 100644 index 0000000..717063d --- /dev/null +++ b/src/functions/references/mod.incn @@ -0,0 +1,3 @@ +"""Column and field reference helpers.""" + +pub from functions.references.col import col diff --git a/src/functions/registry.incn b/src/functions/registry.incn new file mode 100644 index 0000000..f0c5fc5 --- /dev/null +++ b/src/functions/registry.incn @@ -0,0 +1,84 @@ +""" +Shared declaration-side registry state for public InQL helpers. + +Helper modules attach metadata with `@function_registry.add(...)` next to the helper they expose. 
This module owns the +runtime projection for helpers that have actually been loaded in the current process. Checked API metadata remains the +source for the complete public catalog and helper signatures. +""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import FunctionRegistry, FunctionRegistryEntry, SubstraitMapping, SubstraitMappingKind +from projection_builders import ColumnExpr, ScalarFunctionOption, scalar_function_application + +pub static function_registry: FunctionRegistry = FunctionRegistry.new() + + +pub def registered_substrait_mapping(canonical_name: str) -> SubstraitMapping: + """Return the decorator-registered Substrait mapping for one helper name.""" + entry = registered_function_entry(canonical_name) + return entry.substrait + + +pub def registered_function_entry(canonical_name: str) -> FunctionRegistryEntry: + """Return the decorator-registered metadata for one helper name.""" + match function_registry.entry_by_name(canonical_name): + Some(entry) => return entry + None => pass + message = f"missing function registry entry for `{canonical_name}`" + return raise_value_error(message) + + +pub def registered_application(canonical_name: str, arguments: list[ColumnExpr]) -> ColumnExpr: + """Build one scalar application using decorator-owned registry metadata.""" + return registered_application_with_options(canonical_name, arguments, []) + + +pub def registered_application_with_options( + canonical_name: str, + arguments: list[ColumnExpr], + options: list[ScalarFunctionOption], +) -> ColumnExpr: + """Build one scalar application with non-expression options using decorator-owned registry metadata.""" + registered_function_entry(canonical_name) + return scalar_function_application(canonical_name, arguments, options) + + +pub def function_registry_entries() -> list[FunctionRegistryEntry]: + """Return runtime registry entries for helpers loaded in the current process.""" + return function_registry.entries + + +pub def 
function_registry_entry(function_ref: str) -> Option[FunctionRegistryEntry]: + """Return a loaded registry entry for one stable function reference when it is known.""" + for entry in function_registry.entries: + if entry.function_ref == function_ref: + return Some(entry) + return None + + +pub def function_registry_entry_by_name(canonical_name: str) -> Option[FunctionRegistryEntry]: + """Return a loaded registry entry for one canonical public function name when it is known.""" + for entry in function_registry.entries: + if entry.canonical_name == canonical_name: + return Some(entry) + return None + + +pub def function_registry_function_refs() -> list[str]: + """Return loaded function references in runtime registry order.""" + return [entry.function_ref for entry in function_registry.entries] + + +pub def function_registry_canonical_names() -> list[str]: + """Return loaded canonical function names in runtime registry order.""" + return [entry.canonical_name for entry in function_registry.entries] + + +pub def function_registry_entry_count() -> int: + """Return the number of loaded runtime registry entries.""" + return len(function_registry.entries) + + +pub def registered_substrait_mapped_function_refs() -> list[str]: + """Return function references with a concrete Substrait extension mapping.""" + return [entry.function_ref for entry in function_registry.entries if entry.substrait.kind == SubstraitMappingKind.ExtensionFunction] diff --git a/src/lib.incn b/src/lib.incn index e2fc618..713e064 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -9,52 +9,82 @@ pub from dataset import BoundedDataSet, DataFrame, DataSet, DataStream, LazyFram pub from dataset.ops import agg_ds, explode_ds, filter_ds, group_by_ds, join_ds, limit_ds, order_by_ds, select_ds pub from aggregate_builders import AggregateKind, AggregateMeasure pub from projection_builders import ( - AddExpr, BoolLiteralExpr, ColumnExpr, ColumnExprKind, ColumnRefExpr, - EqExpr, FloatLiteralExpr, - GtExpr, 
IntLiteralExpr, - MultiplyExpr, ProjectionAssignment, + ScalarFunctionApplicationExpr, + ScalarFunctionOption, StringLiteralExpr, + column_expr_argument_count, + column_expr_function_name, + column_expr_function_ref, column_expr_kind, column_expr_name, + column_expr_option_value, ) +pub from functions.registry import function_registry pub from functions import ( - FUNCTION_REGISTRY, - add, - always_false, - always_true, - bool_expr, - bool_lit, - col, - count, - display, - eq, - float_expr, function_registry_canonical_names, function_registry_entries, function_registry_entry, function_registry_entry_by_name, function_registry_entry_count, function_registry_function_refs, - gt, - int_expr, - int_lit, - lit, - mul, registered_substrait_mapped_function_refs, - str_expr, - str_lit, - sum, ) +pub from functions.references.col import col +pub from functions.literals.always_false import always_false +pub from functions.literals.always_true import always_true +pub from functions.literals.bool_expr import bool_expr +pub from functions.literals.bool_lit import bool_lit +pub from functions.literals.float_expr import float_expr +pub from functions.literals.int_expr import int_expr +pub from functions.literals.int_lit import int_lit +pub from functions.literals.lit import lit +pub from functions.literals.str_expr import str_expr +pub from functions.literals.str_lit import str_lit +pub from functions.aggregates.count import count +pub from functions.aggregates.sum import sum +pub from functions.operators.add import add +pub from functions.operators.and_ import and_ +pub from functions.operators.div import div +pub from functions.operators.eq import eq +pub from functions.operators.equal_null import equal_null +pub from functions.operators.gt import gt +pub from functions.operators.gte import gte +pub from functions.operators.lt import lt +pub from functions.operators.lte import lte +pub from functions.operators.modulo import modulo +pub from functions.operators.mul import mul 
+pub from functions.operators.ne import ne +pub from functions.operators.neg import neg +pub from functions.operators.not_ import not_ +pub from functions.operators.or_ import or_ +pub from functions.operators.sub import sub +pub from functions.casts.cast import cast +pub from functions.casts.try_cast import try_cast +pub from functions.predicates.between import between +pub from functions.predicates.in_ import in_ +pub from functions.predicates.is_nan import is_nan +pub from functions.predicates.is_not_nan import is_not_nan +pub from functions.predicates.is_not_null import is_not_null +pub from functions.predicates.is_null import is_null +pub from functions.conditionals.case_when import case_when +pub from functions.conditionals.coalesce import coalesce +pub from functions.conditionals.nullif import nullif +pub from functions.ordering.asc import asc +pub from functions.ordering.asc_nulls_first import asc_nulls_first +pub from functions.ordering.asc_nulls_last import asc_nulls_last +pub from functions.ordering.desc import desc +pub from functions.ordering.desc_nulls_first import desc_nulls_first +pub from functions.ordering.desc_nulls_last import desc_nulls_last +pub from functions.formatting.display import display pub from function_registry import ( FunctionAliasPolicy, - FunctionArg, FunctionChange, FunctionClass, FunctionDeprecation, @@ -65,22 +95,31 @@ pub from function_registry import ( FunctionRegistry, FunctionRegistryEntry, FunctionSpec, - FunctionSignature, FunctionVersion, SubstraitMapping, SubstraitMappingKind, deterministic_spec, extension_mapping, function_ref_for, - literal_arg, - required_arg, rewrite_mapping, - signature, + sort_field_mapping, + structural_mapping, v0_1, v0_2, v0_3, ) -pub from backends import BackendKind, DataFusion, TableSource, csv_source, parquet_source +pub from backends import ( + BackendKind, + BackendOption, + BackendSelection, + DataFusion, + TableSource, + arrow_source, + csv_source, + datafusion_backend_from_selection, + 
datafusion_backend_selection, + parquet_source, +) pub from metadata import inql_version pub from session.types import Session, SessionBuilder pub from session.errors import SessionError, SessionErrorKind, format_session_diagnostic, report_session_error @@ -114,6 +153,7 @@ pub from substrait.relations import ( set_rel, set_rel_of_kind, sort_rel, + sort_rel_of_columns, ) pub from substrait.plans import ( empty_plan, @@ -135,7 +175,11 @@ pub from substrait.inspect import ( root_rel, set_operation_name, ) -pub from substrait.extensions import explode_extension_uri, function_extension_uri, registered_substrait_extension_uris +pub from substrait.function_extensions import ( + explode_extension_uri, + function_extension_uri, + registered_substrait_extension_uris, +) pub from substrait.conformance_catalog import ( ConformanceCapabilityTags, ConformancePortability, diff --git a/src/prism/lower.incn b/src/prism/lower.incn index e44edba..9ae303c 100644 --- a/src/prism/lower.incn +++ b/src/prism/lower.incn @@ -3,14 +3,14 @@ from rust::substrait::proto import Plan, Rel from rust::incan_stdlib::errors import raise_value_error from prism.output_columns import rewritten_output_columns -from substrait.extensions import explode_extension_uri +from substrait.function_extensions import explode_extension_uri from substrait.plans import plan_from_root_relation from substrait.relations import ( extension_single_rel, fetch_rel, join_rel, read_named_table_rel, - sort_rel, + sort_rel_of_columns, try_aggregate_rel_of_columns, try_filter_rel_of_columns, try_project_rel_of_columns, @@ -118,7 +118,14 @@ def _try_lower_node(view: PrismOptimizedView, node_id: int) -> Result[Rel, Subst [], node.aggregate_measures, ) - PrismNodeKind.OrderBy => return Ok(sort_rel(_try_lower_node(view, node.input_ids[0])?)) + PrismNodeKind.OrderBy => + return Ok( + sort_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, + rewritten_output_columns(view, node.input_ids[0]), + node.sort_columns, + ), + ) 
PrismNodeKind.Limit => return Ok(fetch_rel(_try_lower_node(view, node.input_ids[0])?, 0, node.limit_count)) PrismNodeKind.Explode => return Ok(extension_single_rel(_try_lower_node(view, node.input_ids[0])?, explode_extension_uri())) diff --git a/src/prism/mod.incn b/src/prism/mod.incn index 1754b53..229cbaa 100644 --- a/src/prism/mod.incn +++ b/src/prism/mod.incn @@ -67,6 +67,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=predicate, limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -84,6 +85,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -98,6 +100,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -114,6 +117,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -130,6 +134,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[with_column_assignment(name, expr)], ) @@ -146,6 +151,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=0, group_columns=columns, + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -162,12 +168,13 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=measures, projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) - def order_by(self) -> Self: + def order_by(self, columns: list[ColumnExpr]) -> Self: """Append one ordering node and return the derived tip.""" next_tip_id = 
append_node( store_id=self.store_id, @@ -178,6 +185,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=columns, aggregate_measures=[], projection_assignments=[], ) @@ -194,6 +202,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=n, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -210,6 +219,7 @@ pub class PrismCursor[T with Clone]: filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -252,6 +262,7 @@ pub def prism_cursor_named_table[T with Clone](table_name: str) -> PrismCursor[T filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -299,9 +310,9 @@ pub def prism_cursor_apply_agg[T with Clone]( return cursor.agg(measures) -pub def prism_cursor_apply_order_by[T with Clone](cursor: PrismCursor[T]) -> PrismCursor[T]: +pub def prism_cursor_apply_order_by[T with Clone](cursor: PrismCursor[T], columns: list[ColumnExpr]) -> PrismCursor[T]: """Apply dataset-level ordering intent through Prism.""" - return cursor.order_by() + return cursor.order_by(columns) pub def prism_cursor_apply_limit[T with Clone](cursor: PrismCursor[T], n: int) -> PrismCursor[T]: diff --git a/src/prism/rewrite.incn b/src/prism/rewrite.incn index d31460b..6247b0b 100644 --- a/src/prism/rewrite.incn +++ b/src/prism/rewrite.incn @@ -112,7 +112,12 @@ def _derive_rewrite_result(store_id: PrismStoreId, tip_id: int) -> PrismRewriteR applied_rule_names.append(str("collapse_adjacent_aggregate")) continue if _can_collapse_adjacent_order_by(authored_node, remapped_inputs, rewritten_nodes): - authored_to_rewritten_id[authored_node_id] = remapped_inputs[0] + collapsed_id = len(rewritten_nodes) + rewritten_nodes.append( + _build_collapsed_order_by_node(authored_node, remapped_inputs, 
rewritten_nodes, collapsed_id), + ) + rewritten_origin_ids.append(authored_node_id) + authored_to_rewritten_id[authored_node_id] = collapsed_id applied_rule_names.append(str("collapse_adjacent_order_by")) continue @@ -161,6 +166,7 @@ def _build_collapsed_limit_node( filter_predicate=always_true(), limit_count=_min_int(limit_input.limit_count, node.limit_count), group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=[], ) @@ -196,6 +202,7 @@ def _build_collapsed_project_node( filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=[], projection_assignments=merged_assignments, ) @@ -231,6 +238,7 @@ def _build_collapsed_aggregate_node( filter_predicate=always_true(), limit_count=0, group_columns=[], + sort_columns=[], aggregate_measures=merged_measures, projection_assignments=[], ) @@ -247,6 +255,29 @@ def _can_collapse_adjacent_order_by( return rewritten_nodes[remapped_inputs[0]].kind == PrismNodeKind.OrderBy +def _build_collapsed_order_by_node( + node: PrismNode, + remapped_inputs: list[int], + rewritten_nodes: list[PrismNode], + rewritten_id: int, +) -> PrismNode: + """Build one collapsed order-by node where the later ordering replaces the adjacent parent ordering.""" + order_input = rewritten_nodes[remapped_inputs[0]] + return PrismNode( + node_id=rewritten_id, + kind=PrismNodeKind.OrderBy, + input_ids=[order_input.input_ids[0]], + named_table=str(""), + join_predicate=false, + filter_predicate=always_true(), + limit_count=0, + group_columns=[], + sort_columns=node.sort_columns, + aggregate_measures=[], + projection_assignments=[], + ) + + def _build_rewritten_node(node: PrismNode, remapped_inputs: list[int], rewritten_id: int) -> PrismNode: """Build one rewritten node by copying authored fields with remapped input ids.""" return PrismNode( @@ -258,6 +289,7 @@ def _build_rewritten_node(node: PrismNode, remapped_inputs: list[int], rewritten filter_predicate=node.filter_predicate.clone(), 
limit_count=node.limit_count, group_columns=node.group_columns, + sort_columns=node.sort_columns, aggregate_measures=node.aggregate_measures, projection_assignments=node.projection_assignments, ) @@ -302,6 +334,7 @@ def _compact_optimized_view(view: PrismOptimizedView) -> PrismOptimizedView: filter_predicate=old_node.filter_predicate.clone(), limit_count=old_node.limit_count, group_columns=old_node.group_columns, + sort_columns=old_node.sort_columns, aggregate_measures=old_node.aggregate_measures, projection_assignments=old_node.projection_assignments, ), diff --git a/src/prism/store.incn b/src/prism/store.incn index 4b50635..d2e38a5 100644 --- a/src/prism/store.incn +++ b/src/prism/store.incn @@ -2,27 +2,29 @@ from aggregate_builders import AggregateMeasure from projection_builders import ( - AddExpr, BoolLiteralExpr, ColumnExpr, ColumnRefExpr, - EqExpr, FloatLiteralExpr, - GtExpr, IntLiteralExpr, - MultiplyExpr, ProjectionAssignment, + ScalarFunctionApplicationExpr, + ScalarFunctionOption, StringLiteralExpr, ) from prism.types import PrismNode, PrismNodeKind, PrismStoreAdoption, PrismStoreId model PrismStoredNode: + """One append-only node entry tagged with its owning store id.""" + store_id_raw: int node: PrismNode model PrismStoreNodeCount: + """The next node id allocated for one Prism store.""" + store_id_raw: int next_node_id: int @@ -50,6 +52,7 @@ pub def append_node( filter_predicate: ColumnExpr, limit_count: int, group_columns: list[ColumnExpr], + sort_columns: list[ColumnExpr], aggregate_measures: list[AggregateMeasure], projection_assignments: list[ProjectionAssignment], ) -> int: @@ -68,6 +71,7 @@ pub def append_node( filter_predicate=filter_predicate, limit_count=limit_count, group_columns=group_columns, + sort_columns=sort_columns, aggregate_measures=aggregate_measures, projection_assignments=projection_assignments, ) @@ -113,9 +117,11 @@ pub def adopt_cursor_subgraph( id_map[source_node_id] = reused_adopted_id else: adopted_group_columns = [column 
for column in source_node.group_columns] + adopted_sort_columns = [column for column in source_node.sort_columns] adopted_measures = [measure for measure in source_node.aggregate_measures] adopted_assignments = [assignment for assignment in source_node.projection_assignments] target_group_columns = [column for column in source_node.group_columns] + target_sort_columns = [column for column in source_node.sort_columns] target_measures = [measure for measure in source_node.aggregate_measures] target_assignments = [assignment for assignment in source_node.projection_assignments] adopted_id = append_node( @@ -127,6 +133,7 @@ pub def adopt_cursor_subgraph( filter_predicate=source_node.filter_predicate.clone(), limit_count=source_node.limit_count, group_columns=adopted_group_columns, + sort_columns=adopted_sort_columns, aggregate_measures=adopted_measures, projection_assignments=adopted_assignments, ) @@ -140,6 +147,7 @@ pub def adopt_cursor_subgraph( filter_predicate=source_node.filter_predicate.clone(), limit_count=source_node.limit_count, group_columns=target_group_columns, + sort_columns=target_sort_columns, aggregate_measures=target_measures, projection_assignments=target_assignments, ), @@ -220,6 +228,8 @@ def _nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema return false if not _column_expr_lists_structurally_equal(candidate.group_columns, source_node.group_columns): return false + if not _column_expr_lists_structurally_equal(candidate.sort_columns, source_node.sort_columns): + return false if len(candidate.aggregate_measures) != len(source_node.aggregate_measures): return false for idx in range(len(candidate.aggregate_measures)): @@ -291,44 +301,37 @@ def _column_exprs_structurally_equal(left: ColumnExpr, right: ColumnExpr) -> boo match right: BoolLiteralExpr(right_literal) => return left_literal.value == right_literal.value _ => return false - AddExpr(left_add) => - match right: - AddExpr(right_add) => - if len(left_add.arguments) != 2 
or len(right_add.arguments) != 2: - return false - if not _column_exprs_structurally_equal(left_add.arguments[0], right_add.arguments[0]): - return false - return _column_exprs_structurally_equal(left_add.arguments[1], right_add.arguments[1]) - _ => return false - MultiplyExpr(left_multiply) => + ScalarFunctionApplicationExpr(left_application) => match right: - MultiplyExpr(right_multiply) => - if len(left_multiply.arguments) != 2 or len(right_multiply.arguments) != 2: - return false - if not _column_exprs_structurally_equal(left_multiply.arguments[0], right_multiply.arguments[0]): + ScalarFunctionApplicationExpr(right_application) => + if left_application.function_ref != right_application.function_ref: return false - return _column_exprs_structurally_equal(left_multiply.arguments[1], right_multiply.arguments[1]) - _ => return false - EqExpr(left_eq) => - match right: - EqExpr(right_eq) => - if len(left_eq.arguments) != 2 or len(right_eq.arguments) != 2: + if left_application.canonical_name != right_application.canonical_name: return false - if not _column_exprs_structurally_equal(left_eq.arguments[0], right_eq.arguments[0]): + if not _scalar_function_options_structurally_equal( + left_application.options, + right_application.options, + ): return false - return _column_exprs_structurally_equal(left_eq.arguments[1], right_eq.arguments[1]) - _ => return false - GtExpr(left_gt) => - match right: - GtExpr(right_gt) => - if len(left_gt.arguments) != 2 or len(right_gt.arguments) != 2: - return false - if not _column_exprs_structurally_equal(left_gt.arguments[0], right_gt.arguments[0]): - return false - return _column_exprs_structurally_equal(left_gt.arguments[1], right_gt.arguments[1]) + return _column_expr_lists_structurally_equal(left_application.arguments, right_application.arguments) _ => return false +def _scalar_function_options_structurally_equal( + left: list[ScalarFunctionOption], + right: list[ScalarFunctionOption], +) -> bool: + """Return whether two scalar 
function option lists are structurally equivalent.""" + if len(left) != len(right): + return false + for idx in range(len(left)): + if left[idx].name != right[idx].name: + return false + if left[idx].value != right[idx].value: + return false + return true + + def _latest_store_next_node_id(store_id: PrismStoreId) -> int: """ Return next store-local node id for one store. diff --git a/src/prism/types.incn b/src/prism/types.incn index 138b35f..a5573cf 100644 --- a/src/prism/types.incn +++ b/src/prism/types.incn @@ -39,6 +39,7 @@ pub model PrismNode: pub filter_predicate: ColumnExpr pub limit_count: int pub group_columns: list[ColumnExpr] + pub sort_columns: list[ColumnExpr] pub aggregate_measures: list[AggregateMeasure] pub projection_assignments: list[ProjectionAssignment] diff --git a/src/projection_builders.incn b/src/projection_builders.incn index a5b5000..a37da2f 100644 --- a/src/projection_builders.incn +++ b/src/projection_builders.incn @@ -6,6 +6,7 @@ grouping keys, aggregate inputs, and query-block lowering. Today they provide a scalar intent without requiring parser or typechecker changes in the Incan compiler. 
""" +from function_registry import function_ref_for from rust::incan_stdlib::errors import raise_value_error @@ -18,10 +19,7 @@ pub enum ColumnExprKind(str): FloatLiteral = "float_literal" StringLiteral = "string_literal" BoolLiteral = "bool_literal" - Add = "add" - Multiply = "multiply" - Eq = "eq" - Gt = "gt" + ScalarFunction = "scalar_function" @derive(Clone) @@ -60,34 +58,24 @@ pub model BoolLiteralExpr: @derive(Clone) -pub model AddExpr: - """Binary addition scalar expression.""" +pub model ScalarFunctionOption: + """One non-expression option attached to a scalar function application.""" - pub arguments: list[ColumnExpr] - - -@derive(Clone) -pub model MultiplyExpr: - """Binary multiplication scalar expression.""" - - pub arguments: list[ColumnExpr] - - -@derive(Clone) -pub model EqExpr: - """Binary equality scalar expression.""" - - pub arguments: list[ColumnExpr] + pub name: str + pub value: str @derive(Clone) -pub model GtExpr: - """Binary greater-than scalar expression.""" +pub model ScalarFunctionApplicationExpr: + """Registry-backed scalar function or operator application.""" + pub function_ref: str + pub canonical_name: str pub arguments: list[ColumnExpr] + pub options: list[ScalarFunctionOption] -pub type ColumnExpr = Union[ColumnRefExpr, IntLiteralExpr, FloatLiteralExpr, StringLiteralExpr, BoolLiteralExpr, AddExpr, MultiplyExpr, EqExpr, GtExpr] +pub type ColumnExpr = Union[ColumnRefExpr, IntLiteralExpr, FloatLiteralExpr, StringLiteralExpr, BoolLiteralExpr, ScalarFunctionApplicationExpr] @derive(Clone) @@ -132,24 +120,23 @@ pub def lit(value: Union[int, float, str, bool]) -> ColumnExpr: str(text) => return str_expr(text) -pub def add(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one binary addition expression.""" - return AddExpr(arguments=[left, right]) - - -pub def mul(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one binary multiply expression.""" - return MultiplyExpr(arguments=[left, right]) - - -pub def eq(left: 
ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one equality predicate scalar expression.""" - return EqExpr(arguments=[left, right]) +pub def scalar_function_option(name: str, value: str) -> ScalarFunctionOption: + """Build one scalar function application option.""" + return ScalarFunctionOption(name=name, value=value) -pub def gt(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one greater-than predicate scalar expression.""" - return GtExpr(arguments=[left, right]) +pub def scalar_function_application( + canonical_name: str, + arguments: list[ColumnExpr], + options: list[ScalarFunctionOption], +) -> ColumnExpr: + """Build one registry-backed scalar function application expression.""" + return ScalarFunctionApplicationExpr( + function_ref=function_ref_for(canonical_name), + canonical_name=canonical_name, + arguments=arguments, + options=options, + ) pub def with_column_assignment(name: str, expr: ColumnExpr) -> ProjectionAssignment: @@ -165,7 +152,7 @@ pub def column_expr_name(expr: ColumnExpr) -> str: pub def require_column_expr_name(expr: ColumnExpr, context: str) -> str: - """Return one direct column name or fail with a clear message for unsupported expression shapes.""" + """Return one direct column name or fail with a clear message for invalid expression shapes.""" match expr: ColumnRefExpr(column) => return column.column_name _ => @@ -188,10 +175,39 @@ pub def column_expr_kind(expr: ColumnExpr) -> ColumnExprKind: FloatLiteralExpr(_) => return ColumnExprKind.FloatLiteral StringLiteralExpr(_) => return ColumnExprKind.StringLiteral BoolLiteralExpr(_) => return ColumnExprKind.BoolLiteral - AddExpr(_) => return ColumnExprKind.Add - MultiplyExpr(_) => return ColumnExprKind.Multiply - EqExpr(_) => return ColumnExprKind.Eq - GtExpr(_) => return ColumnExprKind.Gt + ScalarFunctionApplicationExpr(_) => return ColumnExprKind.ScalarFunction + + +pub def column_expr_function_ref(expr: ColumnExpr) -> str: + """Return the registry function reference 
for one scalar function expression, otherwise empty.""" + match expr: + ScalarFunctionApplicationExpr(application) => return application.function_ref + _ => return "" + + +pub def column_expr_function_name(expr: ColumnExpr) -> str: + """Return the canonical function name for one scalar function expression, otherwise empty.""" + match expr: + ScalarFunctionApplicationExpr(application) => return application.canonical_name + _ => return "" + + +pub def column_expr_argument_count(expr: ColumnExpr) -> int: + """Return the scalar argument count for one scalar function expression, otherwise zero.""" + match expr: + ScalarFunctionApplicationExpr(application) => return len(application.arguments) + _ => return 0 + + +pub def column_expr_option_value(expr: ColumnExpr, option_name: str) -> str: + """Return one scalar function option value, or empty when absent.""" + match expr: + ScalarFunctionApplicationExpr(application) => + for option in application.options: + if option.name == option_name: + return option.value + return "" + _ => return "" pub def is_bool_literal_expr(expr: ColumnExpr, expected: bool) -> bool: diff --git a/src/session/active.incn b/src/session/active.incn index 0f685d8..78d44da 100644 --- a/src/session/active.incn +++ b/src/session/active.incn @@ -22,7 +22,7 @@ pub class ActiveRegistration: def clone(self) -> Self: """Return one cloned active registration entry.""" - return ActiveRegistration(logical_name=self.logical_name, source=self.source) + return ActiveRegistration(logical_name=self.logical_name, source=self.source.clone()) @derive(Clone) @@ -32,7 +32,10 @@ pub class ActiveSessionState: def clone(self) -> Self: """Return one cloned active-session snapshot.""" - return ActiveSessionState(backend=self.backend, registrations=self.registrations) + return ActiveSessionState( + backend=self.backend.clone(), + registrations=[registration.clone() for registration in self.registrations], + ) pub model ActiveSessionError: diff --git 
a/src/session/backend_dispatch.incn b/src/session/backend_dispatch.incn new file mode 100644 index 0000000..3bafd25 --- /dev/null +++ b/src/session/backend_dispatch.incn @@ -0,0 +1,62 @@ +"""Backend dispatch boundary for Session execution.""" + +import std.async +from rust::substrait::proto import Plan +from rust::incan_stdlib::async::runtime import block_on +from backends import BackendKind, BackendSelection +from dataset.materialization import DataFrameMaterialization +from session.backend_types import BackendError, BackendErrorKind, BackendRegistration, backend_error +from session.domain import SinkKind +from session.datafusion_backend import ( + datafusion_collect_materialization_async, + datafusion_execute_async, + datafusion_write_csv_async, + datafusion_write_parquet_async, +) + + +pub def backend_execute_plan( + selection: BackendSelection, + registrations: list[BackendRegistration], + plan: Plan, +) -> Result[None, BackendError]: + """Execute one Substrait plan through the selected backend adapter.""" + match selection.kind: + BackendKind.DataFusionEngine => + match block_on(datafusion_execute_async(registrations, plan)): + Ok(result) => return result + Err(err) => return Err(backend_error(BackendErrorKind.RuntimeInitError, err.message())) + + +pub def backend_collect_plan( + selection: BackendSelection, + registrations: list[BackendRegistration], + plan: Plan, +) -> Result[DataFrameMaterialization, BackendError]: + """Execute and materialize one Substrait plan through the selected backend adapter.""" + match selection.kind: + BackendKind.DataFusionEngine => + match block_on(datafusion_collect_materialization_async(registrations, plan)): + Ok(result) => return result + Err(err) => return Err(backend_error(BackendErrorKind.RuntimeInitError, err.message())) + + +pub def backend_write_plan( + selection: BackendSelection, + registrations: list[BackendRegistration], + plan: Plan, + uri: str, + sink_kind: SinkKind, +) -> Result[None, BackendError]: + """Execute 
one Substrait plan and write it through the selected backend adapter.""" + match selection.kind: + BackendKind.DataFusionEngine => + match sink_kind: + SinkKind.Csv => + match block_on(datafusion_write_csv_async(registrations, plan, uri)): + Ok(result) => return result + Err(err) => return Err(backend_error(BackendErrorKind.RuntimeInitError, err.message())) + SinkKind.Parquet => + match block_on(datafusion_write_parquet_async(registrations, plan, uri)): + Ok(result) => return result + Err(err) => return Err(backend_error(BackendErrorKind.RuntimeInitError, err.message())) diff --git a/src/session/backend_types.incn b/src/session/backend_types.incn new file mode 100644 index 0000000..31fd676 --- /dev/null +++ b/src/session/backend_types.incn @@ -0,0 +1,36 @@ +"""Backend adapter boundary types shared by Session and concrete adapters.""" + +from backends import TableSource + + +@derive(Clone) +pub enum BackendErrorKind(str): + """Stable categories for backend adapter failures.""" + + BackendPlanningError = "backend_planning_error" + BackendExecutionError = "backend_execution_error" + BackendSinkError = "backend_sink_error" + BackendRegistrationError = "backend_registration_error" + RuntimeInitError = "runtime_init_error" + + +@derive(Clone) +pub class BackendRegistration: + """One logical source binding passed from Session state into a backend adapter.""" + + pub logical_name: str + pub source: TableSource + + def clone(self) -> Self: + """Return one cloned backend registration entry.""" + return BackendRegistration(logical_name=self.logical_name, source=self.source.clone()) + + +pub model BackendError: + pub kind: BackendErrorKind + pub message: str + + +pub def backend_error(kind: BackendErrorKind, message: str) -> BackendError: + """Build one backend error with stable kind/message fields.""" + return BackendError(kind=kind, message=message) diff --git a/src/session/datafusion_backend.incn b/src/session/datafusion_backend.incn index 73a9a53..eb13a7b 100644 --- 
a/src/session/datafusion_backend.incn +++ b/src/session/datafusion_backend.incn @@ -10,40 +10,19 @@ from rust::datafusion::prelude import CsvReadOptions, ParquetReadOptions from rust::datafusion::dataframe import DataFrameWriteOptions from rust::datafusion_substrait::substrait::proto import Plan as ConsumerPlan from rust::datafusion_substrait::logical_plan::consumer import from_substrait_plan -from backends import TableSource +from backends import SourceKind, TableSource from dataset.materialization import DataFrameMaterialization -from session.source_formats import DataFusionSourceRegistration, datafusion_registration_for_source +from session.backend_types import BackendError, BackendErrorKind, BackendRegistration, backend_error from substrait.inspect import root_names @derive(Clone) -pub enum BackendErrorKind(str): - """Stable categories for DataFusion backend failures.""" +enum DataFusionSourceRegistration(str): + """DataFusion adapter registration path for one portable source format.""" - BackendPlanningError = "backend_planning_error" - BackendExecutionError = "backend_execution_error" - BackendSinkError = "backend_sink_error" - BackendRegistrationError = "backend_registration_error" - - -@derive(Clone) -pub class BackendRegistration: - pub logical_name: str - pub source: TableSource - - def clone(self) -> Self: - """Return one cloned backend registration entry.""" - return BackendRegistration(logical_name=self.logical_name, source=self.source) - - -pub model BackendError: - pub kind: BackendErrorKind - pub message: str - - -def _backend_error(kind: BackendErrorKind, message: str) -> BackendError: - """Build one backend error with stable kind/message fields.""" - return BackendError(kind=kind, message=message) + Csv = "csv" + Parquet = "parquet" + Arrow = "arrow" pub async def datafusion_execute_async( @@ -60,11 +39,11 @@ pub async def datafusion_execute_async( Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): Ok(df) => match await 
df.collect(): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) pub async def datafusion_collect_materialization_async( @@ -93,13 +72,13 @@ pub async def datafusion_collect_materialization_async( preview_text=rendered, ), ) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) pub async def datafusion_write_csv_async( @@ -117,11 +96,11 @@ pub async def datafusion_write_csv_async( Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): Ok(df) => match await df.write_csv(uri, DataFrameWriteOptions.new(), None): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendSinkError, err.to_string())) + Err(err) => return 
Err(backend_error(BackendErrorKind.BackendSinkError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) pub async def datafusion_write_parquet_async( @@ -139,11 +118,11 @@ pub async def datafusion_write_parquet_async( Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): Ok(df) => match await df.write_parquet(uri, DataFrameWriteOptions.new(), None): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendSinkError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendSinkError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) async def _register_sources( @@ -157,23 +136,31 @@ async def _register_sources( async def _register_one(ctx: SessionContext, logical_name: str, source: TableSource) -> Result[None, BackendError]: - """Register one logical source through the central source-format policy.""" + """Register one logical source through the DataFusion adapter source mapping.""" match datafusion_registration_for_source(source): DataFusionSourceRegistration.Csv => csv_opts = CsvReadOptions.new().has_header(true) match await ctx.register_csv(logical_name, source.uri, csv_opts): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendRegistrationError, 
err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendRegistrationError, err.to_string())) DataFusionSourceRegistration.Parquet => parquet_opts = ParquetReadOptions.default() match await ctx.register_parquet(logical_name, source.uri, parquet_opts): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendRegistrationError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendRegistrationError, err.to_string())) DataFusionSourceRegistration.Arrow => arrow_opts = ArrowReadOptions.default() match await ctx.register_arrow(logical_name, source.uri, arrow_opts): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendRegistrationError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendRegistrationError, err.to_string())) + + +def datafusion_registration_for_source(source: TableSource) -> DataFusionSourceRegistration: + """Return the DataFusion adapter registration path for one source descriptor.""" + match source.source_kind: + SourceKind.Csv => return DataFusionSourceRegistration.Csv + SourceKind.Parquet => return DataFusionSourceRegistration.Parquet + SourceKind.Arrow => return DataFusionSourceRegistration.Arrow def _consumer_plan_from_current_plan(plan: Plan) -> Result[ConsumerPlan, BackendError]: @@ -181,11 +168,11 @@ def _consumer_plan_from_current_plan(plan: Plan) -> Result[ConsumerPlan, Backend encoded = plan.encode_to_vec() match ConsumerPlan.decode(encoded.as_slice()): Ok(decoded) => return Ok(decoded) - Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) + Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) def _rust_usize_to_int(value: RustUsize) -> Result[int, BackendError]: """Convert one Rust `usize` count into Incan `int` through a checked `i64` boundary.""" match RustI64.try_from(value): Ok(converted) => return 
Ok(converted.into()) - Err(_) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, "row count does not fit Incan int")) + Err(_) => return Err(backend_error(BackendErrorKind.BackendExecutionError, "row count does not fit Incan int")) diff --git a/src/session/errors.incn b/src/session/errors.incn index bb65c54..e80092b 100644 --- a/src/session/errors.incn +++ b/src/session/errors.incn @@ -18,7 +18,6 @@ pub enum SessionErrorKind(str): BackendRegistrationError = "backend_registration_error" InvalidScalarExpression = "invalid_scalar_expression" UnknownScalarColumn = "unknown_scalar_column" - UnsupportedScalarExpression = "unsupported_scalar_expression" pub model SessionError: diff --git a/src/session/mod.incn b/src/session/mod.incn index 4cc4cc2..5e92c85 100644 --- a/src/session/mod.incn +++ b/src/session/mod.incn @@ -1,5 +1,5 @@ """ -Session execution-context package (RFC 004). +Session execution-context package. Module layout: - `session.types`: public Session API and read/execute/collect/write boundaries @@ -8,7 +8,8 @@ Module layout: - `session.csv_schema`: CSV schema inference helpers used by Session read paths - `session.source_formats`: central source-format policy shared by planning and backend registration - `session.active`: active-session registry used by LazyFrame collect/display conveniences -- `session.datafusion_backend`: internal DataFusion backend implementation +- `session.backend_dispatch`: backend adapter dispatch over Substrait plans +- `session.datafusion_backend`: DataFusion adapter implementation """ pub from session.types import Session, SessionBuilder diff --git a/src/session/source_formats.incn b/src/session/source_formats.incn index 0b6c7ff..63ea0ca 100644 --- a/src/session/source_formats.incn +++ b/src/session/source_formats.incn @@ -14,45 +14,23 @@ pub enum PlannedSchemaPolicy(str): Empty = "empty" -@derive(Clone) -pub enum DataFusionSourceRegistration(str): - """DataFusion registration path for one portable source 
format.""" - - Csv = "csv" - Parquet = "parquet" - Arrow = "arrow" - - @derive(Clone) pub class SourceFormatPolicy: - """Source-format policy shared by Session schema planning and backend registration.""" + """Portable source-format policy used by Session schema planning.""" pub schema_policy: PlannedSchemaPolicy - pub datafusion_registration: DataFusionSourceRegistration def clone(self) -> Self: """Return one cloned source-format policy.""" - return SourceFormatPolicy(schema_policy=self.schema_policy, datafusion_registration=self.datafusion_registration) + return SourceFormatPolicy(schema_policy=self.schema_policy) pub def source_format_policy_for_kind(kind: SourceKind) -> SourceFormatPolicy: """Return the central policy for one closed source kind.""" match kind: - SourceKind.Csv => - return SourceFormatPolicy( - schema_policy=PlannedSchemaPolicy.InferCsvHeader, - datafusion_registration=DataFusionSourceRegistration.Csv, - ) - SourceKind.Parquet => - return SourceFormatPolicy( - schema_policy=PlannedSchemaPolicy.Empty, - datafusion_registration=DataFusionSourceRegistration.Parquet, - ) - SourceKind.Arrow => - return SourceFormatPolicy( - schema_policy=PlannedSchemaPolicy.Empty, - datafusion_registration=DataFusionSourceRegistration.Arrow, - ) + SourceKind.Csv => return SourceFormatPolicy(schema_policy=PlannedSchemaPolicy.InferCsvHeader) + SourceKind.Parquet => return SourceFormatPolicy(schema_policy=PlannedSchemaPolicy.Empty) + SourceKind.Arrow => return SourceFormatPolicy(schema_policy=PlannedSchemaPolicy.Empty) pub def source_format_policy(source: TableSource) -> SourceFormatPolicy: @@ -66,8 +44,3 @@ pub def planned_schema_columns_for_source(source: TableSource) -> Result[list[Ro match policy.schema_policy: PlannedSchemaPolicy.InferCsvHeader => return infer_csv_schema_columns(source.uri) PlannedSchemaPolicy.Empty => return Ok([]) - - -pub def datafusion_registration_for_source(source: TableSource) -> DataFusionSourceRegistration: - """Return the DataFusion 
registration path for one source descriptor.""" - return source_format_policy(source).datafusion_registration diff --git a/src/session/types.incn b/src/session/types.incn index a5b91a2..2bb529a 100644 --- a/src/session/types.incn +++ b/src/session/types.incn @@ -1,14 +1,13 @@ """Public Session API and execution-context types.""" -import std.async from rust::substrait::proto import Plan -from rust::incan_stdlib::async::runtime import block_on from backends import ( - BackendKind, BackendSelection, DataFusion, TableSource, backend_kind_name, + datafusion_backend_from_selection, + datafusion_backend_selection, source_kind_name, arrow_source, csv_source, @@ -29,15 +28,8 @@ from session.active import ( get_active_session_state, set_active_session_state, ) -from session.datafusion_backend import ( - BackendError, - BackendErrorKind, - BackendRegistration, - datafusion_collect_materialization_async, - datafusion_execute_async, - datafusion_write_csv_async, - datafusion_write_parquet_async, -) +from session.backend_dispatch import backend_collect_plan, backend_execute_plan, backend_write_plan +from session.backend_types import BackendError, BackendErrorKind, BackendRegistration from substrait.errors import SubstraitLoweringError, SubstraitLoweringErrorKind from substrait.schema_registry import register_named_table_schema from substrait.inspect import root_rel, read_named_table_name @@ -77,9 +69,13 @@ pub class Session: """Return one stable backend name for this Session instance.""" return backend_kind_name(self._backend) + def backend_selection(self) -> BackendSelection: + """Return one cloned backend selection envelope for adapter-level inspection.""" + return self._backend.clone() + def datafusion_backend(self) -> DataFusion: """Return the current DataFusion backend configuration.""" - return self._backend.datafusion.clone() + return datafusion_backend_from_selection(self._backend) def registration_count(self) -> int: """Return the number of registered logical sources in 
this Session.""" @@ -170,37 +166,25 @@ pub class Session: def execute[T with Clone](self, data: LazyFrame[T]) -> Result[LazyFrame[T], SessionError]: """Validate and execute one lazy plan while preserving deferred carrier shape.""" - _ensure_datafusion_backend(self._backend)? plan = _plan_from_lazy_frame(data)? _validate_named_table_binding(self._registrations, plan)? - match block_on(datafusion_execute_async(_to_backend_registrations(self._registrations), plan)): - Ok(exec_result) => - match exec_result: - Ok(_) => return Ok(data) - Err(err) => return Err(_session_error_from_backend_error(err)) - Err(err) => return Err(SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message())) + match backend_execute_plan(self._backend, _to_backend_registrations(self._registrations), plan): + Ok(_) => return Ok(data) + Err(err) => return Err(_session_error_from_backend_error(err)) def collect[T with Clone](self, data: LazyFrame[T]) -> Result[DataFrame[T], SessionError]: """Validate and execute one lazy plan, returning a structured materialized DataFrame.""" - _ensure_datafusion_backend(self._backend)? plan = _plan_from_lazy_frame(data)? rel = root_rel(plan.clone()) _validate_named_table_binding(self._registrations, plan)? 
- match block_on(datafusion_collect_materialization_async(_to_backend_registrations(self._registrations), plan)): - Ok(collect_result) => - match collect_result: - Ok(materialization) => - return Ok( - DataFrame( - _type_witness=_empty_type_witness(), - _materialization=materialization, - _substrait_rel=rel, - ), - ) - Err(err) => return Err(_session_error_from_backend_error(err)) - Err(err) => return Err(SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message())) + match backend_collect_plan(self._backend, _to_backend_registrations(self._registrations), plan): + Ok(materialization) => + return Ok( + DataFrame(_type_witness=_empty_type_witness(), _materialization=materialization, _substrait_rel=rel), + ) + Err(err) => return Err(_session_error_from_backend_error(err)) def write_csv[T with Clone](self, data: LazyFrame[T], uri: str) -> Result[None, SessionError]: """Execute one lazy plan and write result rows to a CSV sink URI.""" @@ -221,19 +205,12 @@ pub class Session: def _write_plan_to_sink(self, plan: Plan, uri: str, sink_kind: SinkKind) -> Result[None, SessionError]: """Run one validated plan through the selected sink writer and normalize runtime/backend errors.""" sink_uri = _sink_uri_from_text(uri)? - _ensure_datafusion_backend(self._backend)? _validate_named_table_binding(self._registrations, plan)? 
registrations = _to_backend_registrations(self._registrations) - match sink_kind: - SinkKind.Csv => - match block_on(datafusion_write_csv_async(registrations, plan, sink_uri.0)): - Ok(write_result) => return _write_result_from_backend_result(write_result) - Err(err) => return Err(SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message())) - SinkKind.Parquet => - match block_on(datafusion_write_parquet_async(registrations, plan, sink_uri.0)): - Ok(write_result) => return _write_result_from_backend_result(write_result) - Err(err) => return Err(SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message())) + return _write_result_from_backend_result( + backend_write_plan(self._backend, registrations, plan, sink_uri.0, sink_kind), + ) pub class SessionBuilder: @@ -241,9 +218,13 @@ pub class SessionBuilder: pub _backend: BackendSelection + def with_backend(self, backend: BackendSelection) -> Self: + """Select one backend through the portable backend-selection envelope.""" + return SessionBuilder(_backend=backend.clone()) + def with_datafusion(self, backend: DataFusion) -> Self: """Select DataFusion backend options for the Session being built.""" - return SessionBuilder(_backend=BackendSelection(kind=BackendKind.DataFusionEngine, datafusion=backend.clone())) + return self.with_backend(datafusion_backend_selection(backend.clone())) def build(self) -> Session: """Build one Session from the current builder backend selection.""" @@ -285,6 +266,8 @@ def _session_error_from_backend_error(err: BackendError) -> SessionError: return SessionError(kind=SessionErrorKind.BackendSinkError, message=err.message) BackendErrorKind.BackendRegistrationError => return SessionError(kind=SessionErrorKind.BackendRegistrationError, message=err.message) + BackendErrorKind.RuntimeInitError => + return SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message) def _session_error_from_lowering_error(err: SubstraitLoweringError) -> SessionError: @@ -294,8 
+277,6 @@ def _session_error_from_lowering_error(err: SubstraitLoweringError) -> SessionEr return SessionError(kind=SessionErrorKind.InvalidScalarExpression, message=err.message) SubstraitLoweringErrorKind.UnknownScalarColumn => return SessionError(kind=SessionErrorKind.UnknownScalarColumn, message=err.message) - SubstraitLoweringErrorKind.UnsupportedScalarExpression => - return SessionError(kind=SessionErrorKind.UnsupportedScalarExpression, message=err.message) def _duplicate_registration_error(logical_name: str) -> SessionError: @@ -356,18 +337,6 @@ def _sink_uri_from_text(uri: str) -> Result[SinkUri, SessionError]: return SinkUri.from_underlying(uri).map_err(_invalid_sink_from_validation_error) -def _ensure_datafusion_backend(selection: BackendSelection) -> Result[None, SessionError]: - """Guard backend dispatch: this slice only supports DataFusion execution.""" - if selection.kind == BackendKind.DataFusionEngine: - return Ok(None) - return Err( - SessionError( - kind=SessionErrorKind.UnsupportedBackend, - message=f"backend '{backend_kind_name(selection)}' is not implemented", - ), - ) - - def _empty_type_witness[T with Clone]() -> list[T]: """Provide an empty generic witness list so runtime carries the type parameter.""" return [] diff --git a/src/substrait/errors.incn b/src/substrait/errors.incn index 9c63070..76d0b29 100644 --- a/src/substrait/errors.incn +++ b/src/substrait/errors.incn @@ -7,7 +7,6 @@ pub enum SubstraitLoweringErrorKind(str): InvalidScalarExpression = "invalid_scalar_expression" UnknownScalarColumn = "unknown_scalar_column" - UnsupportedScalarExpression = "unsupported_scalar_expression" pub model SubstraitLoweringError: @@ -32,8 +31,3 @@ pub def unknown_scalar_column(name: str) -> SubstraitLoweringError: kind=SubstraitLoweringErrorKind.UnknownScalarColumn, message=f"unknown scalar expression column '{name}'", ) - - -pub def unsupported_scalar_expression(message: str) -> SubstraitLoweringError: - """Create one unsupported scalar-expression 
validation error.""" - return SubstraitLoweringError(kind=SubstraitLoweringErrorKind.UnsupportedScalarExpression, message=message) diff --git a/src/substrait/expr_lowering.incn b/src/substrait/expr_lowering.incn index f6a76a9..d5dcd72 100644 --- a/src/substrait/expr_lowering.incn +++ b/src/substrait/expr_lowering.incn @@ -8,38 +8,43 @@ conversion helpers needed by proto fields along that path. from rust::incan_stdlib::errors import raise_value_error from rust::std::boxed import Box from rust::std::primitive import i32 as RustI32, u32 as RustU32 +from function_registry import FunctionRegistryEntry, SubstraitMappingKind +from functions.registry import function_registry_entry from projection_builders import ( - AddExpr, BoolLiteralExpr, ColumnExpr, ColumnRefExpr, - EqExpr, FloatLiteralExpr, - GtExpr, IntLiteralExpr, - MultiplyExpr, ProjectionAssignment, + ScalarFunctionApplicationExpr, StringLiteralExpr, ) -from rust::substrait::proto import Expression, FunctionArgument, ProjectRel -from rust::substrait::proto::expression import FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction +from rust::substrait::proto import Expression, FunctionArgument, ProjectRel, Type +from rust::substrait::proto::expression import ( + Cast, + FieldReference, + IfThen, + Literal, + ReferenceSegment, + RexType, + ScalarFunction, + SingularOrList, +) +from rust::substrait::proto::expression::cast import FailureBehavior from rust::substrait::proto::expression::field_reference import ( ReferenceType as FieldReferenceType, RootReference, RootType, ) +from rust::substrait::proto::expression::if_then import IfClause from rust::substrait::proto::expression::literal import LiteralType from rust::substrait::proto::expression::reference_segment import ReferenceType as SegmentReferenceType, StructField from rust::substrait::proto::function_argument import ArgType from rust::substrait::proto::rel_common import EmitKind +from rust::substrait::proto::type import Boolean, Fp64, I64, Kind, 
Nullability, String as SubstraitString from substrait.errors import SubstraitLoweringError, invalid_scalar_expression, unknown_scalar_column -from substrait.extensions import ( - ADD_FUNCTION_ANCHOR, - EQUAL_FUNCTION_ANCHOR, - GT_FUNCTION_ANCHOR, - MULTIPLY_FUNCTION_ANCHOR, - scalar_function_name_from_anchor, -) +from substrait.extensions import scalar_function_name_from_anchor pub def bool_expr(value: bool) -> Expression: @@ -134,6 +139,51 @@ pub def scalar_function_expr(function_reference: u32, arguments: list[Expression ) +def _nullable_type(kind: str) -> Result[Type, SubstraitLoweringError]: + """Lower one public cast target spelling into a nullable Substrait primitive type.""" + n = Nullability.Nullable + normalized = kind.lower() + if normalized == "bool" or normalized == "boolean": + return Ok(Type(kind=Some(Kind.Bool(Boolean(type_variation_reference=0, nullability=n.into()))))) + if normalized == "int" or normalized == "i64" or normalized == "int64": + return Ok(Type(kind=Some(Kind.I64(I64(type_variation_reference=0, nullability=n.into()))))) + if normalized == "float" or normalized == "f64" or normalized == "float64" or normalized == "double": + return Ok(Type(kind=Some(Kind.Fp64(Fp64(type_variation_reference=0, nullability=n.into()))))) + if normalized == "str" or normalized == "string" or normalized == "utf8": + return Ok(Type(kind=Some(Kind.String(SubstraitString(type_variation_reference=0, nullability=n.into()))))) + return Err(invalid_scalar_expression(f"invalid cast target_type `{kind}`")) + + +def _cast_expr(input: Expression, target_type: Type, failure_behavior: FailureBehavior) -> Expression: + """Build one Substrait Cast Rex expression.""" + return Expression( + rex_type=Some( + RexType.Cast( + Box.new( + Cast(type=Some(target_type), input=Some(Box.new(input)), failure_behavior=failure_behavior.into()), + ), + ), + ), + ) + + +def _singular_or_list_expr(value: Expression, options: list[Expression]) -> Expression: + """Build one Substrait 
SingularOrList Rex expression for membership predicates.""" + return Expression( + rex_type=Some(RexType.SingularOrList(Box.new(SingularOrList(value=Some(Box.new(value)), options=options)))), + ) + + +def _if_then_expr(ifs: list[IfClause], otherwise: Expression) -> Expression: + """Build one Substrait IfThen Rex expression for searched case expressions.""" + return Expression(rex_type=Some(RexType.IfThen(Box.new(IfThen(ifs=ifs, else=Some(Box.new(otherwise))))))) + + +def _if_clause(condition: Expression, result: Expression) -> IfClause: + """Build one Substrait IfThen clause.""" + return IfClause(if=Some(condition), then=Some(result)) + + @derive(Clone) model ResolvedProjectionBinding: """One output-column binding mapped to a lowered Substrait expression relative to the current project input.""" @@ -201,54 +251,129 @@ def _resolved_projection_expr( FloatLiteralExpr(literal) => return Ok(f64_expr(literal.value)) StringLiteralExpr(literal) => return Ok(string_expr(literal.value)) BoolLiteralExpr(literal) => return Ok(bool_expr(literal.value)) - AddExpr(add_expr) => - if len(add_expr.arguments) != 2: - return Err(invalid_scalar_expression("binary scalar expression requires exactly two arguments")) - return _resolved_binary_scalar_function_expr( - bindings, - add_expr.arguments[0], - add_expr.arguments[1], - ADD_FUNCTION_ANCHOR, - ) - MultiplyExpr(mul_expr) => - if len(mul_expr.arguments) != 2: - return Err(invalid_scalar_expression("binary scalar expression requires exactly two arguments")) - return _resolved_binary_scalar_function_expr( - bindings, - mul_expr.arguments[0], - mul_expr.arguments[1], - MULTIPLY_FUNCTION_ANCHOR, - ) - EqExpr(eq_expr) => - if len(eq_expr.arguments) != 2: - return Err(invalid_scalar_expression("binary scalar expression requires exactly two arguments")) - return _resolved_binary_scalar_function_expr( - bindings, - eq_expr.arguments[0], - eq_expr.arguments[1], - EQUAL_FUNCTION_ANCHOR, + ScalarFunctionApplicationExpr(application) => + return 
_resolved_scalar_function_application_expr(bindings, application) + + +def _registry_entry_for_application( + application: ScalarFunctionApplicationExpr, +) -> Result[FunctionRegistryEntry, SubstraitLoweringError]: + """Resolve one scalar application through the loaded function registry.""" + match function_registry_entry(application.function_ref): + Some(entry) => return Ok(entry) + None => + return Err(invalid_scalar_expression(f"{application.function_ref} is not loaded in the function registry")) + + +def _resolved_scalar_function_application_expr( + bindings: list[ResolvedProjectionBinding], + application: ScalarFunctionApplicationExpr, +) -> Result[Expression, SubstraitLoweringError]: + """Lower one scalar-function expression through its registry-owned Substrait mapping.""" + entry = _registry_entry_for_application(application)? + match entry.substrait.kind: + SubstraitMappingKind.ExtensionFunction => + lowered_arguments = [_resolved_projection_expr(bindings, argument)? for argument in application.arguments] + return Ok(scalar_function_expr(entry.substrait.anchor, lowered_arguments)) + SubstraitMappingKind.StructuralFunction => + return Err( + invalid_scalar_expression( + f"{entry.function_ref} is only valid in {entry.substrait.function_name} context", + ), ) - GtExpr(gt_expr) => - if len(gt_expr.arguments) != 2: - return Err(invalid_scalar_expression("binary scalar expression requires exactly two arguments")) - return _resolved_binary_scalar_function_expr( - bindings, - gt_expr.arguments[0], - gt_expr.arguments[1], - GT_FUNCTION_ANCHOR, + SubstraitMappingKind.Rewrite => + return Err( + invalid_scalar_expression( + f"{entry.function_ref} is registered as a rewrite helper and must return its canonical expression shape: {entry.substrait.rewrite}", + ), ) + SubstraitMappingKind.CoreFunction => + return _resolved_core_function_application_expr(bindings, application, entry) + + +def _application_option_value(application: ScalarFunctionApplicationExpr, option_name: 
str) -> str: + """Return one scalar application option value, or empty when it is absent.""" + for option in application.options: + if option.name == option_name: + return option.value + return "" -def _resolved_binary_scalar_function_expr( +def _resolved_cast_application_expr( bindings: list[ResolvedProjectionBinding], - left_expr: ColumnExpr, - right_expr: ColumnExpr, - function_anchor: u32, + application: ScalarFunctionApplicationExpr, + failure_behavior: FailureBehavior, ) -> Result[Expression, SubstraitLoweringError]: - """Lower one binary scalar-function expression against resolved bindings.""" - left = _resolved_projection_expr(bindings, left_expr.clone())? - right = _resolved_projection_expr(bindings, right_expr.clone())? - return Ok(scalar_function_expr(function_anchor, [left, right])) + """Lower one registry-backed cast application into Substrait's built-in Cast Rex shape.""" + if len(application.arguments) != 1: + return Err(invalid_scalar_expression(f"{application.function_ref} requires exactly one scalar argument")) + target_type = _application_option_value(application, "target_type") + if target_type == "": + return Err(invalid_scalar_expression(f"{application.function_ref} requires a target_type option")) + input = _resolved_projection_expr(bindings, application.arguments[0])? + return Ok(_cast_expr(input, _nullable_type(target_type)?, failure_behavior)) + + +def _resolved_in_list_application_expr( + bindings: list[ResolvedProjectionBinding], + application: ScalarFunctionApplicationExpr, +) -> Result[Expression, SubstraitLoweringError]: + """Lower one registry-backed membership predicate into Substrait's SingularOrList Rex shape.""" + if len(application.arguments) < 2: + return Err(invalid_scalar_expression(f"{application.function_ref} requires one value and at least one option")) + value = _resolved_projection_expr(bindings, application.arguments[0])? + options = [_resolved_projection_expr(bindings, application.arguments[idx])? 
for idx in range( + 1, + len(application.arguments), + )] + return Ok(_singular_or_list_expr(value, options)) + + +def _resolved_case_when_application_expr( + bindings: list[ResolvedProjectionBinding], + application: ScalarFunctionApplicationExpr, +) -> Result[Expression, SubstraitLoweringError]: + """Lower one registry-backed searched case expression into Substrait's IfThen Rex shape.""" + argument_count = len(application.arguments) + if argument_count < 3: + return Err( + invalid_scalar_expression( + f"{application.function_ref} requires at least one condition/result pair and an otherwise expression", + ), + ) + pair_argument_count = argument_count - 1 + mut condition_count = 0 + while condition_count * 2 < pair_argument_count: + condition_count += 1 + if condition_count * 2 != pair_argument_count: + return Err(invalid_scalar_expression(f"{application.function_ref} has mismatched condition/result arguments")) + + clauses = [_if_clause( + _resolved_projection_expr(bindings, application.arguments[idx])?, + _resolved_projection_expr(bindings, application.arguments[idx + condition_count])?, + ) for idx in range(condition_count)] + otherwise = _resolved_projection_expr(bindings, application.arguments[argument_count - 1])? 
+ return Ok(_if_then_expr(clauses, otherwise)) + + +def _resolved_core_function_application_expr( + bindings: list[ResolvedProjectionBinding], + application: ScalarFunctionApplicationExpr, + entry: FunctionRegistryEntry, +) -> Result[Expression, SubstraitLoweringError]: + """Lower registry-backed applications that target built-in Substrait Rex shapes.""" + core_name = entry.substrait.function_name + if core_name == "cast": + return _resolved_cast_application_expr(bindings, application, FailureBehavior.ThrowException) + if core_name == "try_cast": + return _resolved_cast_application_expr(bindings, application, FailureBehavior.ReturnNull) + if core_name == "singular_or_list": + return _resolved_in_list_application_expr(bindings, application) + if core_name == "if_then": + return _resolved_case_when_application_expr(bindings, application) + return Err( + invalid_scalar_expression(f"{entry.function_ref} has unknown Substrait core function mapping `{core_name}`"), + ) def _apply_projection_assignments_to_bindings( diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index a7e01e6..98cca28 100644 --- a/src/substrait/extensions.incn +++ b/src/substrait/extensions.incn @@ -1,8 +1,8 @@ """ Substrait extension and function-anchor bookkeeping for InQL. -This module owns stable function anchors, extension URIs, anchor-name mapping, and recursive extension-declaration -collection over relation and expression trees. +This module derives extension declarations from registry metadata and collects required extension URNs over relation and +expression trees. 
""" from rust::incan_stdlib::errors import raise_value_error @@ -12,7 +12,10 @@ from rust::substrait::proto::extensions::simple_extension_declaration import Ext from rust::substrait::proto::function_argument import ArgType from rust::substrait::proto::expression import RexType from rust::substrait::proto::rel import RelType +from function_registry import FunctionClass, SubstraitMappingKind +from functions.registry import function_registry_entries from substrait.errors import SubstraitLoweringError, invalid_scalar_expression +from substrait.function_extensions import ExtensionFunctionKind, FunctionExtensionSpec, function_extension_uri from substrait.traversal import relation_children @@ -24,48 +27,8 @@ model ExtensionUrnSpec: urn: str -@derive(Clone) -enum ExtensionFunctionKind(str): - """Function extension categories used by the current Substrait extension registry.""" - - Aggregate = "aggregate" - Scalar = "scalar" - - -@derive(Clone) -model FunctionExtensionSpec: - """One function-extension anchor/name/kind fact.""" - - anchor: u32 - name: str - kind: ExtensionFunctionKind - - -pub const SUM_FUNCTION_ANCHOR: u32 = 0 -pub const COUNT_FUNCTION_ANCHOR: u32 = 1 -pub const EQUAL_FUNCTION_ANCHOR: u32 = 2 -pub const GT_FUNCTION_ANCHOR: u32 = 3 -pub const ADD_FUNCTION_ANCHOR: u32 = 4 -pub const MULTIPLY_FUNCTION_ANCHOR: u32 = 5 const FUNCTION_EXTENSION_URN_ANCHOR: u32 = 0 const RELATION_EXTENSION_URN_ANCHOR: u32 = 1 -const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" -const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" - - -pub def function_extension_uri() -> str: - """Return the registered extension URI used for shared function anchors.""" - return FUNCTION_EXTENSION_URI - - -pub def explode_extension_uri() -> str: - """Return the registered extension URI used for EXPLODE-style gap encoding.""" - return EXPLODE_EXTENSION_URI - - -pub def registered_substrait_extension_uris() -> list[str]: - """Return 
the registered extension URIs used by current package-level Substrait lowering.""" - return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI] pub def aggregate_function_name_from_anchor(anchor: u32) -> str: @@ -79,20 +42,29 @@ pub def scalar_function_name_from_anchor(anchor: u32) -> str: def _function_extension_specs() -> list[FunctionExtensionSpec]: - """Return all known function-extension specs in stable declaration order.""" - return [FunctionExtensionSpec(anchor=SUM_FUNCTION_ANCHOR, name="sum", kind=ExtensionFunctionKind.Aggregate), FunctionExtensionSpec( - anchor=COUNT_FUNCTION_ANCHOR, - name="count", - kind=ExtensionFunctionKind.Aggregate, - ), FunctionExtensionSpec(anchor=EQUAL_FUNCTION_ANCHOR, name="equal", kind=ExtensionFunctionKind.Scalar), FunctionExtensionSpec( - anchor=GT_FUNCTION_ANCHOR, - name="gt", - kind=ExtensionFunctionKind.Scalar, - ), FunctionExtensionSpec(anchor=ADD_FUNCTION_ANCHOR, name="add", kind=ExtensionFunctionKind.Scalar), FunctionExtensionSpec( - anchor=MULTIPLY_FUNCTION_ANCHOR, - name="multiply", - kind=ExtensionFunctionKind.Scalar, - )] + """Return Substrait extension specs derived from declaration-side registry metadata.""" + mut specs: list[FunctionExtensionSpec] = [] + for entry in function_registry_entries(): + if entry.substrait.kind != SubstraitMappingKind.ExtensionFunction: + continue + + if entry.function_class == FunctionClass.Aggregate: + specs.append( + FunctionExtensionSpec( + anchor=entry.substrait.anchor, + name=entry.substrait.function_name, + kind=ExtensionFunctionKind.Aggregate, + ), + ) + elif entry.function_class == FunctionClass.Scalar: + specs.append( + FunctionExtensionSpec( + anchor=entry.substrait.anchor, + name=entry.substrait.function_name, + kind=ExtensionFunctionKind.Scalar, + ), + ) + return specs def _function_spec_from_anchor(anchor: u32) -> Result[FunctionExtensionSpec, SubstraitLoweringError]: @@ -158,14 +130,53 @@ def _scalar_extension_decl(anchor: u32) -> SimpleExtensionDeclaration: def 
_expr_uses_scalar_function_anchor(expr: Expression, expected_anchor: u32) -> bool: """Return whether one expression tree uses the requested scalar-function anchor.""" - if let Some(RexType.ScalarFunction(fun)) = expr.rex_type: - if fun.function_reference == expected_anchor: - return true + match expr.rex_type: + Some(RexType.ScalarFunction(fun)) => + if fun.function_reference == expected_anchor: + return true + + for argument in fun.arguments: + if let Some(ArgType.Value(value)) = argument.arg_type: + if _expr_uses_scalar_function_anchor(value, expected_anchor): + return true + Some(RexType.Cast(cast)) => + if let Some(input) = cast.input: + if _expr_uses_scalar_function_anchor(input.as_ref().clone(), expected_anchor): + return true + Some(RexType.SingularOrList(singular_or_list)) => + list_expr = singular_or_list.as_ref().clone() + if let Some(value) = list_expr.value: + if _expr_uses_scalar_function_anchor(value.as_ref().clone(), expected_anchor): + return true + for option in list_expr.options: + if _expr_uses_scalar_function_anchor(option, expected_anchor): + return true + _ => pass + return false + - for argument in fun.arguments: - if let Some(ArgType.Value(value)) = argument.arg_type: - if _expr_uses_scalar_function_anchor(value, expected_anchor): +def _expr_uses_if_then(expr: Expression) -> bool: + """Return whether one expression tree contains a Substrait IfThen Rex shape.""" + match expr.rex_type: + Some(RexType.IfThen(_)) => return true + Some(RexType.ScalarFunction(fun)) => + for argument in fun.arguments: + if let Some(ArgType.Value(value)) = argument.arg_type: + if _expr_uses_if_then(value): + return true + Some(RexType.Cast(cast)) => + if let Some(input) = cast.input: + if _expr_uses_if_then(input.as_ref().clone()): + return true + Some(RexType.SingularOrList(singular_or_list)) => + list_expr = singular_or_list.as_ref().clone() + if let Some(value) = list_expr.value: + if _expr_uses_if_then(value.as_ref().clone()): return true + for option in 
list_expr.options: + if _expr_uses_if_then(option): + return true + _ => pass return false @@ -202,27 +213,45 @@ def _rel_uses_scalar_function_anchor(rel: Rel, expected_anchor: u32) -> bool: return false +def _rel_uses_if_then(rel: Rel) -> bool: + """Return whether one relation subtree contains a Substrait IfThen Rex expression.""" + match rel.rel_type.clone(): + Some(RelType.Filter(filter_rel)) => + if let Some(condition) = filter_rel.condition: + if _expr_uses_if_then(condition.as_ref().clone()): + return true + Some(RelType.Project(project_rel)) => + for expr in project_rel.expressions: + if _expr_uses_if_then(expr): + return true + _ => pass + + for child in relation_children(rel): + if _rel_uses_if_then(child): + return true + return false + + def _aggregate_extension_anchors_for_rel(rel: Rel) -> list[u32]: """Collect aggregate-function anchors used by one relation subtree in stable declaration order.""" mut anchors: list[u32] = [] - if _rel_uses_aggregate_function_anchor(rel.clone(), SUM_FUNCTION_ANCHOR): - anchors.append(SUM_FUNCTION_ANCHOR) - if _rel_uses_aggregate_function_anchor(rel, COUNT_FUNCTION_ANCHOR): - anchors.append(COUNT_FUNCTION_ANCHOR) + for spec in _function_extension_specs(): + if spec.kind == ExtensionFunctionKind.Aggregate and _rel_uses_aggregate_function_anchor(rel.clone(), spec.anchor): + anchors.append(spec.anchor) return anchors def _scalar_extension_anchors_for_rel(rel: Rel) -> list[u32]: """Collect scalar-function anchors used by one relation subtree in stable declaration order.""" + if _rel_uses_if_then(rel.clone()): + # `IfThen` branch fields are generated from keyword-named proto fields. Declare scalar extensions + # conservatively for now so nested case predicates execute without relying on field-level traversal here. 
+ return [spec.anchor for spec in _function_extension_specs() if spec.kind == ExtensionFunctionKind.Scalar] + mut anchors: list[u32] = [] - if _rel_uses_scalar_function_anchor(rel.clone(), EQUAL_FUNCTION_ANCHOR): - anchors.append(EQUAL_FUNCTION_ANCHOR) - if _rel_uses_scalar_function_anchor(rel.clone(), GT_FUNCTION_ANCHOR): - anchors.append(GT_FUNCTION_ANCHOR) - if _rel_uses_scalar_function_anchor(rel.clone(), ADD_FUNCTION_ANCHOR): - anchors.append(ADD_FUNCTION_ANCHOR) - if _rel_uses_scalar_function_anchor(rel, MULTIPLY_FUNCTION_ANCHOR): - anchors.append(MULTIPLY_FUNCTION_ANCHOR) + for spec in _function_extension_specs(): + if spec.kind == ExtensionFunctionKind.Scalar and _rel_uses_scalar_function_anchor(rel.clone(), spec.anchor): + anchors.append(spec.anchor) return anchors @@ -297,7 +326,7 @@ pub def extension_urns_for_rel(rel: Rel) -> list[SimpleExtensionUrn]: """Collect extension URNs required by one relation subtree.""" mut specs: list[ExtensionUrnSpec] = [] if _function_extension_urn_is_required(rel.clone()): - specs.append(ExtensionUrnSpec(anchor=FUNCTION_EXTENSION_URN_ANCHOR, urn=FUNCTION_EXTENSION_URI)) + specs.append(ExtensionUrnSpec(anchor=FUNCTION_EXTENSION_URN_ANCHOR, urn=function_extension_uri())) for urn in _collect_extension_urn_strings(rel): specs.append(ExtensionUrnSpec(anchor=RELATION_EXTENSION_URN_ANCHOR, urn=urn)) return [SimpleExtensionUrn(extension_urn_anchor=spec.anchor, urn=spec.urn) for spec in specs] diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn new file mode 100644 index 0000000..409cf9d --- /dev/null +++ b/src/substrait/function_extensions.incn @@ -0,0 +1,66 @@ +""" +Stable Substrait extension anchors and URIs for registry-backed InQL functions. + +Function helpers import anchors from this module when declaring their registry metadata. Runtime extension declarations +derive names and kinds from those registry entries instead of keeping a second hand-written function catalog. 
+""" + + +@derive(Clone) +pub enum ExtensionFunctionKind(str): + """Function extension categories used by the current Substrait extension registry.""" + + Aggregate = "aggregate" + Scalar = "scalar" + + +@derive(Clone) +pub model FunctionExtensionSpec: + """One function-extension anchor/name/kind fact derived from function registry metadata.""" + + pub anchor: u32 + pub name: str + pub kind: ExtensionFunctionKind + + +pub const SUM_FUNCTION_ANCHOR: u32 = 0 +pub const COUNT_FUNCTION_ANCHOR: u32 = 1 +pub const EQUAL_FUNCTION_ANCHOR: u32 = 2 +pub const GT_FUNCTION_ANCHOR: u32 = 3 +pub const ADD_FUNCTION_ANCHOR: u32 = 4 +pub const MULTIPLY_FUNCTION_ANCHOR: u32 = 5 +pub const NOT_EQUAL_FUNCTION_ANCHOR: u32 = 6 +pub const LT_FUNCTION_ANCHOR: u32 = 7 +pub const LTE_FUNCTION_ANCHOR: u32 = 8 +pub const GTE_FUNCTION_ANCHOR: u32 = 9 +pub const IS_NOT_DISTINCT_FROM_FUNCTION_ANCHOR: u32 = 10 +pub const AND_FUNCTION_ANCHOR: u32 = 11 +pub const OR_FUNCTION_ANCHOR: u32 = 12 +pub const NOT_FUNCTION_ANCHOR: u32 = 13 +pub const IS_NULL_FUNCTION_ANCHOR: u32 = 14 +pub const IS_NOT_NULL_FUNCTION_ANCHOR: u32 = 15 +pub const IS_NAN_FUNCTION_ANCHOR: u32 = 16 +pub const SUBTRACT_FUNCTION_ANCHOR: u32 = 17 +pub const DIVIDE_FUNCTION_ANCHOR: u32 = 18 +pub const MODULUS_FUNCTION_ANCHOR: u32 = 19 +pub const NEGATE_FUNCTION_ANCHOR: u32 = 20 +pub const COALESCE_FUNCTION_ANCHOR: u32 = 21 +pub const NULLIF_FUNCTION_ANCHOR: u32 = 22 +pub const BETWEEN_FUNCTION_ANCHOR: u32 = 23 +const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" +const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" + + +pub def function_extension_uri() -> str: + """Return the registered extension URI used for shared function anchors.""" + return FUNCTION_EXTENSION_URI + + +pub def explode_extension_uri() -> str: + """Return the registered extension URI used for EXPLODE-style gap encoding.""" + return EXPLODE_EXTENSION_URI + + +pub def 
registered_substrait_extension_uris() -> list[str]: + """Return the registered extension URIs used by current package-level Substrait lowering.""" + return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI] diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index d14b53f..0ffa860 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -15,10 +15,11 @@ from rust::substrait::proto::read_rel import NamedTable as ReadNamedTable, ReadT from rust::substrait::proto::rel import RelType from rust::substrait::proto::rel_common import Direct, EmitKind from rust::substrait::proto::set_rel import SetOp +from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure from projection_builders import scalar_expr_output_name from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit -from substrait.extensions import COUNT_FUNCTION_ANCHOR +from substrait.function_extensions import COUNT_FUNCTION_ANCHOR from substrait.schema_registry import named_table_columns, unknown_named_struct from substrait.traversal import relation_children @@ -293,6 +294,54 @@ pub def set_operation_name(rel: Rel) -> str: _ => return "NotSet" +def _sort_direction_name(value: RustI32) -> str: + """Return a stable Substrait sort-direction name for test and diagnostics.""" + asc_nulls_first: RustI32 = SortDirection.AscNullsFirst.into() + asc_nulls_last: RustI32 = SortDirection.AscNullsLast.into() + desc_nulls_first: RustI32 = SortDirection.DescNullsFirst.into() + desc_nulls_last: RustI32 = SortDirection.DescNullsLast.into() + if value == asc_nulls_first: + return "AscNullsFirst" + if value == asc_nulls_last: + return "AscNullsLast" + if value == desc_nulls_first: + return "DescNullsFirst" + if value == desc_nulls_last: + return "DescNullsLast" + return "Unknown" + + +pub def sort_field_count(rel: Rel) -> int: + """Return the number of sort fields on one SortRel, or zero 
for non-sort relations.""" + match rel.rel_type: + Some(RelType.Sort(sort_rel)) => return len(sort_rel.sorts) + _ => return 0 + + +pub def sort_field_direction_name(rel: Rel, index: int) -> str: + """Return the direction/null-placement name for one SortRel field.""" + match rel.rel_type: + Some(RelType.Sort(sort_rel)) => + if index < 0 or index >= len(sort_rel.sorts): + return "Unknown" + match sort_rel.sorts[index].sort_kind: + Some(SortKind.Direction(direction)) => return _sort_direction_name(direction) + _ => return "Unknown" + _ => return "Unknown" + + +pub def sort_field_expr_index(rel: Rel, index: int) -> int: + """Return the direct field-reference index used by one SortRel field, or `-1` when not a field reference.""" + match rel.rel_type: + Some(RelType.Sort(sort_rel)) => + if index < 0 or index >= len(sort_rel.sorts): + return -1 + match sort_rel.sorts[index].expr: + Some(expr) => return field_index_from_expression(expr) + None => return -1 + _ => return -1 + + pub def reference_subtree_ordinal(rel: Rel) -> RustI32: """Return the subtree ordinal for one `ReferenceRel`, or `-1` when the input is not a reference.""" match rel.rel_type: diff --git a/src/substrait/mod.incn b/src/substrait/mod.incn index e11ac2f..523a858 100644 --- a/src/substrait/mod.incn +++ b/src/substrait/mod.incn @@ -38,6 +38,7 @@ pub from substrait.relations import ( set_rel, set_rel_of_kind, sort_rel, + sort_rel_of_columns, ) pub from substrait.plans import ( empty_plan, @@ -64,6 +65,13 @@ pub from substrait.inspect import ( root_names, root_rel, set_operation_name, + sort_field_count, + sort_field_direction_name, + sort_field_expr_index, source_named_table_name, ) -pub from substrait.extensions import explode_extension_uri, function_extension_uri, registered_substrait_extension_uris +pub from substrait.function_extensions import ( + explode_extension_uri, + function_extension_uri, + registered_substrait_extension_uris, +) diff --git a/src/substrait/relations.incn 
b/src/substrait/relations.incn index 4e2d374..948b3b5 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -44,7 +44,9 @@ from rust::substrait::proto::rel_common import Direct, Emit, EmitKind from rust::substrait::proto::set_rel import SetOp from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure -from projection_builders import ColumnExpr, ProjectionAssignment +from function_registry import SubstraitMappingKind +from functions.registry import function_registry_entry +from projection_builders import ColumnExpr, ProjectionAssignment, ScalarFunctionApplicationExpr, col from substrait.expr_lowering import ( bool_expr, filter_predicate_expr, @@ -55,7 +57,7 @@ from substrait.expr_lowering import ( string_expr, ) from substrait.errors import SubstraitLoweringError, invalid_scalar_expression -from substrait.extensions import COUNT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR +from substrait.function_extensions import COUNT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR from substrait.inspect import relation_output_columns from substrait.schema_registry import named_table_base_schema, unknown_named_struct @@ -234,6 +236,60 @@ def _lowered_rel_or_raise(result: Result[Rel, SubstraitLoweringError]) -> Rel: return raise_value_error(message) +def _sort_direction_for_mapping_detail(direction: str) -> Result[SortDirection, SubstraitLoweringError]: + """Return the Substrait direction/null-placement encoded by one sort-field registry mapping.""" + if direction == "asc_nulls_first": + return Ok(SortDirection.AscNullsFirst) + if direction == "asc_nulls_last": + return Ok(SortDirection.AscNullsLast) + if direction == "desc_nulls_first": + return Ok(SortDirection.DescNullsFirst) + if direction == "desc_nulls_last": + return Ok(SortDirection.DescNullsLast) + return Err(invalid_scalar_expression(f"unknown sort-field direction `{direction}`")) + + +def _sort_field_direction_detail(application: 
ScalarFunctionApplicationExpr) -> str: + """Return the sort-field registry mapping detail for one ordering helper application, or empty.""" + match function_registry_entry(application.function_ref): + Some(entry) => + if entry.substrait.kind == SubstraitMappingKind.StructuralFunction and entry.substrait.function_name == "sort_field": + return entry.substrait.detail + None => pass + return "" + + +def _sort_field(input_columns: list[str], key: ColumnExpr) -> Result[SortField, SubstraitLoweringError]: + """Lower one InQL sort-key expression into a Substrait SortField.""" + match key: + ScalarFunctionApplicationExpr(application) => + direction_detail = _sort_field_direction_detail(application) + if direction_detail != "": + if len(application.arguments) != 1: + return Err( + invalid_scalar_expression(f"{application.function_ref} requires exactly one sort-key argument"), + ) + return Ok( + SortField( + expr=Some(scalar_expr(input_columns, application.arguments[0])?), + sort_kind=Some(SortKind.Direction(_sort_direction_for_mapping_detail(direction_detail)?.into())), + ), + ) + return Ok( + SortField( + expr=Some(scalar_expr(input_columns, key)?), + sort_kind=Some(SortKind.Direction(SortDirection.AscNullsFirst.into())), + ), + ) + _ => + return Ok( + SortField( + expr=Some(scalar_expr(input_columns, key)?), + sort_kind=Some(SortKind.Direction(SortDirection.AscNullsFirst.into())), + ), + ) + + pub def read_named_table_rel(table_name: str) -> Rel: """Construct a logical `ReadRel(NamedTable)` root for a registered table name.""" read = ReadRel( @@ -451,13 +507,38 @@ pub def try_aggregate_rel_of_columns( pub def sort_rel(input: Rel) -> Rel: - """Wrap a child relation in `SortRel` with one placeholder ascending sort field.""" - sort_field = SortField( - expr=Some(string_expr("__sort_key__")), - sort_kind=Some(SortKind.Direction(SortDirection.AscNullsFirst.into())), - ) - return _rel_sort( - SortRel(common=Some(_direct_common()), input=Some(Box.new(input)), sorts=[sort_field], 
advanced_extension=None), + """Wrap a child relation in `SortRel` using the first known output column as the default sort key.""" + input_columns = relation_output_columns(input.clone()) + if len(input_columns) == 0: + return _lowered_rel_or_raise( + Err(invalid_scalar_expression("sort_rel requires at least one input column or explicit sort key")), + ) + return sort_rel_of_columns(input, input_columns, [col(input_columns[0])]) + + +pub def sort_rel_of_columns(input: Rel, input_columns: list[str], sort_keys: list[ColumnExpr]) -> Rel: + """Wrap a child relation in `SortRel` using explicit input columns and sort-key expressions.""" + return _lowered_rel_or_raise(try_sort_rel_of_columns(input, input_columns, sort_keys)) + + +pub def try_sort_rel_of_columns( + input: Rel, + input_columns: list[str], + sort_keys: list[ColumnExpr], +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in `SortRel` with real sort fields.""" + if len(sort_keys) == 0: + return Err(invalid_scalar_expression("order_by requires at least one sort key")) + sort_fields = [_sort_field(input_columns, key)? 
for key in sort_keys] + return Ok( + _rel_sort( + SortRel( + common=Some(_direct_common()), + input=Some(Box.new(input)), + sorts=sort_fields, + advanced_extension=None, + ), + ), ) diff --git a/tests/test_core_scalar_functions.incn b/tests/test_core_scalar_functions.incn new file mode 100644 index 0000000..bd44f83 --- /dev/null +++ b/tests/test_core_scalar_functions.incn @@ -0,0 +1,145 @@ +"""Test: RFC 015 core scalar helper surface and shared expression model.""" + +from functions import ( + add, + and_, + asc, + asc_nulls_first, + asc_nulls_last, + between, + bool_lit, + case_when, + cast, + col, + coalesce, + desc, + desc_nulls_first, + desc_nulls_last, + div, + eq, + equal_null, + gt, + gte, + in_, + int_expr, + int_lit, + is_nan, + is_not_nan, + is_not_null, + is_null, + lit, + lt, + lte, + modulo, + mul, + ne, + neg, + not_, + nullif, + or_, + str_lit, + sub, + try_cast, +) +from function_registry import function_ref_for +from projection_builders import ( + ColumnExpr, + ColumnExprKind, + column_expr_argument_count, + column_expr_function_name, + column_expr_function_ref, + column_expr_kind, + column_expr_option_value, +) + + +def _assert_scalar_application(expr: ColumnExpr, expected_name: str) -> None: + """Assert one helper result is represented by the shared scalar application node.""" + assert column_expr_kind(expr) == ColumnExprKind.ScalarFunction, f"{expected_name} should use the scalar function kind" + assert column_expr_function_name(expr) == expected_name, f"{expected_name} should preserve its canonical name" + expected_ref = function_ref_for(expected_name) + assert column_expr_function_ref(expr) == expected_ref, "scalar application should preserve its registry function ref" + + +def test_core_scalar_functions__operator_helpers_share_one_application_node() -> None: + """Assert function and operator helpers do not create one expression model per function.""" + # -- Arrange -- + amount = col("amount") + status = col("status") + + # -- Act / Assert -- 
+ _assert_scalar_application(cast(amount, "string"), "cast") + _assert_scalar_application(try_cast(status, "int64"), "try_cast") + _assert_scalar_application(add(amount, lit(1)), "add") + _assert_scalar_application(sub(amount, int_lit(1)), "sub") + _assert_scalar_application(mul(amount, int_lit(2)), "mul") + _assert_scalar_application(div(amount, int_lit(2)), "div") + _assert_scalar_application(modulo(amount, int_lit(2)), "mod") + _assert_scalar_application(neg(amount), "neg") + _assert_scalar_application(eq(status, str_lit("paid")), "eq") + _assert_scalar_application(ne(status, str_lit("void")), "ne") + _assert_scalar_application(lt(amount, int_lit(10)), "lt") + _assert_scalar_application(lte(amount, int_lit(10)), "lte") + _assert_scalar_application(gt(amount, int_lit(10)), "gt") + _assert_scalar_application(gte(amount, int_lit(10)), "gte") + _assert_scalar_application(equal_null(status, str_lit("paid")), "equal_null") + _assert_scalar_application(and_(gt(amount, int_lit(10)), eq(status, str_lit("paid"))), "and_") + _assert_scalar_application(or_(eq(status, str_lit("paid")), eq(status, str_lit("open"))), "or_") + _assert_scalar_application(not_(eq(status, str_lit("void"))), "not_") + _assert_scalar_application(is_null(status), "is_null") + _assert_scalar_application(is_not_null(status), "is_not_null") + _assert_scalar_application(is_nan(amount), "is_nan") + _assert_scalar_application(coalesce([status, str_lit("unknown")]), "coalesce") + _assert_scalar_application(nullif(status, str_lit("")), "nullif") + _assert_scalar_application(case_when([gt(amount, int_lit(10))], [str_lit("large")], str_lit("small")), "case_when") + _assert_scalar_application(in_(status, [str_lit("paid"), str_lit("open")]), "in_") + _assert_scalar_application(between(amount, int_lit(1), int_lit(10)), "between") + _assert_scalar_application(asc(amount), "asc") + _assert_scalar_application(desc(amount), "desc") + _assert_scalar_application(asc_nulls_first(amount), "asc_nulls_first") + 
_assert_scalar_application(asc_nulls_last(amount), "asc_nulls_last") + _assert_scalar_application(desc_nulls_first(amount), "desc_nulls_first") + _assert_scalar_application(desc_nulls_last(amount), "desc_nulls_last") + + +def test_core_scalar_functions__is_not_nan_rewrites_to_canonical_expression() -> None: + """Assert is_not_nan returns the expression shape it documents in registry metadata.""" + # -- Arrange / Act -- + expr = is_not_nan(col("amount")) + + # -- Assert -- + _assert_scalar_application(expr, "not_") + assert column_expr_argument_count(expr) == 1, "is_not_nan should wrap one is_nan predicate in not_" + + +def test_core_scalar_functions__cast_records_type_as_registry_option() -> None: + """Assert non-expression scalar helper metadata is attached as application options.""" + # -- Arrange -- + amount = col("amount") + + # -- Act -- + cast_expr = cast(amount, "decimal(10,2)") + try_cast_expr = try_cast(amount, "float64") + + # -- Assert -- + _assert_scalar_application(cast_expr, "cast") + _assert_scalar_application(try_cast_expr, "try_cast") + assert column_expr_argument_count(cast_expr) == 1, "cast should keep only scalar inputs in the argument list" + assert column_expr_argument_count(try_cast_expr) == 1, "try_cast should keep only scalar inputs in the argument list" + assert column_expr_option_value(cast_expr, "target_type") == "decimal(10,2)", "cast should preserve the target type option" + assert column_expr_option_value(try_cast_expr, "target_type") == "float64", "try_cast should preserve the target type option" + + +def test_core_scalar_functions__literal_helpers_stay_structural() -> None: + """Assert literal compatibility helpers route through the canonical literal expression representation.""" + # -- Arrange / Act -- + canonical = lit(7) + typed_expr = int_expr(7) + filter_style = int_lit(7) + boolean_literal = bool_lit(true) + + # -- Assert -- + assert column_expr_kind(canonical) == ColumnExprKind.IntLiteral, "lit should still build structural 
literals" + assert column_expr_kind(typed_expr) == ColumnExprKind.IntLiteral, "int_expr should use the same literal node" + assert column_expr_kind(filter_style) == ColumnExprKind.IntLiteral, "int_lit should use the same literal node" + assert column_expr_kind(boolean_literal) == ColumnExprKind.BoolLiteral, "bool_lit should use the same literal node" diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index bbde8b5..25c11ad 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -21,7 +21,7 @@ from functions import ( sum, ) from projection_builders import ColumnExprKind, column_expr_kind, column_expr_name -from substrait.extensions import explode_extension_uri +from substrait.function_extensions import explode_extension_uri from substrait.inspect import plan_contains_relation_kind, plan_has_extension_urn, relation_kind_name, root_rel from substrait.plans import plan_encoded_len, plan_from_named_table, plan_from_root_relation from substrait.relations import read_named_table_rel @@ -406,7 +406,7 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None # -- Act -- aggregated = grouped.agg([count()]) - ordered: LazyFrame[Order] = lazy_frame_named_table("orders").order_by() + ordered: LazyFrame[Order] = lazy_frame_named_table("orders").order_by([col("id")]) limited: LazyFrame[Order] = lazy_frame_named_table("orders").limit(10) exploded: LazyFrame[Order] = lazy_frame_named_table("orders").explode() @@ -421,12 +421,14 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None def test_lazy_frame__deeper_independent_roots_still_lower_with_stable_shapes() -> None: # -- Arrange -- + _register_order_schema("orders") + _register_order_schema("orders_archive") left: LazyFrame[Order] = lazy_frame_named_table("orders").filter(always_false()) right_base: LazyFrame[Order] = lazy_frame_named_table("orders_archive") # -- Act -- - right_joined: LazyFrame[Order] = 
right_base.filter(always_false()).order_by().join( - right_base.filter(always_false()).order_by(), + right_joined: LazyFrame[Order] = right_base.filter(always_false()).order_by([col("id")]).join( + right_base.filter(always_false()).order_by([col("id")]), true, ) joined: LazyFrame[Order] = left.join(right_joined, true) diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 593deb3..fe8fd07 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -6,9 +6,25 @@ from functions import ( add, always_false, always_true, + and_, + asc, + asc_nulls_first, + asc_nulls_last, + between, + bool_expr, + bool_lit, + case_when, + cast, col, + coalesce, count, + desc, + desc_nulls_first, + desc_nulls_last, + div, eq, + equal_null, + float_expr, function_registry_canonical_names, function_registry_entries, function_registry_entry, @@ -16,11 +32,30 @@ from functions import ( function_registry_entry_count, function_registry_function_refs, gt, + gte, + in_, + int_expr, int_lit, + is_nan, + is_not_nan, + is_not_null, + is_null, lit, + lt, + lte, + modulo, + mul, + ne, + neg, + not_, + nullif, + or_, registered_substrait_mapped_function_refs, + str_expr, str_lit, + sub, sum, + try_cast, ) from function_registry import ( FunctionAliasPolicy, @@ -32,11 +67,30 @@ from function_registry import ( v0_1, ) from projection_builders import ColumnExprKind, column_expr_kind -from substrait.extensions import ( +from substrait.function_extensions import ( ADD_FUNCTION_ANCHOR, + AND_FUNCTION_ANCHOR, + BETWEEN_FUNCTION_ANCHOR, + COALESCE_FUNCTION_ANCHOR, COUNT_FUNCTION_ANCHOR, + DIVIDE_FUNCTION_ANCHOR, EQUAL_FUNCTION_ANCHOR, GT_FUNCTION_ANCHOR, + GTE_FUNCTION_ANCHOR, + IS_NAN_FUNCTION_ANCHOR, + IS_NOT_DISTINCT_FROM_FUNCTION_ANCHOR, + IS_NOT_NULL_FUNCTION_ANCHOR, + IS_NULL_FUNCTION_ANCHOR, + LT_FUNCTION_ANCHOR, + LTE_FUNCTION_ANCHOR, + MODULUS_FUNCTION_ANCHOR, + MULTIPLY_FUNCTION_ANCHOR, + NEGATE_FUNCTION_ANCHOR, + 
NOT_EQUAL_FUNCTION_ANCHOR, + NOT_FUNCTION_ANCHOR, + NULLIF_FUNCTION_ANCHOR, + OR_FUNCTION_ANCHOR, + SUBTRACT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR, function_extension_uri, ) @@ -52,76 +106,157 @@ def _contains_text(items: list[str], expected: str) -> bool: def _entry_or_fail(function_ref: str) -> FunctionRegistryEntry: """Return one registry entry or fail the test with the missing function reference.""" - match function_registry_entry(function_ref): + lookup_ref = f"{function_ref}" + match function_registry_entry(lookup_ref): Some(entry) => return entry - None => return fail_t(f"missing function registry entry: {function_ref}") + None => return fail_t("missing function registry entry") def _entry_by_name_or_fail(canonical_name: str) -> FunctionRegistryEntry: """Return one registry entry by canonical name or fail the test.""" - match function_registry_entry_by_name(canonical_name): + lookup_name = f"{canonical_name}" + match function_registry_entry_by_name(lookup_name): Some(entry) => return entry - None => return fail_t(f"missing function registry entry by name: {canonical_name}") + None => return fail_t("missing function registry entry by name") -def _assert_entry_signature( - function_ref: str, - canonical_name: str, - return_type_rule: str, - arg_names: list[str], - arg_type_rules: list[str], -) -> None: - """Assert one registry entry signature matches the public helper contract.""" - entry = _entry_or_fail(function_ref) - assert entry.canonical_name == canonical_name, f"{function_ref} should expose the expected canonical name" - assert entry.signature.return_type_rule == return_type_rule, f"{function_ref} should expose the expected return type" - assert len(entry.signature.args) == len(arg_names), f"{function_ref} should expose the expected arity" - assert len(arg_names) == len(arg_type_rules), f"{function_ref} test fixture should pair each argument with one type" +def _expected_rfc015_registry_names() -> list[str]: + """Return the expected registered public 
helper names after RFC 015.""" + return ["col", "lit", "sum", "count", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last"] - for index in range(len(arg_names)): - arg = entry.signature.args[index] - assert arg.name == arg_names[index], f"{function_ref} should preserve argument names" - assert arg.type_rule == arg_type_rules[index], f"{function_ref} should preserve argument type rules" + +def _expected_substrait_mapped_names() -> list[str]: + """Return helpers with concrete Substrait extension-function mappings.""" + return ["sum", "count", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between"] + + +def _exercise_current_public_helpers() -> None: + """Touch each current registered helper so runtime registry tests observe loaded modules only.""" + amount = col("amount") + status = col("status") + lit(1) + int_expr(1) + float_expr(1.5) + str_expr("paid") + bool_expr(true) + int_lit(1) + str_lit("paid") + bool_lit(true) + always_true() + always_false() + sum(amount) + count() + add(amount, lit(1)) + mul(amount, int_lit(2)) + eq(status, str_lit("paid")) + gt(amount, int_lit(10)) + cast(amount, "string") + try_cast(status, "int64") + ne(status, str_lit("void")) + lt(amount, int_lit(10)) + lte(amount, int_lit(10)) + gte(amount, int_lit(10)) + equal_null(status, str_lit("paid")) + and_(gt(amount, int_lit(10)), eq(status, str_lit("paid"))) + or_(eq(status, str_lit("paid")), eq(status, str_lit("open"))) + not_(eq(status, str_lit("void"))) + is_null(status) + 
is_not_null(status) + is_nan(amount) + is_not_nan(amount) + sub(amount, int_lit(1)) + div(amount, int_lit(2)) + modulo(amount, int_lit(2)) + neg(amount) + coalesce([status, str_lit("unknown")]) + nullif(status, str_lit("")) + case_when([gt(amount, int_lit(10))], [str_lit("large")], str_lit("small")) + in_(status, [str_lit("paid"), str_lit("open")]) + between(amount, int_lit(1), int_lit(10)) + asc(amount) + desc(amount) + asc_nulls_first(amount) + asc_nulls_last(amount) + desc_nulls_first(amount) + desc_nulls_last(amount) + return + + +def _assert_registered(canonical_name: str) -> None: + """Assert one canonical helper name is registered under the stable function-ref prefix.""" + refs = function_registry_function_refs() + names = function_registry_canonical_names() + ref_name = f"{canonical_name}" + assert _contains_text(refs, function_ref_for(ref_name)), "registry should include the expected function ref" + assert _contains_text(names, canonical_name), "registry should expose the expected canonical name" + + +def _assert_structural_mapping(canonical_name: str, context: str, detail: str) -> None: + """Assert one registered helper declares relation-context structural lowering metadata.""" + lookup_name = f"{canonical_name}" + entry = _entry_by_name_or_fail(lookup_name) + assert entry.function_class == FunctionClass.Ordering, "ordering helper should expose ordering function-class metadata" + assert entry.substrait.kind == SubstraitMappingKind.StructuralFunction, "ordering helper should lower through structural context" + assert entry.substrait.function_name == context, "ordering helper should name its lowering context" + assert entry.substrait.detail == detail, "ordering helper should carry direction/null-placement metadata" + + +def _assert_extension_mapping(canonical_name: str, function_name: str, anchor: u32) -> None: + """Assert one helper declares the expected concrete Substrait extension mapping.""" + entry = _entry_or_fail(function_ref_for(canonical_name)) + 
mapped_refs = registered_substrait_mapped_function_refs() + assert _contains_text(mapped_refs, function_ref_for(canonical_name)), f"{canonical_name} should be in the Substrait extension mapping set" + assert entry.substrait.kind == SubstraitMappingKind.ExtensionFunction, f"{canonical_name} should use a Substrait extension function" + assert entry.substrait.uri == function_extension_uri(), f"{canonical_name} should use the shared function extension URI" + assert entry.substrait.function_name == function_name, f"{canonical_name} should use the registered extension name" + assert entry.substrait.anchor == anchor, f"{canonical_name} should carry the stable Substrait anchor" + + +def _assert_core_mapping(canonical_name: str, function_name: str) -> None: + """Assert one helper declares the expected built-in Substrait Rex mapping.""" + entry = _entry_or_fail(function_ref_for(canonical_name)) + assert entry.substrait.kind == SubstraitMappingKind.CoreFunction, f"{canonical_name} should use a Substrait core mapping" + assert entry.substrait.function_name == function_name, f"{canonical_name} should use the registered core mapping name" + + +def _assert_rewrite_mapping(canonical_name: str, rewrite: str) -> None: + """Assert one helper declares a deterministic rewrite mapping.""" + entry = _entry_or_fail(function_ref_for(canonical_name)) + assert entry.substrait.kind == SubstraitMappingKind.Rewrite, f"{canonical_name} should use a rewrite mapping" + assert entry.substrait.rewrite == rewrite, f"{canonical_name} should describe its canonical rewrite shape" def _assert_not_deprecated(entry: FunctionRegistryEntry) -> None: """Assert one current registry entry is not marked deprecated.""" match entry.lifecycle.deprecated: - Some(_) => return fail_t(f"{entry.function_ref} should not be deprecated") + Some(_) => return fail_t("current registry entry should not be deprecated") None => pass def test_function_registry__covers_current_public_helpers() -> None: - """Assert that RFC 014 
registry metadata covers the current public helper surface.""" + """Assert that RFC 015 registry metadata covers the current public helper surface.""" # -- Arrange -- - refs = function_registry_function_refs() - names = function_registry_canonical_names() + _exercise_current_public_helpers() + expected_names = _expected_rfc015_registry_names() # -- Act -- entry_count = function_registry_entry_count() # -- Assert -- - assert entry_count == 17, "registry should cover the current public helper set" + assert entry_count == len(expected_names), "registry should cover the current public helper set" assert len(function_registry_entries()) == entry_count, "entry count helper should match the runtime registry entries" - assert _contains_text(refs, function_ref_for("col")), "registry should include col" - assert _contains_text(refs, function_ref_for("lit")), "registry should include lit" - assert _contains_text(refs, function_ref_for("sum")), "registry should include sum" - assert _contains_text(refs, function_ref_for("count")), "registry should include count" - assert _contains_text(refs, function_ref_for("add")), "registry should include add" - assert _contains_text(refs, function_ref_for("mul")), "registry should include mul" - assert _contains_text(refs, function_ref_for("eq")), "registry should include eq" - assert _contains_text(refs, function_ref_for("gt")), "registry should include gt" - assert _contains_text(names, "always_true"), "registry should expose canonical helper names" - assert _contains_text(names, "always_false"), "registry should expose canonical helper names" + for expected_name in expected_names: + _assert_registered(expected_name) def test_function_registry__lifecycle_metadata_is_versioned() -> None: """Assert registered helpers carry typed lifecycle facts, not an erased default string.""" # -- Arrange -- + _exercise_current_public_helpers() entries = function_registry_entries() # -- Act / Assert -- - assert len(entries) == 17, "lifecycle fixture 
should cover the current registry surface" + assert len(entries) == len(_expected_rfc015_registry_names()), "lifecycle fixture should cover the current registry surface" for entry in entries: assert entry.lifecycle.since.major == v0_1.major, f"{entry.function_ref} should expose its introduction major version" assert entry.lifecycle.since.minor == v0_1.minor, f"{entry.function_ref} should expose its introduction minor version" @@ -129,110 +264,110 @@ def test_function_registry__lifecycle_metadata_is_versioned() -> None: _assert_not_deprecated(entry) -def test_function_registry__signatures_match_current_public_helpers() -> None: - """Assert every registered signature mirrors the current public helper contract.""" - # -- Arrange -- - literal_union = "int | float | str | bool" - - # -- Act -- - entry_count = function_registry_entry_count() - - # -- Assert -- - assert entry_count == 17, "signature fixture should cover the whole registry" - _assert_entry_signature(function_ref_for("col"), "col", "ColumnExpr", ["name"], ["str"]) - _assert_entry_signature(function_ref_for("lit"), "lit", "ColumnExpr", ["value"], [literal_union]) - _assert_entry_signature(function_ref_for("sum"), "sum", "AggregateMeasure", ["expr"], ["ColumnExpr"]) - _assert_entry_signature(function_ref_for("count"), "count", "AggregateMeasure", [], []) - _assert_entry_signature(function_ref_for("int_expr"), "int_expr", "ColumnExpr", ["value"], ["int"]) - _assert_entry_signature(function_ref_for("float_expr"), "float_expr", "ColumnExpr", ["value"], ["float"]) - _assert_entry_signature(function_ref_for("str_expr"), "str_expr", "ColumnExpr", ["value"], ["str"]) - _assert_entry_signature(function_ref_for("bool_expr"), "bool_expr", "ColumnExpr", ["value"], ["bool"]) - _assert_entry_signature( - function_ref_for("add"), - "add", - "ColumnExpr", - ["left", "right"], - ["ColumnExpr", "ColumnExpr"], - ) - _assert_entry_signature( - function_ref_for("mul"), - "mul", - "ColumnExpr", - ["left", "right"], - 
["ColumnExpr", "ColumnExpr"], - ) - _assert_entry_signature(function_ref_for("int_lit"), "int_lit", "ColumnExpr", ["value"], ["int"]) - _assert_entry_signature(function_ref_for("str_lit"), "str_lit", "ColumnExpr", ["value"], ["str"]) - _assert_entry_signature(function_ref_for("bool_lit"), "bool_lit", "ColumnExpr", ["value"], ["bool"]) - _assert_entry_signature(function_ref_for("always_true"), "always_true", "ColumnExpr", [], []) - _assert_entry_signature(function_ref_for("always_false"), "always_false", "ColumnExpr", [], []) - _assert_entry_signature(function_ref_for("eq"), "eq", "ColumnExpr", ["left", "right"], ["ColumnExpr", "ColumnExpr"]) - _assert_entry_signature(function_ref_for("gt"), "gt", "ColumnExpr", ["left", "right"], ["ColumnExpr", "ColumnExpr"]) - - def test_function_registry__lookup_exposes_canonical_metadata() -> None: """Assert lookup helpers expose canonical registry metadata by ref and by name.""" # -- Arrange -- - col_entry = _entry_or_fail(function_ref_for("col")) - count_entry = _entry_by_name_or_fail("count") + _exercise_current_public_helpers() + col_ref = function_ref_for("col") + count_name = "count" # -- Act -- - col_arg = col_entry.signature.args[0] + col_entry = _entry_or_fail(col_ref) + count_entry = _entry_by_name_or_fail(count_name) # -- Assert -- assert col_entry.canonical_name == "col", "lookup by function ref should return canonical name" assert col_entry.function_class == FunctionClass.Scalar, "col should be classified as scalar" assert col_entry.alias_policy == FunctionAliasPolicy.CoreImport, "core helpers should use core import policy" - assert col_arg.name == "name", "registry signature should preserve public helper argument names" - assert col_arg.type_rule == "str", "registry signature should preserve public helper argument type rules" assert count_entry.function_ref == function_ref_for("count"), "lookup by name should return the stable function ref" assert count_entry.function_class == FunctionClass.Aggregate, "count should 
be classified as aggregate" - assert len(count_entry.signature.args) == 0, "count should be registered as zero-argument" def test_function_registry__substrait_extension_mappings_are_structured() -> None: """Assert Substrait extension-backed helpers carry stable mapping facts.""" # -- Arrange -- - add_entry = _entry_or_fail(function_ref_for("add")) - sum_entry = _entry_or_fail(function_ref_for("sum")) - count_entry = _entry_or_fail(function_ref_for("count")) - eq_entry = _entry_or_fail(function_ref_for("eq")) - gt_entry = _entry_or_fail(function_ref_for("gt")) + _exercise_current_public_helpers() # -- Act -- mapped_refs = registered_substrait_mapped_function_refs() # -- Assert -- - assert _contains_text(mapped_refs, function_ref_for("add")), "add should be in the Substrait extension mapping set" - assert _contains_text(mapped_refs, function_ref_for("sum")), "sum should be in the Substrait extension mapping set" - assert add_entry.substrait.kind == SubstraitMappingKind.ExtensionFunction, "add should use a Substrait extension function" - assert add_entry.substrait.uri == function_extension_uri(), "add should use the shared function extension URI" - assert add_entry.substrait.function_name == "add", "add should use the registered extension name" - assert add_entry.substrait.anchor == ADD_FUNCTION_ANCHOR, "add should carry the stable Substrait anchor" - assert sum_entry.substrait.function_name == "sum", "sum should use the registered extension name" - assert sum_entry.substrait.anchor == SUM_FUNCTION_ANCHOR, "sum should carry the stable Substrait anchor" - assert count_entry.substrait.anchor == COUNT_FUNCTION_ANCHOR, "count should carry the stable Substrait anchor" - assert eq_entry.substrait.anchor == EQUAL_FUNCTION_ANCHOR, "eq should carry the stable Substrait anchor" - assert gt_entry.substrait.anchor == GT_FUNCTION_ANCHOR, "gt should carry the stable Substrait anchor" + assert len(mapped_refs) == len(_expected_substrait_mapped_names()), "only helpers with honest 
current Substrait mappings should be extension-backed" + for canonical_name in _expected_substrait_mapped_names(): + assert _contains_text(mapped_refs, function_ref_for(canonical_name)), "mapped helper should be in the Substrait extension mapping set" + _assert_extension_mapping("sum", "sum", SUM_FUNCTION_ANCHOR) + _assert_extension_mapping("count", "count", COUNT_FUNCTION_ANCHOR) + _assert_extension_mapping("add", "add", ADD_FUNCTION_ANCHOR) + _assert_extension_mapping("mul", "multiply", MULTIPLY_FUNCTION_ANCHOR) + _assert_extension_mapping("eq", "equal", EQUAL_FUNCTION_ANCHOR) + _assert_extension_mapping("gt", "gt", GT_FUNCTION_ANCHOR) + _assert_extension_mapping("ne", "not_equal", NOT_EQUAL_FUNCTION_ANCHOR) + _assert_extension_mapping("lt", "lt", LT_FUNCTION_ANCHOR) + _assert_extension_mapping("lte", "lte", LTE_FUNCTION_ANCHOR) + _assert_extension_mapping("gte", "gte", GTE_FUNCTION_ANCHOR) + _assert_extension_mapping("equal_null", "is_not_distinct_from", IS_NOT_DISTINCT_FROM_FUNCTION_ANCHOR) + _assert_extension_mapping("and_", "and", AND_FUNCTION_ANCHOR) + _assert_extension_mapping("or_", "or", OR_FUNCTION_ANCHOR) + _assert_extension_mapping("not_", "not", NOT_FUNCTION_ANCHOR) + _assert_extension_mapping("is_null", "is_null", IS_NULL_FUNCTION_ANCHOR) + _assert_extension_mapping("is_not_null", "is_not_null", IS_NOT_NULL_FUNCTION_ANCHOR) + _assert_extension_mapping("is_nan", "is_nan", IS_NAN_FUNCTION_ANCHOR) + _assert_extension_mapping("sub", "subtract", SUBTRACT_FUNCTION_ANCHOR) + _assert_extension_mapping("div", "divide", DIVIDE_FUNCTION_ANCHOR) + _assert_extension_mapping("mod", "modulus", MODULUS_FUNCTION_ANCHOR) + _assert_extension_mapping("neg", "negate", NEGATE_FUNCTION_ANCHOR) + _assert_extension_mapping("coalesce", "coalesce", COALESCE_FUNCTION_ANCHOR) + _assert_extension_mapping("nullif", "nullif", NULLIF_FUNCTION_ANCHOR) + _assert_extension_mapping("between", "between", BETWEEN_FUNCTION_ANCHOR) + + +def 
test_function_registry__ordering_helpers_are_contextual_sort_fields() -> None: + """Assert RFC 015 ordering helpers are modeled as sort-field context helpers.""" + # -- Arrange -- + _exercise_current_public_helpers() + + # -- Act / Assert -- + _assert_structural_mapping("asc", "sort_field", "asc_nulls_first") + _assert_structural_mapping("desc", "sort_field", "desc_nulls_last") + _assert_structural_mapping("asc_nulls_first", "sort_field", "asc_nulls_first") + _assert_structural_mapping("asc_nulls_last", "sort_field", "asc_nulls_last") + _assert_structural_mapping("desc_nulls_first", "sort_field", "desc_nulls_first") + _assert_structural_mapping("desc_nulls_last", "sort_field", "desc_nulls_last") + + +def test_function_registry__core_substrait_mappings_are_structured() -> None: + """Assert helpers backed by built-in Substrait Rex shapes declare structured core mappings.""" + # -- Arrange -- + _exercise_current_public_helpers() + + # -- Act / Assert -- + _assert_core_mapping("cast", "cast") + _assert_core_mapping("try_cast", "try_cast") + _assert_core_mapping("in_", "singular_or_list") + _assert_core_mapping("case_when", "if_then") def test_function_registry__rewrite_mappings_identify_non_extension_helpers() -> None: """Assert helper entries that lower as selections or literals are marked as deterministic rewrites.""" # -- Arrange -- - col_entry = _entry_or_fail(function_ref_for("col")) - lit_entry = _entry_or_fail(function_ref_for("lit")) - always_true_entry = _entry_or_fail(function_ref_for("always_true")) - always_false_entry = _entry_or_fail(function_ref_for("always_false")) + _exercise_current_public_helpers() + col_ref = function_ref_for("col") + lit_ref = function_ref_for("lit") + always_true_ref = function_ref_for("always_true") + always_false_ref = function_ref_for("always_false") # -- Act -- - lit_arg = lit_entry.signature.args[0] + col_entry = _entry_or_fail(col_ref) + lit_entry = _entry_or_fail(lit_ref) + always_true_entry = 
_entry_or_fail(always_true_ref) + always_false_entry = _entry_or_fail(always_false_ref) # -- Assert -- assert col_entry.substrait.kind == SubstraitMappingKind.Rewrite, "col should lower as a direct field-reference rewrite" assert lit_entry.substrait.kind == SubstraitMappingKind.Rewrite, "lit should lower as a literal rewrite" assert always_true_entry.substrait.kind == SubstraitMappingKind.Rewrite, "always_true should lower as a literal rewrite" assert always_false_entry.substrait.kind == SubstraitMappingKind.Rewrite, "always_false should lower as a literal rewrite" - assert lit_arg.literal_only, "literal helper metadata should mark literal-only arguments" + _assert_rewrite_mapping("is_not_nan", "not_(is_nan(expr))") assert always_true_entry.null_behavior == FunctionNullBehavior.Predicate, "predicate helpers should expose predicate null behavior" assert always_false_entry.null_behavior == FunctionNullBehavior.Predicate, "predicate helpers should expose predicate null behavior" @@ -249,14 +384,26 @@ def test_function_registry__public_helpers_preserve_existing_behavior() -> None: add_expr = add(amount, lit(7)) eq_expr = eq(status, str_lit("paid")) gt_expr = gt(amount, int_lit(10)) + core_exprs = [and_(eq_expr, gt_expr), asc(amount), asc_nulls_first(amount), asc_nulls_last(amount), between( + amount, + int_lit(1), + int_lit(10), + ), bool_expr(true), coalesce([status, str_expr("unknown")]), desc(amount), desc_nulls_first(amount), desc_nulls_last( + amount, + ), div(amount, lit(2)), equal_null(status, str_lit("paid")), float_expr(1.5), gte(amount, int_lit(10)), in_( + status, + [str_lit("paid"), str_lit("open")], + ), lt(amount, int_lit(10)), lte(amount, int_lit(10)), modulo(amount, lit(2))] # -- Assert -- assert column_expr_kind(amount) == ColumnExprKind.Column, "col should still build a column reference" assert column_expr_kind(lit(true)) == ColumnExprKind.BoolLiteral, "lit should still build typed literals" assert sum_measure.kind == AggregateKind.Sum, "sum wrapper 
should preserve aggregate kind" assert count_measure.kind == AggregateKind.Count, "count wrapper should preserve aggregate kind" - assert column_expr_kind(add_expr) == ColumnExprKind.Add, "add wrapper should preserve expression kind" - assert column_expr_kind(eq_expr) == ColumnExprKind.Eq, "eq wrapper should preserve expression kind" - assert column_expr_kind(gt_expr) == ColumnExprKind.Gt, "gt wrapper should preserve expression kind" + assert column_expr_kind(add_expr) == ColumnExprKind.ScalarFunction, "add should use the shared scalar function kind" + assert column_expr_kind(eq_expr) == ColumnExprKind.ScalarFunction, "eq should use the shared scalar function kind" + assert column_expr_kind(gt_expr) == ColumnExprKind.ScalarFunction, "gt should use the shared scalar function kind" + for core_expr in core_exprs: + assert column_expr_kind(core_expr) != ColumnExprKind.Column, "core scalar helpers should build scalar expressions" assert column_expr_kind(always_true()) == ColumnExprKind.BoolLiteral, "always_true should still build a bool literal" assert column_expr_kind(always_false()) == ColumnExprKind.BoolLiteral, "always_false should still build a bool literal" diff --git a/tests/test_prism.incn b/tests/test_prism.incn index 9c3aa49..e668c2a 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -99,7 +99,7 @@ def test_prism__same_store_join_reuses_shared_history() -> None: def test_prism__same_store_join_with_longer_branches_is_still_one_append() -> None: # -- Arrange -- base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) - left: PrismCursor[Order] = base.filter(always_false()).order_by() + left: PrismCursor[Order] = base.filter(always_false()).order_by([col("id")]) right: PrismCursor[Order] = base.select().limit(5) # -- Act -- @@ -159,10 +159,12 @@ def test_prism__cross_store_join_dedups_equivalent_reachable_rhs_nodes() -> None def test_prism__cross_store_join_dedups_equivalent_rhs_multistep_branches() -> None: # -- Arrange -- + 
_register_projection_test_schema(str("orders")) + _register_projection_test_schema(str("orders_archive")) left: PrismCursor[Order] = prism_cursor_named_table(str("orders")).filter(always_false()) right_base: PrismCursor[Order] = prism_cursor_named_table(str("orders_archive")) - right_left: PrismCursor[Order] = right_base.filter(always_false()).order_by() - right_right: PrismCursor[Order] = right_base.filter(always_false()).order_by() + right_left: PrismCursor[Order] = right_base.filter(always_false()).order_by([col("id")]) + right_right: PrismCursor[Order] = right_base.filter(always_false()).order_by([col("id")]) # -- Act -- right_joined: PrismCursor[Order] = right_left.join(right_right.clone(), true) @@ -187,7 +189,7 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: # -- Act -- aggregated: PrismCursor[Order] = grouped.agg([count()]) - ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by() + ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by([col("id")]) limited: PrismCursor[Order] = prism_cursor_named_table(str("orders")).limit(10) exploded: PrismCursor[Order] = prism_cursor_named_table(str("orders")).explode() @@ -222,7 +224,7 @@ def test_prism__rewrite_collapses_adjacent_limits_projects_and_order_by() -> Non projected: PrismCursor[Order] = prism_cursor_named_table(str("orders")).select().select() # -- Act -- - ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by().order_by() + ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by([col("id")]).order_by([col("id")]) # -- Assert -- assert prism_cursor_authored_node_count(limited) == 3, "authored history should keep both limit nodes" diff --git a/tests/test_session.incn b/tests/test_session.incn index 176d861..414d38a 100644 --- a/tests/test_session.incn +++ b/tests/test_session.incn @@ -2,7 +2,7 @@ from std.testing import assert_is_err, assert_is_ok, fail, parametrize from 
rust::std::path import Path -from backends import DataFusion, csv_source, parquet_source +from backends import DataFusion, csv_source, datafusion_backend_selection, parquet_source from dataset import LazyFrame, lazy_frame_named_table from session import Session, SessionError, SessionErrorKind from substrait.inspect import read_kind_name, root_rel @@ -55,6 +55,20 @@ def test_session__builder_selects_datafusion_backend() -> None: assert backend.enable_optimizer is false, "builder should preserve backend-specific DataFusion options" +def test_session__builder_accepts_portable_backend_selection() -> None: + """Session.builder().with_backend should accept an adapter-neutral backend selection envelope.""" + # -- Arrange -- + selection = datafusion_backend_selection(DataFusion(enable_optimizer=false)) + + # -- Act -- + session = Session.builder().with_backend(selection).build() + backend: DataFusion = session.datafusion_backend() + + # -- Assert -- + assert session.backend_name() == "datafusion", "generic backend selection should preserve the selected kind" + assert backend.enable_optimizer is false, "generic backend selection should preserve encoded backend options" + + def test_session__public_types_construct_locally() -> None: """ Session and its public API types should be constructible and callable locally without unexpected errors, diff --git a/tests/test_session_filters.incn b/tests/test_session_filters.incn index 73fc626..7d9b716 100644 --- a/tests/test_session_filters.incn +++ b/tests/test_session_filters.incn @@ -1,9 +1,32 @@ """End-to-end Session filter execution tests over the DataFusion backend.""" -from functions import col, eq, gt, lit +from functions import ( + and_, + between, + cast, + col, + desc, + eq, + equal_null, + gt, + gte, + in_, + is_not_nan, + is_not_null, + is_null, + lit, + lt, + lte, + mul, + ne, + not_, + or_, +) from dataset import LazyFrame +from projection_builders import ColumnExpr from session import Session -from std.testing import 
assert_is_ok +from std.testing import assert_is_ok, fail_t +from substrait.inspect import root_rel, sort_field_count, sort_field_direction_name @derive(Clone) @@ -21,6 +44,18 @@ model OrderLine: const ORDER_LINES_CSV_FIXTURE: str = "tests/fixtures/order_lines.csv" +def _row_count_for_filter(predicate: ColumnExpr) -> int: + """Collect the order-line fixture with one predicate and return the materialized row count.""" + mut session = Session.default() + lazy: LazyFrame[OrderLine] = assert_is_ok( + session.read_csv("order_lines", ORDER_LINES_CSV_FIXTURE), + "order lines fixture should load", + ) + match session.collect(lazy.filter(predicate)): + Ok(df) => return df.row_count() + Err(err) => return fail_t(err.error_message()) + + def test_session_filters__collect_executes_gt_predicate() -> None: """collect should execute builder-backed gt filters through the DataFusion path.""" # -- Arrange -- @@ -63,3 +98,65 @@ def test_session_filters__collect_executes_eq_predicate() -> None: assert payload.contains("SKU-MAT-07"), "status == open should keep the third open row" assert payload.contains("SKU-LAMP-11") is false, "status == open should exclude closed rows" assert payload.contains("SKU-MONITOR-24") is false, "status == open should exclude cancelled rows" + + +def test_session_filters__collect_executes_comparison_family_predicates() -> None: + """collect should execute registry-backed comparison predicates through DataFusion.""" + # -- Arrange / Act / Assert -- + assert _row_count_for_filter(ne(col("status"), lit("open"))) == 2, "status != open should keep two rows" + assert _row_count_for_filter(lt(col("qty"), lit(3))) == 3, "qty < 3 should keep three rows" + assert _row_count_for_filter(lte(col("qty"), lit(2))) == 3, "qty <= 2 should keep three rows" + assert _row_count_for_filter(gte(col("qty"), lit(3))) == 2, "qty >= 3 should keep two rows" + assert _row_count_for_filter(equal_null(col("status"), lit("open"))) == 3, "null-safe equality should match equality for 
non-null fixture values" + assert _row_count_for_filter(between(col("qty"), lit(2), lit(3))) == 3, "inclusive between should keep qty 2 and 3" + assert _row_count_for_filter(in_(col("status"), [lit("open"), lit("closed")])) == 4, "in_ should keep status values in the provided set" + + +def test_session_filters__collect_executes_boolean_and_null_predicates() -> None: + """collect should execute registry-backed boolean and null predicates through DataFusion.""" + # -- Arrange -- + open_and_large = and_(eq(col("status"), lit("open")), gte(col("qty"), lit(3))) + closed_or_cancelled = or_(eq(col("status"), lit("closed")), eq(col("status"), lit("cancelled"))) + + # -- Act / Assert -- + assert _row_count_for_filter(open_and_large) == 1, "and_ should require both predicates" + assert _row_count_for_filter(closed_or_cancelled) == 2, "or_ should keep either matching predicate" + assert _row_count_for_filter(not_(eq(col("status"), lit("open")))) == 2, "not_ should invert boolean predicates" + assert _row_count_for_filter(is_not_null(col("status"))) == 5, "all fixture statuses should be non-null" + assert _row_count_for_filter(is_null(col("status"))) == 0, "fixture statuses should contain no nulls" + assert _row_count_for_filter(is_not_nan(cast(col("qty"), "float64"))) == 5, "integer quantities cast to floats should not be NaN" + + +def test_session_filters__mixed_scalar_query_shape_executes_end_to_end() -> None: + """collect should execute a query-like mix of predicate, projection, and ordering scalar helpers.""" + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[OrderLine] = assert_is_ok( + session.read_csv("order_lines", ORDER_LINES_CSV_FIXTURE), + "order lines fixture should load", + ) + query_like = lazy.filter(and_(gt(col("qty"), lit(2)), in_(col("status"), [lit("open"), lit("closed")]))).with_column( + "line_total", + mul(cast(col("unit_price"), "float64"), cast(col("qty"), "float64")), + ).order_by([desc(col("line_total"))]) + 
plan_root = root_rel(query_like.to_substrait_plan()) + df = assert_is_ok(session.collect(query_like), "mixed scalar query shape should collect") + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert sort_field_count(plan_root) == 1, "query-like order_by should emit one sort field" + assert sort_field_direction_name(plan_root, 0) == "DescNullsLast", "desc should lower through registry sort metadata" + assert df.row_count() == 2, "mixed predicate should keep only open/closed rows with quantity above two" + assert len(resolved) == 9, "computed projection should append one materialized output column" + assert resolved[8] == "line_total", "computed projection should preserve the requested output alias" + assert payload.contains("line_total"), "preview should render the computed output column" + assert payload.contains("SKU-MAT-07"), "open row with qty 3 should survive the predicate" + assert payload.contains("SKU-LAMP-11"), "closed row with qty 4 should survive the predicate" + assert payload.contains("SKU-CHAIR-01") is false, "qty 2 row should be filtered out" + assert payload.contains("SKU-DESK-02") is false, "qty 1 row should be filtered out" + assert payload.contains("SKU-MONITOR-24") is false, "cancelled row should be filtered out" + assert payload.contains("102.75"), "line_total should include 34.25 * 3" + assert payload.contains("79.8"), "line_total should include 19.95 * 4" diff --git a/tests/test_session_projection.incn b/tests/test_session_projection.incn index d4ee9ee..c39809f 100644 --- a/tests/test_session_projection.incn +++ b/tests/test_session_projection.incn @@ -1,10 +1,10 @@ """End-to-end Session projection execution tests over the DataFusion backend.""" -from functions import add, col, lit, mul -from dataset import LazyFrame +from functions import add, case_when, cast, coalesce, col, desc, div, gt, lit, modulo, mul, neg, nullif, sub, try_cast +from dataset import DataFrame, LazyFrame from session import Session, 
SessionErrorKind -from std.testing import assert_is_err, assert_is_ok -from substrait.inspect import root_names +from std.testing import assert_is_err, assert_is_ok, fail_t +from substrait.inspect import relation_kind_name, root_names, root_rel, sort_field_count, sort_field_direction_name @derive(Clone) @@ -16,6 +16,13 @@ pub model AggregateOrder: const AGGREGATE_ORDERS_CSV_FIXTURE: str = "tests/fixtures/aggregate_orders.csv" +def _collect_or_fail(mut session: Session, projected: LazyFrame[AggregateOrder]) -> DataFrame[AggregateOrder]: + """Collect a projected aggregate-order frame or fail with the backend diagnostic.""" + match session.collect(projected): + Ok(df) => return df + Err(err) => return fail_t(err.error_message()) + + def test_session_projection__plan_root_names_match_append_projection() -> None: # -- Arrange -- mut session = Session.default() @@ -83,6 +90,47 @@ def test_session_projection__collect_executes_with_column_append() -> None: assert payload.contains("14"), "materialized projection should include 7 * 2" +def test_session_projection__collect_executes_core_scalar_projection_functions() -> None: + """collect should execute newly mapped RFC 015 scalar projection helpers through DataFusion.""" + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + projected = lazy.with_column("amount_minus_one", sub(col("amount"), lit(1))).with_column( + "amount_div_five", + div(col("amount"), lit(5)), + ).with_column("amount_mod_two", modulo(col("amount"), lit(2))).with_column("negative_amount", neg(col("amount"))).with_column( + "customer_or_unknown", + coalesce([col("customer_id"), lit("unknown")]), + ).with_column("customer_null_if_a", nullif(col("customer_id"), lit("A"))).with_column( + "amount_text", + cast(col("amount"), "string"), + ).with_column("customer_try_int", 
try_cast(col("customer_id"), "int64")).with_column( + "amount_bucket", + case_when([gt(col("amount"), lit(10))], [lit("large")], lit("small")), + ) + df = _collect_or_fail(session, projected) + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 3, "projected scalar helpers should preserve the input rows" + assert len(resolved) == 11, "projection should expose all appended scalar helper outputs" + assert payload.contains("amount_minus_one"), "sub projection should materialize its alias" + assert payload.contains("amount_div_five"), "div projection should materialize its alias" + assert payload.contains("amount_mod_two"), "mod projection should materialize its alias" + assert payload.contains("negative_amount"), "neg projection should materialize its alias" + assert payload.contains("customer_or_unknown"), "coalesce projection should materialize its alias" + assert payload.contains("customer_null_if_a"), "nullif projection should materialize its alias" + assert payload.contains("amount_text"), "cast projection should materialize its alias" + assert payload.contains("customer_try_int"), "try_cast projection should materialize its alias" + assert payload.contains("amount_bucket"), "case_when projection should materialize its alias" + + def test_session_projection__collect_executes_identity_select() -> None: # -- Arrange -- mut session = Session.default() @@ -102,6 +150,27 @@ def test_session_projection__collect_executes_identity_select() -> None: assert resolved[1] == "amount", "identity select should preserve the second input name" +def test_session_projection__collect_executes_order_by_sort_fields() -> None: + """collect should execute order_by plans that carry real Substrait sort fields.""" + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should 
load", + ) + ordered = lazy.order_by([desc(col("amount"))]) + plan_root = root_rel(ordered.to_substrait_plan()) + df = _collect_or_fail(session, ordered) + + # -- Assert -- + assert relation_kind_name(plan_root) == "SortRel", "order_by should lower to a Substrait SortRel" + assert sort_field_count(plan_root) == 1, "order_by should emit one Substrait sort field" + assert sort_field_direction_name(plan_root, 0) == "DescNullsLast", "desc should preserve direction/null placement" + assert df.row_count() == 3, "ordered collect should preserve input row count" + + def test_session_projection__collect_unknown_projection_column_returns_planning_error() -> None: # -- Arrange -- mut session = Session.default() diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index bc7aefa..1cc59cf 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -1,9 +1,51 @@ """Tests for RFC 002 proto-backed Substrait emission and conformance alignment.""" -from aggregate_builders import AggregateMeasure, AggregateKind -from functions import add, always_true, col, lit, mul -from projection_builders import with_column_assignment -from substrait.extensions import explode_extension_uri, function_extension_uri, registered_substrait_extension_uris +from std.testing import fail_t +from functions import ( + add, + always_true, + and_, + asc, + asc_nulls_last, + between, + case_when, + cast, + col, + coalesce, + count, + desc, + div, + eq, + equal_null, + gt, + gte, + in_, + is_nan, + is_not_nan, + is_not_null, + is_null, + lit, + lt, + lte, + modulo, + mul, + ne, + neg, + not_, + nullif, + or_, + sub, + sum, + try_cast, +) +from projection_builders import ColumnExpr, with_column_assignment +from substrait.errors import SubstraitLoweringErrorKind +from substrait.expr_lowering import scalar_expr +from substrait.function_extensions import ( + explode_extension_uri, + function_extension_uri, + registered_substrait_extension_uris, +) from substrait.inspect 
import ( plan_contains_relation_kind, plan_has_extension_urn, @@ -14,6 +56,9 @@ from substrait.inspect import ( root_names, root_rel, set_operation_name, + sort_field_count, + sort_field_direction_name, + sort_field_expr_index, ) from substrait.plans import ( plan_encoded_len, @@ -43,7 +88,7 @@ from substrait.relations import ( reference_rel, set_rel, set_rel_of_kind, - sort_rel, + sort_rel_of_columns, ) from substrait.schema_registry import named_table_columns, register_named_table_schema from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind @@ -201,7 +246,7 @@ def test_plan__combinators_compose_deterministically() -> None: # -- Arrange -- base = read_named_table_rel("orders") with_filter = filter_rel(base, always_true()) - with_sort = sort_rel(with_filter) + with_sort = sort_rel_of_columns(with_filter, ["id"], [asc(col("id"))]) with_fetch = fetch_rel(with_sort, 0, 10) # -- Act -- @@ -241,18 +286,80 @@ def test_plan__project_rel_supports_projection_expression_payloads() -> None: assert root_output_names[2] == "id_plus_one", "project root should preserve all derived aliases" +def test_plan__ordering_helper_is_invalid_outside_sort_context() -> None: + """Assert ordering helpers are valid only as order_by sort-key wrappers.""" + # -- Arrange -- + expr = asc(col("id")) + + # -- Act -- + result = scalar_expr(["id"], expr) + + # -- Assert -- + match result: + Err(err) => + assert err.kind == SubstraitLoweringErrorKind.InvalidScalarExpression, "context errors should be invalid scalar expressions" + assert err.message.contains("inql.functions.asc"), "diagnostic should name the registered function reference" + assert err.message.contains("sort_field context"), "diagnostic should explain the valid lowering context" + Ok(_) => return fail_t("asc should not lower as a standalone scalar expression") + + +def test_plan__ordering_helpers_lower_to_real_sort_fields() -> None: + """Assert order_by sort-key helpers become Substrait SortField direction/null-placement 
metadata.""" + # -- Arrange -- + base = read_named_table_rel("orders") + + # -- Act -- + sorted = sort_rel_of_columns(base, ["id", "amount"], [desc(col("amount")), asc_nulls_last(col("id"))]) + + # -- Assert -- + assert relation_kind_name(sorted) == "SortRel", "ordering helpers should lower to SortRel" + assert sort_field_count(sorted) == 2, "SortRel should carry one field per requested sort key" + assert sort_field_expr_index(sorted, 0) == 1, "first sort field should target amount" + assert sort_field_direction_name(sorted, 0) == "DescNullsLast", "desc should use descending nulls-last ordering" + assert sort_field_expr_index(sorted, 1) == 0, "second sort field should target id" + assert sort_field_direction_name(sorted, 1) == "AscNullsLast", "asc_nulls_last should preserve null placement" + + +def _assert_scalar_expr_lowers(expr: ColumnExpr) -> None: + """Assert one scalar expression lowers successfully against the shared scalar fixture columns.""" + match scalar_expr(["id", "amount", "status"], expr): + Ok(_) => pass + Err(err) => return fail_t(err.error_message()) + + +def test_plan__core_scalar_extension_mappings_lower_to_substrait() -> None: + """Assert RFC 015 core scalar helpers emit Substrait scalar functions.""" + # -- Arrange / Act / Assert -- + _assert_scalar_expr_lowers(ne(col("id"), lit(1))) + _assert_scalar_expr_lowers(cast(col("id"), "string")) + _assert_scalar_expr_lowers(try_cast(col("status"), "int64")) + _assert_scalar_expr_lowers(lt(col("id"), lit(10))) + _assert_scalar_expr_lowers(lte(col("id"), lit(10))) + _assert_scalar_expr_lowers(gte(col("id"), lit(1))) + _assert_scalar_expr_lowers(equal_null(col("status"), lit("open"))) + _assert_scalar_expr_lowers(and_(gt(col("id"), lit(1)), eq(col("status"), lit("open")))) + _assert_scalar_expr_lowers(or_(eq(col("status"), lit("open")), eq(col("status"), lit("closed")))) + _assert_scalar_expr_lowers(not_(eq(col("status"), lit("void")))) + _assert_scalar_expr_lowers(is_null(col("status"))) + 
_assert_scalar_expr_lowers(is_not_null(col("status"))) + _assert_scalar_expr_lowers(is_nan(col("amount"))) + _assert_scalar_expr_lowers(is_not_nan(col("amount"))) + _assert_scalar_expr_lowers(in_(col("status"), [lit("open"), lit("closed")])) + _assert_scalar_expr_lowers(case_when([gt(col("amount"), lit(10))], [lit("large")], lit("small"))) + _assert_scalar_expr_lowers(sub(col("amount"), lit(1))) + _assert_scalar_expr_lowers(div(col("amount"), lit(2))) + _assert_scalar_expr_lowers(modulo(col("amount"), lit(2))) + _assert_scalar_expr_lowers(neg(col("amount"))) + _assert_scalar_expr_lowers(coalesce([col("status"), lit("unknown")])) + _assert_scalar_expr_lowers(nullif(col("status"), lit(""))) + _assert_scalar_expr_lowers(between(col("amount"), lit(1), lit(10))) + + def test_plan__aggregate_rel_surfaces_group_and_measure_output_columns() -> None: # -- Arrange -- _register_orders_schema() base = read_named_table_rel("orders") - aggregated = aggregate_rel( - base, - [col("id")], - [AggregateMeasure(kind=AggregateKind.Sum, expr=col("id")), AggregateMeasure( - kind=AggregateKind.Count, - expr=col(""), - )], - ) + aggregated = aggregate_rel(base, [col("id")], [sum(col("id")), count()]) plan = plan_from_root_relation(aggregated, ["id", "sum_id", "count"]) # -- Act -- @@ -271,11 +378,7 @@ def test_plan__aggregate_rel_accepts_scalar_group_and_measure_expressions() -> N # -- Arrange -- _register_orders_schema() base = read_named_table_rel("orders") - aggregated = aggregate_rel( - base, - [add(col("id"), lit(1))], - [AggregateMeasure(kind=AggregateKind.Sum, expr=add(col("id"), lit(2)))], - ) + aggregated = aggregate_rel(base, [add(col("id"), lit(1))], [sum(add(col("id"), lit(2)))]) # -- Act -- output_columns = relation_output_columns(aggregated) @@ -416,16 +519,16 @@ def test_conformance__core_scenarios_validate_emission_output() -> None: ) CoreScenarioKey.AggregateGroupingSets => plan = plan_from_root_relation( - aggregate_rel( - read_named_table_rel(_fixture_table_main()), - 
[col(_fixture_col_primary())], - [AggregateMeasure(kind=AggregateKind.Count, expr=col(""))], - ), + aggregate_rel(read_named_table_rel(_fixture_table_main()), [col(_fixture_col_primary())], [count()]), [_fixture_col_primary()], ) CoreScenarioKey.SortRelOrdering => plan = plan_from_root_relation( - sort_rel(read_named_table_rel(_fixture_table_main())), + sort_rel_of_columns( + read_named_table_rel(_fixture_table_main()), + [_fixture_col_primary()], + [asc(col(_fixture_col_primary()))], + ), [_fixture_col_primary()], ) CoreScenarioKey.FetchRelLimitOffset =>