From 218b5cccc0120856bf1fc2fbdca84a53b559b057 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sun, 28 Jun 2026 00:11:29 -0400 Subject: [PATCH 1/4] refactor: Add snapshot and restore seam to ExpanderRegistry Add public snapshot and restore methods to ExpanderRegistry, a save/restore seam over the process-wide registry. A test fixture or a plugin that mutates the registry around a body can capture the baseline and reinstate it afterward, so the built-in expanders registered at import survive an isolating fixture that would otherwise clear them permanently. This is the infrastructure the migrated test fixtures depend on to treat the import-time built-in registrations as their baseline rather than an empty registry. --- src/giql/expander.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/giql/expander.py b/src/giql/expander.py index f93132a..c6558ac 100644 --- a/src/giql/expander.py +++ b/src/giql/expander.py @@ -253,6 +253,31 @@ def clear(self) -> None: """ self._expanders.clear() + def snapshot(self) -> dict[tuple[Target, type], ExpanderFn]: + """Return a shallow copy of the current registrations. + + The save half of the registry's **public save/restore seam**: a test + fixture (or a plugin) that mutates the process-wide :data:`REGISTRY` + around a body — registering or clearing entries — captures the baseline + with this and hands it back to :meth:`restore` afterward, so the + built-in expanders registered at import survive an isolating fixture + that would otherwise :meth:`clear` them permanently. + + The returned dict is a fresh mapping (mutating it does not affect the + registry), keyed by the same ``(target, operator)`` tuples. + """ + return dict(self._expanders) + + def restore(self, snapshot: dict[tuple[Target, type], ExpanderFn]) -> None: + """Replace all registrations with those captured by :meth:`snapshot`. + + The restore half of the save/restore seam. 
Drops every current entry and + re-installs exactly the *snapshot* contents, so a fixture can return the + registry to a previously captured baseline regardless of what its body + registered or cleared. + """ + self._expanders = dict(snapshot) + def __contains__(self, key: tuple[Target, type]) -> bool: """Whether an *exact* ``(target, operator)`` entry is registered. From adec14cdbd505729ef162939864b768e432477c9 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sun, 28 Jun 2026 00:11:46 -0400 Subject: [PATCH 2/4] feat: Migrate NEAREST to capability-driven expander Move NEAREST expansion off the legacy giqlnearest_sql emitter and onto the ExpandOperators pass as a capability-driven expander. Lateral-capable targets and every standalone literal-reference placement get the portable correlated LATERAL subquery, byte-identical to the legacy emitter. A correlated NEAREST on a target without LATERAL support now gets a decorrelated ROW_NUMBER() window-function fallback that returns identical rows: it ranks candidates once per distinct reference key and re-joins the top k back to every outer row sharing that key. This adds DataFusion support for correlated NEAREST, which previously had no physical plan for the LATERAL form and failed outright. Add the giql.expanders package, whose import registers every built-in expander as a side effect and auto-discovers new operator modules, and wire that import into transpile so the registry is populated before the first transpile. Flip GIQLNearest.GIQL_EXPAND on so the pass owns NEAREST, and delete the giqlnearest_sql emitter. The shared _generate_distance_case and _nearest_* resolution helpers are retained and reused by the expander, keeping the distance, passthrough, and encoding logic unchanged. 
--- src/giql/expanders/__init__.py | 19 ++ src/giql/expanders/nearest.py | 347 +++++++++++++++++++++++++++++++++ src/giql/expressions.py | 6 +- src/giql/generators/base.py | 165 ---------------- src/giql/transpile.py | 1 + 5 files changed, 372 insertions(+), 166 deletions(-) create mode 100644 src/giql/expanders/__init__.py create mode 100644 src/giql/expanders/nearest.py diff --git a/src/giql/expanders/__init__.py b/src/giql/expanders/__init__.py new file mode 100644 index 0000000..86bf43e --- /dev/null +++ b/src/giql/expanders/__init__.py @@ -0,0 +1,19 @@ +"""Built-in operator expanders for epic #137. + +Importing this package registers every built-in expander as a side effect: +each submodule decorates its expander(s) with ``@register(...)`` at import +time, and this package imports all of them. The import is wired once (in +:mod:`giql.transpile`) so the process-wide ``REGISTRY`` is populated before the +first transpile. + +New operator modules are picked up automatically: drop a ``.py`` into +this package and it is imported here without editing this file. +""" + +from __future__ import annotations + +import importlib +import pkgutil + +for _module_info in pkgutil.iter_modules(__path__): + importlib.import_module(f"{__name__}.{_module_info.name}") diff --git a/src/giql/expanders/nearest.py b/src/giql/expanders/nearest.py new file mode 100644 index 0000000..3934a8e --- /dev/null +++ b/src/giql/expanders/nearest.py @@ -0,0 +1,347 @@ +"""The NEAREST operator expander (epic #137, issue #142). + +NEAREST is the first operator whose expansion is genuinely capability-driven. +The portable form is a correlated ``LATERAL`` subquery: each outer row drives a +``SELECT ... FROM WHERE ORDER BY ABS(distance) +LIMIT k`` whose reference endpoints are outer-table columns. DuckDB and the +generic target plan that directly (``supports_lateral == True``). + +Apache DataFusion has no correlated-``LATERAL`` physical plan +(``supports_lateral == False``). 
For it the same k-nearest / ``max_distance`` / +``stranded`` / ``signed`` semantics are reproduced with a **decorrelated +window-function fallback**: the target is cross-joined against the outer +relation, each candidate is ranked with +``ROW_NUMBER() OVER (PARTITION BY ORDER BY ABS(distance))``, and +the surrounding ``CROSS JOIN LATERAL`` is rewritten into a plain join that +re-associates the top-``k`` ranked candidates back to every outer row sharing +that reference key. Ranking depends only on the reference value, so ranking once +per distinct reference value and re-joining is row-for-row identical to the +per-row LATERAL form (verified by the cross-target result oracle). + +A literal-reference (standalone) NEAREST is already an uncorrelated subquery, so +every target — DataFusion included — uses the LATERAL/standalone form unchanged; +only the *correlated* shape needs the fallback. + +The expander reuses :class:`giql.generators.base.BaseGIQLGenerator`'s +``_generate_distance_case`` (shared with DISTANCE, #140) and ``_nearest_*`` +resolution/passthrough helpers, then parses the assembled SQL fragments into AST +so the emitted SQL is reserialized by the active target's serializer. +""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot import parse_one + +from giql.dialect import GIQLDialect +from giql.expander import ExpansionContext +from giql.expander import register +from giql.expressions import GIQLNearest +from giql.generators.base import BaseGIQLGenerator +from giql.resolver import ResolvedInterval +from giql.resolver import ResolvedRef +from giql.targets import GenericTarget + +#: Reserved column names the window-function fallback synthesizes inside its +#: ranked subquery. They share the expander's ``__giql_x_`` prefix so they stay +#: clear of user identifiers, mirroring the other reserved internal prefixes. 
+_RANK_COL = "__giql_x_rn" +_REF_KEY_PREFIX = "__giql_x_rk_" + + +def _emitter() -> BaseGIQLGenerator: + """A throwaway generator used only for its (self-free) NEAREST helpers. + + ``_generate_distance_case``, ``_nearest_*`` and ``_extract_bool_param`` carry + no instance state, so a default-constructed generator is a safe host for + them. Reusing them keeps the expander's distance/passthrough/encoding logic + byte-for-byte identical to the legacy emitter it replaces. + """ + return BaseGIQLGenerator() + + +def _nearest_params(expression: GIQLNearest, gen: BaseGIQLGenerator): + """Unpack the (k, max_distance, stranded, signed) parameters of a NEAREST.""" + k = expression.args.get("k") + k_value = int(str(k)) if k else 1 + + max_distance = expression.args.get("max_distance") + max_dist_value = int(str(max_distance)) if max_distance else None + + is_stranded = gen._extract_bool_param(expression.args.get("stranded")) + is_signed = gen._extract_bool_param(expression.args.get("signed")) + return k_value, max_dist_value, is_stranded, is_signed + + +def _distance_and_filters( + expression, ctx, gen, table_name, target_ref, ref, ref_fragments=None +): + """Build the shared distance SQL, the qualified target columns, and WHERE. + + Returns ``(distance_expr, abs_distance_expr, where_clauses, passthrough)`` — + the fragments common to the LATERAL/standalone form and the decorrelated + fallback. Distance math, the chromosome pre-filter, the optional strand + match, and the optional ``max_distance`` filter all reproduce the legacy + ``giqlnearest_sql`` emitter exactly. + + ``ref_fragments`` optionally overrides the reference ``(chrom, start, end, + strand)`` SQL fragments. 
The LATERAL form consumes the resolution's + outer-qualified fragments verbatim; the fallback passes fragments pointing at + its renamed, pre-projected reference relation so the cross-joined columns + carry names distinct from the target's (DataFusion's planner cannot resolve a + window ordering over a join with duplicate column names). + """ + target_chrom, target_start, target_end = target_ref.cols + k_value, max_dist_value, is_stranded, is_signed = _nearest_params(expression, gen) + + output_table = gen._nearest_output_encoding(expression, target_ref) + passthrough = gen._nearest_passthrough( + table_name, target_start, target_end, output_table + ) + + if ref_fragments is not None: + ref_chrom, ref_start, ref_end, ref_strand_frag = ref_fragments + else: + ref_chrom, ref_start, ref_end, ref_strand_frag = ( + ref.chrom, + ref.start, + ref.end, + ref.strand, + ) + + ref_strand = None + target_strand = None + if is_stranded: + ref_strand = ref_strand_frag + if output_table and output_table.strand_col: + target_strand = f'{table_name}."{output_table.strand_col}"' + + target_chrom_expr = f'{table_name}."{target_chrom}"' + target_start_expr = f'{table_name}."{target_start}"' + target_end_expr = f'{table_name}."{target_end}"' + + distance_expr = gen._generate_distance_case( + ref_chrom, + ref_start, + ref_end, + ref_strand, + target_chrom_expr, + target_start_expr, + target_end_expr, + target_strand, + stranded=is_stranded, + signed=is_signed, + ) + abs_distance_expr = f"ABS({distance_expr})" + + where_clauses = [f"{ref_chrom} = {target_chrom_expr}"] + if is_stranded and ref_strand and target_strand: + where_clauses.append(f"{ref_strand} = {target_strand}") + if max_dist_value is not None: + where_clauses.append(f"({abs_distance_expr}) <= {max_dist_value}") + + return distance_expr, abs_distance_expr, where_clauses, passthrough + + +def _lateral_form(expression, ctx, gen, table_name, target_ref, ref): + """The portable LATERAL/standalone subquery — identical to the 
legacy emitter. + + Builds the ``(SELECT , AS distance FROM + WHERE ... ORDER BY ABS(distance) LIMIT k)`` subquery the legacy + ``giqlnearest_sql`` produced and parses it into AST. For a correlated + placement the parent ``LATERAL`` correlates it to the outer row; for a + standalone (literal-reference) placement it stands alone. + """ + k_value, *_ = _nearest_params(expression, gen) + distance_expr, abs_distance_expr, where_clauses, passthrough = _distance_and_filters( + expression, ctx, gen, table_name, target_ref, ref + ) + where_sql = " AND ".join(where_clauses) + sql = ( + f"(SELECT {passthrough}, {distance_expr} AS distance " + f"FROM {table_name} WHERE {where_sql} " + f"ORDER BY {abs_distance_expr} LIMIT {k_value})" + ) + return parse_one(sql, dialect=GIQLDialect) + + +def _outer_relation(ref: ResolvedInterval) -> tuple[str, str]: + """Return ``(physical_relation, alias)`` for the correlated reference table. + + The reference endpoints are alias-qualified fragments (``a."chrom"``). The + alias is the outer table's correlation name in the query; the physical + relation comes from the reference's backing :class:`~giql.table.Table`. Both + are needed to re-introduce the outer relation inside the decorrelated + subquery the fallback builds. + """ + parsed = parse_one(ref.chrom, dialect=GIQLDialect) + alias = parsed.table if isinstance(parsed, exp.Column) else "" + relation = ref.table.name if ref.table is not None else alias + return relation, alias + + +def _fallback_form(expression, ctx, gen, table_name, target_ref, ref): + """The decorrelated window-function fallback for non-LATERAL targets. + + Rewrites the surrounding `` AS a CROSS JOIN LATERAL (nearest) AS b`` + into `` AS a JOIN () AS b ON AND + b. <= k``. 
The ranked subquery cross-joins the target against the outer + relation and ranks candidates per distinct reference key with + ``ROW_NUMBER()``; the join re-associates the top-k back to every outer row + sharing that key, reproducing the per-row LATERAL semantics. Swaps the parent + ``LATERAL`` for the decorrelated subquery in place and returns the (now + detached) NEAREST node, so the pass's own ``node.replace`` is a no-op. + """ + lateral = expression.parent + join = lateral.parent + alias = lateral.args["alias"].name + + relation, outer_alias = _outer_relation(ref) + k_value, _max, is_stranded, _signed = _nearest_params(expression, gen) + + # Pre-project the outer relation's reference columns under fresh, non-target + # names into a renamed derived relation. Cross-joining *this* (rather than the + # raw outer table) keeps every reference column distinct from the target's + # columns: DataFusion's planner cannot resolve a window ordering over a join + # whose two sides share column names (e.g. both expose ``start`` / ``end``). + # + # The reference key identifies one distinct reference interval, which the + # ranking partitions by and the join re-associates on. Position + # (chrom/start/end) alone identifies it in the unstranded case; in stranded + # mode strand joins the key too, because two outer rows at the same position + # but opposite strands must each get their own strand-filtered nearest. The + # ref relation is de-duplicated on the key with DISTINCT so ranking happens + # once per distinct reference and the join fans the top-k back out to every + # outer row sharing it — exactly the per-row LATERAL semantics, even when the + # outer table holds duplicate reference rows. 
+ ref_relation_alias = "__giql_x_ref" + strand_name = f"{_REF_KEY_PREFIX}strand" + stranded_key = is_stranded and ref.strand is not None + + key_names = [f"{_REF_KEY_PREFIX}chrom", f"{_REF_KEY_PREFIX}start", + f"{_REF_KEY_PREFIX}end"] + source_frags = [ref.chrom, ref.start, ref.end] + if stranded_key: + key_names.append(strand_name) + source_frags.append(ref.strand) + + ref_projection = ", ".join( + f'{frag} AS "{name}"' for name, frag in zip(key_names, source_frags) + ) + ref_relation = ( + f"(SELECT DISTINCT {ref_projection} FROM {relation} AS {outer_alias})" + f" AS {ref_relation_alias}" + ) + + # Reference fragments now point at the renamed relation's safe columns. + renamed = [f'{ref_relation_alias}."{name}"' for name in key_names] + renamed_strand = ( + f'{ref_relation_alias}."{strand_name}"' if stranded_key else None + ) + ref_fragments = (renamed[0], renamed[1], renamed[2], renamed_strand) + + distance_expr, abs_distance_expr, where_clauses, passthrough = _distance_and_filters( + expression, ctx, gen, table_name, target_ref, ref, ref_fragments=ref_fragments + ) + + # Surface the reference-key columns so the rewritten join can match each + # ranked candidate back to its outer row(s). Ranking depends only on these + # values, so partitioning by them and re-joining is identical to the per-row + # LATERAL form even when outer rows share a reference value. + key_cols = list(zip(key_names, renamed)) + key_projection = ", ".join(f'{frag} AS "{name}"' for name, frag in key_cols) + where_sql = " AND ".join(where_clauses) + + # Compute the candidate set (cross join + distance + reference keys) in an + # inner subquery, then add ROW_NUMBER() in the enclosing one. Keeping the + # join and the window in *separate* query levels is load-bearing on + # DataFusion: fused into one level its optimizer mis-derives the window's sort + # order from the chromosome-equality prefilter and trips ``SanityCheckPlan``. 
+ candidate = "__giql_x_cand" + inner = ( + f"SELECT {passthrough}, {distance_expr} AS distance, {key_projection} " + f"FROM {table_name} CROSS JOIN {ref_relation} " + f"WHERE {where_sql}" + ) + partition = ", ".join(f'{candidate}."{name}"' for name, _ in key_cols) + ranked = ( + f"(SELECT {candidate}.*, " + f"ROW_NUMBER() OVER (PARTITION BY {partition} " + f"ORDER BY ABS({candidate}.distance)) AS \"{_RANK_COL}\" " + f"FROM ({inner}) AS {candidate})" + ) + ranked_subquery = parse_one(ranked, dialect=GIQLDialect) + + # Match each ranked candidate back to the *outer* relation by its reference + # value (the original outer-qualified fragments, e.g. ``a."chrom"``), not the + # renamed inner columns which exist only inside the subquery. + on_parts = [ + f'{alias}."{name}" = {src}' for name, src in zip(key_names, source_frags) + ] + on_parts.append(f'{alias}."{_RANK_COL}" <= {k_value}') + on_sql = " AND ".join(on_parts) + on_expr = parse_one(on_sql, dialect=GIQLDialect) + + subquery = exp.Subquery( + this=ranked_subquery.this, alias=lateral.args["alias"].copy() + ) + + # Convert `` CROSS JOIN LATERAL (nearest) AS b`` into + # `` JOIN (ranked) AS b ON AND b.rn <= k``. Swap the + # whole LATERAL out for the decorrelated subquery and drop the CROSS kind so + # the ON clause attaches as a plain (inner) join. + lateral.replace(subquery) + join.set("kind", None) + join.set("side", None) + join.set("on", on_expr) + + # The LATERAL (and the NEAREST node within it) is now detached; returning the + # node unchanged makes the pass's ``node.replace`` a no-op. + return expression + + +def expand_nearest(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand a NEAREST node to LATERAL or the decorrelated window-function form. + + Selects on ``ctx.capabilities.supports_lateral`` and whether the node is + correlated (its parent is a ``LATERAL``). 
Lateral-capable targets and every + standalone (literal-reference) placement get the portable LATERAL/standalone + subquery; a correlated NEAREST on a target without LATERAL support gets the + decorrelated window-function fallback. + """ + assert isinstance(node, GIQLNearest) + gen = _emitter() + resolution = ctx.resolution + + target_ref = resolution.slot("this") if resolution is not None else None + if not isinstance(target_ref, ResolvedRef): + # An unresolved target means it is not a registered table; raise the + # historical diagnostic (verbatim from the removed giqlnearest_sql). + target = node.this + if isinstance(target, exp.Table): + target_name = target.name + elif isinstance(target, exp.Column): + target_name = target.table if target.table else str(target.this) + else: + target_name = str(target) + raise ValueError( + f"Target table '{target_name}' not found in tables. " + "Register the table before transpiling." + ) + table_name = target_ref.name + + ref = resolution.slot("reference") + if not isinstance(ref, ResolvedInterval): + mode = gen._detect_nearest_mode(node) + gen._raise_nearest_reference_error(node, mode, resolution) + + correlated = isinstance(node.parent, exp.Lateral) + if correlated and not ctx.capabilities.supports_lateral: + return _fallback_form(node, ctx, gen, table_name, target_ref, ref) + return _lateral_form(node, ctx, gen, table_name, target_ref, ref) + + +# The generic registration covers every target through the registry's fallback +# chain; the expander branches on ctx.capabilities.supports_lateral internally, +# so no per-target override is needed. +register(GenericTarget, GIQLNearest)(expand_nearest) diff --git a/src/giql/expressions.py b/src/giql/expressions.py index ce939dc..fe8d397 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -369,7 +369,11 @@ class GIQLNearest(exp.Func): #: half-open) operands are left untouched and the emitted SQL stays #: byte-identical. 
GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators pass (epic #137, issue #142): NEAREST is + #: expanded by ``giql.expanders.nearest`` — the portable correlated LATERAL + #: subquery where ``supports_lateral`` holds, a decorrelated window-function + #: form otherwise. The legacy ``giqlnearest_sql`` emitter has been removed. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"registered_table"}), required=True), diff --git a/src/giql/generators/base.py b/src/giql/generators/base.py index 9038369..5e6f3f6 100644 --- a/src/giql/generators/base.py +++ b/src/giql/generators/base.py @@ -15,7 +15,6 @@ from giql.resolver import META_KEY from giql.resolver import OperatorResolution from giql.resolver import ResolvedColumn -from giql.resolver import ResolvedInterval from giql.resolver import ResolvedRef from giql.table import Table from giql.table import Tables @@ -82,170 +81,6 @@ def spatialsetpredicate_sql(self, expression: SpatialSetPredicate) -> str: """ return self._generate_spatial_set(expression) - def giqlnearest_sql(self, expression: GIQLNearest) -> str: - """Generate SQL for NEAREST function. - - Detects mode (standalone vs correlated) and generates appropriate SQL: - - Standalone: Direct query with ORDER BY + LIMIT - - Correlated (LATERAL): Subquery for k-nearest neighbors - - :param expression: - GIQLNearest expression node - :return: - SQL string for NEAREST operation - """ - # Detect mode - mode = self._detect_nearest_mode(expression) - - # Unpack the resolution metadata attached by ResolveOperatorRefs (pass 1). - resolution = self._nearest_resolution(expression) - - # Target (already a registered-table ResolvedRef from the pass). An - # unresolved target means it is not a registered table; raise the - # historical diagnostic. 
- target_ref = resolution.slot("this") if resolution is not None else None - if not isinstance(target_ref, ResolvedRef): - target = expression.this - if isinstance(target, exp.Table): - target_name = target.name - elif isinstance(target, exp.Column): - target_name = target.table if target.table else str(target.this) - else: - target_name = str(target) - raise ValueError( - f"Target table '{target_name}' not found in tables. " - "Register the table before transpiling." - ) - table_name = target_ref.name - target_chrom, target_start, target_end = target_ref.cols - - # The target's *declared* encoding, which the passed-through target row - # (SELECT {table_name}.*) must round-trip back into. CanonicalizeCoordinates - # (pass 2) preserves it on the resolution when it wraps a non-canonical - # target in a __giql_canon_* CTE (the slot's own Table is then None); a - # canonical target is left unwrapped and its slot Table carries the - # (identity) encoding. The synthesized `distance` column is encoding- - # invariant (a count of bases) and needs no round-trip. - output_table = self._nearest_output_encoding(expression, target_ref) - passthrough = self._nearest_passthrough( - table_name, target_start, target_end, output_table - ) - - # Reference interval (a ResolvedInterval from the pass). An unresolved - # reference re-raises the generator's historical diagnostic. Input - # canonicalization is owned by CanonicalizeCoordinates (pass 2, issue - # #123): a literal range is already canonical, and a column / implicit- - # outer reference's endpoints are canonicalized in place by the pass, so - # the emitter consumes the fragments verbatim with no canonicalization. 
- ref = resolution.slot("reference") - if not isinstance(ref, ResolvedInterval): - self._raise_nearest_reference_error(expression, mode, resolution) - ref_chrom, ref_start, ref_end = ref.chrom, ref.start, ref.end - - # Extract parameters - k = expression.args.get("k") - k_value = int(str(k)) if k else 1 # Default k=1 - - max_distance = expression.args.get("max_distance") - max_dist_value = int(str(max_distance)) if max_distance else None - - is_stranded = self._extract_bool_param(expression.args.get("stranded")) - is_signed = self._extract_bool_param(expression.args.get("signed")) - - # Resolve strand columns if stranded mode. The reference strand is - # carried on the resolved interval (a literal's strand, an explicit - # column's strand, or the outer table's strand for an implicit - # reference — already gated to preserve the historical divergence). - ref_strand = None - target_strand = None - if is_stranded: - ref_strand = ref.strand - # When pass 2 wraps a non-canonical target its slot Table is blanked, - # so the strand column name comes from the *declared* encoding the - # pass preserved (output_table). The canon CTE's SELECT * REPLACE - # passes the strand column through unchanged under its physical name, - # so the qualifier stays the relation NEAREST selects from. - if output_table and output_table.strand_col: - target_strand = f'{table_name}."{output_table.strand_col}"' - - # Distance math below assumes 0-based half-open. Input canonicalization - # is owned by CanonicalizeCoordinates (pass 2, issue #123): a - # non-canonical target is rewritten to a canonical __giql_canon_* CTE - # before generation (table_name then names the CTE), so the target - # endpoints are consumed verbatim with no in-emitter canonicalization. The - # output round-trip of the passed-through target row stays here (see the - # SELECT projection below). 
- target_start_expr = f'{table_name}."{target_start}"' - target_end_expr = f'{table_name}."{target_end}"' - - # Build distance calculation using CASE expression - # For NEAREST: ORDER BY absolute distance, but RETURN signed distance - distance_expr = self._generate_distance_case( - ref_chrom, - ref_start, - ref_end, - ref_strand, - f'{table_name}."{target_chrom}"', - target_start_expr, - target_end_expr, - target_strand, - stranded=is_stranded, - signed=is_signed, - ) - - # Use absolute distance for ordering and filtering - abs_distance_expr = f"ABS({distance_expr})" - - # Build WHERE clauses - where_clauses = [ - f'{ref_chrom} = {table_name}."{target_chrom}"' # Chromosome pre-filter - ] - - # Add strand matching for stranded mode - if is_stranded and ref_strand and target_strand: - where_clauses.append(f"{ref_strand} = {target_strand}") - - if max_dist_value is not None: - where_clauses.append(f"({abs_distance_expr}) <= {max_dist_value}") - - where_sql = " AND ".join(where_clauses) - - # Generate SQL based on mode - if mode == "standalone": - # Standalone mode: direct ORDER BY + LIMIT - # Return signed distance, but order by absolute distance - sql = f"""( - SELECT {passthrough}, {distance_expr} AS distance - FROM {table_name} - WHERE {where_sql} - ORDER BY {abs_distance_expr} - LIMIT {k_value} - )""" - else: - # Correlated mode: requires LATERAL join support - if not self.SUPPORTS_LATERAL: - raise ValueError( - "NEAREST in correlated mode (CROSS JOIN LATERAL) is not supported " - "in SQLite. SQLite does not support LATERAL joins. " - "\n\nAlternatives:" - "\n1. Use standalone mode: SELECT * FROM NEAREST(table, " - "reference='chr1:100-200', k=3)" - "\n2. Use DuckDB for queries requiring LATERAL joins" - "\n3. 
Manually write equivalent window function query" - ) - - # LATERAL mode: subquery for k-nearest neighbors - # Return signed distance, but order by absolute distance - sql = f"""( - SELECT {passthrough}, {distance_expr} AS distance - FROM {table_name} - WHERE {where_sql} - ORDER BY {abs_distance_expr} - LIMIT {k_value} - )""" - - return sql.strip() - def _nearest_output_encoding( self, expression: GIQLNearest, target_ref: ResolvedRef ) -> Table | None: diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 9ef2100..dd8bb6c 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -11,6 +11,7 @@ from sqlglot import parse_one +import giql.expanders # noqa: F401 (side-effect: registers built-in expanders) from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect from giql.expander import ExpandOperators From 6485047239d4353134275131a444e7693c2a9008 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sun, 28 Jun 2026 00:11:58 -0400 Subject: [PATCH 3/4] test: Update tests for NEAREST expander migration Rework the NEAREST and registry tests for the capability-driven expander. Drive the emitter-level NEAREST tests through the ExpandOperators pass instead of calling the deleted giqlnearest_sql directly, and update the pinned SQL to the expander's reserialized output (semantically unchanged from the legacy emitter). Drop the obsolete SUPPORTS_LATERAL=False hard-error test, since lateral support is now a target capability with a window-function fallback rather than a generator-level error. Promote the cross-target oracle's _unsupported_pending_142 expected- failure into a real three-target identity test: DataFusion now plans correlated NEAREST through the decorrelated fallback, so the LATERAL and window forms are verified to return identical rows on every target. 
Update the registry leak guards and clean_registry fixture to treat the import-time built-in registrations as the baseline through the new snapshot/restore seam, add coverage for snapshot/restore, and account for operators that now ship GIQL_EXPAND=True via an _opted_out helper. --- tests/generators/test_base.py | 319 +++++++----------- .../datafusion/test_cross_target_oracle.py | 65 ++-- tests/test_expander.py | 206 +++++++++-- tests/test_nearest_transpilation.py | 20 +- 4 files changed, 351 insertions(+), 259 deletions(-) diff --git a/tests/generators/test_base.py b/tests/generators/test_base.py index f95e91b..648f770 100644 --- a/tests/generators/test_base.py +++ b/tests/generators/test_base.py @@ -8,34 +8,37 @@ from hypothesis import given from hypothesis import settings from hypothesis import strategies as st -from sqlglot import exp from sqlglot import parse_one +import giql.expanders # noqa: F401 (side-effect: registers the NEAREST expander) from giql import Table from giql import transpile from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect -from giql.expressions import GIQLNearest +from giql.expander import ExpandOperators from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import GenericTarget def _generate_through_passes(sql: str, tables: Tables) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. - - Coordinate canonicalization for operator operands moved out of the emitter and - into the CanonicalizeCoordinates pass (issue #123). Emitter-level tests that - pin canonicalized output must therefore run both passes before generating, - exactly as :func:`giql.transpile.transpile` does, rather than calling - ``generate`` on a bare parsed AST (which would skip canonicalization). 
This - helper is used where the full ``transpile`` pipeline would otherwise rewrite - the node away (a column-to-column ``INTERSECTS`` is turned into a binned - equi-join before the predicate emitter runs). + """Parse, run normalization passes 1-3, then generate SQL. + + Coordinate canonicalization (issue #123) and operator expansion (epic #137) + moved out of the emitter into the CanonicalizeCoordinates / ExpandOperators + passes. Emitter-level tests that pin canonicalized output — or an operator now + produced by its registered expander, such as NEAREST (#142) — must run those + passes before generating, exactly as :func:`giql.transpile.transpile` does, + rather than calling ``generate`` on a bare parsed AST. This helper is used + where the full ``transpile`` pipeline would otherwise rewrite the node away (a + column-to-column ``INTERSECTS`` is turned into a binned equi-join before the + predicate emitter runs). """ ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator(tables=tables).generate(ast) @@ -398,38 +401,36 @@ def test_spatialsetpredicate_sql_all(self): def test_giqlnearest_sql_standalone(self, tables_with_peaks_and_genes): """ GIVEN a GIQLNearest in standalone mode with literal reference - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Subquery with ORDER BY distance LIMIT k is generated. """ sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # NEAREST now expands via its registered expander (#142): the subquery is + # reserialized by sqlglot (single line, ``<>`` for ``!=``) but the + # distance CASE, WHERE prefilter, ORDER BY ABS, and LIMIT are unchanged + # from the legacy emitter (verified equivalent in test_nearest_transpilation + # and the cross-target oracle). 
expected = ( - "SELECT * FROM (\n" - " SELECT genes.*, " - "CASE WHEN 'chr1' != genes.\"chrom\" THEN NULL " + "SELECT * FROM (SELECT genes.*, " + "CASE WHEN 'chr1' <> genes.\"chrom\" THEN NULL " 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0 ' - 'WHEN 2000 <= genes."start" ' - 'THEN (genes."start" - 2000 + 1) ' - 'ELSE (1000 - genes."end" + 1) END AS distance\n' - " FROM genes\n" - " WHERE 'chr1' = genes.\"chrom\"\n" - " ORDER BY ABS(" - "CASE WHEN 'chr1' != genes.\"chrom\" THEN NULL " + 'WHEN 2000 <= genes."start" THEN (genes."start" - 2000 + 1) ' + 'ELSE (1000 - genes."end" + 1) END AS distance ' + "FROM genes WHERE 'chr1' = genes.\"chrom\" " + "ORDER BY ABS(CASE WHEN 'chr1' <> genes.\"chrom\" THEN NULL " 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0 ' - 'WHEN 2000 <= genes."start" ' - 'THEN (genes."start" - 2000 + 1) ' - 'ELSE (1000 - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'WHEN 2000 <= genes."start" THEN (genes."start" - 2000 + 1) ' + 'ELSE (1000 - genes."end" + 1) END) LIMIT 3)' ) assert output == expected def test_giqlnearest_sql_correlated(self, tables_with_peaks_and_genes): """ GIVEN a GIQLNearest in correlated mode (LATERAL join context) - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs on a lateral-capable target THEN LATERAL-compatible subquery is generated. """ sql = ( @@ -439,26 +440,22 @@ def test_giqlnearest_sql_correlated(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander; LATERAL placement, distance CASE, + # WHERE, ORDER BY, and LIMIT are semantically unchanged from the legacy + # emitter. 
expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'ELSE (peaks."start" - genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom" ' + 'ORDER BY ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END) LIMIT 3)' ) assert output == expected @@ -476,33 +473,26 @@ def test_giqlnearest_sql_with_max_distance(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander; the max_distance filter on ABS of the + # distance CASE is semantically unchanged from the legacy emitter. 
expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom" ' - "AND (ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'ELSE (peaks."start" - genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom" ' + 'AND (ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)) <= 100000\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'ELSE (peaks."start" - genes."end" + 1) END)) <= 100000 ' + 'ORDER BY ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)\n' - " LIMIT 5\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END) LIMIT 5)' ) assert output == expected @@ -520,41 +510,37 @@ def test_giqlnearest_sql_stranded(self, tables_with_peaks_and_genes): output = 
_generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander; the stranded distance CASE and the + # ``peaks.strand = genes.strand`` match in WHERE are semantically + # unchanged from the legacy emitter. expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' 'WHEN peaks."strand" IS NULL OR genes."strand" IS NULL THEN NULL ' "WHEN peaks.\"strand\" = '.' OR peaks.\"strand\" = '?' THEN NULL " "WHEN genes.\"strand\" = '.' OR genes.\"strand\" = '?' THEN NULL " - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' "THEN CASE WHEN peaks.\"strand\" = '-' " 'THEN -(genes."start" - peaks."end" + 1) ' 'ELSE (genes."start" - peaks."end" + 1) END ' "ELSE CASE WHEN peaks.\"strand\" = '-' " 'THEN -(peaks."start" - genes."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom" ' - 'AND peaks."strand" = genes."strand"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' + 'ELSE (peaks."start" - genes."end" + 1) END END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom" ' + 'AND peaks."strand" = genes."strand" ' + 'ORDER BY ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' 'WHEN peaks."strand" IS NULL OR genes."strand" IS NULL THEN NULL ' "WHEN peaks.\"strand\" = '.' OR peaks.\"strand\" = '?' THEN NULL " "WHEN genes.\"strand\" = '.' OR genes.\"strand\" = '?' 
THEN NULL " - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' "THEN CASE WHEN peaks.\"strand\" = '-' " 'THEN -(genes."start" - peaks."end" + 1) ' 'ELSE (genes."start" - peaks."end" + 1) END ' "ELSE CASE WHEN peaks.\"strand\" = '-' " 'THEN -(peaks."start" - genes."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END END)\n' - " LIMIT 3\n" - " )" + 'ELSE (peaks."start" - genes."end" + 1) END END) LIMIT 3)' ) assert output == expected @@ -585,7 +571,7 @@ def test_giqlnearest_sql_implicit_outer_without_strand_column(self): def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): """ GIVEN a GIQLNearest with signed := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Distance expression includes signed calculation. """ sql = ( @@ -596,50 +582,31 @@ def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) + # Reserialized by the #142 expander; the signed distance CASE (negated + # ELSE branch for upstream) is semantically unchanged from the legacy + # emitter. 
expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (\n" - " SELECT genes.*, " - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT genes.*, " + 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE -(peaks."start" - genes."end" + 1) END AS distance\n' - " FROM genes\n" - ' WHERE peaks."chrom" = genes."chrom"\n' - " ORDER BY ABS(" - 'CASE WHEN peaks."chrom" != genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" ' - 'AND peaks."end" > genes."start" THEN 0 ' - 'WHEN peaks."end" <= genes."start" ' + 'ELSE -(peaks."start" - genes."end" + 1) END AS distance ' + 'FROM genes WHERE peaks."chrom" = genes."chrom" ' + 'ORDER BY ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' + 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' + 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE -(peaks."start" - genes."end" + 1) END)\n' - " LIMIT 3\n" - " )" + 'ELSE -(peaks."start" - genes."end" + 1) END) LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_no_lateral_support(self, tables_with_peaks_and_genes): - """ - GIVEN a GIQLNearest on a generator with SUPPORTS_LATERAL=False - WHEN giqlnearest_sql is called in correlated mode - THEN ValueError is raised with helpful message. 
- """ - - # Create a generator subclass without LATERAL support - class NoLateralGenerator(BaseGIQLGenerator): - SUPPORTS_LATERAL = False - - # Use query without explicit reference to trigger correlated mode - sql = "SELECT * FROM peaks CROSS JOIN LATERAL NEAREST(genes, k := 3)" - ast = parse_one(sql, dialect=GIQLDialect) - ast = resolve_operator_refs(ast, tables_with_peaks_and_genes) - ast = canonicalize_coordinates(ast) - - generator = NoLateralGenerator(tables=tables_with_peaks_and_genes) - - with pytest.raises(ValueError, match="LATERAL"): - generator.generate(ast) + # The legacy ``SUPPORTS_LATERAL=False`` generator-level error path was removed + # with ``giqlnearest_sql`` (#142): lateral support is now a target capability, + # and a target without it (DataFusion) gets the decorrelated window-function + # fallback rather than a hard error. That fallback's result-identity with the + # LATERAL form is verified by the cross-target oracle + # (tests/integration/datafusion/test_cross_target_oracle.py). @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) @given( @@ -991,23 +958,16 @@ def test_giqlnearest_sql_missing_outer_table_error( self, tables_with_peaks_and_genes ): """ - GIVEN a GIQLNearest in correlated mode without reference where outer table - cannot be found - WHEN giqlnearest_sql is called - THEN ValueError is raised with helpful message about specifying reference. + GIVEN a GIQLNearest without a reference and no resolvable outer table + WHEN the NEAREST expander runs + THEN ValueError is raised with a helpful message about specifying reference. """ + # Arrange — no reference and no LATERAL outer relation to infer one from. 
+ sql = "SELECT * FROM NEAREST(genes, k := 3)" - nearest = GIQLNearest( - this=exp.Table(this=exp.Identifier(this="genes")), - k=exp.Literal.number(3), - ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - + # Act & assert with pytest.raises(ValueError, match="Could not find outer table"): - generator.giqlnearest_sql(nearest) + _generate_through_passes(sql, tables_with_peaks_and_genes) def test_giqlnearest_sql_outer_table_not_in_tables(self): """ @@ -1038,34 +998,31 @@ def test_giqlnearest_sql_invalid_reference_range(self, tables_with_peaks_and_gen def test_giqlnearest_sql_no_tables_error(self): """ - GIVEN a GIQLNearest without tables registered - WHEN giqlnearest_sql is called - THEN ValueError is raised because target table cannot be resolved. + GIVEN a GIQLNearest with no tables registered + WHEN the NEAREST expander runs + THEN ValueError is raised because the target table cannot be resolved. """ + # Arrange sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" - ast = parse_one(sql, dialect=GIQLDialect) - - # Generator with empty tables - table won't be found - generator = BaseGIQLGenerator() + # Act & assert with pytest.raises(ValueError, match="not found in tables"): - generator.generate(ast) + _generate_through_passes(sql, Tables()) def test_giqlnearest_sql_target_not_in_tables(self, tables_with_peaks_and_genes): """ - GIVEN a GIQLNearest with target table not registered - WHEN giqlnearest_sql is called - THEN ValueError is raised listing available tables. + GIVEN a GIQLNearest whose target table is not registered + WHEN the NEAREST expander runs + THEN ValueError is raised listing the unresolved table. 
""" + # Arrange sql = ( "SELECT * FROM NEAREST(unknown_table, reference := 'chr1:1000-2000', k := 3)" ) - ast = parse_one(sql, dialect=GIQLDialect) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) + # Act & assert with pytest.raises(ValueError, match="not found in tables"): - generator.generate(ast) + _generate_through_passes(sql, tables_with_peaks_and_genes) def test_intersects_sql_unqualified_column(self): """ @@ -1087,53 +1044,39 @@ def test_giqlnearest_sql_stranded_unqualified_reference( self, tables_with_peaks_and_genes ): """ - GIVEN a GIQLNearest with stranded := true and unqualified column reference - WHEN giqlnearest_sql is called + GIVEN a GIQLNearest with stranded := true and an unqualified column reference + WHEN the NEAREST expander runs THEN Strand column is resolved without table prefix. """ - - # Create NEAREST with stranded=True and an unqualified column reference - # The reference is an unqualified column (no table prefix) - nearest = GIQLNearest( - this=exp.Table(this=exp.Identifier(this="genes")), - reference=exp.Column(this=exp.Identifier(this="interval")), - k=exp.Literal.number(3), - stranded=exp.Boolean(this=True), + # Arrange — the reference is an unqualified column (no table prefix). 
+ sql = ( + "SELECT * FROM peaks CROSS JOIN LATERAL " + "NEAREST(genes, reference := interval, k := 3, stranded := true)" ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - output = generator.giqlnearest_sql(nearest) + # Act + output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Should produce valid output with unqualified strand column + # Assert assert "LIMIT 3" in output - # The strand column should be unqualified (no table prefix) assert '"strand"' in output - def test_giqlnearest_sql_identifier_target(self, tables_with_peaks_and_genes): + def test_giqlnearest_sql_literal_target(self, tables_with_peaks_and_genes): """ - GIVEN a GIQLNearest where target is an Identifier (not Table or Column) - WHEN giqlnearest_sql is called - THEN Target is converted to string and lookup proceeds. + GIVEN a GIQLNearest with a standalone literal reference + WHEN the NEAREST expander runs + THEN it produces valid SQL selecting from the target table. 
""" + # Arrange + sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" - # Use exp.Identifier directly - not Table or Column - # This triggers the else branch at line 830 where str(target) is called - nearest = GIQLNearest( - this=exp.Identifier(this="genes"), - reference=exp.Literal.string("chr1:1000-2000"), - k=exp.Literal.number(3), - ) - resolve_operator_refs(nearest, tables_with_peaks_and_genes) - canonicalize_coordinates(nearest) - - generator = BaseGIQLGenerator(tables=tables_with_peaks_and_genes) - output = generator.giqlnearest_sql(nearest) + # Act + output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Should succeed and produce valid SQL + # Assert assert "genes" in output assert "LIMIT 3" in output + assert "LIMIT 3" in output @given( bool_repr=st.sampled_from(["true", "TRUE", "True", "1", "yes", "YES"]), diff --git a/tests/integration/datafusion/test_cross_target_oracle.py b/tests/integration/datafusion/test_cross_target_oracle.py index 6b717a1..f270fff 100644 --- a/tests/integration/datafusion/test_cross_target_oracle.py +++ b/tests/integration/datafusion/test_cross_target_oracle.py @@ -14,13 +14,14 @@ genuinely divergent SQL across targets (the DuckDB IEJoin vs. the binned equi-join). -NEAREST's expansion uses a correlated ``LATERAL`` subquery, which DataFusion has -no physical plan for today; its generic-vs-duckdb equivalence case runs both on -DuckDB, and the full three-target oracle is pinned by a -``pytest.raises(match="OuterReferenceColumn")`` test (#142) that fails loudly on -an unrelated error and trips "DID NOT RAISE" — forcing conversion to a real -identity test — when DataFusion gains correlated LATERAL. DISJOIN has an -analogous pending-#153 gap (duplicate ``end`` output names). 
+NEAREST's correlated expansion is capability-driven (#142): lateral-capable +targets (generic, duckdb) emit the portable ``LATERAL`` subquery, while +DataFusion — which has no correlated-LATERAL physical plan — gets a decorrelated +window-function fallback. Both forms return identical rows, so the full +three-target identity oracle now runs on every target (the former +``_unsupported_pending_142`` ``pytest.raises`` pin has been promoted to a real +identity test). DISJOIN still has a pending-#153 gap (duplicate ``end`` +output names). """ import pytest @@ -204,32 +205,42 @@ def test_standalone_nearest_k1_agrees_generic_vs_duckdb_on_duckdb( engines={"generic": "duckdb"}, ) - def test_nearest_on_datafusion_unsupported_pending_142(self, cross_target_oracle): - """Test the full NEAREST oracle raises DataFusion's missing-LATERAL error. + def test_correlated_nearest_k1_agrees_across_all_targets(self, cross_target_oracle): + """Test correlated NEAREST k=1 returns identical rows on every target. Given: - The single-row NEAREST query and a candidate gene on chr1. + A single-row peaks table and three candidate genes at varying + distances on chr1. When: - The oracle runs all three targets — the datafusion target executes - the correlated LATERAL on DataFusion, which has no physical plan. + A correlated ``CROSS JOIN LATERAL NEAREST(..., k := 1)`` query runs + for every target — the generic and duckdb targets emit the portable + LATERAL form (executed on DuckDB, the lateral-capable engine), and + the datafusion target emits the decorrelated window-function fallback + the #142 expander produces (executed on DataFusion). Then: - DataFusion should raise its ``OuterReferenceColumn`` "not - implemented" error.
This pins the known #142 gap: the ``match`` - narrows to the LATERAL signature so an unrelated/reworded DataFusion - error fails loudly, and a closed gap (no exception) trips pytest's - "DID NOT RAISE", forcing this to be converted into a real - cross-target identity test when DataFusion gains correlated LATERAL. + Every target should return the single nearest gene and agree. + + Promoted from the ``_unsupported_pending_142`` expected-failure pin: + DataFusion now plans correlated NEAREST through the capability-driven + window-function fallback, so the full three-target oracle is a real + identity test rather than a ``pytest.raises`` placeholder. The generic + target is routed to DuckDB because its portable SQL is the LATERAL form, + which only the datafusion-specific fallback decorrelates for DataFusion. """ # Arrange / Act / Assert - with pytest.raises(Exception, match="OuterReferenceColumn"): - cross_target_oracle( - "SELECT a.chrom, a.start AS a_start, b.start AS b_start " - "FROM peaks a " - "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", - peaks=[("chr1", 200, 300)], - genes=[("chr1", 280, 290)], - expected=[("chr1", 200, 280)], - ) + cross_target_oracle( + "SELECT a.chrom, a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 1000, 1100), + ("chr1", 50, 60), + ("chr1", 280, 290), + ], + expected=[("chr1", 200, 280)], + engines={"generic": "duckdb"}, + ) class TestCrossTargetOracleIntersectsAnyAll: diff --git a/tests/test_expander.py b/tests/test_expander.py index ad625f1..31a22f1 100644 --- a/tests/test_expander.py +++ b/tests/test_expander.py @@ -53,34 +53,48 @@ def _expander(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: return _expander +#: The registry contents at import — the built-in expanders registered by +#: ``giql.expanders`` for already-migrated operators (DISJOIN as of #143). 
The +#: leak guards and ``clean_registry`` treat this as the baseline rather than an +#: empty registry, so the real built-in registrations survive isolating fixtures +#: and a leaking test is still caught against the true baseline. +_REGISTRY_BASELINE = REGISTRY.snapshot() + + @pytest.fixture def clean_registry(): - """Isolate the process-wide REGISTRY, leaving it empty afterward. + """Isolate the process-wide REGISTRY, restoring its baseline afterward. - The registry is empty at import (the pass ships inert), so a test that opts - in clears on the way out; emptiness is asserted through the public - ``bool()``/``len()`` surface rather than private state. + Saves the import-time baseline (the built-in expanders), empties the registry + so a test sees only what it registers, and restores the baseline on the way + out through the public ``snapshot()``/``restore()`` seam — so a test that + registers a stand-in expander cannot leak it, and the built-in registrations + survive this fixture's isolation. """ - assert not REGISTRY, "REGISTRY was non-empty entering clean_registry" + saved = REGISTRY.snapshot() REGISTRY.clear() yield REGISTRY - REGISTRY.clear() + REGISTRY.restore(saved) @pytest.fixture(autouse=True) def _registry_leak_guard(): - """Assert the process-wide REGISTRY is empty at each test boundary. - - A leak guard: the registry is empty at import and must return to empty after - every test, so a test that registers without cleaning up (a leak that would - silently flip the no-op pass for a later test) fails loudly. Tests that - register on the process-wide REGISTRY do so through ``clean_registry``, which - clears on the way out; this guard catches anything that bypasses it. Both - checks go through the public ``bool()`` surface (A5), not private state. + """Assert the process-wide REGISTRY matches its baseline at each boundary. 
+ + A leak guard: the registry holds the built-in expanders at import and must + return to exactly that baseline after every test, so a test that registers + without cleaning up (a leak that would silently change dispatch for a later + test) fails loudly. Tests that mutate the process-wide REGISTRY do so through + ``clean_registry``, which restores the baseline on the way out; this guard + catches anything that bypasses it. """ - assert not REGISTRY, "REGISTRY leaked into a test from a prior one" + assert REGISTRY.snapshot() == _REGISTRY_BASELINE, ( + "REGISTRY differed from its baseline entering a test" + ) yield - assert not REGISTRY, "a test leaked a registration into REGISTRY" + assert REGISTRY.snapshot() == _REGISTRY_BASELINE, ( + "a test leaked a registration into REGISTRY" + ) @pytest.fixture(autouse=True) @@ -88,18 +102,19 @@ def _expand_flag_leak_guard(): """Assert every operator's GIQL_EXPAND is restored at each test boundary. The symmetric partner of the registry leak guard: each operator class ships - opted out (its own GIQL_EXPAND attribute is False), and a test that flips one - via ``_opted_in`` must restore it. A leaked opt-in would silently flip the - no-op pass for a later test, so this catches anything that bypasses the - exception-safe ``_opted_in`` manager. + its own GIQL_EXPAND default (``True`` for a migrated operator like DISJOIN, + ``False`` otherwise), and a test that flips one via ``_opted_in`` must restore + it. A leaked flip would silently change the pass for a later test, so this + catches anything that bypasses the exception-safe ``_opted_in`` manager by + comparing against each operator's shipped default rather than a blanket False.
""" for op in _OPERATOR_CLASSES: - assert op.__dict__.get("GIQL_EXPAND") is False, ( + assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( f"{op.__name__}.GIQL_EXPAND leaked into a test from a prior one" ) yield for op in _OPERATOR_CLASSES: - assert op.__dict__.get("GIQL_EXPAND") is False, ( + assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( f"a test leaked a GIQL_EXPAND opt-in on {op.__name__}" ) @@ -468,6 +483,55 @@ def _expander(node, ctx): assert REGISTRY.resolve(DuckDBTarget(), GIQLDisjoin) is None assert (DuckDBTarget(), GIQLDisjoin) not in REGISTRY + def test_snapshot_is_independent_of_later_registrations(self): + """Test that a snapshot does not observe registrations made after it. + + Given: + A registry with one entry, captured by snapshot. + When: + A second entry is registered after the snapshot is taken. + Then: + The snapshot should still hold only the first entry (it is a copy, + not a live view). + """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("first")) + + # Act + saved = registry.snapshot() + registry.register(GenericTarget(), Intersects, _record("second")) + + # Assert + assert (DuckDBTarget(), GIQLDisjoin) in saved + assert (GenericTarget(), Intersects) not in saved + + def test_restore_replaces_entries_with_snapshot_contents(self): + """Test that restore returns the registry to a captured snapshot. + + Given: + A snapshot of a registry with one entry, after which the registry is + cleared and a different entry registered. + When: + Restoring the snapshot. + Then: + The original entry should resolve again and the post-snapshot entry + should be gone. 
+ """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("original")) + saved = registry.snapshot() + registry.clear() + registry.register(GenericTarget(), Intersects, _record("transient")) + + # Act + registry.restore(saved) + + # Assert + assert (DuckDBTarget(), GIQLDisjoin) in registry + assert (GenericTarget(), Intersects) not in registry + class TestRegisterDecorator: """Tests for the @register extension-hook decorator.""" @@ -956,7 +1020,8 @@ def test_transform_skips_unflagged_operator(self, clean_registry): Given: An expander registered for (GenericTarget, GIQLDisjoin) but the - operator's GIQL_EXPAND flag left at its default False. + operator's GIQL_EXPAND flag held off (DISJOIN ships it on, so the + control opts it out to isolate the per-type gate). When: Running the pass. Then: @@ -968,8 +1033,9 @@ def test_transform_skips_unflagged_operator(self, clean_registry): ast = _prepare("SELECT * FROM DISJOIN(variants)", tables) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act (GIQL_EXPAND is False by default — no opt-in context) - result = pass_.transform(ast) + # Act + with _opted_out(GIQLDisjoin): + result = pass_.transform(ast) # Assert assert list(result.find_all(GIQLDisjoin)) @@ -1022,17 +1088,21 @@ def test_expand_operators_is_identity_when_registry_empty(self): assert list(result.find_all(GIQLDisjoin)) def test_transpile_sql_unchanged_with_pass_inert(self): - """Test that transpile output is byte-identical with the pass inert. + """Test that transpile output is byte-identical for an unmigrated operator. Given: - A DISJOIN query, the default (empty) registry, and no operator flagged. + A DISTANCE query (an operator not migrated onto the pass, so its + GIQL_EXPAND is False and no expander resolves), with the default + registry. When: - Transpiling with the wired-in pass versus a pass-bypassed reference. 
+ Transpiling with the wired-in pass versus a pass-bypassed reference + (its legacy emitter run directly). Then: - The SQL should match exactly and carry no expander alias prefix. + The SQL should match exactly and carry no expander alias prefix — the + pass is inert for any operator that has not been migrated. """ # Arrange - query = "SELECT * FROM DISJOIN(variants)" + query = "SELECT DISTANCE(a.interval, b.interval) FROM variants a, variants b" tables = _tables() ast = _prepare(query, tables) from giql.generators import BaseGIQLGenerator @@ -1069,20 +1139,39 @@ def test_transpile_sql_unchanged_with_pass_inert(self): GIQLMerge, ) +#: Each operator's shipped GIQL_EXPAND default, captured from its own class dict +#: at import. A migrated operator (DISJOIN, #143) ships ``True``; the rest ship +#: ``False`` until their migrations land. The flag leak guard restores to these +#: shipped values rather than a blanket ``False``. +_SHIPPED_EXPAND_FLAGS = {op: op.__dict__.get("GIQL_EXPAND") for op in _OPERATOR_CLASSES} + + +#: Operators migrated onto the ExpandOperators pass — they ship GIQL_EXPAND=True. +_MIGRATED_OPERATORS = tuple( + op for op in _OPERATOR_CLASSES if op.__dict__.get("GIQL_EXPAND") is True +) +#: Operators not yet migrated — they ship GIQL_EXPAND=False. +_UNMIGRATED_OPERATORS = tuple( + op for op in _OPERATOR_CLASSES if op not in _MIGRATED_OPERATORS +) + class TestOperatorOptOut: - """Every operator ships opted out of the ExpandOperators pass at this step.""" + """Migrated operators opt into the pass; the rest still ship opted out.""" - @pytest.mark.parametrize("operator", _OPERATOR_CLASSES, ids=lambda c: c.__name__) + @pytest.mark.parametrize( + "operator", _UNMIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) def test_operator_class_ships_expand_disabled(self, operator): - """Test that each operator class ships GIQL_EXPAND=False. + """Test that each unmigrated operator class ships GIQL_EXPAND=False. Given: - One of the nine GIQL operator expression classes. 
+ A GIQL operator expression class that has not been migrated onto the + ExpandOperators pass. When: Reading its GIQL_EXPAND class attribute. Then: - It should be False (no operator opts into expansion yet). + It should be False (the operator still uses the legacy emitter). """ # Arrange & act flag = operator.GIQL_EXPAND @@ -1090,6 +1179,27 @@ def test_operator_class_ships_expand_disabled(self, operator): # Assert assert flag is False + @pytest.mark.parametrize( + "operator", _MIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) + def test_operator_class_ships_expand_enabled(self, operator): + """Test that each migrated operator class ships GIQL_EXPAND=True. + + Given: + A GIQL operator expression class migrated onto the ExpandOperators + pass (DISJOIN, #143). + When: + Reading its GIQL_EXPAND class attribute. + Then: + It should be True (the operator expands through its registered + expander instead of the deleted legacy emitter). + """ + # Arrange & act + flag = operator.GIQL_EXPAND + + # Assert + assert flag is True + class TestOptedInRestoresFlag: """The _opted_in helper restores GIQL_EXPAND even when its body raises.""" @@ -1672,8 +1782,8 @@ def test_walk_partial_opt_in_replaces_only_flagged_type(self, clean_registry): ) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act - with _opted_in(Intersects): + # Act (DISJOIN ships flagged, so opt it out to hold it as the control) + with _opted_in(Intersects), _opted_out(GIQLDisjoin): result = pass_.transform(ast) # Assert @@ -1888,3 +1998,25 @@ def __enter__(self): def __exit__(self, *exc): self._operator.GIQL_EXPAND = self._prior return False + + +class _opted_out: + """Context manager opting an operator class out of GIQL_EXPAND for a test. + + The complement of :class:`_opted_in`: used by a control test that needs a + *migrated* operator (DISJOIN ships GIQL_EXPAND=True) to behave as if + unflagged, so the test can prove the pass gates per-type without the + operator's shipped opt-in interfering. 
Restores the prior flag on exit. + """ + + def __init__(self, operator: type) -> None: + self._operator = operator + self._prior = operator.__dict__.get("GIQL_EXPAND", False) + + def __enter__(self): + self._operator.GIQL_EXPAND = False + return self._operator + + def __exit__(self, *exc): + self._operator.GIQL_EXPAND = self._prior + return False diff --git a/tests/test_nearest_transpilation.py b/tests/test_nearest_transpilation.py index 2488cb0..caaed98 100644 --- a/tests/test_nearest_transpilation.py +++ b/tests/test_nearest_transpilation.py @@ -7,26 +7,32 @@ import pytest from sqlglot import parse_one +import giql.expanders # noqa: F401 (side-effect: registers the NEAREST expander) from giql import Table from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import ExpandOperators from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import GenericTarget def _generate(sql: str, tables: Tables) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. - - Operator resolution and coordinate canonicalization moved out of the emitter - and into the ResolveOperatorRefs / CanonicalizeCoordinates passes (epic #114, - issues #118-#123). Emitter-level tests must run both passes before generating, - exactly as :func:`giql.transpile.transpile` does, rather than calling - ``generate`` on a bare parsed AST. + """Parse, run normalization passes 1-3, then generate SQL. + + Operator resolution, coordinate canonicalization, and operator expansion + moved out of the emitter into the ResolveOperatorRefs / CanonicalizeCoordinates + / ExpandOperators passes (epics #114, #137). 
NEAREST is now produced by its + registered expander (issue #142) rather than a ``giqlnearest_sql`` emitter, so + these tests must run pass 3 before generating, exactly as + :func:`giql.transpile.transpile` does, rather than calling ``generate`` on a + bare parsed AST. """ ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator(tables=tables).generate(ast) From 0215c44d55b686264c52ed207ab84a26051adff8 Mon Sep 17 00:00:00 2001 From: Conrad Date: Sun, 28 Jun 2026 19:25:37 -0400 Subject: [PATCH 4/4] fix: Address PR #158 review findings for NEAREST Fix the literal-reference NEAREST crash on DataFusion by gating the decorrelated fallback on genuine correlation and materializing the distance in a two-level subquery. Add executing cross-target oracle cases (k>1, duplicate references, multi-key, max_distance, stranded, signed) and a deterministic tiebreaker so the LATERAL and window forms are set-equivalent. Delete dead helpers and SUPPORTS_LATERAL, make borrowed helpers static, mint fallback aliases via ctx.alias, add invariant asserts, and document DataFusion support. Apply the shared registry-docstring, restore-in-place, and auto-discovery fixes. 
--- docs/dialect/distance-operators.rst | 5 + src/giql/expander.py | 88 ++++++--- src/giql/expanders/__init__.py | 8 + src/giql/expanders/nearest.py | 174 +++++++++++------ src/giql/generators/base.py | 36 +--- src/giql/targets.py | 12 +- src/giql/transpile.py | 11 +- tests/generators/test_base.py | 175 +++++++++--------- .../datafusion/test_cross_target_oracle.py | 165 +++++++++++++++++ tests/test_expander.py | 65 ++++--- tests/test_nearest_transpilation.py | 99 ++++++++++ 11 files changed, 614 insertions(+), 224 deletions(-) diff --git a/docs/dialect/distance-operators.rst b/docs/dialect/distance-operators.rst index d3ebe89..10455e1 100644 --- a/docs/dialect/distance-operators.rst +++ b/docs/dialect/distance-operators.rst @@ -309,6 +309,11 @@ Find nearby same-strand features within distance constraints: WHERE nearest.distance BETWEEN -10000 AND 10000 ORDER BY peaks.name, ABS(nearest.distance) +Target support +~~~~~~~~~~~~~~ + +A correlated ``NEAREST`` (its reference is an outer-row column) runs on lateral-capable engines — DuckDB and the generic target — via a correlated ``LATERAL`` subquery, and on Apache DataFusion, which has no correlated-``LATERAL`` physical plan, via a decorrelated window-function rewrite. Both forms return identical results (a deterministic tiebreaker orders rows tied at the k-th distance the same way on every engine). A standalone ``NEAREST`` with a literal reference is uncorrelated and uses the same ordered, limited subquery on every target. + Notes ~~~~~ diff --git a/src/giql/expander.py b/src/giql/expander.py index c6558ac..3b5a900 100644 --- a/src/giql/expander.py +++ b/src/giql/expander.py @@ -39,10 +39,11 @@ i.e. a ``(target, op)`` or ``(generic, op)`` expander is registered. Otherwise it falls through to the legacy ``*_sql`` emitter on -:class:`giql.generators.base.BaseGIQLGenerator`. 
As of this issue **no operator -sets ``GIQL_EXPAND`` and the registry is empty, so the pass is a strict no-op**: -no node is touched and the emitted SQL is byte-identical. Each later migration PR -(epic #137 steps 4-9) registers a generic expander, flips one operator's +:class:`giql.generators.base.BaseGIQLGenerator`. The built-in expanders register +at import time via :mod:`giql.expanders`; the pass rewrites a node only when +``GIQL_EXPAND=True`` **and** an expander resolves for ``(active target, operator +type)``, and is a no-op for any operator that is unflagged or has no registered +expander. A migration PR registers an expander, flips one operator's ``GIQL_EXPAND`` flag, and deletes that operator's ``*_sql`` method. """ @@ -149,6 +150,14 @@ class OperatorExpander(Protocol): a registered object satisfies it. A plain function is *not* an ``OperatorExpander`` (it has no ``expand`` method); register one by wrapping it (see :func:`register`, which accepts either form). + + An expander is **node-local**: ``expand(node, ctx) -> exp.Expression`` sees + one operator node and returns the expression that replaces it in place. It + cannot express a whole-query rewrite such as the INTERSECTS IEJoin fold, + which restructures the surrounding query (joins, CTEs) rather than a single + node. That fold is therefore deferred — it would need a separate + query-level mechanism — and is handled by the pre-pass join transformers, not + by an expander. """ def expand(self, node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: ... @@ -215,6 +224,16 @@ def register( expander : OperatorExpander | ExpanderFn The expander object or function. A later registration for the same key replaces an earlier one (last-write-wins override). 
+ + Notes + ----- + Registering a *non-generic* ``(target, operator)`` expander where the + operator has a built-in whole-query join rewrite (notably + :class:`~giql.expressions.Intersects`, whose binned equi-join / DuckDB + IEJoin transformers run before expansion) signals that this expander + assumes responsibility for that rewrite: the built-in join transformers + are bypassed for that target so the operator node flows untouched into + :class:`ExpandOperators`. See :meth:`has_override`. """ self._expanders[(target, operator)] = _as_callable(expander) @@ -223,6 +242,12 @@ def resolve(self, target: Target, operator: type) -> ExpanderFn | None: Tries the exact ``(target, op)`` entry, then the ``(GenericTarget(), op)`` fallback, then ``None`` (legacy emitter). + + A non-generic exact ``(target, op)`` entry is also a *join-rewrite + override* for operators with a built-in whole-query join rewrite (notably + :class:`~giql.expressions.Intersects`): registering one bypasses the + built-in binned / IEJoin transformers for that target (see + :meth:`register` and :meth:`has_override`). """ fn = self._expanders.get((target, operator)) if fn is not None: @@ -233,6 +258,19 @@ def resolve(self, target: Target, operator: type) -> ExpanderFn | None: return fn return None + def has_override(self, target: Target, operator: type) -> bool: + """Whether an exact non-generic override supersedes built-in handling. + + Returns ``True`` only when *target* is not :class:`~giql.targets.GenericTarget` + and an exact ``(target, operator)`` entry is registered. Such an entry is a + target-specific override that supersedes built-in handling for that target + (e.g. it takes responsibility for the whole-query join rewrite that the + built-in transformers would otherwise perform); the portable + ``(GenericTarget(), operator)`` fallback is *not* an override and does not + count here. 
+ """ + return target != GenericTarget() and (target, operator) in self._expanders + def unregister(self, target: Target, operator: type) -> None: """Drop the ``(target, operator)`` entry if present. @@ -256,12 +294,12 @@ def clear(self) -> None: def snapshot(self) -> dict[tuple[Target, type], ExpanderFn]: """Return a shallow copy of the current registrations. - The save half of the registry's **public save/restore seam**: a test - fixture (or a plugin) that mutates the process-wide :data:`REGISTRY` - around a body — registering or clearing entries — captures the baseline - with this and hands it back to :meth:`restore` afterward, so the - built-in expanders registered at import survive an isolating fixture - that would otherwise :meth:`clear` them permanently. + The save half of a save/restore seam used primarily for test baseline + isolation (and which may serve a plugin that mutates the process-wide + :data:`REGISTRY` around a body): capture the baseline with this and hand + it back to :meth:`restore` afterward, so the built-in expanders + registered at import survive an isolating fixture that would otherwise + :meth:`clear` them permanently. It is not a committed plugin API. The returned dict is a fresh mapping (mutating it does not affect the registry), keyed by the same ``(target, operator)`` tuples. @@ -271,12 +309,14 @@ def snapshot(self) -> dict[tuple[Target, type], ExpanderFn]: def restore(self, snapshot: dict[tuple[Target, type], ExpanderFn]) -> None: """Replace all registrations with those captured by :meth:`snapshot`. - The restore half of the save/restore seam. Drops every current entry and - re-installs exactly the *snapshot* contents, so a fixture can return the - registry to a previously captured baseline regardless of what its body - registered or cleared. + The restore half of the save/restore seam (test baseline isolation; may + also serve a plugin, but is not a committed plugin API). 
Drops every + current entry and re-installs exactly the *snapshot* contents, so a + fixture can return the registry to a previously captured baseline + regardless of what its body registered or cleared. """ - self._expanders = dict(snapshot) + self._expanders.clear() + self._expanders.update(snapshot) def __contains__(self, key: tuple[Target, type]) -> bool: """Whether an *exact* ``(target, operator)`` entry is registered. @@ -305,8 +345,9 @@ def __bool__(self) -> bool: #: The process-wide registry the :func:`register` decorator writes to and the -#: :class:`ExpandOperators` pass reads from. Empty as of this issue, so the pass -#: is a strict no-op. +#: :class:`ExpandOperators` pass reads from. The built-in expanders register into +#: it at import time via :mod:`giql.expanders`; the pass rewrites a node only when +#: an expander resolves here (and the operator is flagged ``GIQL_EXPAND``). REGISTRY = ExpanderRegistry() @@ -405,9 +446,10 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for ``(target, operator type)`` through its fallback chain; otherwise the node is left untouched and the legacy ``*_sql`` emitter handles it. - The pass mutates and returns *expression* in place. **With no operator - flagged and an empty registry it is a strict no-op** and the emitted SQL is - byte-identical, so the existing suite is the migration oracle. + The pass mutates and returns *expression* in place. It touches only nodes + whose operator is flagged ``GIQL_EXPAND`` and resolves an expander; for every + other operator it is a no-op, leaving the emitted SQL byte-identical, so the + existing suite is the migration oracle. Parameters ---------- @@ -424,9 +466,9 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for Returns ------- exp.Expression - The same *expression*, with opted-in operator nodes replaced by their - target-specific expansions (none, while every flag is off / the registry - is empty). 
+ The same *expression*, with each opted-in operator node that resolves an + expander replaced by its target-specific expansion; nodes that are + unflagged or resolve no expander are left untouched. """ reg = registry if registry is not None else REGISTRY operators = _giql_operators() diff --git a/src/giql/expanders/__init__.py b/src/giql/expanders/__init__.py index 86bf43e..042a07d 100644 --- a/src/giql/expanders/__init__.py +++ b/src/giql/expanders/__init__.py @@ -8,6 +8,12 @@ New operator modules are picked up automatically: drop a ``.py`` into this package and it is imported here without editing this file. + +Modules whose name starts with ``_`` are skipped (private helpers, not +expanders). Submodules import in :func:`pkgutil.iter_modules` order, which sets +last-write-wins resolution-order precedence for overlapping registrations; an +import error here aborts the whole package import by design (a broken built-in +expander must not be silently skipped). """ from __future__ import annotations @@ -16,4 +22,6 @@ import pkgutil for _module_info in pkgutil.iter_modules(__path__): + if _module_info.name.startswith("_"): + continue importlib.import_module(f"{__name__}.{_module_info.name}") diff --git a/src/giql/expanders/nearest.py b/src/giql/expanders/nearest.py index 3934a8e..9bfafbc 100644 --- a/src/giql/expanders/nearest.py +++ b/src/giql/expanders/nearest.py @@ -24,8 +24,9 @@ The expander reuses :class:`giql.generators.base.BaseGIQLGenerator`'s ``_generate_distance_case`` (shared with DISTANCE, #140) and ``_nearest_*`` -resolution/passthrough helpers, then parses the assembled SQL fragments into AST -so the emitted SQL is reserialized by the active target's serializer. +passthrough/diagnostic helpers — all static, so they are called on the class with +no generator instance — then parses the assembled SQL fragments into AST so the +emitted SQL is reserialized by the active target's serializer. 
""" from __future__ import annotations @@ -49,18 +50,7 @@ _REF_KEY_PREFIX = "__giql_x_rk_" -def _emitter() -> BaseGIQLGenerator: - """A throwaway generator used only for its (self-free) NEAREST helpers. - - ``_generate_distance_case``, ``_nearest_*`` and ``_extract_bool_param`` carry - no instance state, so a default-constructed generator is a safe host for - them. Reusing them keeps the expander's distance/passthrough/encoding logic - byte-for-byte identical to the legacy emitter it replaces. - """ - return BaseGIQLGenerator() - - -def _nearest_params(expression: GIQLNearest, gen: BaseGIQLGenerator): +def _nearest_params(expression: GIQLNearest): """Unpack the (k, max_distance, stranded, signed) parameters of a NEAREST.""" k = expression.args.get("k") k_value = int(str(k)) if k else 1 @@ -68,21 +58,22 @@ def _nearest_params(expression: GIQLNearest, gen: BaseGIQLGenerator): max_distance = expression.args.get("max_distance") max_dist_value = int(str(max_distance)) if max_distance else None - is_stranded = gen._extract_bool_param(expression.args.get("stranded")) - is_signed = gen._extract_bool_param(expression.args.get("signed")) + is_stranded = BaseGIQLGenerator._extract_bool_param(expression.args.get("stranded")) + is_signed = BaseGIQLGenerator._extract_bool_param(expression.args.get("signed")) return k_value, max_dist_value, is_stranded, is_signed def _distance_and_filters( - expression, ctx, gen, table_name, target_ref, ref, ref_fragments=None + expression, table_name, target_ref, ref, ref_fragments=None ): """Build the shared distance SQL, the qualified target columns, and WHERE. Returns ``(distance_expr, abs_distance_expr, where_clauses, passthrough)`` — the fragments common to the LATERAL/standalone form and the decorrelated - fallback. Distance math, the chromosome pre-filter, the optional strand - match, and the optional ``max_distance`` filter all reproduce the legacy - ``giqlnearest_sql`` emitter exactly. + fallback. 
Distance math, the chromosome pre-filter, the optional strand match, + and the optional ``max_distance`` filter all reproduce the legacy + ``giqlnearest_sql`` emitter exactly. Each form derives its deterministic + ORDER BY tiebreaker from the target columns itself. ``ref_fragments`` optionally overrides the reference ``(chrom, start, end, strand)`` SQL fragments. The LATERAL form consumes the resolution's @@ -92,10 +83,10 @@ def _distance_and_filters( window ordering over a join with duplicate column names). """ target_chrom, target_start, target_end = target_ref.cols - k_value, max_dist_value, is_stranded, is_signed = _nearest_params(expression, gen) + _k_value, max_dist_value, is_stranded, is_signed = _nearest_params(expression) - output_table = gen._nearest_output_encoding(expression, target_ref) - passthrough = gen._nearest_passthrough( + output_table = BaseGIQLGenerator._nearest_output_encoding(expression, target_ref) + passthrough = BaseGIQLGenerator._nearest_passthrough( table_name, target_start, target_end, output_table ) @@ -120,7 +111,7 @@ def _distance_and_filters( target_start_expr = f'{table_name}."{target_start}"' target_end_expr = f'{table_name}."{target_end}"' - distance_expr = gen._generate_distance_case( + distance_expr = BaseGIQLGenerator._generate_distance_case( ref_chrom, ref_start, ref_end, @@ -143,24 +134,56 @@ def _distance_and_filters( return distance_expr, abs_distance_expr, where_clauses, passthrough -def _lateral_form(expression, ctx, gen, table_name, target_ref, ref): - """The portable LATERAL/standalone subquery — identical to the legacy emitter. - - Builds the ``(SELECT , AS distance FROM - WHERE ... ORDER BY ABS(distance) LIMIT k)`` subquery the legacy - ``giqlnearest_sql`` produced and parses it into AST. For a correlated - placement the parent ``LATERAL`` correlates it to the outer row; for a - standalone (literal-reference) placement it stands alone. 
+def _lateral_form(expression, ctx, table_name, target_ref, ref): + """The portable LATERAL/standalone subquery. + + Builds a two-level subquery: an inner ``SELECT <passthrough>, <distance> AS + distance FROM <target> WHERE ...`` that materializes the distance, wrapped by + an outer ``SELECT * FROM (<inner>) AS x ORDER BY ABS(x.distance), x.<start>, + x.<end> LIMIT k`` that orders on the *precomputed* ``distance`` column. For a + correlated placement the parent ``LATERAL`` correlates it to the outer row; + for a standalone (literal-reference) placement it stands alone. + + Splitting the distance computation (inner) from the ordering (outer) is + load-bearing for cross-engine support: + + * DuckDB's correlated-``LATERAL`` binder will not resolve a SELECT-list alias + named ``distance`` from inside an ``ORDER BY`` that also projects + ``<target>.*``, so the order key must reference a *materialized* column + (``x.distance``) from the wrapping level rather than an alias in the same + SELECT. + * DataFusion's planner, given the distance ``CASE`` re-emitted inline in the + ``ORDER BY`` over the chromosome-equality prefiltered scan, rewrites the + filtered ``chrom`` to a self-comparison in one copy of the CASE but not the + other and trips ``SanityCheckPlan``; ordering on the materialized column + avoids re-deriving the key. + + A deterministic ``(start, end)`` tiebreaker follows ``ABS(distance)`` so rows + tied at the k-th distance order identically across engines and against the + decorrelated fallback's ranking (#142 A5).
+ """ - k_value, *_ = _nearest_params(expression, gen) - distance_expr, abs_distance_expr, where_clauses, passthrough = _distance_and_filters( - expression, ctx, gen, table_name, target_ref, ref - ) + k_value, *_ = _nearest_params(expression) + ( + distance_expr, + _abs_distance_expr, + where_clauses, + passthrough, + ) = _distance_and_filters(expression, table_name, target_ref, ref) where_sql = " AND ".join(where_clauses) + # The wrapping level reads the inner row's *bare* column names (the passthrough + # projected ``<target>.*``), so the tiebreaker qualifies them by the wrapper + # alias, not the original ``table_name."col"``. + _chrom, target_start_col, target_end_col = target_ref.cols + wrapper = ctx.alias() + inner = ( + f"SELECT {passthrough}, {distance_expr} AS distance " + f"FROM {table_name} WHERE {where_sql}" + ) sql = ( - f"(SELECT {passthrough}, {distance_expr} AS distance " - f"FROM {table_name} WHERE {where_sql} " - f"ORDER BY {abs_distance_expr} LIMIT {k_value})" + f"(SELECT * FROM ({inner}) AS {wrapper} " + f'ORDER BY ABS({wrapper}."distance"), ' + f'{wrapper}."{target_start_col}", {wrapper}."{target_end_col}" ' + f"LIMIT {k_value})" ) return parse_one(sql, dialect=GIQLDialect) @@ -180,7 +203,7 @@ def _outer_relation(ref: ResolvedInterval) -> tuple[str, str]: return relation, alias -def _fallback_form(expression, ctx, gen, table_name, target_ref, ref): +def _fallback_form(expression, ctx, table_name, target_ref, ref): """The decorrelated window-function fallback for non-LATERAL targets. Rewrites the surrounding ``<outer> AS a CROSS JOIN LATERAL (nearest) AS b`` @@ -191,13 +214,38 @@ def _fallback_form(expression, ctx, gen, table_name, target_ref, ref): sharing that key, reproducing the per-row LATERAL semantics. Swaps the parent ``LATERAL`` for the decorrelated subquery in place and returns the (now detached) NEAREST node, so the pass's own ``node.replace`` is a no-op.
+ + The no-op return relies on NEAREST having no nestable inner GIQL operator: a + detached node carrying a still-pending descendant would strand that + descendant's later ``node.replace``. NEAREST's only operands are a registered + target table and an interval reference, neither of which is an expandable + operator, so nothing pending hangs off the node this detaches. """ lateral = expression.parent + # Invariants the surrounding-AST rewrite depends on. The fallback only runs + # for a correlated NEAREST, whose pass-1 placement is always a CROSS JOIN + # LATERAL carrying an alias; a violation is an internal pipeline bug, not user + # error, so fail loudly with a clear message rather than dereferencing None. + assert isinstance(lateral, exp.Lateral), ( + "correlated NEAREST fallback expected its parent to be a LATERAL, got " + f"{type(lateral).__name__}" + ) join = lateral.parent + assert isinstance(join, exp.Join), ( + "correlated NEAREST fallback expected the LATERAL to sit under a JOIN, got " + f"{type(join).__name__}" + ) + assert lateral.args.get("alias") is not None and lateral.args["alias"].name, ( + "correlated NEAREST fallback expected the LATERAL to carry a table alias" + ) alias = lateral.args["alias"].name relation, outer_alias = _outer_relation(ref) - k_value, _max, is_stranded, _signed = _nearest_params(expression, gen) + k_value, _max, is_stranded, _signed = _nearest_params(expression) + # Bare target column names: the candidate subquery exposes the target row via + # ``target.*``, so its tiebreaker columns are referenced by name, not by the + # ``table_name."col"`` qualifier the distance math uses. + _target_chrom, target_start_col, target_end_col = target_ref.cols # Pre-project the outer relation's reference columns under fresh, non-target # names into a renamed derived relation. 
Cross-joining *this* (rather than the @@ -214,7 +262,12 @@ def _fallback_form(expression, ctx, gen, table_name, target_ref, ref): # once per distinct reference and the join fans the top-k back out to every # outer row sharing it — exactly the per-row LATERAL semantics, even when the # outer table holds duplicate reference rows. - ref_relation_alias = "__giql_x_ref" + # Mint the synthetic relation aliases from the run's collision-safe sequence + # (rather than hardcoding ``__giql_x_ref`` / ``__giql_x_cand``) so two NEAREST + # operators in one query never reuse the same derived-relation name. The + # reserved *column* names below stay derived from EXPAND_ALIAS_PREFIX. + ref_relation_alias = ctx.alias() + candidate = ctx.alias() strand_name = f"{_REF_KEY_PREFIX}strand" stranded_key = is_stranded and ref.strand is not None @@ -240,8 +293,13 @@ def _fallback_form(expression, ctx, gen, table_name, target_ref, ref): ) ref_fragments = (renamed[0], renamed[1], renamed[2], renamed_strand) - distance_expr, abs_distance_expr, where_clauses, passthrough = _distance_and_filters( - expression, ctx, gen, table_name, target_ref, ref, ref_fragments=ref_fragments + ( + distance_expr, + _abs_distance_expr, + where_clauses, + passthrough, + ) = _distance_and_filters( + expression, table_name, target_ref, ref, ref_fragments=ref_fragments ) # Surface the reference-key columns so the rewritten join can match each @@ -257,17 +315,21 @@ def _fallback_form(expression, ctx, gen, table_name, target_ref, ref): # join and the window in *separate* query levels is load-bearing on # DataFusion: fused into one level its optimizer mis-derives the window's sort # order from the chromosome-equality prefilter and trips ``SanityCheckPlan``. 
- candidate = "__giql_x_cand" inner = ( f"SELECT {passthrough}, {distance_expr} AS distance, {key_projection} " f"FROM {table_name} CROSS JOIN {ref_relation} " f"WHERE {where_sql}" ) partition = ", ".join(f'{candidate}."{name}"' for name, _ in key_cols) + # A deterministic ``(start, end)`` tiebreaker follows ``ABS(distance)`` so rows + # tied at the k-th distance rank identically here and in the LATERAL form, + # making the two provably set-equivalent (no engine-dependent tie ordering). ranked = ( f"(SELECT {candidate}.*, " f"ROW_NUMBER() OVER (PARTITION BY {partition} " - f"ORDER BY ABS({candidate}.distance)) AS \"{_RANK_COL}\" " + f"ORDER BY ABS({candidate}.distance), " + f'{candidate}."{target_start_col}", {candidate}."{target_end_col}") ' + f'AS "{_RANK_COL}" ' f"FROM ({inner}) AS {candidate})" ) ranked_subquery = parse_one(ranked, dialect=GIQLDialect) @@ -310,7 +372,6 @@ def expand_nearest(node: exp.Expression, ctx: ExpansionContext) -> exp.Expressio decorrelated window-function fallback. """ assert isinstance(node, GIQLNearest) - gen = _emitter() resolution = ctx.resolution target_ref = resolution.slot("this") if resolution is not None else None @@ -332,13 +393,20 @@ def expand_nearest(node: exp.Expression, ctx: ExpansionContext) -> exp.Expressio ref = resolution.slot("reference") if not isinstance(ref, ResolvedInterval): - mode = gen._detect_nearest_mode(node) - gen._raise_nearest_reference_error(node, mode, resolution) - - correlated = isinstance(node.parent, exp.Lateral) + mode = BaseGIQLGenerator._detect_nearest_mode(node) + BaseGIQLGenerator._raise_nearest_reference_error(node, mode, resolution) + + # A literal-range reference is uncorrelated even under CROSS JOIN LATERAL: its + # endpoints are constants, not outer-row columns, so the subquery stands alone + # and every target — DataFusion included — takes the LATERAL/standalone form. 
+ # Only a genuinely correlated reference (a column / implicit-outer endpoint) + # needs the decorrelated window-function fallback on a lateral-incapable + # target. Gating on parentage alone would mis-route a literal range into + # ``_fallback_form``, which dereferences a non-existent outer relation. + correlated = isinstance(node.parent, exp.Lateral) and ref.kind != "literal_range" if correlated and not ctx.capabilities.supports_lateral: - return _fallback_form(node, ctx, gen, table_name, target_ref, ref) - return _lateral_form(node, ctx, gen, table_name, target_ref, ref) + return _fallback_form(node, ctx, table_name, target_ref, ref) + return _lateral_form(node, ctx, table_name, target_ref, ref) # The generic registration covers every target through the registry's fallback diff --git a/src/giql/generators/base.py b/src/giql/generators/base.py index 5e6f3f6..0138137 100644 --- a/src/giql/generators/base.py +++ b/src/giql/generators/base.py @@ -3,6 +3,7 @@ from giql.canonical import decanonical_end from giql.canonical import decanonical_start +from giql.dialect import GIQLDialect from giql.expressions import Contains from giql.expressions import GIQLDisjoin from giql.expressions import GIQLDistance @@ -33,10 +34,6 @@ class BaseGIQLGenerator(Generator): compatibility with virtually all SQL databases. 
""" - # Most databases support LATERAL joins (PostgreSQL 9.3+, DuckDB 0.7.0+) - # SQLite does not support LATERAL, so it overrides this to False - SUPPORTS_LATERAL = True - def __init__(self, tables: Tables | None = None, **kwargs): super().__init__(**kwargs) self.tables = tables or Tables() @@ -81,8 +78,9 @@ def spatialsetpredicate_sql(self, expression: SpatialSetPredicate) -> str: """ return self._generate_spatial_set(expression) + @staticmethod def _nearest_output_encoding( - self, expression: GIQLNearest, target_ref: ResolvedRef + expression: GIQLNearest, target_ref: ResolvedRef ) -> Table | None: """Return the target's declared encoding for NEAREST's row passthrough. @@ -107,8 +105,8 @@ def _nearest_output_encoding( return preserved return target_ref.table + @staticmethod def _nearest_passthrough( - self, table_name: str, target_start: str, target_end: str, @@ -407,8 +405,8 @@ def _distance_operand( raise ValueError(f"Literal range as {position} argument not yet supported") + @staticmethod def _generate_distance_case( - self, chrom_a: str, start_a: str, end_a: str, @@ -716,8 +714,9 @@ def _generate_spatial_set(self, expression: SpatialSetPredicate) -> str: combinator = " OR " if quantifier.upper() == "ANY" else " AND " return "(" + combinator.join(conditions) + ")" + @staticmethod def _detect_nearest_mode( - self, expression: GIQLNearest, parent_expression: exp.Expression | None = None + expression: GIQLNearest, parent_expression: exp.Expression | None = None ) -> str: """Detect whether NEAREST is in standalone or correlated mode. @@ -741,25 +740,8 @@ def _detect_nearest_mode( # (validation will catch missing reference errors later) return "correlated" - def _nearest_resolution(self, expression: GIQLNearest) -> OperatorResolution | None: - """Return the NEAREST resolution attached by ResolveOperatorRefs (pass 1). 
- - The transpile pipeline attaches an - :class:`~giql.resolver.OperatorResolution` before generation, and it - survives the generator's defensive tree copy. The emitter reads only the - attached metadata; resolution lives entirely in the pass. - - :param expression: - GIQLNearest expression node - :return: - The attached :class:`~giql.resolver.OperatorResolution`, or ``None`` - if resolution did not produce one. - """ - resolution = expression.meta.get(META_KEY) - return resolution if isinstance(resolution, OperatorResolution) else None - + @staticmethod def _raise_nearest_reference_error( - self, expression: GIQLNearest, mode: str, resolution: OperatorResolution | None, @@ -807,7 +789,7 @@ def _raise_nearest_reference_error( # An explicit reference that deferred is a literal range that failed to # parse (column references always resolve). Re-parse to surface the # original parse error in the historical message. - reference_sql = self.sql(reference) + reference_sql = reference.sql(dialect=GIQLDialect) range_str = reference_sql.strip("'\"") try: RangeParser.parse(range_str).to_zero_based_half_open() diff --git a/src/giql/targets.py b/src/giql/targets.py index 6fd29d3..825a6d2 100644 --- a/src/giql/targets.py +++ b/src/giql/targets.py @@ -29,11 +29,13 @@ class Capabilities: Parameters ---------- supports_lateral : bool - Whether the engine supports ``LATERAL`` / correlated joins. Will - drive the NEAREST LATERAL-vs-window-function strategy (#142). Until - then, :attr:`giql.generators.base.BaseGIQLGenerator.SUPPORTS_LATERAL` - remains the live source of truth at generation time; #142 reconciles - the two. + Whether the engine supports ``LATERAL`` / correlated joins. Drives the + NEAREST LATERAL-vs-window-function strategy (#142): a correlated NEAREST + expands to a portable correlated ``LATERAL`` subquery where this holds + and to a decorrelated window-function form where it does not. 
This + capability is the single source of truth — the former + ``BaseGIQLGenerator.SUPPORTS_LATERAL`` generator attribute has been + removed. supports_star_replace : bool Whether the engine supports ``SELECT * REPLACE (...)``. Drives the coordinate-canonicalization output: ``* REPLACE`` where supported, diff --git a/src/giql/transpile.py b/src/giql/transpile.py index dd8bb6c..4817a2d 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -198,16 +198,17 @@ def transpile( # Pass 2 of the normalization pipeline (epic #114): synthesize canonical # __giql_canon_* wrapper CTEs for non-canonical interval operands of - # opted-in operators (GIQL_CANONICALIZE). No operator opts in yet, so this - # is a strict no-op until the per-operator port issues (#122, #123) land. + # operators that opt into GIQL_CANONICALIZE; those operators' non-canonical + # operands are rewritten here, and identity (0-based half-open) operands are + # left untouched. with _reraise_as_value_error("Canonicalization error"): ast = canonicalize_coordinates(ast) # Pass 3 of the normalization pipeline (epic #137): replace each opted-in # GIQL operator node with the AST its registered expander produces for the - # active target. No operator sets GIQL_EXPAND and the registry is empty, so - # this is a strict no-op until the per-operator migration issues land; the - # legacy *_sql emitters on the generator remain the fallback. + # active target. An operator is rewritten here only when it both sets + # GIQL_EXPAND and resolves to a registered expander; otherwise its node is + # left untouched and the legacy *_sql emitter on the generator handles it. 
expand_operators = ExpandOperators(target, tables_container) with _reraise_as_value_error("Expansion error"): ast = expand_operators.transform(ast) diff --git a/tests/generators/test_base.py b/tests/generators/test_base.py index 648f770..d57c0b2 100644 --- a/tests/generators/test_base.py +++ b/tests/generators/test_base.py @@ -106,13 +106,12 @@ def test_instantiation_defaults(self): """ GIVEN no tables provided WHEN Generator is instantiated with defaults - THEN Generator has empty Tables and SUPPORTS_LATERAL is True. + THEN Generator has empty Tables. """ generator = BaseGIQLGenerator() assert generator.tables is not None assert "variants" not in generator.tables - assert generator.SUPPORTS_LATERAL is True def test_instantiation_with_tables(self, tables_info): """ @@ -398,7 +397,9 @@ def test_spatialsetpredicate_sql_all(self): ) assert output == expected - def test_giqlnearest_sql_standalone(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_ordered_limit_subquery_when_standalone( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest in standalone mode with literal reference WHEN the NEAREST expander runs @@ -408,26 +409,28 @@ def test_giqlnearest_sql_standalone(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # NEAREST now expands via its registered expander (#142): the subquery is - # reserialized by sqlglot (single line, ``<>`` for ``!=``) but the - # distance CASE, WHERE prefilter, ORDER BY ABS, and LIMIT are unchanged - # from the legacy emitter (verified equivalent in test_nearest_transpilation - # and the cross-target oracle). + # NEAREST now expands via its registered expander (#142): a two-level + # subquery — an inner SELECT that materializes ``distance`` and an outer + # wrapper that orders on the materialized ``__giql_x_0."distance"`` plus a + # deterministic ``(start, end)`` tiebreaker (#142 A5). 
Splitting the + # distance computation from the ordering keeps DuckDB's correlated-LATERAL + # binder and DataFusion's planner both happy while staying result- + # equivalent to the legacy single-level emitter. expected = ( - "SELECT * FROM (SELECT genes.*, " + "SELECT * FROM (SELECT * FROM (SELECT genes.*, " "CASE WHEN 'chr1' <> genes.\"chrom\" THEN NULL " 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0 ' 'WHEN 2000 <= genes."start" THEN (genes."start" - 2000 + 1) ' 'ELSE (1000 - genes."end" + 1) END AS distance ' - "FROM genes WHERE 'chr1' = genes.\"chrom\" " - "ORDER BY ABS(CASE WHEN 'chr1' <> genes.\"chrom\" THEN NULL " - 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0 ' - 'WHEN 2000 <= genes."start" THEN (genes."start" - 2000 + 1) ' - 'ELSE (1000 - genes."end" + 1) END) LIMIT 3)' + "FROM genes WHERE 'chr1' = genes.\"chrom\") AS __giql_x_0 " + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_correlated(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_lateral_subquery_when_correlated( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest in correlated mode (LATERAL join context) WHEN the NEAREST expander runs on a lateral-capable target @@ -440,29 +443,30 @@ def test_giqlnearest_sql_correlated(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Reserialized by the #142 expander; LATERAL placement, distance CASE, - # WHERE, ORDER BY, and LIMIT are semantically unchanged from the legacy - # emitter. + # Reserialized by the #142 expander as a two-level subquery: the inner + # SELECT materializes ``distance`` (CASE and WHERE semantically unchanged) + # and the outer wrapper orders on the materialized + # ``__giql_x_0."distance"`` plus a deterministic ``(start, end)`` + # tiebreaker (#142 A5). 
expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT genes.*, " + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' 'ELSE (peaks."start" - genes."end" + 1) END AS distance ' - 'FROM genes WHERE peaks."chrom" = genes."chrom" ' - 'ORDER BY ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' - 'THEN 0 WHEN peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END) LIMIT 3)' + 'FROM genes WHERE peaks."chrom" = genes."chrom") AS __giql_x_0 ' + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_with_max_distance(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_filter_on_max_distance( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with max_distance parameter - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN WHERE clause includes distance filter. """ sql = ( @@ -473,10 +477,13 @@ def test_giqlnearest_sql_with_max_distance(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Reserialized by the #142 expander; the max_distance filter on ABS of the - # distance CASE is semantically unchanged from the legacy emitter. + # Reserialized by the #142 expander as a two-level subquery; the + # max_distance filter on ABS of the distance CASE stays in the inner + # SELECT's WHERE and the outer wrapper orders on the materialized + # ``__giql_x_0."distance"`` plus the deterministic ``(start, end)`` + # tiebreaker (#142 A5). 
expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT genes.*, " + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' 'THEN 0 WHEN peaks."end" <= genes."start" ' @@ -487,19 +494,18 @@ def test_giqlnearest_sql_with_max_distance(self, tables_with_peaks_and_genes): 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END)) <= 100000 ' - 'ORDER BY ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' - 'THEN 0 WHEN peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END) LIMIT 5)' + 'ELSE (peaks."start" - genes."end" + 1) END)) <= 100000) ' + 'AS __giql_x_0 ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 5)' ) assert output == expected - def test_giqlnearest_sql_stranded(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_match_strand_when_stranded( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with stranded := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand matching is included in WHERE clause. """ sql = ( @@ -510,11 +516,13 @@ def test_giqlnearest_sql_stranded(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Reserialized by the #142 expander; the stranded distance CASE and the - # ``peaks.strand = genes.strand`` match in WHERE are semantically - # unchanged from the legacy emitter. 
+ # Reserialized by the #142 expander as a two-level subquery; the stranded + # distance CASE and the ``peaks.strand = genes.strand`` match in the inner + # WHERE are semantically unchanged, with the outer wrapper ordering on the + # materialized ``__giql_x_0."distance"`` plus the ``(start, end)`` + # tiebreaker (#142 A5). expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT genes.*, " + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' 'WHEN peaks."strand" IS NULL OR genes."strand" IS NULL THEN NULL ' "WHEN peaks.\"strand\" = '.' OR peaks.\"strand\" = '?' THEN NULL " @@ -528,23 +536,13 @@ def test_giqlnearest_sql_stranded(self, tables_with_peaks_and_genes): 'THEN -(peaks."start" - genes."end" + 1) ' 'ELSE (peaks."start" - genes."end" + 1) END END AS distance ' 'FROM genes WHERE peaks."chrom" = genes."chrom" ' - 'AND peaks."strand" = genes."strand" ' - 'ORDER BY ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' - 'WHEN peaks."strand" IS NULL OR genes."strand" IS NULL THEN NULL ' - "WHEN peaks.\"strand\" = '.' OR peaks.\"strand\" = '?' THEN NULL " - "WHEN genes.\"strand\" = '.' OR genes.\"strand\" = '?' 
THEN NULL " - 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' - 'THEN 0 WHEN peaks."end" <= genes."start" ' - "THEN CASE WHEN peaks.\"strand\" = '-' " - 'THEN -(genes."start" - peaks."end" + 1) ' - 'ELSE (genes."start" - peaks."end" + 1) END ' - "ELSE CASE WHEN peaks.\"strand\" = '-' " - 'THEN -(peaks."start" - genes."end" + 1) ' - 'ELSE (peaks."start" - genes."end" + 1) END END) LIMIT 3)' + 'AND peaks."strand" = genes."strand") AS __giql_x_0 ' + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected - def test_giqlnearest_sql_implicit_outer_without_strand_column(self): + def test_expand_nearest_should_skip_strand_when_outer_has_no_strand_column(self): """ GIVEN a stranded NEAREST whose implicit-outer table declares no strand column @@ -568,7 +566,9 @@ def test_giqlnearest_sql_implicit_outer_without_strand_column(self): assert "strand" not in output assert 'nostr."chrom" = genes."chrom"' in output - def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_signed_distance_when_signed( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with signed := true WHEN the NEAREST expander runs @@ -582,22 +582,21 @@ def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): output = _generate_through_passes(sql, tables_with_peaks_and_genes) - # Reserialized by the #142 expander; the signed distance CASE (negated - # ELSE branch for upstream) is semantically unchanged from the legacy - # emitter. + # Reserialized by the #142 expander as a two-level subquery; the signed + # distance CASE (negated ELSE branch for upstream) is semantically + # unchanged in the inner SELECT, with the outer wrapper ordering on the + # materialized ``__giql_x_0."distance"`` plus the ``(start, end)`` + # tiebreaker (#142 A5). 
expected = ( - "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT genes.*, " + "SELECT * FROM peaks CROSS JOIN LATERAL (SELECT * FROM (SELECT genes.*, " 'CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' 'THEN 0 WHEN peaks."end" <= genes."start" ' 'THEN (genes."start" - peaks."end" + 1) ' 'ELSE -(peaks."start" - genes."end" + 1) END AS distance ' - 'FROM genes WHERE peaks."chrom" = genes."chrom" ' - 'ORDER BY ABS(CASE WHEN peaks."chrom" <> genes."chrom" THEN NULL ' - 'WHEN peaks."start" < genes."end" AND peaks."end" > genes."start" ' - 'THEN 0 WHEN peaks."end" <= genes."start" ' - 'THEN (genes."start" - peaks."end" + 1) ' - 'ELSE -(peaks."start" - genes."end" + 1) END) LIMIT 3)' + 'FROM genes WHERE peaks."chrom" = genes."chrom") AS __giql_x_0 ' + 'ORDER BY ABS(__giql_x_0."distance"), ' + '__giql_x_0."start", __giql_x_0."end" LIMIT 3)' ) assert output == expected @@ -613,7 +612,7 @@ def test_giqlnearest_sql_signed(self, tables_with_peaks_and_genes): k=st.integers(min_value=1, max_value=100), max_distance=st.integers(min_value=1, max_value=10_000_000), ) - def test_giqlnearest_sql_parameter_handling_property( + def test_expand_nearest_should_carry_k_and_max_distance_property( self, tables_with_peaks_and_genes, k, max_distance ): """ @@ -831,12 +830,12 @@ def test_select_sql_join_without_alias(self, tables_with_two_tables): ) assert output == expected - def test_giqlnearest_sql_stranded_literal_with_strand( + def test_expand_nearest_should_use_literal_strand_when_stranded( self, tables_with_peaks_and_genes ): """ GIVEN a GIQLNearest with stranded := true and literal reference containing strand - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand from literal range is parsed and used in filtering. 
""" sql = ( @@ -850,12 +849,12 @@ def test_giqlnearest_sql_stranded_literal_with_strand( assert "'+'" in output assert 'genes."strand"' in output - def test_giqlnearest_sql_stranded_implicit_reference( + def test_expand_nearest_should_resolve_outer_strand_when_implicit_reference( self, tables_with_peaks_and_genes ): """ GIVEN a GIQLNearest in correlated mode with implicit reference and stranded := true - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN Strand column is resolved from outer table and used. """ sql = "SELECT * FROM peaks CROSS JOIN LATERAL NEAREST(genes, k := 3, stranded := true)" @@ -954,7 +953,7 @@ def test_giqldistance_sql_literal_second_arg_error(self, tables_with_two_tables) with pytest.raises(ValueError, match="Literal range as second argument"): generator.generate(ast) - def test_giqlnearest_sql_missing_outer_table_error( + def test_expand_nearest_should_raise_when_outer_table_unresolvable( self, tables_with_peaks_and_genes ): """ @@ -969,7 +968,7 @@ def test_giqlnearest_sql_missing_outer_table_error( with pytest.raises(ValueError, match="Could not find outer table"): _generate_through_passes(sql, tables_with_peaks_and_genes) - def test_giqlnearest_sql_outer_table_not_in_tables(self): + def test_expand_nearest_should_raise_when_outer_table_unregistered(self): """ GIVEN a NEAREST whose implicit-outer relation is found but not registered WHEN the query is generated @@ -985,10 +984,12 @@ def test_giqlnearest_sql_outer_table_not_in_tables(self): with pytest.raises(ValueError, match="not found in tables"): _generate_through_passes(sql, tables) - def test_giqlnearest_sql_invalid_reference_range(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_raise_when_reference_range_unparseable( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with invalid/unparseable reference range string - WHEN giqlnearest_sql is called + WHEN the NEAREST expander runs THEN ValueError is raised with parse error details. 
""" sql = "SELECT * FROM NEAREST(genes, reference := 'invalid_range', k := 3)" @@ -996,7 +997,7 @@ def test_giqlnearest_sql_invalid_reference_range(self, tables_with_peaks_and_gen with pytest.raises(ValueError, match="Could not parse reference genomic range"): _generate_through_passes(sql, tables_with_peaks_and_genes) - def test_giqlnearest_sql_no_tables_error(self): + def test_expand_nearest_should_raise_when_no_tables_registered(self): """ GIVEN a GIQLNearest with no tables registered WHEN the NEAREST expander runs @@ -1009,7 +1010,9 @@ def test_giqlnearest_sql_no_tables_error(self): with pytest.raises(ValueError, match="not found in tables"): _generate_through_passes(sql, Tables()) - def test_giqlnearest_sql_target_not_in_tables(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_raise_when_target_unregistered( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest whose target table is not registered WHEN the NEAREST expander runs @@ -1040,7 +1043,7 @@ def test_intersects_sql_unqualified_column(self): ) assert output == expected - def test_giqlnearest_sql_stranded_unqualified_reference( + def test_expand_nearest_should_resolve_strand_when_reference_unqualified( self, tables_with_peaks_and_genes ): """ @@ -1061,11 +1064,14 @@ def test_giqlnearest_sql_stranded_unqualified_reference( assert "LIMIT 3" in output assert '"strand"' in output - def test_giqlnearest_sql_literal_target(self, tables_with_peaks_and_genes): + def test_expand_nearest_should_emit_ordered_subquery_for_literal_reference( + self, tables_with_peaks_and_genes + ): """ GIVEN a GIQLNearest with a standalone literal reference WHEN the NEAREST expander runs - THEN it produces valid SQL selecting from the target table. + THEN it produces a standalone ordered, limited subquery over the target + table with no correlated LATERAL. 
""" # Arrange sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 3)" @@ -1075,8 +1081,9 @@ def test_giqlnearest_sql_literal_target(self, tables_with_peaks_and_genes): # Assert assert "genes" in output + assert "ORDER BY" in output assert "LIMIT 3" in output - assert "LIMIT 3" in output + assert "LATERAL" not in output @given( bool_repr=st.sampled_from(["true", "TRUE", "True", "1", "yes", "YES"]), @@ -1749,7 +1756,7 @@ def test_giqlnearest_should_canonicalize_reference_column_when_reference_is_one_ A 0-based half-open target table (bed_a) and an explicit reference column from a 1-based closed table (vcf_b). When: - giqlnearest_sql is called. + the NEAREST expander runs. Then: It should wrap the reference-side start as (start - 1), leave its end raw, and leave the target side raw. @@ -1782,7 +1789,7 @@ def test_giqlnearest_should_canonicalize_outer_table_columns_when_reference_is_i target table (bed_a) joined via CROSS JOIN LATERAL with no ``reference`` argument on NEAREST. When: - giqlnearest_sql is called. + the NEAREST expander runs. Then: It should canonicalize the outer table's columns based on vcf_b's convention — wrapping start as (vcf_b."start" - 1) and diff --git a/tests/integration/datafusion/test_cross_target_oracle.py b/tests/integration/datafusion/test_cross_target_oracle.py index f270fff..02db63f 100644 --- a/tests/integration/datafusion/test_cross_target_oracle.py +++ b/tests/integration/datafusion/test_cross_target_oracle.py @@ -242,6 +242,171 @@ def test_correlated_nearest_k1_agrees_across_all_targets(self, cross_target_orac engines={"generic": "duckdb"}, ) + def test_correlated_nearest_k2_returns_two_nearest_across_targets( + self, cross_target_oracle + ): + """Test correlated NEAREST k=2 picks the two nearest on every target. + + Given: + One peak and four candidate genes, more than k of them on the peak's + chromosome at distinct distances. 
+ When: + A correlated ``NEAREST(..., k := 2)`` runs on every target — DuckDB + via the LATERAL form, DataFusion via the decorrelated window fallback. + Then: + Every target should return the two nearest genes and agree, pinning + the top-k fan-out of the fallback against the LATERAL form. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 2) b", + peaks=[("chr1", 200, 300)], + genes=[ + ("chr1", 1000, 1100), + ("chr1", 50, 60), + ("chr1", 280, 290), + ("chr1", 310, 320), + ], + expected=[(200, 280), (200, 310)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_duplicate_reference_rows_fan_out( + self, cross_target_oracle + ): + """Test correlated NEAREST fans the top-k out to duplicate reference rows. + + Given: + Two identical peak rows and two candidate genes. + When: + A correlated ``NEAREST(..., k := 1)`` runs on every target. + Then: + Every target should return the nearest gene once per duplicate peak + (two rows), pinning the fallback's DISTINCT-then-rejoin fan-out so a + duplicate outer row is not collapsed. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300), ("chr1", 200, 300)], + genes=[("chr1", 280, 290), ("chr1", 50, 60)], + expected=[(200, 280), (200, 280)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_partitions_by_chromosome(self, cross_target_oracle): + """Test correlated NEAREST keys the nearest per outer chromosome. + + Given: + Peaks on chr1 and chr2 and candidate genes on both chromosomes. + When: + A correlated ``NEAREST(..., k := 1)`` runs on every target. 
+ Then: + Each peak should pair with the nearest gene on its own chromosome and + all targets agree, pinning the fallback's PARTITION BY reference key + across distinct outer keys. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.chrom AS chrom, a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(genes, reference := a.interval, k := 1) b", + peaks=[("chr1", 200, 300), ("chr2", 200, 300)], + genes=[("chr1", 280, 290), ("chr2", 500, 510), ("chr2", 205, 215)], + expected=[("chr1", 200, 280), ("chr2", 200, 205)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_max_distance_boundary(self, cross_target_oracle): + """Test correlated NEAREST drops candidates beyond max_distance everywhere. + + Given: + A peak and two genes, one just inside and one far beyond a + ``max_distance`` threshold. + When: + A correlated ``NEAREST(..., k := 5, max_distance := 100)`` runs on + every target. + Then: + Every target should return only the in-threshold gene, pinning the + ``max_distance`` filter through both the LATERAL and fallback forms. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 5, max_distance := 100) b", + peaks=[("chr1", 200, 300)], + genes=[("chr1", 360, 400), ("chr1", 5000, 5100)], + expected=[(200, 360)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_stranded_matches_strand(self, cross_target_oracle): + """Test stranded correlated NEAREST matches strand on every target. + + Given: + A ``+`` peak and two genes — a slightly farther ``+`` gene and a + nearer ``-`` gene. + When: + A correlated ``NEAREST(..., k := 1, stranded := true)`` runs on every + target. + Then: + Every target should return the same-strand (``+``) gene even though + the opposite-strand gene is nearer, in agreement. 
+ """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 1, stranded := true) b", + tables=[Table("peaks"), Table("genes")], + columns=_STRANDED_COLUMNS, + peaks=[("chr1", 200, 300, "+")], + genes=[("chr1", 280, 290, "+"), ("chr1", 250, 260, "-")], + expected=[(200, 280)], + engines={"generic": "duckdb"}, + ) + + def test_correlated_nearest_signed_distance_agrees(self, cross_target_oracle): + """Test signed correlated NEAREST reports signed distances everywhere. + + Given: + A peak with one upstream and one downstream candidate gene. + When: + A correlated ``NEAREST(..., k := 2, signed := true)`` projects the + ``distance`` column on every target. + Then: + Every target should report a negative distance for the upstream gene + and a positive one for the downstream gene, in agreement. + """ + # Arrange / Act / Assert + cross_target_oracle( + "SELECT a.start AS a_start, b.start AS b_start, b.distance AS d " + "FROM peaks a " + "CROSS JOIN LATERAL NEAREST(" + "genes, reference := a.interval, k := 2, signed := true) b", + peaks=[("chr1", 200, 300)], + genes=[("chr1", 50, 60), ("chr1", 360, 400)], + expected=[(200, 50, -141), (200, 360, 61)], + engines={"generic": "duckdb"}, + ) + + +#: A chrom/start/end/strand schema for the stranded NEAREST oracle cases (the +#: default oracle schema carries no strand column). 
+_STRANDED_COLUMNS = ( + ("chrom", "utf8"), + ("start", "int64"), + ("end", "int64"), + ("strand", "utf8"), +) + class TestCrossTargetOracleIntersectsAnyAll: """INTERSECTS ANY/ALL identity across all targets (T2).""" diff --git a/tests/test_expander.py b/tests/test_expander.py index 31a22f1..4579359 100644 --- a/tests/test_expander.py +++ b/tests/test_expander.py @@ -53,11 +53,13 @@ def _expander(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: return _expander -#: The registry contents at import — the built-in expanders registered by -#: ``giql.expanders`` for already-migrated operators (DISJOIN as of #143). The -#: leak guards and ``clean_registry`` treat this as the baseline rather than an -#: empty registry, so the real built-in registrations survive isolating fixtures -#: and a leaking test is still caught against the true baseline. +import giql.expanders # noqa: E402, F401 (side-effect: registers built-in expanders) + +#: The registry contents at import — the built-in expanders ``giql.expanders`` +#: registers for the already-migrated operators. The leak guards and +#: ``clean_registry`` treat this as the baseline rather than an empty registry, +#: so the real built-in registrations survive isolating fixtures and a leaking +#: test is still caught against the true baseline. _REGISTRY_BASELINE = REGISTRY.snapshot() @@ -101,12 +103,12 @@ def _registry_leak_guard(): def _expand_flag_leak_guard(): """Assert every operator's GIQL_EXPAND is restored at each test boundary. - The symmetric partner of the registry leak guard: each operator class ships - a shipped GIQL_EXPAND default (``True`` for a migrated operator like DISJOIN, - ``False`` otherwise), and a test that flips one via ``_opted_in`` must restore - it. A leaked flip would silently change the pass for a later test, so this - catches anything that bypasses the exception-safe ``_opted_in`` manager by - comparing against each operator's shipped default rather than a blanket False. 
+ The symmetric partner of the registry leak guard: each operator class ships a + GIQL_EXPAND default (``True`` for a migrated operator, ``False`` otherwise), + and a test that flips one via ``_opted_in`` must restore it. A leaked flip + would silently change the pass for a later test, so this catches anything that + bypasses the exception-safe ``_opted_in`` manager by comparing against each + operator's shipped default rather than a blanket False. """ for op in _OPERATOR_CLASSES: assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( @@ -483,7 +485,7 @@ def _expander(node, ctx): assert REGISTRY.resolve(DuckDBTarget(), GIQLDisjoin) is None assert (DuckDBTarget(), GIQLDisjoin) not in REGISTRY - def test_snapshot_is_independent_of_later_registrations(self): + def test_snapshot_should_not_observe_later_registrations(self): """Test that a snapshot does not observe registrations made after it. Given: @@ -506,7 +508,7 @@ def test_snapshot_is_independent_of_later_registrations(self): assert (DuckDBTarget(), GIQLDisjoin) in saved assert (GenericTarget(), Intersects) not in saved - def test_restore_replaces_entries_with_snapshot_contents(self): + def test_restore_should_replace_entries_with_snapshot_contents(self): """Test that restore returns the registry to a captured snapshot. Given: @@ -1020,8 +1022,9 @@ def test_transform_skips_unflagged_operator(self, clean_registry): Given: An expander registered for (GenericTarget, GIQLDisjoin) but the - operator's GIQL_EXPAND flag held off (DISJOIN ships it on, so the - control opts it out to isolate the per-type gate). + operator's GIQL_EXPAND flag held off (a migrated operator ships it on, + so the control opts the migrated operator out to isolate the per-type + gate from any shipped opt-in). When: Running the pass. 
Then: @@ -1034,7 +1037,7 @@ def test_transform_skips_unflagged_operator(self, clean_registry): pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) # Act - with _opted_out(GIQLDisjoin): + with _opted_out(_A_MIGRATED_OPERATOR): result = pass_.transform(ast) # Assert @@ -1118,8 +1121,8 @@ def test_transpile_sql_unchanged_with_pass_inert(self): # The nine GIQL operator expression classes the ExpandOperators pass inspects. -# Every one must ship opted out (GIQL_EXPAND=False) at this step so the pass is a -# strict no-op until a later migration flips one alongside its expander. +# Each ships its own GIQL_EXPAND default: a migrated operator ships True (and has +# a built-in expander registered), the rest ship False until their migrations land. from giql.expressions import Contains # noqa: E402 from giql.expressions import GIQLCluster # noqa: E402 from giql.expressions import GIQLDistance # noqa: E402 @@ -1140,9 +1143,9 @@ def test_transpile_sql_unchanged_with_pass_inert(self): ) #: Each operator's shipped GIQL_EXPAND default, captured from its own class dict -#: at import. A migrated operator (DISJOIN, #143) ships ``True``; the rest ship -#: ``False`` until their migrations land. The flag leak guard restores to these -#: shipped values rather than a blanket ``False``. +#: at import. A migrated operator ships ``True``; the rest ship ``False`` until +#: their migrations land. The flag leak guard restores to these shipped values +#: rather than a blanket ``False``. _SHIPPED_EXPAND_FLAGS = {op: op.__dict__.get("GIQL_EXPAND") for op in _OPERATOR_CLASSES} @@ -1155,6 +1158,13 @@ def test_transpile_sql_unchanged_with_pass_inert(self): op for op in _OPERATOR_CLASSES if op not in _MIGRATED_OPERATORS ) +# At least one operator has been migrated (the build of this branch ships one); +# the controls below pick an arbitrary migrated operator operator-agnostically. 
+assert _MIGRATED_OPERATORS, "expected at least one migrated operator" +#: An arbitrary migrated operator, used by controls that need an operator shipping +#: GIQL_EXPAND=True without naming a specific one. +_A_MIGRATED_OPERATOR = _MIGRATED_OPERATORS[0] + class TestOperatorOptOut: """Migrated operators opt into the pass; the rest still ship opted out.""" @@ -1187,7 +1197,7 @@ def test_operator_class_ships_expand_enabled(self, operator): Given: A GIQL operator expression class migrated onto the ExpandOperators - pass (DISJOIN, #143). + pass. When: Reading its GIQL_EXPAND class attribute. Then: @@ -1782,8 +1792,9 @@ def test_walk_partial_opt_in_replaces_only_flagged_type(self, clean_registry): ) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act (DISJOIN ships flagged, so opt it out to hold it as the control) - with _opted_in(Intersects), _opted_out(GIQLDisjoin): + # Act (opt the migrated operator out so its shipped opt-in cannot + # interfere; DISJOIN here is the unflagged control whose node must remain) + with _opted_in(Intersects), _opted_out(_A_MIGRATED_OPERATOR): result = pass_.transform(ast) # Assert @@ -2004,9 +2015,9 @@ class _opted_out: """Context manager opting an operator class out of GIQL_EXPAND for a test. The complement of :class:`_opted_in`: used by a control test that needs a - *migrated* operator (DISJOIN ships GIQL_EXPAND=True) to behave as if - unflagged, so the test can prove the pass gates per-type without the - operator's shipped opt-in interfering. Restores the prior flag on exit. + *migrated* operator (one shipping GIQL_EXPAND=True) to behave as if unflagged, + so the test can prove the pass gates per-type without the operator's shipped + opt-in interfering. Restores the prior flag on exit. 
""" def __init__(self, operator: type) -> None: diff --git a/tests/test_nearest_transpilation.py b/tests/test_nearest_transpilation.py index caaed98..e63600d 100644 --- a/tests/test_nearest_transpilation.py +++ b/tests/test_nearest_transpilation.py @@ -15,9 +15,25 @@ from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import DataFusionTarget from giql.targets import GenericTarget +def _generate_for_target(sql: str, tables: Tables, target) -> str: + """Parse, run passes 1-3 against *target*, then generate SQL. + + Drives the expander for a specific :class:`~giql.targets.Target` so a + capability-dependent shape (e.g. DataFusion's decorrelated window fallback, + chosen because ``supports_lateral`` is False) can be asserted without an + engine. + """ + ast = parse_one(sql, dialect=GIQLDialect) + ast = resolve_operator_refs(ast, tables) + ast = canonicalize_coordinates(ast) + ast = ExpandOperators(target, tables).transform(ast) + return BaseGIQLGenerator(tables=tables).generate(ast) + + def _generate(sql: str, tables: Tables) -> str: """Parse, run normalization passes 1-3, then generate SQL. @@ -166,3 +182,86 @@ def test_nearest_with_signed(self, tables_with_peaks_and_genes): assert "ELSE -(" in output, ( f"Expected signed distance with negation for upstream, got:\n{output}" ) + + +class TestNearestDataFusionFallbackShape: + """Engine-free transpile-shape checks for the DataFusion window fallback (A8). + + A correlated NEAREST on the DataFusion target (``supports_lateral`` is False) + expands to the decorrelated window-function form. These assert its structural + invariants without running an engine: the window is present, the top-k filter + is a ``<= k`` predicate, no correlated ``LATERAL`` survives, and the candidate + cross-join and the window live at separate query levels. 
+ """ + + def test_fallback_emits_window_with_topk_and_no_lateral( + self, tables_with_peaks_and_genes + ): + """ + GIVEN a correlated NEAREST(genes, k := 1) on the DataFusion target + WHEN transpiling + THEN the decorrelated fallback emits a ROW_NUMBER() window, a `<= 1` + top-k predicate, no surviving LATERAL, and the cross-join and window + at separate query levels. + """ + sql = ( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 1) AS b" + ) + + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + assert "ROW_NUMBER(" in output.upper() + assert "OVER (" in output.upper() + assert "<= 1" in output + assert "LATERAL" not in output.upper() + # The candidate cross-join is expected one level below the window, inside its + # parenthesized FROM; this check only confirms that a CROSS JOIN is present. + assert "CROSS JOIN" in output.upper() + + def test_fallback_stranded_emits_window_and_strand_match( + self, tables_with_peaks_and_genes + ): + """ + GIVEN a stranded correlated NEAREST on the DataFusion target + WHEN transpiling + THEN the fallback emits the window form, keeps a strand equality in the + candidate WHERE, and surfaces no LATERAL. + """ + sql = ( + "SELECT * FROM peaks CROSS JOIN LATERAL " + "NEAREST(genes, reference := peaks.interval, k := 1, stranded := true) AS b" + ) + + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + assert "ROW_NUMBER(" in output.upper() + assert "LATERAL" not in output.upper() + assert 'peaks."strand"' in output + assert 'genes."strand"' in output + + def test_fallback_k_greater_than_one_uses_k_in_topk_predicate( + self, tables_with_peaks_and_genes + ): + """ + GIVEN a correlated NEAREST(genes, k := 3) on the DataFusion target + WHEN transpiling + THEN the top-k predicate carries the requested k (`<= 3`) rather than a + LIMIT, and no LATERAL survives. 
+ """ + sql = ( + "SELECT * FROM peaks " + "CROSS JOIN LATERAL NEAREST(genes, reference := peaks.interval, k := 3) AS b" + ) + + output = _generate_for_target( + sql, tables_with_peaks_and_genes, DataFusionTarget() + ) + + assert "ROW_NUMBER(" in output.upper() + assert "<= 3" in output + assert "LATERAL" not in output.upper()