diff --git a/docs/dialect/distance-operators.rst b/docs/dialect/distance-operators.rst index d3ebe89..10455e1 100644 --- a/docs/dialect/distance-operators.rst +++ b/docs/dialect/distance-operators.rst @@ -309,6 +309,11 @@ Find nearby same-strand features within distance constraints: WHERE nearest.distance BETWEEN -10000 AND 10000 ORDER BY peaks.name, ABS(nearest.distance) +Target support +~~~~~~~~~~~~~~ + +A correlated ``NEAREST`` (its reference is an outer-row column) runs on lateral-capable engines — DuckDB and the generic target — via a correlated ``LATERAL`` subquery, and on Apache DataFusion, which has no correlated-``LATERAL`` physical plan, via a decorrelated window-function rewrite. Both forms return identical results (a deterministic tiebreaker orders rows tied at the k-th distance the same way on every engine). A standalone ``NEAREST`` with a literal reference is uncorrelated and uses the same ordered, limited subquery on every target. + Notes ~~~~~ diff --git a/src/giql/expander.py b/src/giql/expander.py index f93132a..3b5a900 100644 --- a/src/giql/expander.py +++ b/src/giql/expander.py @@ -39,10 +39,11 @@ i.e. a ``(target, op)`` or ``(generic, op)`` expander is registered. Otherwise it falls through to the legacy ``*_sql`` emitter on -:class:`giql.generators.base.BaseGIQLGenerator`. As of this issue **no operator -sets ``GIQL_EXPAND`` and the registry is empty, so the pass is a strict no-op**: -no node is touched and the emitted SQL is byte-identical. Each later migration PR -(epic #137 steps 4-9) registers a generic expander, flips one operator's +:class:`giql.generators.base.BaseGIQLGenerator`. The built-in expanders register +at import time via :mod:`giql.expanders`; the pass rewrites a node only when +``GIQL_EXPAND=True`` **and** an expander resolves for ``(active target, operator +type)``, and is a no-op for any operator that is unflagged or has no registered +expander. 
A migration PR registers an expander, flips one operator's ``GIQL_EXPAND`` flag, and deletes that operator's ``*_sql`` method. """ @@ -149,6 +150,14 @@ class OperatorExpander(Protocol): a registered object satisfies it. A plain function is *not* an ``OperatorExpander`` (it has no ``expand`` method); register one by wrapping it (see :func:`register`, which accepts either form). + + An expander is **node-local**: ``expand(node, ctx) -> exp.Expression`` sees + one operator node and returns the expression that replaces it in place. It + cannot express a whole-query rewrite such as the INTERSECTS IEJoin fold, + which restructures the surrounding query (joins, CTEs) rather than a single + node. That fold is therefore deferred — it would need a separate + query-level mechanism — and is handled by the pre-pass join transformers, not + by an expander. """ def expand(self, node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: ... @@ -215,6 +224,16 @@ def register( expander : OperatorExpander | ExpanderFn The expander object or function. A later registration for the same key replaces an earlier one (last-write-wins override). + + Notes + ----- + Registering a *non-generic* ``(target, operator)`` expander where the + operator has a built-in whole-query join rewrite (notably + :class:`~giql.expressions.Intersects`, whose binned equi-join / DuckDB + IEJoin transformers run before expansion) signals that this expander + assumes responsibility for that rewrite: the built-in join transformers + are bypassed for that target so the operator node flows untouched into + :class:`ExpandOperators`. See :meth:`has_override`. """ self._expanders[(target, operator)] = _as_callable(expander) @@ -223,6 +242,12 @@ def resolve(self, target: Target, operator: type) -> ExpanderFn | None: Tries the exact ``(target, op)`` entry, then the ``(GenericTarget(), op)`` fallback, then ``None`` (legacy emitter). 
+ + A non-generic exact ``(target, op)`` entry is also a *join-rewrite + override* for operators with a built-in whole-query join rewrite (notably + :class:`~giql.expressions.Intersects`): registering one bypasses the + built-in binned / IEJoin transformers for that target (see + :meth:`register` and :meth:`has_override`). """ fn = self._expanders.get((target, operator)) if fn is not None: @@ -233,6 +258,19 @@ def resolve(self, target: Target, operator: type) -> ExpanderFn | None: return fn return None + def has_override(self, target: Target, operator: type) -> bool: + """Whether an exact non-generic override supersedes built-in handling. + + Returns ``True`` only when *target* is not :class:`~giql.targets.GenericTarget` + and an exact ``(target, operator)`` entry is registered. Such an entry is a + target-specific override that supersedes built-in handling for that target + (e.g. it takes responsibility for the whole-query join rewrite that the + built-in transformers would otherwise perform); the portable + ``(GenericTarget(), operator)`` fallback is *not* an override and does not + count here. + """ + return target != GenericTarget() and (target, operator) in self._expanders + def unregister(self, target: Target, operator: type) -> None: """Drop the ``(target, operator)`` entry if present. @@ -253,6 +291,33 @@ def clear(self) -> None: """ self._expanders.clear() + def snapshot(self) -> dict[tuple[Target, type], ExpanderFn]: + """Return a shallow copy of the current registrations. + + The save half of a save/restore seam used primarily for test baseline + isolation (and which may serve a plugin that mutates the process-wide + :data:`REGISTRY` around a body): capture the baseline with this and hand + it back to :meth:`restore` afterward, so the built-in expanders + registered at import survive an isolating fixture that would otherwise + :meth:`clear` them permanently. It is not a committed plugin API. 
+ + The returned dict is a fresh mapping (mutating it does not affect the + registry), keyed by the same ``(target, operator)`` tuples. + """ + return dict(self._expanders) + + def restore(self, snapshot: dict[tuple[Target, type], ExpanderFn]) -> None: + """Replace all registrations with those captured by :meth:`snapshot`. + + The restore half of the save/restore seam (test baseline isolation; may + also serve a plugin, but is not a committed plugin API). Drops every + current entry and re-installs exactly the *snapshot* contents, so a + fixture can return the registry to a previously captured baseline + regardless of what its body registered or cleared. + """ + self._expanders.clear() + self._expanders.update(snapshot) + def __contains__(self, key: tuple[Target, type]) -> bool: """Whether an *exact* ``(target, operator)`` entry is registered. @@ -280,8 +345,9 @@ def __bool__(self) -> bool: #: The process-wide registry the :func:`register` decorator writes to and the -#: :class:`ExpandOperators` pass reads from. Empty as of this issue, so the pass -#: is a strict no-op. +#: :class:`ExpandOperators` pass reads from. The built-in expanders register into +#: it at import time via :mod:`giql.expanders`; the pass rewrites a node only when +#: an expander resolves here (and the operator is flagged ``GIQL_EXPAND``). REGISTRY = ExpanderRegistry() @@ -380,9 +446,10 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for ``(target, operator type)`` through its fallback chain; otherwise the node is left untouched and the legacy ``*_sql`` emitter handles it. - The pass mutates and returns *expression* in place. **With no operator - flagged and an empty registry it is a strict no-op** and the emitted SQL is - byte-identical, so the existing suite is the migration oracle. + The pass mutates and returns *expression* in place. 
It touches only nodes + whose operator is flagged ``GIQL_EXPAND`` and resolves an expander; for every + other operator it is a no-op, leaving the emitted SQL byte-identical, so the + existing suite is the migration oracle. Parameters ---------- @@ -399,9 +466,9 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for Returns ------- exp.Expression - The same *expression*, with opted-in operator nodes replaced by their - target-specific expansions (none, while every flag is off / the registry - is empty). + The same *expression*, with each opted-in operator node that resolves an + expander replaced by its target-specific expansion; nodes that are + unflagged or resolve no expander are left untouched. """ reg = registry if registry is not None else REGISTRY operators = _giql_operators() diff --git a/src/giql/expanders/__init__.py b/src/giql/expanders/__init__.py new file mode 100644 index 0000000..042a07d --- /dev/null +++ b/src/giql/expanders/__init__.py @@ -0,0 +1,27 @@ +"""Built-in operator expanders for epic #137. + +Importing this package registers every built-in expander as a side effect: +each submodule decorates its expander(s) with ``@register(...)`` at import +time, and this package imports all of them. The import is wired once (in +:mod:`giql.transpile`) so the process-wide ``REGISTRY`` is populated before the +first transpile. + +New operator modules are picked up automatically: drop a ``<operator>.py`` into +this package and it is imported here without editing this file. + +Modules whose name starts with ``_`` are skipped (private helpers, not +expanders). Submodules import in :func:`pkgutil.iter_modules` order, which sets +last-write-wins resolution-order precedence for overlapping registrations; an +import error here aborts the whole package import by design (a broken built-in +expander must not be silently skipped). 
+""" + +from __future__ import annotations + +import importlib +import pkgutil + +for _module_info in pkgutil.iter_modules(__path__): + if _module_info.name.startswith("_"): + continue + importlib.import_module(f"{__name__}.{_module_info.name}") diff --git a/src/giql/expanders/nearest.py b/src/giql/expanders/nearest.py new file mode 100644 index 0000000..9bfafbc --- /dev/null +++ b/src/giql/expanders/nearest.py @@ -0,0 +1,415 @@ +"""The NEAREST operator expander (epic #137, issue #142). + +NEAREST is the first operator whose expansion is genuinely capability-driven. +The portable form is a correlated ``LATERAL`` subquery: each outer row drives a +``SELECT ... FROM <target> WHERE <filters> ORDER BY ABS(distance) +LIMIT k`` whose reference endpoints are outer-table columns. DuckDB and the +generic target plan that directly (``supports_lateral == True``). + +Apache DataFusion has no correlated-``LATERAL`` physical plan +(``supports_lateral == False``). For it the same k-nearest / ``max_distance`` / +``stranded`` / ``signed`` semantics are reproduced with a **decorrelated +window-function fallback**: the target is cross-joined against the outer +relation, each candidate is ranked with +``ROW_NUMBER() OVER (PARTITION BY <reference key> ORDER BY ABS(distance))``, and +the surrounding ``CROSS JOIN LATERAL`` is rewritten into a plain join that +re-associates the top-``k`` ranked candidates back to every outer row sharing +that reference key. Ranking depends only on the reference value, so ranking once +per distinct reference value and re-joining is row-for-row identical to the +per-row LATERAL form (verified by the cross-target result oracle). + +A literal-reference (standalone) NEAREST is already an uncorrelated subquery, so +every target — DataFusion included — uses the LATERAL/standalone form unchanged; +only the *correlated* shape needs the fallback. 
+ +The expander reuses :class:`giql.generators.base.BaseGIQLGenerator`'s +``_generate_distance_case`` (shared with DISTANCE, #140) and ``_nearest_*`` +passthrough/diagnostic helpers — all static, so they are called on the class with +no generator instance — then parses the assembled SQL fragments into AST so the +emitted SQL is reserialized by the active target's serializer. +""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot import parse_one + +from giql.dialect import GIQLDialect +from giql.expander import ExpansionContext +from giql.expander import register +from giql.expressions import GIQLNearest +from giql.generators.base import BaseGIQLGenerator +from giql.resolver import ResolvedInterval +from giql.resolver import ResolvedRef +from giql.targets import GenericTarget + +#: Reserved column names the window-function fallback synthesizes inside its +#: ranked subquery. They share the expander's ``__giql_x_`` prefix so they stay +#: clear of user identifiers, mirroring the other reserved internal prefixes. +_RANK_COL = "__giql_x_rn" +_REF_KEY_PREFIX = "__giql_x_rk_" + + +def _nearest_params(expression: GIQLNearest): + """Unpack the (k, max_distance, stranded, signed) parameters of a NEAREST.""" + k = expression.args.get("k") + k_value = int(str(k)) if k else 1 + + max_distance = expression.args.get("max_distance") + max_dist_value = int(str(max_distance)) if max_distance else None + + is_stranded = BaseGIQLGenerator._extract_bool_param(expression.args.get("stranded")) + is_signed = BaseGIQLGenerator._extract_bool_param(expression.args.get("signed")) + return k_value, max_dist_value, is_stranded, is_signed + + +def _distance_and_filters( + expression, table_name, target_ref, ref, ref_fragments=None +): + """Build the shared distance SQL, the qualified target columns, and WHERE. 
+ + Returns ``(distance_expr, abs_distance_expr, where_clauses, passthrough)`` — + the fragments common to the LATERAL/standalone form and the decorrelated + fallback. Distance math, the chromosome pre-filter, the optional strand match, + and the optional ``max_distance`` filter all reproduce the legacy + ``giqlnearest_sql`` emitter exactly. Each form derives its deterministic + ORDER BY tiebreaker from the target columns itself. + + ``ref_fragments`` optionally overrides the reference ``(chrom, start, end, + strand)`` SQL fragments. The LATERAL form consumes the resolution's + outer-qualified fragments verbatim; the fallback passes fragments pointing at + its renamed, pre-projected reference relation so the cross-joined columns + carry names distinct from the target's (DataFusion's planner cannot resolve a + window ordering over a join with duplicate column names). + """ + target_chrom, target_start, target_end = target_ref.cols + _k_value, max_dist_value, is_stranded, is_signed = _nearest_params(expression) + + output_table = BaseGIQLGenerator._nearest_output_encoding(expression, target_ref) + passthrough = BaseGIQLGenerator._nearest_passthrough( + table_name, target_start, target_end, output_table + ) + + if ref_fragments is not None: + ref_chrom, ref_start, ref_end, ref_strand_frag = ref_fragments + else: + ref_chrom, ref_start, ref_end, ref_strand_frag = ( + ref.chrom, + ref.start, + ref.end, + ref.strand, + ) + + ref_strand = None + target_strand = None + if is_stranded: + ref_strand = ref_strand_frag + if output_table and output_table.strand_col: + target_strand = f'{table_name}."{output_table.strand_col}"' + + target_chrom_expr = f'{table_name}."{target_chrom}"' + target_start_expr = f'{table_name}."{target_start}"' + target_end_expr = f'{table_name}."{target_end}"' + + distance_expr = BaseGIQLGenerator._generate_distance_case( + ref_chrom, + ref_start, + ref_end, + ref_strand, + target_chrom_expr, + target_start_expr, + target_end_expr, + target_strand, + 
stranded=is_stranded, + signed=is_signed, + ) + abs_distance_expr = f"ABS({distance_expr})" + + where_clauses = [f"{ref_chrom} = {target_chrom_expr}"] + if is_stranded and ref_strand and target_strand: + where_clauses.append(f"{ref_strand} = {target_strand}") + if max_dist_value is not None: + where_clauses.append(f"({abs_distance_expr}) <= {max_dist_value}") + + return distance_expr, abs_distance_expr, where_clauses, passthrough + + +def _lateral_form(expression, ctx, table_name, target_ref, ref): + """The portable LATERAL/standalone subquery. + + Builds a two-level subquery: an inner ``SELECT <passthrough>, <distance> AS + distance FROM <target> WHERE ...`` that materializes the distance, wrapped by + an outer ``SELECT * FROM (<inner>) AS x ORDER BY ABS(x.distance), x.<start>, + x.<end> LIMIT k`` that orders on the *precomputed* ``distance`` column. For a + correlated placement the parent ``LATERAL`` correlates it to the outer row; + for a standalone (literal-reference) placement it stands alone. + + Splitting the distance computation (inner) from the ordering (outer) is + load-bearing for cross-engine support: + + * DuckDB's correlated-``LATERAL`` binder will not resolve a SELECT-list alias + named ``distance`` from inside an ``ORDER BY`` that also projects + ``<target>.*``, so the order key must reference a *materialized* column + (``x.distance``) from the wrapping level rather than an alias in the same + SELECT. + * DataFusion's planner, given the distance ``CASE`` re-emitted inline in the + ``ORDER BY`` over the chromosome-equality prefiltered scan, rewrites the + filtered ``chrom`` to a self-comparison in one copy of the CASE but not the + other and trips ``SanityCheckPlan``; ordering on the materialized column + avoids re-deriving the key. + + A deterministic ``(start, end)`` tiebreaker follows ``ABS(distance)`` so rows + tied at the k-th distance order identically across engines and against the + decorrelated fallback's ranking (#142 A5). 
+ """ + k_value, *_ = _nearest_params(expression) + ( + distance_expr, + _abs_distance_expr, + where_clauses, + passthrough, + ) = _distance_and_filters(expression, table_name, target_ref, ref) + where_sql = " AND ".join(where_clauses) + # The wrapping level reads the inner row's *bare* column names (the passthrough + # projected ``<target>.*``), so the tiebreaker qualifies them by the wrapper + # alias, not the original ``table_name."col"``. + _chrom, target_start_col, target_end_col = target_ref.cols + wrapper = ctx.alias() + inner = ( + f"SELECT {passthrough}, {distance_expr} AS distance " + f"FROM {table_name} WHERE {where_sql}" + ) + sql = ( + f"(SELECT * FROM ({inner}) AS {wrapper} " + f'ORDER BY ABS({wrapper}."distance"), ' + f'{wrapper}."{target_start_col}", {wrapper}."{target_end_col}" ' + f"LIMIT {k_value})" + ) + return parse_one(sql, dialect=GIQLDialect) + + +def _outer_relation(ref: ResolvedInterval) -> tuple[str, str]: + """Return ``(physical_relation, alias)`` for the correlated reference table. + + The reference endpoints are alias-qualified fragments (``a."chrom"``). The + alias is the outer table's correlation name in the query; the physical + relation comes from the reference's backing :class:`~giql.table.Table`. Both + are needed to re-introduce the outer relation inside the decorrelated + subquery the fallback builds. + """ + parsed = parse_one(ref.chrom, dialect=GIQLDialect) + alias = parsed.table if isinstance(parsed, exp.Column) else "" + relation = ref.table.name if ref.table is not None else alias + return relation, alias + + +def _fallback_form(expression, ctx, table_name, target_ref, ref): + """The decorrelated window-function fallback for non-LATERAL targets. + + Rewrites the surrounding ``<outer> AS a CROSS JOIN LATERAL (nearest) AS b`` + into ``<outer> AS a JOIN (<ranked>) AS b ON <key match> AND + b.<rn> <= k``. 
The ranked subquery cross-joins the target against the outer + relation and ranks candidates per distinct reference key with + ``ROW_NUMBER()``; the join re-associates the top-k back to every outer row + sharing that key, reproducing the per-row LATERAL semantics. Swaps the parent + ``LATERAL`` for the decorrelated subquery in place and returns the (now + detached) NEAREST node, so the pass's own ``node.replace`` is a no-op. + + The no-op return relies on NEAREST having no nestable inner GIQL operator: a + detached node carrying a still-pending descendant would strand that + descendant's later ``node.replace``. NEAREST's only operands are a registered + target table and an interval reference, neither of which is an expandable + operator, so nothing pending hangs off the node this detaches. + """ + lateral = expression.parent + # Invariants the surrounding-AST rewrite depends on. The fallback only runs + # for a correlated NEAREST, whose pass-1 placement is always a CROSS JOIN + # LATERAL carrying an alias; a violation is an internal pipeline bug, not user + # error, so fail loudly with a clear message rather than dereferencing None. 
+ assert isinstance(lateral, exp.Lateral), ( + "correlated NEAREST fallback expected its parent to be a LATERAL, got " + f"{type(lateral).__name__}" + ) + join = lateral.parent + assert isinstance(join, exp.Join), ( + "correlated NEAREST fallback expected the LATERAL to sit under a JOIN, got " + f"{type(join).__name__}" + ) + assert lateral.args.get("alias") is not None and lateral.args["alias"].name, ( + "correlated NEAREST fallback expected the LATERAL to carry a table alias" + ) + alias = lateral.args["alias"].name + + relation, outer_alias = _outer_relation(ref) + k_value, _max, is_stranded, _signed = _nearest_params(expression) + # Bare target column names: the candidate subquery exposes the target row via + # ``target.*``, so its tiebreaker columns are referenced by name, not by the + # ``table_name."col"`` qualifier the distance math uses. + _target_chrom, target_start_col, target_end_col = target_ref.cols + + # Pre-project the outer relation's reference columns under fresh, non-target + # names into a renamed derived relation. Cross-joining *this* (rather than the + # raw outer table) keeps every reference column distinct from the target's + # columns: DataFusion's planner cannot resolve a window ordering over a join + # whose two sides share column names (e.g. both expose ``start`` / ``end``). + # + # The reference key identifies one distinct reference interval, which the + # ranking partitions by and the join re-associates on. Position + # (chrom/start/end) alone identifies it in the unstranded case; in stranded + # mode strand joins the key too, because two outer rows at the same position + # but opposite strands must each get their own strand-filtered nearest. The + # ref relation is de-duplicated on the key with DISTINCT so ranking happens + # once per distinct reference and the join fans the top-k back out to every + # outer row sharing it — exactly the per-row LATERAL semantics, even when the + # outer table holds duplicate reference rows. 
+ # Mint the synthetic relation aliases from the run's collision-safe sequence + # (rather than hardcoding ``__giql_x_ref`` / ``__giql_x_cand``) so two NEAREST + # operators in one query never reuse the same derived-relation name. The + # reserved *column* names below stay derived from EXPAND_ALIAS_PREFIX. + ref_relation_alias = ctx.alias() + candidate = ctx.alias() + strand_name = f"{_REF_KEY_PREFIX}strand" + stranded_key = is_stranded and ref.strand is not None + + key_names = [f"{_REF_KEY_PREFIX}chrom", f"{_REF_KEY_PREFIX}start", + f"{_REF_KEY_PREFIX}end"] + source_frags = [ref.chrom, ref.start, ref.end] + if stranded_key: + key_names.append(strand_name) + source_frags.append(ref.strand) + + ref_projection = ", ".join( + f'{frag} AS "{name}"' for name, frag in zip(key_names, source_frags) + ) + ref_relation = ( + f"(SELECT DISTINCT {ref_projection} FROM {relation} AS {outer_alias})" + f" AS {ref_relation_alias}" + ) + + # Reference fragments now point at the renamed relation's safe columns. + renamed = [f'{ref_relation_alias}."{name}"' for name in key_names] + renamed_strand = ( + f'{ref_relation_alias}."{strand_name}"' if stranded_key else None + ) + ref_fragments = (renamed[0], renamed[1], renamed[2], renamed_strand) + + ( + distance_expr, + _abs_distance_expr, + where_clauses, + passthrough, + ) = _distance_and_filters( + expression, table_name, target_ref, ref, ref_fragments=ref_fragments + ) + + # Surface the reference-key columns so the rewritten join can match each + # ranked candidate back to its outer row(s). Ranking depends only on these + # values, so partitioning by them and re-joining is identical to the per-row + # LATERAL form even when outer rows share a reference value. 
+ key_cols = list(zip(key_names, renamed)) + key_projection = ", ".join(f'{frag} AS "{name}"' for name, frag in key_cols) + where_sql = " AND ".join(where_clauses) + + # Compute the candidate set (cross join + distance + reference keys) in an + # inner subquery, then add ROW_NUMBER() in the enclosing one. Keeping the + # join and the window in *separate* query levels is load-bearing on + # DataFusion: fused into one level its optimizer mis-derives the window's sort + # order from the chromosome-equality prefilter and trips ``SanityCheckPlan``. + inner = ( + f"SELECT {passthrough}, {distance_expr} AS distance, {key_projection} " + f"FROM {table_name} CROSS JOIN {ref_relation} " + f"WHERE {where_sql}" + ) + partition = ", ".join(f'{candidate}."{name}"' for name, _ in key_cols) + # A deterministic ``(start, end)`` tiebreaker follows ``ABS(distance)`` so rows + # tied at the k-th distance rank identically here and in the LATERAL form, + # making the two provably set-equivalent (no engine-dependent tie ordering). + ranked = ( + f"(SELECT {candidate}.*, " + f"ROW_NUMBER() OVER (PARTITION BY {partition} " + f"ORDER BY ABS({candidate}.distance), " + f'{candidate}."{target_start_col}", {candidate}."{target_end_col}") ' + f'AS "{_RANK_COL}" ' + f"FROM ({inner}) AS {candidate})" + ) + ranked_subquery = parse_one(ranked, dialect=GIQLDialect) + + # Match each ranked candidate back to the *outer* relation by its reference + # value (the original outer-qualified fragments, e.g. ``a."chrom"``), not the + # renamed inner columns which exist only inside the subquery. 
+ on_parts = [ + f'{alias}."{name}" = {src}' for name, src in zip(key_names, source_frags) + ] + on_parts.append(f'{alias}."{_RANK_COL}" <= {k_value}') + on_sql = " AND ".join(on_parts) + on_expr = parse_one(on_sql, dialect=GIQLDialect) + + subquery = exp.Subquery( + this=ranked_subquery.this, alias=lateral.args["alias"].copy() + ) + + # Convert ``<outer> CROSS JOIN LATERAL (nearest) AS b`` into + # ``<outer> JOIN (ranked) AS b ON <key match> AND b.rn <= k``. Swap the + # whole LATERAL out for the decorrelated subquery and drop the CROSS kind so + # the ON clause attaches as a plain (inner) join. + lateral.replace(subquery) + join.set("kind", None) + join.set("side", None) + join.set("on", on_expr) + + # The LATERAL (and the NEAREST node within it) is now detached; returning the + # node unchanged makes the pass's ``node.replace`` a no-op. + return expression + + +def expand_nearest(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand a NEAREST node to LATERAL or the decorrelated window-function form. + + Selects on ``ctx.capabilities.supports_lateral`` and whether the node is + correlated (its parent is a ``LATERAL``). Lateral-capable targets and every + standalone (literal-reference) placement get the portable LATERAL/standalone + subquery; a correlated NEAREST on a target without LATERAL support gets the + decorrelated window-function fallback. + """ + assert isinstance(node, GIQLNearest) + resolution = ctx.resolution + + target_ref = resolution.slot("this") if resolution is not None else None + if not isinstance(target_ref, ResolvedRef): + # An unresolved target means it is not a registered table; raise the + # historical diagnostic (verbatim from the removed giqlnearest_sql). 
+ target = node.this + if isinstance(target, exp.Table): + target_name = target.name + elif isinstance(target, exp.Column): + target_name = target.table if target.table else str(target.this) + else: + target_name = str(target) + raise ValueError( + f"Target table '{target_name}' not found in tables. " + "Register the table before transpiling." + ) + table_name = target_ref.name + + ref = resolution.slot("reference") + if not isinstance(ref, ResolvedInterval): + mode = BaseGIQLGenerator._detect_nearest_mode(node) + BaseGIQLGenerator._raise_nearest_reference_error(node, mode, resolution) + + # A literal-range reference is uncorrelated even under CROSS JOIN LATERAL: its + # endpoints are constants, not outer-row columns, so the subquery stands alone + # and every target — DataFusion included — takes the LATERAL/standalone form. + # Only a genuinely correlated reference (a column / implicit-outer endpoint) + # needs the decorrelated window-function fallback on a lateral-incapable + # target. Gating on parentage alone would mis-route a literal range into + # ``_fallback_form``, which dereferences a non-existent outer relation. + correlated = isinstance(node.parent, exp.Lateral) and ref.kind != "literal_range" + if correlated and not ctx.capabilities.supports_lateral: + return _fallback_form(node, ctx, table_name, target_ref, ref) + return _lateral_form(node, ctx, table_name, target_ref, ref) + + +# The generic registration covers every target through the registry's fallback +# chain; the expander branches on ctx.capabilities.supports_lateral internally, +# so no per-target override is needed. +register(GenericTarget, GIQLNearest)(expand_nearest) diff --git a/src/giql/expressions.py b/src/giql/expressions.py index ce939dc..fe8d397 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -369,7 +369,11 @@ class GIQLNearest(exp.Func): #: half-open) operands are left untouched and the emitted SQL stays #: byte-identical. 
GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators pass (epic #137, issue #142): NEAREST is + #: expanded by ``giql.expanders.nearest`` — the portable correlated LATERAL + #: subquery where ``supports_lateral`` holds, a decorrelated window-function + #: form otherwise. The legacy ``giqlnearest_sql`` emitter has been removed. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"registered_table"}), required=True), diff --git a/src/giql/generators/base.py b/src/giql/generators/base.py index 9038369..0138137 100644 --- a/src/giql/generators/base.py +++ b/src/giql/generators/base.py @@ -3,6 +3,7 @@ from giql.canonical import decanonical_end from giql.canonical import decanonical_start +from giql.dialect import GIQLDialect from giql.expressions import Contains from giql.expressions import GIQLDisjoin from giql.expressions import GIQLDistance @@ -15,7 +16,6 @@ from giql.resolver import META_KEY from giql.resolver import OperatorResolution from giql.resolver import ResolvedColumn -from giql.resolver import ResolvedInterval from giql.resolver import ResolvedRef from giql.table import Table from giql.table import Tables @@ -34,10 +34,6 @@ class BaseGIQLGenerator(Generator): compatibility with virtually all SQL databases. """ - # Most databases support LATERAL joins (PostgreSQL 9.3+, DuckDB 0.7.0+) - # SQLite does not support LATERAL, so it overrides this to False - SUPPORTS_LATERAL = True - def __init__(self, tables: Tables | None = None, **kwargs): super().__init__(**kwargs) self.tables = tables or Tables() @@ -82,172 +78,9 @@ def spatialsetpredicate_sql(self, expression: SpatialSetPredicate) -> str: """ return self._generate_spatial_set(expression) - def giqlnearest_sql(self, expression: GIQLNearest) -> str: - """Generate SQL for NEAREST function. 
- - Detects mode (standalone vs correlated) and generates appropriate SQL: - - Standalone: Direct query with ORDER BY + LIMIT - - Correlated (LATERAL): Subquery for k-nearest neighbors - - :param expression: - GIQLNearest expression node - :return: - SQL string for NEAREST operation - """ - # Detect mode - mode = self._detect_nearest_mode(expression) - - # Unpack the resolution metadata attached by ResolveOperatorRefs (pass 1). - resolution = self._nearest_resolution(expression) - - # Target (already a registered-table ResolvedRef from the pass). An - # unresolved target means it is not a registered table; raise the - # historical diagnostic. - target_ref = resolution.slot("this") if resolution is not None else None - if not isinstance(target_ref, ResolvedRef): - target = expression.this - if isinstance(target, exp.Table): - target_name = target.name - elif isinstance(target, exp.Column): - target_name = target.table if target.table else str(target.this) - else: - target_name = str(target) - raise ValueError( - f"Target table '{target_name}' not found in tables. " - "Register the table before transpiling." - ) - table_name = target_ref.name - target_chrom, target_start, target_end = target_ref.cols - - # The target's *declared* encoding, which the passed-through target row - # (SELECT {table_name}.*) must round-trip back into. CanonicalizeCoordinates - # (pass 2) preserves it on the resolution when it wraps a non-canonical - # target in a __giql_canon_* CTE (the slot's own Table is then None); a - # canonical target is left unwrapped and its slot Table carries the - # (identity) encoding. The synthesized `distance` column is encoding- - # invariant (a count of bases) and needs no round-trip. - output_table = self._nearest_output_encoding(expression, target_ref) - passthrough = self._nearest_passthrough( - table_name, target_start, target_end, output_table - ) - - # Reference interval (a ResolvedInterval from the pass). 
An unresolved - # reference re-raises the generator's historical diagnostic. Input - # canonicalization is owned by CanonicalizeCoordinates (pass 2, issue - # #123): a literal range is already canonical, and a column / implicit- - # outer reference's endpoints are canonicalized in place by the pass, so - # the emitter consumes the fragments verbatim with no canonicalization. - ref = resolution.slot("reference") - if not isinstance(ref, ResolvedInterval): - self._raise_nearest_reference_error(expression, mode, resolution) - ref_chrom, ref_start, ref_end = ref.chrom, ref.start, ref.end - - # Extract parameters - k = expression.args.get("k") - k_value = int(str(k)) if k else 1 # Default k=1 - - max_distance = expression.args.get("max_distance") - max_dist_value = int(str(max_distance)) if max_distance else None - - is_stranded = self._extract_bool_param(expression.args.get("stranded")) - is_signed = self._extract_bool_param(expression.args.get("signed")) - - # Resolve strand columns if stranded mode. The reference strand is - # carried on the resolved interval (a literal's strand, an explicit - # column's strand, or the outer table's strand for an implicit - # reference — already gated to preserve the historical divergence). - ref_strand = None - target_strand = None - if is_stranded: - ref_strand = ref.strand - # When pass 2 wraps a non-canonical target its slot Table is blanked, - # so the strand column name comes from the *declared* encoding the - # pass preserved (output_table). The canon CTE's SELECT * REPLACE - # passes the strand column through unchanged under its physical name, - # so the qualifier stays the relation NEAREST selects from. - if output_table and output_table.strand_col: - target_strand = f'{table_name}."{output_table.strand_col}"' - - # Distance math below assumes 0-based half-open. 
Input canonicalization - # is owned by CanonicalizeCoordinates (pass 2, issue #123): a - # non-canonical target is rewritten to a canonical __giql_canon_* CTE - # before generation (table_name then names the CTE), so the target - # endpoints are consumed verbatim with no in-emitter canonicalization. The - # output round-trip of the passed-through target row stays here (see the - # SELECT projection below). - target_start_expr = f'{table_name}."{target_start}"' - target_end_expr = f'{table_name}."{target_end}"' - - # Build distance calculation using CASE expression - # For NEAREST: ORDER BY absolute distance, but RETURN signed distance - distance_expr = self._generate_distance_case( - ref_chrom, - ref_start, - ref_end, - ref_strand, - f'{table_name}."{target_chrom}"', - target_start_expr, - target_end_expr, - target_strand, - stranded=is_stranded, - signed=is_signed, - ) - - # Use absolute distance for ordering and filtering - abs_distance_expr = f"ABS({distance_expr})" - - # Build WHERE clauses - where_clauses = [ - f'{ref_chrom} = {table_name}."{target_chrom}"' # Chromosome pre-filter - ] - - # Add strand matching for stranded mode - if is_stranded and ref_strand and target_strand: - where_clauses.append(f"{ref_strand} = {target_strand}") - - if max_dist_value is not None: - where_clauses.append(f"({abs_distance_expr}) <= {max_dist_value}") - - where_sql = " AND ".join(where_clauses) - - # Generate SQL based on mode - if mode == "standalone": - # Standalone mode: direct ORDER BY + LIMIT - # Return signed distance, but order by absolute distance - sql = f"""( - SELECT {passthrough}, {distance_expr} AS distance - FROM {table_name} - WHERE {where_sql} - ORDER BY {abs_distance_expr} - LIMIT {k_value} - )""" - else: - # Correlated mode: requires LATERAL join support - if not self.SUPPORTS_LATERAL: - raise ValueError( - "NEAREST in correlated mode (CROSS JOIN LATERAL) is not supported " - "in SQLite. SQLite does not support LATERAL joins. " - "\n\nAlternatives:" - "\n1. 
Use standalone mode: SELECT * FROM NEAREST(table, " - "reference='chr1:100-200', k=3)" - "\n2. Use DuckDB for queries requiring LATERAL joins" - "\n3. Manually write equivalent window function query" - ) - - # LATERAL mode: subquery for k-nearest neighbors - # Return signed distance, but order by absolute distance - sql = f"""( - SELECT {passthrough}, {distance_expr} AS distance - FROM {table_name} - WHERE {where_sql} - ORDER BY {abs_distance_expr} - LIMIT {k_value} - )""" - - return sql.strip() - + @staticmethod def _nearest_output_encoding( - self, expression: GIQLNearest, target_ref: ResolvedRef + expression: GIQLNearest, target_ref: ResolvedRef ) -> Table | None: """Return the target's declared encoding for NEAREST's row passthrough. @@ -272,8 +105,8 @@ def _nearest_output_encoding( return preserved return target_ref.table + @staticmethod def _nearest_passthrough( - self, table_name: str, target_start: str, target_end: str, @@ -572,8 +405,8 @@ def _distance_operand( raise ValueError(f"Literal range as {position} argument not yet supported") + @staticmethod def _generate_distance_case( - self, chrom_a: str, start_a: str, end_a: str, @@ -881,8 +714,9 @@ def _generate_spatial_set(self, expression: SpatialSetPredicate) -> str: combinator = " OR " if quantifier.upper() == "ANY" else " AND " return "(" + combinator.join(conditions) + ")" + @staticmethod def _detect_nearest_mode( - self, expression: GIQLNearest, parent_expression: exp.Expression | None = None + expression: GIQLNearest, parent_expression: exp.Expression | None = None ) -> str: """Detect whether NEAREST is in standalone or correlated mode. @@ -906,25 +740,8 @@ def _detect_nearest_mode( # (validation will catch missing reference errors later) return "correlated" - def _nearest_resolution(self, expression: GIQLNearest) -> OperatorResolution | None: - """Return the NEAREST resolution attached by ResolveOperatorRefs (pass 1). 
- - The transpile pipeline attaches an - :class:`~giql.resolver.OperatorResolution` before generation, and it - survives the generator's defensive tree copy. The emitter reads only the - attached metadata; resolution lives entirely in the pass. - - :param expression: - GIQLNearest expression node - :return: - The attached :class:`~giql.resolver.OperatorResolution`, or ``None`` - if resolution did not produce one. - """ - resolution = expression.meta.get(META_KEY) - return resolution if isinstance(resolution, OperatorResolution) else None - + @staticmethod def _raise_nearest_reference_error( - self, expression: GIQLNearest, mode: str, resolution: OperatorResolution | None, @@ -972,7 +789,7 @@ def _raise_nearest_reference_error( # An explicit reference that deferred is a literal range that failed to # parse (column references always resolve). Re-parse to surface the # original parse error in the historical message. - reference_sql = self.sql(reference) + reference_sql = reference.sql(dialect=GIQLDialect) range_str = reference_sql.strip("'\"") try: RangeParser.parse(range_str).to_zero_based_half_open() diff --git a/src/giql/targets.py b/src/giql/targets.py index 6fd29d3..825a6d2 100644 --- a/src/giql/targets.py +++ b/src/giql/targets.py @@ -29,11 +29,13 @@ class Capabilities: Parameters ---------- supports_lateral : bool - Whether the engine supports ``LATERAL`` / correlated joins. Will - drive the NEAREST LATERAL-vs-window-function strategy (#142). Until - then, :attr:`giql.generators.base.BaseGIQLGenerator.SUPPORTS_LATERAL` - remains the live source of truth at generation time; #142 reconciles - the two. + Whether the engine supports ``LATERAL`` / correlated joins. Drives the + NEAREST LATERAL-vs-window-function strategy (#142): a correlated NEAREST + expands to a portable correlated ``LATERAL`` subquery where this holds + and to a decorrelated window-function form where it does not. 
This + capability is the single source of truth — the former + ``BaseGIQLGenerator.SUPPORTS_LATERAL`` generator attribute has been + removed. supports_star_replace : bool Whether the engine supports ``SELECT * REPLACE (...)``. Drives the coordinate-canonicalization output: ``* REPLACE`` where supported, diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 9ef2100..4817a2d 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -11,6 +11,7 @@ from sqlglot import parse_one +import giql.expanders # noqa: F401 (side-effect: registers built-in expanders) from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect from giql.expander import ExpandOperators @@ -197,16 +198,17 @@ def transpile( # Pass 2 of the normalization pipeline (epic #114): synthesize canonical # __giql_canon_* wrapper CTEs for non-canonical interval operands of - # opted-in operators (GIQL_CANONICALIZE). No operator opts in yet, so this - # is a strict no-op until the per-operator port issues (#122, #123) land. + # operators that opt into GIQL_CANONICALIZE; those operators' non-canonical + # operands are rewritten here, and identity (0-based half-open) operands are + # left untouched. with _reraise_as_value_error("Canonicalization error"): ast = canonicalize_coordinates(ast) # Pass 3 of the normalization pipeline (epic #137): replace each opted-in # GIQL operator node with the AST its registered expander produces for the - # active target. No operator sets GIQL_EXPAND and the registry is empty, so - # this is a strict no-op until the per-operator migration issues land; the - # legacy *_sql emitters on the generator remain the fallback. + # active target. An operator is rewritten here only when it both sets + # GIQL_EXPAND and resolves to a registered expander; otherwise its node is + # left untouched and the legacy *_sql emitter on the generator handles it. 
expand_operators = ExpandOperators(target, tables_container) with _reraise_as_value_error("Expansion error"): ast = expand_operators.transform(ast) diff --git a/tests/generators/test_base.py b/tests/generators/test_base.py index f95e91b..d57c0b2 100644 --- a/tests/generators/test_base.py +++ b/tests/generators/test_base.py @@ -8,34 +8,37 @@ from hypothesis import given from hypothesis import settings from hypothesis import strategies as st -from sqlglot import exp from sqlglot import parse_one +import giql.expanders # noqa: F401 (side-effect: registers the NEAREST expander) from giql import Table from giql import transpile from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect -from giql.expressions import GIQLNearest +from giql.expander import ExpandOperators from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import GenericTarget def _generate_through_passes(sql: str, tables: Tables) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. - - Coordinate canonicalization for operator operands moved out of the emitter and - into the CanonicalizeCoordinates pass (issue #123). Emitter-level tests that - pin canonicalized output must therefore run both passes before generating, - exactly as :func:`giql.transpile.transpile` does, rather than calling - ``generate`` on a bare parsed AST (which would skip canonicalization). This - helper is used where the full ``transpile`` pipeline would otherwise rewrite - the node away (a column-to-column ``INTERSECTS`` is turned into a binned - equi-join before the predicate emitter runs). + """Parse, run normalization passes 1-3, then generate SQL. + + Coordinate canonicalization (issue #123) and operator expansion (epic #137) + moved out of the emitter into the CanonicalizeCoordinates / ExpandOperators + passes. 
Emitter-level tests that pin canonicalized output — or an operator now + produced by its registered expander, such as NEAREST (#142) — must run those + passes before generating, exactly as :func:`giql.transpile.transpile` does, + rather than calling ``generate`` on a bare parsed AST. This helper is used + where the full ``transpile`` pipeline would otherwise rewrite the node away (a + column-to-column ``INTERSECTS`` is turned into a binned equi-join before the + predicate emitter runs). """ ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator(tables=tables).generate(ast) @@ -103,13 +106,12 @@ def test_instantiation_defaults(self): """ GIVEN no tables provided WHEN Generator is instantiated with defaults - THEN Generator has empty Tables and SUPPORTS_LATERAL is True. + THEN Generator has empty Tables. """ generator = BaseGIQLGenerator() assert generator.tables is not None assert "variants" not in generator.tables - assert generator.SUPPORTS_LATERAL is True def test_instantiation_with_tables(self, tables_info): """ @@ -395,41 +397,43 @@ def test_spatialsetpredicate_sql_all(self): ) assert output == expected - def test_giqlnearest_sql_standalone(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_ordered_limit_subquery_when_standalone( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest in standalone mode with literal reference - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Subquery with ORDER BY distance LIMIT k is generated. 
""" sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # NEAREST now expands via its registered expander (#142): a two-level + # subquery — an inner SELECT that materializes ``distance`` and an outer + # wrapper that orders on the materialized ``__giql_x_0."distance"`` plus a + # deterministic ``(start, end)`` tiebreaker (#142 A5). Splitting the + # distance computation from the ordering keeps DuckDB's correlated-LATERAL + # binder and DataFusion's planner both happy while staying result- + # equivalent to the legacy single-level emitter. expected = ( - "SELECT * FROM (\n" - " SELECT genes.*, " - "CASE WHEN 'chr1' != genes.\"chrom\" THEN NULL " + "SELECT * FROM (SELECT * FROM (SELECT genes.*, " + "CASE WHEN 'chr1' <> genes.\"chrom\" THEN NULL " 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0 ' - 'WHEN 2000 <= genes."start" ' - 'THEN (genes."start" - 2000 + 1) ' - 'ELSE (1000 - genes."end" + 1) END AS distance\n' - " FROM genes\n" - " WHERE 'chr1' = genes.\"chrom\"\n" - " ORDER BY ABS(" - "CASE WHEN 'chr1' != genes.\"chrom\" THEN NULL " - 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0 ' - 'WHEN 2000 <= genes."start" ' - 'THEN (genes."start" - 2000 + 1) ' - 'ELSE (1000 - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'WHEN 2000 <= genes."start" THEN (genes."start" - 2000 + 1) ' + 'ELSE (1000 - genes."end" + 1) END AS distance ' + "FROM genes WHERE 'chr1' = genes.\"chrom\") AS __giql_x_0 " + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_correlated(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_lateral_subquery_when_correlated( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest in correlated mode (LATERAL join context) - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs on a lateral-capable target THEN 
LATERAL-compatible subquery is generated. """ sql = ( @@ -439,33 +443,30 @@ def test_giqlnearest_sql_correlated(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander as a two-level subquery: the inner + # SELECT materializes ``distance`` (CASE and WHERE semantically unchanged) + # and the outer wrapper orders on the materialized + # ``__giql_x_0."distance"`` plus a deterministic ``(start, end)`` + # tiebreaker (#142 A5). expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom") AS __giql_x_0 ' + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_with_max_distance(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_filter_on_max_distance( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with max_distance parameter - WHEN giqlnearest_sql is 
called + WHEN the NEAREST expander runs THEN WHERE clause includes distance filter. """ sql = ( @@ -476,40 +477,35 @@ def test_giqlnearest_sql_with_max_distance(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander as a two-level subquery; the + # max_distance filter on ABS of the distance CASE stays in the inner + # SELECT's WHERE and the outer wrapper orders on the materialized + # ``__giql_x_0."distance"`` plus the deterministic ``(start, end)`` + # tiebreaker (#142 A5). expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom" ' - "AND (ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'ELSE (peaks."start" - genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom" ' + 'AND (ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)) <= 100000\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN 
peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)\n' - " LIMIT 5\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END)) <= 100000) ' + 'AS __giql_x_0 ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 5)' ) assert output == expected - def test_giqlnearest_sql_stranded(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_match_strand_when_stranded( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with stranded := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand matching is included in WHERE clause. """ sql = ( @@ -520,45 +516,33 @@ def test_giqlnearest_sql_stranded(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander as a two-level subquery; the stranded + # distance CASE and the ``peaks.strand = genes.strand`` match in the inner + # WHERE are semantically unchanged, with the outer wrapper ordering on the + # materialized ``__giql_x_0."distance"`` plus the ``(start, end)`` + # tiebreaker (#142 A5). expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."strand" IS NULL OR genes."strand" IS NULL THEN NULL ' - "WHEN peaks.\"strand\" = '.' OR peaks.\"strand\" = '?' THEN NULL " - "WHEN genes.\"strand\" = '.' OR genes.\"strand\" = '?' 
THEN NULL " - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' - "THEN CASE WHEN peaks.\"strand\" = '-' " - 'THEN -(genes."start" - peaks."end" + 1) ' - 'ELSE (genes."start" - peaks."end" + 1) END ' - "ELSE CASE WHEN peaks.\"strand\" = '-' " - 'THEN -(peaks."start" - genes."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom" ' - 'AND peaks."strand" = genes."strand"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' 'WHEN peaks."strand" IS NULL OR genes."strand" IS NULL THEN NULL ' "WHEN peaks.\"strand\" = '.' OR peaks.\"strand\" = '?' THEN NULL " "WHEN genes.\"strand\" = '.' OR genes.\"strand\" = '?' THEN NULL " - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' "THEN CASE WHEN peaks.\"strand\" = '-' " 'THEN -(genes."start" - peaks."end" + 1) ' 'ELSE (genes."start" - peaks."end" + 1) END ' "ELSE CASE WHEN peaks.\"strand\" = '-' " 'THEN -(peaks."start" - genes."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END END)\n' - " LIMIT 3\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom" ' + 'AND peaks."strand" = genes."strand") AS __giql_x_0 ' + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_implicit_outer_without_strand_column(self): + def test_expand_nearest_should_skip_strand_when_outer_has_no_strand_column(self): """ GIVEN a stranded NEAREST whose implicit-outer table declares no strand column @@ -582,10 
+566,12 @@ def test_giqlnearest_sql_implicit_outer_without_strand_column(self): assert "strand" not in output assert 'nostr."chrom" = genes."chrom"' in output - def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_signed_distance_when_signed( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with signed := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Distance expression includes signed calculation. """ sql = ( @@ -596,57 +582,37 @@ def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander as a two-level subquery; the signed + # distance CASE (negated ELSE branch for upstream) is semantically + # unchanged in the inner SELECT, with the outer wrapper ordering on the + # materialized ``__giql_x_0."distance"`` plus the ``(start, end)`` + # tiebreaker (#142 A5). expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE -(peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE -(peaks."start" - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'ELSE -(peaks."start" - 
genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom") AS __giql_x_0 ' + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_no_lateral_support(self, tables_with_peaks_and_genes): - """ - GIVEN a GIQLNearest on a generator with SUPPORTS_LATERAL=False - WHEN giqlnearest_sql is called in correlated mode - THEN ValueError is raised with helpful message. - """ - - # Create a generator subclass without LATERAL support - class NoLateralGenerator(BaseGIQLGenerator): - SUPPORTS_LATERAL = False - - # Use query without explicit reference to trigger correlated mode - sql = "SELECT * FROM peaks CROSS JOIN LATERAL NEAREST(genes, k := 3)" - ast = parse_one(sql, dialect=GIQLDialect) - ast = resolve_operator_refs(ast, tables_with_peaks_and_genes) - ast = canonicalize_coordinates(ast) - - generator = NoLateralGenerator(tables=tables_with_peaks_and_genes) - - with pytest.raises(ValueError, match="LATERAL"): - generator.generate(ast) + # The legacy ``SUPPORTS_LATERAL=False`` generator-level error path was removed + # with ``giqlnearest_sql`` (#142): lateral support is now a target capability, + # and a target without it (DataFusion) gets the decorrelated window-function + # fallback rather than a hard error. That fallback's result-identity with the + # LATERAL form is verified by the cross-target oracle + # (tests/integration/datafusion/test_cross_target_oracle.py). 
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) @given( k=st.integers(min_value=1, max_value=100), max_distance=st.integers(min_value=1, max_value=10_000_000), ) - def test_giqlnearest_sql_parameter_handling_property( + def test_expand_nearest_should_carry_k_and_max_distance_property( self, tables_with_peaks_and_genes, k, max_distance ): """ @@ -864,12 +830,12 @@ def test_select_sql_join_without_alias(self, tables_with_two_tables): ) assert output == expected - def test_giqlnearest_sql_stranded_literal_with_strand( + def test_expand_nearest_should_use_literal_strand_when_stranded( self, tables_with_peaks_and_genes ): """ GIVEN a GIQLNearest with stranded := true and literal reference containing strand - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand from literal range is parsed and used in filtering. """ sql = ( @@ -883,12 +849,12 @@ def test_giqlnearest_sql_stranded_literal_with_strand( assert "'+'" in output assert 'genes."strand"' in output - def test_giqlnearest_sql_stranded_implicit_reference( + def test_expand_nearest_should_resolve_outer_strand_when_implicit_reference( self, tables_with_peaks_and_genes ): """ GIVEN a GIQLNearest in correlated mode with implicit reference and stranded := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand column is resolved from outer table and used. 
""" sql = "SELECT * FROM peaks CROSS JOIN LATERAL NEAREST(genes, k := 3, stranded := true)" @@ -987,29 +953,22 @@ def test_giqldistance_sql_literal_second_arg_error(self, tables_with_two_tables) with pytest.raises(ValueError, match="Literal range as second argument"): generator.generate(ast) - def test_giqlnearest_sql_missing_outer_table_error( + def test_expand_nearest_should_raise_when_outer_table_unresolvable( self, tables_with_peaks_and_genes ): """ - GIVEN a GIQLNearest in correlated mode without reference where outer table - cannot be found - WHEN giqlnearest_sql is called - THEN ValueError is raised with helpful message about specifying reference. + GIVEN a GIQLNearest without a reference and no resolvable outer table + WHEN the NEAREST expander runs + THEN ValueError is raised with a helpful message about specifying reference. """ + # Arrange — no reference and no LATERAL outer relation to infer one from. + sql = "SELECT * FROM NEAREST(genes, k := 3)" - nearest = GIQLNearest( - this=exp.Table(this=exp.Identifier(this="genes")), - k=exp.Literal.number(3), - ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - + # Act & assert with pytest.raises(ValueError, match="Could not find outer table"): - generator.giqlnearest_sql(nearest) + _generate_through_passes(sql, tables_with_peaks_and_genes) - def test_giqlnearest_sql_outer_table_not_in_tables(self): + def test_expand_nearest_should_raise_when_outer_table_unregistered(self): """ GIVEN a NEAREST whose implicit-outer relation is found but not registered WHEN the query is generated @@ -1025,10 +984,12 @@ def test_giqlnearest_sql_outer_table_not_in_tables(self): with pytest.raises(ValueError, match="not found in tables"): _generate_through_passes(sql, tables) - def test_giqlnearest_sql_invalid_reference_range(self, tables_with_peaks_and_genes): + def 
test_expand_nearest_should_raise_when_reference_range_unparseable( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with invalid/unparseable reference range string - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN ValueError is raised with parse error details. """ sql = "SELECT * FROM NEAREST(genes, reference := 'invalid_range', k := 3)" @@ -1036,36 +997,35 @@ def test_giqlnearest_sql_invalid_reference_range(self, tables_with_peaks_and_gen with pytest.raises(ValueError, match="Could not parse reference genomic range"): _generate_through_passes(sql, tables_with_peaks_and_genes) - def test_giqlnearest_sql_no_tables_error(self): + def test_expand_nearest_should_raise_when_no_tables_registered(self): """ - GIVEN a GIQLNearest without tables registered - WHEN giqlnearest_sql is called - THEN ValueError is raised because target table cannot be resolved. + GIVEN a GIQLNearest with no tables registered + WHEN the NEAREST expander runs + THEN ValueError is raised because the target table cannot be resolved. """ + # Arrange sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" - ast = parse_one(sql, dialect=GIQLDialect) - - # Generator with empty tables - table won't be found - generator = BaseGIQLGenerator() + # Act & assert with pytest.raises(ValueError, match="not found in tables"): - generator.generate(ast) + _generate_through_passes(sql, Tables()) - def test_giqlnearest_sql_target_not_in_tables(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_raise_when_target_unregistered( + self, tables_with_peaks_and_genes + ): """ - GIVEN a GIQLNearest with target table not registered - WHEN giqlnearest_sql is called - THEN ValueError is raised listing available tables. + GIVEN a GIQLNearest whose target table is not registered + WHEN the NEAREST expander runs + THEN ValueError is raised listing the unresolved table. 
""" + # Arrange sql = ( "SELECT * FROM NEAREST(unknown_table, reference := 'chr1:1000-2000', k := 3)" ) - ast = parse_one(sql, dialect=GIQLDialect) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) + # Act & assert with pytest.raises(ValueError, match="not found in tables"): - generator.generate(ast) + _generate_through_passes(sql, tables_with_peaks_and_genes) def test_intersects_sql_unqualified_column(self): """ @@ -1083,57 +1043,47 @@ def test_intersects_sql_unqualified_column(self): ) assert output == expected - def test_giqlnearest_sql_stranded_unqualified_reference( + def test_expand_nearest_should_resolve_strand_when_reference_unqualified( self, tables_with_peaks_and_genes ): """ - GIVEN a GIQLNearest with stranded := true and unqualified column reference - WHEN giqlnearest_sql is called + GIVEN a GIQLNearest with stranded := true and an unqualified column reference + WHEN the NEAREST expander runs THEN Strand column is resolved without table prefix. """ - - # Create NEAREST with stranded=True and an unqualified column reference - # The reference is an unqualified column (no table prefix) - nearest = GIQLNearest( - this=exp.Table(this=exp.Identifier(this="genes")), - reference=exp.Column(this=exp.Identifier(this="interval")), - k=exp.Literal.number(3), - stranded=exp.Boolean(this=True), + # Arrange — the reference is an unqualified column (no table prefix). 
+ sql = ( + "SELECT * FROM peaks CROSS JOIN LATERAL " + "NEAREST(genes, reference := interval, k := 3, stranded := true)" ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - output = generator.giqlnearest_sql(nearest) + # Act + output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Should produce valid output with unqualified strand column + # Assert assert "LIMIT 3" in output - # The strand column should be unqualified (no table prefix) assert '"strand"' in output - def test_giqlnearest_sql_identifier_target(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_ordered_subquery_for_literal_reference( + self, tables_with_peaks_and_genes + ): """ - GIVEN a GIQLNearest where target is an Identifier (not Table or Column) - WHEN giqlnearest_sql is called - THEN Target is converted to string and lookup proceeds. + GIVEN a GIQLNearest with a standalone literal reference + WHEN the NEAREST expander runs + THEN it produces a standalone ordered, limited subquery over the target + table with no correlated LATERAL. 
""" + # Arrange + sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" - # Use exp.Identifier directly - not Table or Column - # This triggers the else branch at line 830 where str(target) is called - nearest = GIQLNearest( - this=exp.Identifier(this="genes"), - reference=exp.Literal.string("chr1:1000-2000"), - k=exp.Literal.number(3), - ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - output = generator.giqlnearest_sql(nearest) + # Act + output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Should succeed and produce valid SQL + # Assert assert "genes" in output + assert "ORDER BY" in output assert "LIMIT 3" in output + assert "LATERAL" not in output @given( bool_repr=st.sampled_from(["true", "TRUE", "True", "1", "yes", "YES"]), @@ -1806,7 +1756,7 @@ def test_giqlnearest_should_canonicalize_reference_column_when_reference_is_one_ A 0-based half-open target table (bed_a) and an explicit reference column from a 1-based closed table (vcf_b). When: - giqlnearest_sql is called. + the NEAREST expander runs. Then: It should wrap the reference-side start as (start - 1), leave its end raw, and leave the target side raw. @@ -1839,7 +1789,7 @@ def test_giqlnearest_should_canonicalize_outer_table_columns_when_reference_is_i target table (bed_a) joined via CROSS JOIN LATERAL with no ``reference`` argument on NEAREST. When: - giqlnearest_sql is called. + the NEAREST expander runs. 
Then: It should canonicalize the outer table's columns based on vcf_b's convention — wrapping start as (vcf_b."start" - 1) and diff --git a/tests/integration/datafusion/test_cross_target_oracle.py b/tests/integration/datafusion/test_cross_target_oracle.py index 6b717a1..02db63f 100644 --- a/tests/integration/datafusion/test_cross_target_oracle.py +++ b/tests/integration/datafusion/test_cross_target_oracle.py @@ -14,13 +14,14 @@ genuinely divergent SQL across targets (the DuckDB IEJoin vs. the binned equi-join). -NEAREST's expansion uses a correlated ``LATERAL`` subquery, which DataFusion has -no physical plan for today; its generic-vs-duckdb equivalence case runs both on -DuckDB, and the full three-target oracle is pinned by a -``pytest.raises(match="OuterReferenceColumn")`` test (#142) that fails loudly on -an unrelated error and trips "DID NOT RAISE" — forcing conversion to a real -identity test — when DataFusion gains correlated LATERAL. DISJOIN has an -analogous pending-#153 gap (duplicate ``end`` output names). +NEAREST's correlated expansion is capability-driven (#142): lateral-capable +targets (generic, duckdb) emit the portable ``LATERAL`` subquery, while +DataFusion — which has no correlated-LATERAL physical plan — gets a decorrelated +window-function fallback. Both forms return identical rows, so the full +three-target identity oracle now runs on every target (the former +``_unsupported_pending_142`` ``pytest.raises`` pin has been promoted to a real +identity test). DISJOIN has an analogous pending-#153 gap (duplicate ``end`` +output names). """ import pytest @@ -204,32 +205,207 @@ def test_standalone_nearest_k1_agrees_generic_vs_duckdb_on_duckdb( engines={"generic": "duckdb"}, ) - def test_nearest_on_datafusion_unsupported_pending_142(self, cross_target_oracle): - """Test the full NEAREST oracle raises DataFusion's missing-LATERAL error. 
+ def test_correlated_nearest_k1_agrees_across_all_targets(self, cross_target_oracle): + """Test correlated NEAREST k=1 returns identical rows on every target. Given: - The single-row NEAREST query and a candidate gene on chr1. + A single-row peaks table and three candidate genes at varying + distances on chr1. When: - The oracle runs all three targets — the datafusion target executes - the correlated LATERAL on DataFusion, which has no physical plan. + A correlated ``CROSS JOIN LATERAL NEAREST(..., k := 1)`` query runs + for every target — the generic and duckdb targets emit the portable + LATERAL form (executed on DuckDB, the lateral-capable engine), and + the datafusion target emits the decorrelated window-function fallback + the #142 expander produces (executed on DataFusion). Then: - DataFusion should raise its ``OuterReferenceColumn`` "not - implemented" error. This pins the known #142 gap: the ``match`` - narrows to the LATERAL signature so an unrelated/reworded DataFusion - error fails loudly, and a closed gap (no exception) trips pytest's - "DID NOT RAISE", forcing this to be converted into a real - cross-target identity test when DataFusion gains correlated LATERAL. + Every target should return the single nearest gene and agree. + + Promoted from the ``_unsupported_pending_142`` expected-failure pin: + DataFusion now plans correlated NEAREST through the capability-driven + window-function fallback, so the full three-target oracle is a real + identity test rather than a ``pytest.raises`` placeholder. The generic + target is routed to DuckDB because its portable SQL is the LATERAL form, + which only the datafusion-specific fallback decorrelates for DataFusion. 
""" # Arrange / Act / Assert - with pytest.raises(Exception, match="OuterReferenceColumn"): - cross_target_oracle( - "SELECT a.chrom, a.start AS a_start, b.start AS b_start " - "FROM peaks a " - "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", - peaks=[("chr1", 200, 300)], - genes=[("chr1", 280, 290)], - expected=[("chr1", 200, 280)], - ) + cross_target_oracle( + "SELECT a.chrom, a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 1000, 1100), + ("chr1", 50, 60), + ("chr1", 280, 290), + ], + expected=[("chr1", 200, 280)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_k2_returns_two_nearest_across_targets( + self, cross_target_oracle + ): + """Test correlated NEAREST k=2 picks the two nearest on every target. + + Given: + One peak and four candidate genes, more than k of them on the peak's + chromosome at distinct distances. + When: + A correlated ``NEAREST(..., k := 2)`` runs on every target — DuckDB + via the LATERAL form, DataFusion via the decorrelated window fallback. + Then: + Every target should return the two nearest genes and agree, pinning + the top-k fan-out of the fallback against the LATERAL form. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 2) b", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 1000, 1100), + ("chr1", 50, 60), + ("chr1", 280, 290), + ("chr1", 310, 320), + ], + expected=[(200, 280), (200, 310)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_duplicate_reference_rows_fan_out( + self, cross_target_oracle + ): + """Test correlated NEAREST fans the top-k out to duplicate reference rows. + + Given: + Two identical peak rows and two candidate genes. 
+ When: + A correlated ``NEAREST(..., k := 1)`` runs on every target. + Then: + Every target should return the nearest gene once per duplicate peak + (two rows), pinning the fallback's DISTINCT-then-rejoin fan-out so a + duplicate outer row is not collapsed. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300), ("chr1", 200, 300)], + genes=[("chr1", 280, 290), ("chr1", 50, 60)], + expected=[(200, 280), (200, 280)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_partitions_by_chromosome(self, cross_target_oracle): + """Test correlated NEAREST keys the nearest per outer chromosome. + + Given: + Peaks on chr1 and chr2 and candidate genes on both chromosomes. + When: + A correlated ``NEAREST(..., k := 1)`` runs on every target. + Then: + Each peak should pair with the nearest gene on its own chromosome and + all targets agree, pinning the fallback's PARTITION BY reference key + across distinct outer keys. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.chrom AS chrom, a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300), ("chr2", 200, 300)], + genes=[("chr1", 280, 290), ("chr2", 500, 510), ("chr2", 205, 215)], + expected=[("chr1", 200, 280), ("chr2", 200, 205)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_max_distance_boundary(self, cross_target_oracle): + """Test correlated NEAREST drops candidates beyond max_distance everywhere. + + Given: + A peak and two genes, one just inside and one far beyond a + ``max_distance`` threshold. + When: + A correlated ``NEAREST(..., k := 5, max_distance := 100)`` runs on + every target. 
+ Then: + Every target should return only the in-threshold gene, pinning the + ``max_distance`` filter through both the LATERAL and fallback forms. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 5, max_distance := 100) b", + peaks=[("chr1", 200, 300)], + genes=[("chr1", 360, 400), ("chr1", 5000, 5100)], + expected=[(200, 360)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_stranded_matches_strand(self, cross_target_oracle): + """Test stranded correlated NEAREST matches strand on every target. + + Given: + A ``+`` peak and two genes that both lie inside the peak's span — a + ``+`` gene and a ``-`` gene. + When: + A correlated ``NEAREST(..., k := 1, stranded := true)`` runs on every + target. + Then: + Every target should return the same-strand (``+``) gene even though + the opposite-strand gene also overlaps the peak, in agreement. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 1, stranded := true) b", + tables=[Table("peaks"), Table("genes")], + columns=_STRANDED_COLUMNS, + peaks=[("chr1", 200, 300, "+")], + genes=[("chr1", 280, 290, "+"), ("chr1", 250, 260, "-")], + expected=[(200, 280)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_signed_distance_agrees(self, cross_target_oracle): + """Test signed correlated NEAREST reports signed distances everywhere. + + Given: + A peak with one upstream and one downstream candidate gene. + When: + A correlated ``NEAREST(..., k := 2, signed := true)`` projects the + ``distance`` column on every target. + Then: + Every target should report a negative distance for the upstream gene + and a positive one for the downstream gene, in agreement.
+ """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start, b.distance AS d " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 2, signed := true) b", + peaks=[("chr1", 200, 300)], + genes=[("chr1", 50, 60), ("chr1", 360, 400)], + expected=[(200, 50, -141), (200, 360, 61)], + engines={"generic": "duckdb"}, + ) + + +#: A chrom/start/end/strand schema for the stranded NEAREST oracle cases (the +#: default oracle schema carries no strand column). +_STRANDED_COLUMNS = ( + ("chrom", "utf8"), + ("start", "int64"), + ("end", "int64"), + ("strand", "utf8"), +) class TestCrossTargetOracleIntersectsAnyAll: diff --git a/tests/test_expander.py b/tests/test_expander.py index ad625f1..4579359 100644 --- a/tests/test_expander.py +++ b/tests/test_expander.py @@ -53,53 +53,70 @@ def _expander(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: return _expander +import giql.expanders # noqa: E402, F401 (side-effect: registers built-in expanders) + +#: The registry contents at import — the built-in expanders ``giql.expanders`` +#: registers for the already-migrated operators. The leak guards and +#: ``clean_registry`` treat this as the baseline rather than an empty registry, +#: so the real built-in registrations survive isolating fixtures and a leaking +#: test is still caught against the true baseline. +_REGISTRY_BASELINE = REGISTRY.snapshot() + + @pytest.fixture def clean_registry(): - """Isolate the process-wide REGISTRY, leaving it empty afterward. + """Isolate the process-wide REGISTRY, restoring its baseline afterward. - The registry is empty at import (the pass ships inert), so a test that opts - in clears on the way out; emptiness is asserted through the public - ``bool()``/``len()`` surface rather than private state. 
+ Saves the import-time baseline (the built-in expanders), empties the registry + so a test sees only what it registers, and restores the baseline on the way + out through the public ``snapshot()``/``restore()`` seam — so a test that + registers a stand-in expander cannot leak it, and the built-in registrations + survive this fixture's isolation. """ - assert not REGISTRY, "REGISTRY was non-empty entering clean_registry" + saved = REGISTRY.snapshot() REGISTRY.clear() yield REGISTRY - REGISTRY.clear() + REGISTRY.restore(saved) @pytest.fixture(autouse=True) def _registry_leak_guard(): - """Assert the process-wide REGISTRY is empty at each test boundary. - - A leak guard: the registry is empty at import and must return to empty after - every test, so a test that registers without cleaning up (a leak that would - silently flip the no-op pass for a later test) fails loudly. Tests that - register on the process-wide REGISTRY do so through ``clean_registry``, which - clears on the way out; this guard catches anything that bypasses it. Both - checks go through the public ``bool()`` surface (A5), not private state. + """Assert the process-wide REGISTRY matches its baseline at each boundary. + + A leak guard: the registry holds the built-in expanders at import and must + return to exactly that baseline after every test, so a test that registers + without cleaning up (a leak that would silently change dispatch for a later + test) fails loudly. Tests that mutate the process-wide REGISTRY do so through + ``clean_registry``, which restores the baseline on the way out; this guard + catches anything that bypasses it. 
""" - assert not REGISTRY, "REGISTRY leaked into a test from a prior one" + assert REGISTRY.snapshot() == _REGISTRY_BASELINE, ( + "REGISTRY differed from its baseline entering a test" + ) yield - assert not REGISTRY, "a test leaked a registration into REGISTRY" + assert REGISTRY.snapshot() == _REGISTRY_BASELINE, ( + "a test leaked a registration into REGISTRY" + ) @pytest.fixture(autouse=True) def _expand_flag_leak_guard(): """Assert every operator's GIQL_EXPAND is restored at each test boundary. - The symmetric partner of the registry leak guard: each operator class ships - opted out (its own GIQL_EXPAND attribute is False), and a test that flips one - via ``_opted_in`` must restore it. A leaked opt-in would silently flip the - no-op pass for a later test, so this catches anything that bypasses the - exception-safe ``_opted_in`` manager. + The symmetric partner of the registry leak guard: each operator class ships a + GIQL_EXPAND default (``True`` for a migrated operator, ``False`` otherwise), + and a test that flips one via ``_opted_in`` must restore it. A leaked flip + would silently change the pass for a later test, so this catches anything that + bypasses the exception-safe ``_opted_in`` manager by comparing against each + operator's shipped default rather than a blanket False. 
""" for op in _OPERATOR_CLASSES: - assert op.__dict__.get("GIQL_EXPAND") is False, ( + assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( f"{op.__name__}.GIQL_EXPAND leaked into a test from a prior one" ) yield for op in _OPERATOR_CLASSES: - assert op.__dict__.get("GIQL_EXPAND") is False, ( + assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( f"a test leaked a GIQL_EXPAND opt-in on {op.__name__}" ) @@ -468,6 +485,55 @@ def _expander(node, ctx): assert REGISTRY.resolve(DuckDBTarget(), GIQLDisjoin) is None assert (DuckDBTarget(), GIQLDisjoin) not in REGISTRY + def test_snapshot_should_not_observe_later_registrations(self): + """Test that a snapshot does not observe registrations made after it. + + Given: + A registry with one entry, captured by snapshot. + When: + A second entry is registered after the snapshot is taken. + Then: + The snapshot should still hold only the first entry (it is a copy, + not a live view). + """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("first")) + + # Act + saved = registry.snapshot() + registry.register(GenericTarget(), Intersects, _record("second")) + + # Assert + assert (DuckDBTarget(), GIQLDisjoin) in saved + assert (GenericTarget(), Intersects) not in saved + + def test_restore_should_replace_entries_with_snapshot_contents(self): + """Test that restore returns the registry to a captured snapshot. + + Given: + A snapshot of a registry with one entry, after which the registry is + cleared and a different entry registered. + When: + Restoring the snapshot. + Then: + The original entry should resolve again and the post-snapshot entry + should be gone. 
+ """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("original")) + saved = registry.snapshot() + registry.clear() + registry.register(GenericTarget(), Intersects, _record("transient")) + + # Act + registry.restore(saved) + + # Assert + assert (DuckDBTarget(), GIQLDisjoin) in registry + assert (GenericTarget(), Intersects) not in registry + class TestRegisterDecorator: """Tests for the @register extension-hook decorator.""" @@ -956,7 +1022,9 @@ def test_transform_skips_unflagged_operator(self, clean_registry): Given: An expander registered for (GenericTarget, GIQLDisjoin) but the - operator's GIQL_EXPAND flag left at its default False. + operator's GIQL_EXPAND flag held off (a migrated operator ships it on, + so the control opts the migrated operator out to isolate the per-type + gate from any shipped opt-in). When: Running the pass. Then: @@ -968,8 +1036,9 @@ def test_transform_skips_unflagged_operator(self, clean_registry): ast = _prepare("SELECT * FROM DISJOIN(variants)", tables) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act (GIQL_EXPAND is False by default — no opt-in context) - result = pass_.transform(ast) + # Act + with _opted_out(_A_MIGRATED_OPERATOR): + result = pass_.transform(ast) # Assert assert list(result.find_all(GIQLDisjoin)) @@ -1022,17 +1091,21 @@ def test_expand_operators_is_identity_when_registry_empty(self): assert list(result.find_all(GIQLDisjoin)) def test_transpile_sql_unchanged_with_pass_inert(self): - """Test that transpile output is byte-identical with the pass inert. + """Test that transpile output is byte-identical for an unmigrated operator. Given: - A DISJOIN query, the default (empty) registry, and no operator flagged. + A DISTANCE query (an operator not migrated onto the pass, so its + GIQL_EXPAND is False and no expander resolves), with the default + registry. When: - Transpiling with the wired-in pass versus a pass-bypassed reference. 
+ Transpiling with the wired-in pass versus a pass-bypassed reference + (its legacy emitter run directly). Then: - The SQL should match exactly and carry no expander alias prefix. + The SQL should match exactly and carry no expander alias prefix — the + pass is inert for any operator that has not been migrated. """ # Arrange - query = "SELECT * FROM DISJOIN(variants)" + query = "SELECT DISTANCE(a.interval, b.interval) FROM variants a, variants b" tables = _tables() ast = _prepare(query, tables) from giql.generators import BaseGIQLGenerator @@ -1048,8 +1121,8 @@ def test_transpile_sql_unchanged_with_pass_inert(self): # The nine GIQL operator expression classes the ExpandOperators pass inspects. -# Every one must ship opted out (GIQL_EXPAND=False) at this step so the pass is a -# strict no-op until a later migration flips one alongside its expander. +# Each ships its own GIQL_EXPAND default: a migrated operator ships True (and has +# a built-in expander registered), the rest ship False until their migrations land. from giql.expressions import Contains # noqa: E402 from giql.expressions import GIQLCluster # noqa: E402 from giql.expressions import GIQLDistance # noqa: E402 @@ -1069,20 +1142,46 @@ def test_transpile_sql_unchanged_with_pass_inert(self): GIQLMerge, ) +#: Each operator's shipped GIQL_EXPAND default, captured from its own class dict +#: at import. A migrated operator ships ``True``; the rest ship ``False`` until +#: their migrations land. The flag leak guard restores to these shipped values +#: rather than a blanket ``False``. +_SHIPPED_EXPAND_FLAGS = {op: op.__dict__.get("GIQL_EXPAND") for op in _OPERATOR_CLASSES} + + +#: Operators migrated onto the ExpandOperators pass — they ship GIQL_EXPAND=True. +_MIGRATED_OPERATORS = tuple( + op for op in _OPERATOR_CLASSES if op.__dict__.get("GIQL_EXPAND") is True +) +#: Operators not yet migrated — they ship GIQL_EXPAND=False. 
+_UNMIGRATED_OPERATORS = tuple( + op for op in _OPERATOR_CLASSES if op not in _MIGRATED_OPERATORS +) + +# At least one operator has been migrated (the build of this branch ships one); +# the controls below pick an arbitrary migrated operator operator-agnostically. +assert _MIGRATED_OPERATORS, "expected at least one migrated operator" +#: An arbitrary migrated operator, used by controls that need an operator shipping +#: GIQL_EXPAND=True without naming a specific one. +_A_MIGRATED_OPERATOR = _MIGRATED_OPERATORS[0] + class TestOperatorOptOut: - """Every operator ships opted out of the ExpandOperators pass at this step.""" + """Migrated operators opt into the pass; the rest still ship opted out.""" - @pytest.mark.parametrize("operator", _OPERATOR_CLASSES, ids=lambda c: c.__name__) + @pytest.mark.parametrize( + "operator", _UNMIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) def test_operator_class_ships_expand_disabled(self, operator): - """Test that each operator class ships GIQL_EXPAND=False. + """Test that each unmigrated operator class ships GIQL_EXPAND=False. Given: - One of the nine GIQL operator expression classes. + A GIQL operator expression class that has not been migrated onto the + ExpandOperators pass. When: Reading its GIQL_EXPAND class attribute. Then: - It should be False (no operator opts into expansion yet). + It should be False (the operator still uses the legacy emitter). """ # Arrange & act flag = operator.GIQL_EXPAND @@ -1090,6 +1189,27 @@ def test_operator_class_ships_expand_disabled(self, operator): # Assert assert flag is False + @pytest.mark.parametrize( + "operator", _MIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) + def test_operator_class_ships_expand_enabled(self, operator): + """Test that each migrated operator class ships GIQL_EXPAND=True. + + Given: + A GIQL operator expression class migrated onto the ExpandOperators + pass. + When: + Reading its GIQL_EXPAND class attribute. 
+ Then: + It should be True (the operator expands through its registered + expander instead of the deleted legacy emitter). + """ + # Arrange & act + flag = operator.GIQL_EXPAND + + # Assert + assert flag is True + class TestOptedInRestoresFlag: """The _opted_in helper restores GIQL_EXPAND even when its body raises.""" @@ -1672,8 +1792,9 @@ def test_walk_partial_opt_in_replaces_only_flagged_type(self, clean_registry): ) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act - with _opted_in(Intersects): + # Act (opt the migrated operator out so its shipped opt-in cannot + # interfere; DISJOIN here is the unflagged control whose node must remain) + with _opted_in(Intersects), _opted_out(_A_MIGRATED_OPERATOR): result = pass_.transform(ast) # Assert @@ -1888,3 +2009,25 @@ def __enter__(self): def __exit__(self, *exc): self._operator.GIQL_EXPAND = self._prior return False + + +class _opted_out: + """Context manager opting an operator class out of GIQL_EXPAND for a test. + + The complement of :class:`_opted_in`: used by a control test that needs a + *migrated* operator (one shipping GIQL_EXPAND=True) to behave as if unflagged, + so the test can prove the pass gates per-type without the operator's shipped + opt-in interfering. Restores the prior flag on exit. 
+ """ + + def __init__(self, operator: type) -> None: + self._operator = operator + self._prior = operator.__dict__.get("GIQL_EXPAND", False) + + def __enter__(self): + self._operator.GIQL_EXPAND = False + return self._operator + + def __exit__(self, *exc): + self._operator.GIQL_EXPAND = self._prior + return False diff --git a/tests/test_nearest_transpilation.py b/tests/test_nearest_transpilation.py index 2488cb0..e63600d 100644 --- a/tests/test_nearest_transpilation.py +++ b/tests/test_nearest_transpilation.py @@ -7,26 +7,48 @@ import pytest from sqlglot import parse_one +import giql.expanders # noqa: F401 (side-effect: registers the NEAREST expander) from giql import Table from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import ExpandOperators from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import DataFusionTarget +from giql.targets import GenericTarget + + +def _generate_for_target(sql: str, tables: Tables, target) -> str: + """Parse, run passes 1-3 against *target*, then generate SQL. + + Drives the expander for a specific :class:`~giql.targets.Target` so a + capability-dependent shape (e.g. DataFusion's decorrelated window fallback, + chosen because ``supports_lateral`` is False) can be asserted without an + engine. + """ + ast = parse_one(sql, dialect=GIQLDialect) + ast = resolve_operator_refs(ast, tables) + ast = canonicalize_coordinates(ast) + ast = ExpandOperators(target, tables).transform(ast) + return BaseGIQLGenerator(tables=tables).generate(ast) def _generate(sql: str, tables: Tables) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. + """Parse, run normalization passes 1-3, then generate SQL. - Operator resolution and coordinate canonicalization moved out of the emitter - and into the ResolveOperatorRefs / CanonicalizeCoordinates passes (epic #114, - issues #118-#123). 
Emitter-level tests must run both passes before generating, - exactly as :func:`giql.transpile.transpile` does, rather than calling - ``generate`` on a bare parsed AST. + Operator resolution, coordinate canonicalization, and operator expansion + moved out of the emitter into the ResolveOperatorRefs / CanonicalizeCoordinates + / ExpandOperators passes (epics #114, #137). NEAREST is now produced by its + registered expander (issue #142) rather than a ``giqlnearest_sql`` emitter, so + these tests must run pass 3 before generating, exactly as + :func:`giql.transpile.transpile` does, rather than calling ``generate`` on a + bare parsed AST. """ ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator(tables=tables).generate(ast) @@ -160,3 +182,86 @@ def test_nearest_with_signed(self, tables_with_peaks_and_genes): assert "ELSE -(" in output, ( f"Expected signed distance with negation for upstream, got:\n{output}" ) + + +class TestNearestDataFusionFallbackShape: + """Engine-free transpile-shape checks for the DataFusion window fallback (A8). + + A correlated NEAREST on the DataFusion target (``supports_lateral`` is False) + expands to the decorrelated window-function form. These assert its structural + invariants without running an engine: the window is present, the top-k filter + is a ``<= k`` predicate, no correlated ``LATERAL`` survives, and the candidate + cross-join and the window live at separate query levels. + """ + + def test_fallback_emits_window_with_topk_and_no_lateral( + self, tables_with_peaks_and_genes + ): + """ + GIVEN a correlated NEAREST(genes, k := 1) on the DataFusion target + WHEN transpiling + THEN the decorrelated fallback emits a ROW_NUMBER() window, a `<= 1` + top-k predicate, no surviving LATERAL, and the cross-join and window + at separate query levels. 
+ """ + sql = ( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 1) AS b" + ) + + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + assert "ROW_NUMBER(" in output.upper() + assert "OVER (" in output.upper() + assert "<= 1" in output + assert "LATERAL" not in output.upper() + # The candidate cross-join sits one level below the window: the window's + # FROM is a parenthesized subquery, so a CROSS JOIN appears nested inside. + assert "CROSS JOIN" in output.upper() + + def test_fallback_stranded_emits_window_and_strand_match( + self, tables_with_peaks_and_genes + ): + """ + GIVEN a stranded correlated NEAREST on the DataFusion target + WHEN transpiling + THEN the fallback emits the window form, keeps a strand equality in the + candidate WHERE, and surfaces no LATERAL. + """ + sql = ( + "SELECT * FROM peaks CROSS JOIN LATERAL " + "NEAREST(genes, reference := peaks.interval, k := 1, stranded := true) AS b" + ) + + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + assert "ROW_NUMBER(" in output.upper() + assert "LATERAL" not in output.upper() + assert 'peaks."strand"' in output + assert 'genes."strand"' in output + + def test_fallback_k_greater_than_one_uses_k_in_topk_predicate( + self, tables_with_peaks_and_genes + ): + """ + GIVEN a correlated NEAREST(genes, k := 3) on the DataFusion target + WHEN transpiling + THEN the top-k predicate carries the requested k (`<= 3`) rather than a + LIMIT, and no LATERAL survives. + """ + sql = ( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 3) AS b" + ) + + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + assert "ROW_NUMBER(" in output.upper() + assert "<= 3" in output + assert "LATERAL" not in output.upper()