diff --git a/src/giql/expander.py b/src/giql/expander.py index f93132a..ae9ca3e 100644 --- a/src/giql/expander.py +++ b/src/giql/expander.py @@ -39,10 +39,11 @@ i.e. a ``(target, op)`` or ``(generic, op)`` expander is registered. Otherwise it falls through to the legacy ``*_sql`` emitter on -:class:`giql.generators.base.BaseGIQLGenerator`. As of this issue **no operator -sets ``GIQL_EXPAND`` and the registry is empty, so the pass is a strict no-op**: -no node is touched and the emitted SQL is byte-identical. Each later migration PR -(epic #137 steps 4-9) registers a generic expander, flips one operator's +:class:`giql.generators.base.BaseGIQLGenerator`. The built-in expanders register +at import time via :mod:`giql.expanders`; the pass rewrites a node only when +``GIQL_EXPAND=True`` **and** an expander resolves for ``(active target, operator +type)``, and is a no-op for any operator that is unflagged or has no registered +expander. A migration PR registers an expander, flips one operator's ``GIQL_EXPAND`` flag, and deletes that operator's ``*_sql`` method. """ @@ -149,6 +150,14 @@ class OperatorExpander(Protocol): a registered object satisfies it. A plain function is *not* an ``OperatorExpander`` (it has no ``expand`` method); register one by wrapping it (see :func:`register`, which accepts either form). + + An expander is **node-local**: ``expand(node, ctx) -> exp.Expression`` sees + one operator node and returns the expression that replaces it in place. It + cannot express a whole-query rewrite such as the INTERSECTS IEJoin fold, + which restructures the surrounding query (joins, CTEs) rather than a single + node. That fold is therefore deferred — it would need a separate + query-level mechanism — and is handled by the pre-pass join transformers, not + by an expander. """ def expand(self, node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: ... @@ -215,6 +224,17 @@ def register( expander : OperatorExpander | ExpanderFn The expander object or function. A later registration for the same key replaces an earlier one (last-write-wins override). + + Notes + ----- + A *non-generic* ``(target, operator)`` entry is intended to also act as a + join-rewrite override for operators with a built-in whole-query join + rewrite (notably :class:`~giql.expressions.Intersects`, whose binned + equi-join / DuckDB IEJoin transformers run before expansion), letting a + per-target expander assume responsibility for that rewrite. That bypass + is intended for a future INTERSECTS consumer and is **not wired by any + caller yet** — no transformer consults :meth:`has_override` here (see + #141). """ self._expanders[(target, operator)] = _as_callable(expander) @@ -223,6 +243,14 @@ def resolve(self, target: Target, operator: type) -> ExpanderFn | None: Tries the exact ``(target, op)`` entry, then the ``(GenericTarget(), op)`` fallback, then ``None`` (legacy emitter). + + A non-generic exact ``(target, op)`` entry is also intended to act as a + *join-rewrite override* for operators with a built-in whole-query join + rewrite (notably :class:`~giql.expressions.Intersects`). That override is + intended for a future INTERSECTS consumer and is **not wired by any + caller yet** — resolution does not itself bypass the built-in + binned / IEJoin transformers (see :meth:`register`, :meth:`has_override`, + and #141). """ fn = self._expanders.get((target, operator)) if fn is not None: @@ -233,6 +261,23 @@ def resolve(self, target: Target, operator: type) -> ExpanderFn | None: return fn return None + def has_override(self, target: Target, operator: type) -> bool: + """Whether an exact non-generic ``(target, operator)`` entry is registered. + + Returns ``True`` only when *target* is not :class:`~giql.targets.GenericTarget` + and an exact ``(target, operator)`` entry is registered; the portable + ``(GenericTarget(), operator)`` fallback is *not* an override and does not + count here. + + Such an entry is intended to mark a target-specific override that + supersedes built-in handling (e.g. taking responsibility for the + whole-query join rewrite the built-in transformers would otherwise + perform). That mechanism is intended for a future INTERSECTS consumer and + is **not wired by any caller yet** — no transformer consults this method + in the current pipeline (see #141). + """ + return target != GenericTarget() and (target, operator) in self._expanders + def unregister(self, target: Target, operator: type) -> None: """Drop the ``(target, operator)`` entry if present. @@ -253,6 +298,31 @@ def clear(self) -> None: """ self._expanders.clear() + def snapshot(self) -> dict[tuple[Target, type], ExpanderFn]: + """Return a shallow copy of the current registrations. + + The save half of a save/restore seam that supports test + baseline-isolation: capture the baseline with this and hand it back to + :meth:`restore` afterward, so the built-in expanders registered at import + survive an isolating fixture that would otherwise :meth:`clear` them + permanently. + + The returned dict is a fresh mapping (mutating it does not affect the + registry), keyed by the same ``(target, operator)`` tuples. + """ + return dict(self._expanders) + + def restore(self, snapshot: dict[tuple[Target, type], ExpanderFn]) -> None: + """Replace all registrations with those captured by :meth:`snapshot`. + + The restore half of the save/restore seam that supports test + baseline-isolation. Drops every current entry and re-installs exactly the + *snapshot* contents, so a fixture can return the registry to a previously + captured baseline regardless of what its body registered or cleared. + """ + self._expanders.clear() + self._expanders.update(snapshot) + def __contains__(self, key: tuple[Target, type]) -> bool: """Whether an *exact* ``(target, operator)`` entry is registered. @@ -280,8 +350,9 @@ def __bool__(self) -> bool: #: The process-wide registry the :func:`register` decorator writes to and the -#: :class:`ExpandOperators` pass reads from. Empty as of this issue, so the pass -#: is a strict no-op. +#: :class:`ExpandOperators` pass reads from. The built-in expanders register into +#: it at import time via :mod:`giql.expanders`; the pass rewrites a node only when +#: an expander resolves here (and the operator is flagged ``GIQL_EXPAND``). REGISTRY = ExpanderRegistry() @@ -380,9 +451,10 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for ``(target, operator type)`` through its fallback chain; otherwise the node is left untouched and the legacy ``*_sql`` emitter handles it. - The pass mutates and returns *expression* in place. **With no operator - flagged and an empty registry it is a strict no-op** and the emitted SQL is - byte-identical, so the existing suite is the migration oracle. + The pass mutates and returns *expression* in place. It touches only nodes + whose operator is flagged ``GIQL_EXPAND`` and resolves an expander; for every + other operator it is a no-op, leaving the emitted SQL byte-identical, so the + existing suite is the migration oracle. Parameters ---------- @@ -399,9 +471,9 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for Returns ------- exp.Expression - The same *expression*, with opted-in operator nodes replaced by their - target-specific expansions (none, while every flag is off / the registry - is empty). + The same *expression*, with each opted-in operator node that resolves an + expander replaced by its target-specific expansion; nodes that are + unflagged or resolve no expander are left untouched. """ reg = registry if registry is not None else REGISTRY operators = _giql_operators() diff --git a/src/giql/expanders/__init__.py b/src/giql/expanders/__init__.py new file mode 100644 index 0000000..2628289 --- /dev/null +++ b/src/giql/expanders/__init__.py @@ -0,0 +1,37 @@ +"""Built-in operator expanders for epic #137. + +Importing this package registers every built-in expander as a side effect: +each submodule decorates its expander(s) with ``@register(...)`` at import +time, and this package imports all of them. The import is wired once (in +:mod:`giql.transpile`) so the process-wide ``REGISTRY`` is populated before the +first transpile. + +New operator modules are picked up automatically: drop a ``.py`` into +this package and it is imported here without editing this file. + +Modules whose name starts with ``_`` are skipped (private helpers, not +expanders). Submodules import in :func:`pkgutil.iter_modules` order, which sets +last-write-wins resolution-order precedence for overlapping registrations; an +import error here aborts the whole package import by design (a broken built-in +expander must not be silently skipped). +""" + +from __future__ import annotations + +import importlib +import pkgutil + +from giql.expander import REGISTRY + +for _module_info in pkgutil.iter_modules(__path__): + if _module_info.name.startswith("_"): + continue + importlib.import_module(f"{__name__}.{_module_info.name}") + +# Fail loudly if discovery registered nothing. Under zipimport or PEP-420 +# namespace-package layouts ``pkgutil.iter_modules`` can yield no submodules, +# silently leaving the registry unpopulated; assert at least one expander landed +# so zero-discovery surfaces here rather than as a mystery legacy-path +# fallthrough. The check is branch-agnostic: it names no one operator, so it +# holds in every wave-3 worktree regardless of which expanders ship. +assert len(REGISTRY) > 0, "giql.expanders auto-discovery registered no expanders" diff --git a/src/giql/expanders/distance.py b/src/giql/expanders/distance.py new file mode 100644 index 0000000..acd2c81 --- /dev/null +++ b/src/giql/expanders/distance.py @@ -0,0 +1,361 @@ +"""The generic DISTANCE operator expander (epic #137, step / issue #140). + +DISTANCE is the proof-of-concept that validates the expander protocol, the +registry dispatch, and the cross-target result-oracle workflow before the +harder operators migrate. It is the simplest operator — a single ``CASE`` +expression with no joins, CTEs, or per-target divergence — so a single +*generic* expander registered for :class:`~giql.targets.GenericTarget` serves +every target. DISTANCE emits identical SQL on DuckDB, DataFusion, and the +generic baseline, so no per-target override is needed. + +The CASE this expander builds matches +:meth:`giql.generators.base.BaseGIQLGenerator._generate_distance_case` exactly +(bedtools ``closest -d`` semantics): overlapping intervals report ``0``, +book-ended (adjacent) intervals report ``1``, and a raw half-open gap of ``N`` +bases reports ``N + 1``. The ``+ 1`` is applied to the absolute gap magnitude +before any directional sign, so a downstream book-ended pair reports ``+1`` and +an upstream one ``-1`` in signed mode. There are four shapes — the cartesian +product of unsigned/signed and non-stranded/stranded — preserved verbatim from +the legacy emitter. + +Because the returned CASE is reserialized by the active target's serializer +(rather than spliced in as a raw string, as the legacy ``giqldistance_sql`` +emitter did), the emitted text changes cosmetically — most visibly ``!=`` +renders as the SQL-standard ``<>``. The two are semantically identical. +""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot import parse_one + +from giql.dialect import GIQLDialect +from giql.expander import ExpansionContext +from giql.expander import register +from giql.expressions import GIQLDistance +from giql.resolver import ResolvedColumn +from giql.targets import GenericTarget + +__all__ = ["expand_distance"] + + +def _frag(fragment: str) -> exp.Expression: + """Parse one canonicalized SQL fragment into AST. + + The pass-1 :class:`~giql.resolver.ResolvedColumn` endpoints are SQL string + fragments (e.g. ``a."end"``, or ``'chr1'`` for a literal range) already + canonicalized in place by pass 2, so they are parsed — not rebuilt — back + into AST under the GIQL dialect. + + Parameters + ---------- + fragment : str + A canonicalized SQL fragment (a column reference or literal). + + Returns + ------- + exp.Expression + The parsed fragment as a sqlglot AST node. + """ + return parse_one(fragment, dialect=GIQLDialect) + + +def _gap(minuend: str, subtrahend: str) -> exp.Expression: + """Build ``(minuend - subtrahend + 1)`` — the bedtools-parity gap magnitude. + + Mirrors the legacy ``({start} - {end} + 1)`` fragment: the ``+ 1`` lifts a + book-ended (adjacent) pair from a raw half-open gap of ``0`` to a reported + distance of ``1``. + + Parameters + ---------- + minuend : str + The SQL fragment subtracted *from* (the left operand of the ``-``). + subtrahend : str + The SQL fragment subtracted (the right operand of the ``-``). + + Returns + ------- + exp.Expression + The parenthesized ``(minuend - subtrahend + 1)`` AST. + """ + diff = exp.Sub(this=_frag(minuend), expression=_frag(subtrahend)) + return exp.paren(exp.Add(this=diff, expression=exp.Literal.number(1))) + + +def _bool_param(param: exp.Expression | None) -> bool: + """Coerce an optional DISTANCE boolean argument to a Python ``bool``. + + Mirrors ``BaseGIQLGenerator._extract_bool_param`` in how it reads the + ``stranded`` / ``signed`` keyword arguments, with one *intentional* + divergence: for an :class:`sqlglot.exp.Boolean` node this returns + ``bool(param.this)`` (a true Python ``bool``), whereas the legacy + ``_extract_bool_param`` returns the raw ``param_expr.this`` (already a + ``bool`` in practice, but unhardened). The coercion is strictly safer — it + guarantees a ``bool`` regardless of what the parser stored — and never + changes the observed branch selection, so the two stay behaviorally + equivalent. See TODO(#146): folding these two readers together. + + Parameters + ---------- + param : exp.Expression | None + The ``stranded`` or ``signed`` argument node, or ``None`` if absent. + + Returns + ------- + bool + The coerced Python boolean (``False`` when the argument is absent). + """ + if not param: + return False + if isinstance(param, exp.Boolean): + return bool(param.this) + return str(param).upper() in ("TRUE", "1", "YES") + + +def _operand(ctx: ExpansionContext, arg: str, position: str) -> ResolvedColumn: + """Return the resolved column for one DISTANCE interval operand. + + Reads the pass-1 metadata attached to the node. A deferred operand (a + literal range, or an unqualified column the resolver could not resolve) has + no column attached; this raises the historical literal-range diagnostic so + the public error contract is preserved. + + Parameters + ---------- + ctx : ExpansionContext + The expansion context carrying the node's pass-1 resolution. + arg : str + The operand slot key (``"this"`` or ``"expression"``). + position : str + Human-readable operand position (``"first"`` / ``"second"``) for the + diagnostic message. + + Returns + ------- + ResolvedColumn + The resolved column metadata for the operand. + + Raises + ------ + ValueError + If the operand was deferred (a literal range or unresolved column). + """ + # TODO(#146): this read-required-column-or-raise pattern is duplicated across + # expanders; hoist it to a shared ``ExpansionContext.require_column`` helper + # once a second expander needs it. + resolution = ctx.resolution + if resolution is not None: + resolved = resolution.column(arg) + if resolved is not None: + return resolved + raise ValueError(f"Literal range as {position} argument not yet supported") + + +def _unsigned_distance(col_a: ResolvedColumn, col_b: ResolvedColumn) -> exp.Expression: + """Branch 1: unsigned (absolute) non-stranded distance, returning ``|gap| + 1``.""" + return _wrap_overlap_case( + col_a, + col_b, + downstream=_gap(col_b.start, col_a.end), + upstream=_gap(col_a.start, col_b.end), + ) + + +def _signed_distance(col_a: ResolvedColumn, col_b: ResolvedColumn) -> exp.Expression: + """Branch 2: signed non-stranded distance. + + ``+`` downstream (B after A), ``-`` upstream (B before A). + """ + return _wrap_overlap_case( + col_a, + col_b, + downstream=_gap(col_b.start, col_a.end), + upstream=exp.Neg(this=_gap(col_a.start, col_b.end)), + ) + + +def _stranded_distance( + col_a: ResolvedColumn, col_b: ResolvedColumn, signed: bool +) -> exp.Expression: + """Branches 3 & 4: stranded distance, flipping sign on A's ``-`` strand. + + The downstream and upstream gaps each become a nested ``CASE`` keyed on + ``strand_a``. The ``signed`` flag additionally layers the directional sign + on top of the strand flip, exactly as the legacy emitter's two stranded + branches do. + """ + strand_a = col_a.strand + strand_b = col_b.strand + assert strand_a is not None and strand_b is not None # gated by caller + + down_gap = _gap(col_b.start, col_a.end) + up_gap = _gap(col_a.start, col_b.end) + + # Downstream (B after A) is identical across the signed and unsigned arms: + # positive by default, flipped negative on A's '-' strand. Only *upstream* + # differs between the arms, so hoist downstream and compute upstream per arm. + downstream = _strand_flip_case( + strand_a, neg=exp.Neg(this=down_gap), pos=down_gap.copy() + ) + if signed: + # Stranded + signed: upstream (B before A) is negative by default but + # flips positive on A's '-' strand (the directional sign layered on top + # of the strand flip). + upstream = _strand_flip_case( + strand_a, neg=up_gap, pos=exp.Neg(this=up_gap.copy()) + ) + else: + # Stranded but not signed: upstream carries the strand flip only. + upstream = _strand_flip_case( + strand_a, neg=exp.Neg(this=up_gap), pos=up_gap.copy() + ) + + case = _wrap_overlap_case(col_a, col_b, downstream=downstream, upstream=upstream) + # Prepend the strand-validity guards ahead of the overlap guards. The WHEN + # order matters: chrom mismatch, then strand NULL/'.'/'?', then overlap. + return _prepend_strand_guards(case, strand_a, strand_b) + + +def _strand_flip_case( + strand_a: str, neg: exp.Expression, pos: exp.Expression +) -> exp.Expression: + """Build ``CASE WHEN strand_a = '-' THEN neg ELSE pos END``.""" + return ( + exp.Case() + .when(exp.EQ(this=_frag(strand_a), expression=exp.Literal.string("-")), neg) + .else_(pos) + ) + + +def _wrap_overlap_case( + col_a: ResolvedColumn, + col_b: ResolvedColumn, + downstream: exp.Expression, + upstream: exp.Expression, +) -> exp.Expression: + """Build the shared distance CASE skeleton common to all four branches. + + ``CASE WHEN chrom_a != chrom_b THEN NULL WHEN THEN 0 WHEN + end_a <= start_b THEN ELSE END``. + + Parameters + ---------- + col_a, col_b : ResolvedColumn + The resolved A and B interval operands. + downstream : exp.Expression + The B-after-A gap expression (already carrying any sign / strand flip). + upstream : exp.Expression + The B-before-A gap expression (already carrying any sign / strand flip). + + Returns + ------- + exp.Expression + The assembled distance ``CASE`` expression. + """ + overlap = exp.and_( + exp.LT(this=_frag(col_a.start), expression=_frag(col_b.end)), + exp.GT(this=_frag(col_a.end), expression=_frag(col_b.start)), + ) + chrom_mismatch = exp.NEQ(this=_frag(col_a.chrom), expression=_frag(col_b.chrom)) + end_a_le_start_b = exp.LTE(this=_frag(col_a.end), expression=_frag(col_b.start)) + return ( + exp.Case() + .when(chrom_mismatch, exp.Null()) + .when(overlap, exp.Literal.number(0)) + .when(end_a_le_start_b, downstream) + .else_(upstream) + ) + + +def _prepend_strand_guards( + case: exp.Case, strand_a: str, strand_b: str +) -> exp.Case: + """Insert the strand-validity WHEN guards after the chrom guard. + + Distance is undefined for an unstranded (``'.'``/``'?'``) or missing strand, + matching the legacy emitter. + + Parameters + ---------- + case : exp.Case + The overlap/gap ``CASE`` to prepend the strand guards onto (mutated in + place). + strand_a, strand_b : str + The A and B strand-column SQL fragments. + + Returns + ------- + exp.Case + The same *case*, whose WHEN order is now: chrom mismatch -> NULL, either + strand NULL -> NULL, ``strand_a`` is ``'.'``/``'?'`` -> NULL, + ``strand_b`` is ``'.'``/``'?'`` -> NULL, then the original overlap/gap + branches. + """ + sa = _frag(strand_a) + sb = _frag(strand_b) + null_guard = exp.condition(exp.Is(this=sa, expression=exp.Null())).or_( + exp.Is(this=sb.copy(), expression=exp.Null()) + ) + a_unstranded = exp.condition( + exp.EQ(this=sa.copy(), expression=exp.Literal.string(".")) + ).or_(exp.EQ(this=sa.copy(), expression=exp.Literal.string("?"))) + b_unstranded = exp.condition( + exp.EQ(this=sb.copy(), expression=exp.Literal.string(".")) + ).or_(exp.EQ(this=sb.copy(), expression=exp.Literal.string("?"))) + + guards = [ + exp.If(this=null_guard, true=exp.Null()), + exp.If(this=a_unstranded, true=exp.Null()), + exp.If(this=b_unstranded, true=exp.Null()), + ] + # The existing WHENs keep their order; the strand guards slot in right after + # the leading chrom-mismatch guard. + existing = case.args["ifs"] + case.set("ifs", existing[:1] + guards + existing[1:]) + return case + + +# KEEP IN SYNC: this expander and +# ``BaseGIQLGenerator._generate_distance_case`` (base.py) build the *same* +# distance CASE by two routes. The legacy method is retained only because +# NEAREST still calls it for its ORDER BY / filter math; once NEAREST migrates +# to the expander path that method can be deleted and this duplication retired. +# Until then, any change to the distance math here must be mirrored there (and +# vice versa). The parity test in tests/test_distance_udf.py guards the drift. +@register(GenericTarget, GIQLDistance) +def expand_distance(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand a ``GIQLDistance`` node into a standard SQL ``CASE`` expression. + + The portable expander registered for every target. Reads ``stranded`` / + ``signed`` from the node and the pass-1 resolved interval operands, then + builds the matching one of the four CASE shapes + (unsigned/signed x non-stranded/stranded). + + Parameters + ---------- + node : exp.Expression + The ``GIQLDistance`` operator node being expanded. + ctx : ExpansionContext + The expansion context carrying the node's pass-1 resolution. + + Returns + ------- + exp.Expression + The distance ``CASE`` that replaces the operator node and is rendered by + the active target's serializer. + """ + stranded = _bool_param(node.args.get("stranded")) + signed = _bool_param(node.args.get("signed")) + + col_a = _operand(ctx, "this", "first") + col_b = _operand(ctx, "expression", "second") + + # Strand columns are consumed only in stranded mode, and only when both + # operands actually carry a strand fragment — mirroring the legacy emitter's + # `strand_a is None or strand_b is None` fall-through to the unstranded path. + if stranded and col_a.strand is not None and col_b.strand is not None: + return _stranded_distance(col_a, col_b, signed=signed) + if signed: + return _signed_distance(col_a, col_b) + return _unsigned_distance(col_a, col_b) diff --git a/src/giql/expressions.py b/src/giql/expressions.py index ce939dc..0c6bc80 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -321,7 +321,10 @@ class GIQLDistance(exp.Func): } GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + # Migrated to the registry's AST-expansion path (epic #137, issue #140): the + # generic expander in giql.expanders.distance builds the CASE; the legacy + # giqldistance_sql emitter is gone. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), diff --git a/src/giql/generators/base.py b/src/giql/generators/base.py index 9038369..a4be873 100644 --- a/src/giql/generators/base.py +++ b/src/giql/generators/base.py @@ -5,7 +5,6 @@ from giql.canonical import decanonical_start from giql.expressions import Contains from giql.expressions import GIQLDisjoin -from giql.expressions import GIQLDistance from giql.expressions import GIQLNearest from giql.expressions import Intersects from giql.expressions import SpatialSetPredicate @@ -493,85 +492,6 @@ def _disjoin_passthrough( f't.* REPLACE ({pt_start} AS "{target_start}", {pt_end} AS "{target_end}")' ) - def giqldistance_sql(self, expression: GIQLDistance) -> str: - """Generate SQL CASE expression for DISTANCE function. - - Reads the :class:`~giql.resolver.ResolvedColumn` metadata that - ``ResolveOperatorRefs`` (pass 1) attaches to each interval operand. When - the pass deferred an operand (a literal range, or an unqualified column) - the emitter raises the historical literal-range diagnostic. - - Coordinate canonicalization is owned by ``CanonicalizeCoordinates`` - (pass 2, issue #123): the resolved metadata's endpoints are already - canonicalized in place, so the emitter consumes them verbatim. - - :param expression: - GIQLDistance expression node - :return: - SQL CASE expression string calculating genomic distance - """ - stranded = self._extract_bool_param(expression.args.get("stranded")) - signed = self._extract_bool_param(expression.args.get("signed")) - - col_a = self._distance_operand(expression, "this", "first") - col_b = self._distance_operand(expression, "expression", "second") - - # Strand columns are consumed only in stranded mode (matching the - # historical 3-tuple vs 4-tuple branching in the legacy emitter). - strand_a = col_a.strand if stranded else None - strand_b = col_b.strand if stranded else None - - # Distance math below assumes 0-based half-open. Input canonicalization is - # owned by CanonicalizeCoordinates (pass 2, issue #123): each operand's - # start/end fragments are canonicalized in place by the pass, so the - # emitter consumes them verbatim with no in-emitter canonicalization. The - # returned distance is an encoding-invariant base count, so it needs no - # output de-canonicalization. - - # Generate CASE expression - return self._generate_distance_case( - col_a.chrom, - col_a.start, - col_a.end, - strand_a, - col_b.chrom, - col_b.start, - col_b.end, - strand_b, - stranded=stranded, - signed=signed, - ) - - def _distance_operand( - self, expression: GIQLDistance, arg: str, position: str - ) -> ResolvedColumn: - """Return the :class:`ResolvedColumn` for one DISTANCE interval operand. - - Reads the metadata attached by ``ResolveOperatorRefs`` (pass 1). When the - pass deferred the operand — a literal range or an unqualified column it - could not resolve — no column is attached and this raises the historical - literal-range diagnostic. - - :param expression: - GIQLDistance expression node - :param arg: - The operand arg key (``"this"`` or ``"expression"``) - :param position: - Human-readable operand position for the error message (``"first"`` - or ``"second"``) - :return: - The resolved column operand - :raises ValueError: - If the operand is a literal range rather than a column reference - """ - resolution = expression.meta.get(META_KEY) - if isinstance(resolution, OperatorResolution): - resolved = resolution.column(arg) - if resolved is not None: - return resolved - - raise ValueError(f"Literal range as {position} argument not yet supported") - def _generate_distance_case( self, chrom_a: str, @@ -587,6 +507,17 @@ def _generate_distance_case( ) -> str: """Generate SQL CASE expression for distance calculation. + .. note:: + + KEEP IN SYNC: this method and the AST builder in + ``giql.expanders.distance`` (``expand_distance``) produce the *same* + distance CASE by two routes. DISTANCE itself migrated to the expander + (epic #137, issue #140); this method survives only because NEAREST + still calls it for its ORDER BY / filter math. Once NEAREST migrates, + delete this method and retire the duplication. Until then, any change + to the distance math here must be mirrored in the expander (and vice + versa); the parity test in tests/test_distance_udf.py guards drift. + Distances follow bedtools ``closest -d`` semantics: overlapping intervals report ``0``, book-ended (adjacent) intervals where ``A.end == B.start`` in half-open coordinates report ``1``, and a raw diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 9ef2100..81ee8e1 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -11,6 +11,7 @@ from sqlglot import parse_one +import giql.expanders # noqa: F401 from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect from giql.expander import ExpandOperators @@ -195,18 +196,19 @@ def transpile( with _reraise_as_value_error("Resolution error"): ast = resolve_operator_refs(ast, tables_container) - # Pass 2 of the normalization pipeline (epic #114): synthesize canonical - # __giql_canon_* wrapper CTEs for non-canonical interval operands of - # opted-in operators (GIQL_CANONICALIZE). No operator opts in yet, so this - # is a strict no-op until the per-operator port issues (#122, #123) land. + # Pass 2 of the normalization pipeline (epic #114): for each operator that + # opts into GIQL_CANONICALIZE, rewrite its non-canonical interval operands — + # synthesizing canonical __giql_canon_* wrapper CTEs — so downstream passes + # and emitters see canonical 0-based half-open coordinates. with _reraise_as_value_error("Canonicalization error"): ast = canonicalize_coordinates(ast) # Pass 3 of the normalization pipeline (epic #137): replace each opted-in # GIQL operator node with the AST its registered expander produces for the - # active target. No operator sets GIQL_EXPAND and the registry is empty, so - # this is a strict no-op until the per-operator migration issues land; the - # legacy *_sql emitters on the generator remain the fallback. + # active target. Each operator that opts in (GIQL_EXPAND) with a registered + # expander is rewritten here; any operator that is unflagged or has no + # registered expander falls through to its legacy *_sql emitter on the + # generator. expand_operators = ExpandOperators(target, tables_container) with _reraise_as_value_error("Expansion error"): ast = expand_operators.transform(ast) diff --git a/tests/generators/test_base.py b/tests/generators/test_base.py index f95e91b..a7141c4 100644 --- a/tests/generators/test_base.py +++ b/tests/generators/test_base.py @@ -11,31 +11,39 @@ from sqlglot import exp from sqlglot import parse_one +import giql # noqa: F401 (ensures the built-in expanders are registered) from giql import Table from giql import transpile from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import ExpandOperators from giql.expressions import GIQLNearest from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import GenericTarget def _generate_through_passes(sql: str, tables: Tables) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. + """Parse, run normalization passes 1-3, then generate SQL. Coordinate canonicalization for operator operands moved out of the emitter and - into the CanonicalizeCoordinates pass (issue #123). Emitter-level tests that - pin canonicalized output must therefore run both passes before generating, - exactly as :func:`giql.transpile.transpile` does, rather than calling - ``generate`` on a bare parsed AST (which would skip canonicalization). This - helper is used where the full ``transpile`` pipeline would otherwise rewrite - the node away (a column-to-column ``INTERSECTS`` is turned into a binned - equi-join before the predicate emitter runs). + into the CanonicalizeCoordinates pass (issue #123), and DISTANCE generation + itself moved onto the registry's AST-expansion pass (epic #137, issue #140). + Emitter-level tests that pin canonicalized / expanded output must therefore + run all three passes before generating, exactly as + :func:`giql.transpile.transpile` does, rather than calling ``generate`` on a + bare parsed AST. The expansion pass only touches operators that opt in + (``GIQL_EXPAND``); operators still on the legacy emitter (NEAREST, the + spatial predicates) pass through untouched. This helper is used where the + full ``transpile`` pipeline would otherwise rewrite the node away (a + column-to-column ``INTERSECTS`` is turned into a binned equi-join before the + predicate emitter runs). """ ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator(tables=tables).generate(ast) @@ -669,7 +677,7 @@ def test_giqlnearest_sql_parameter_handling_property( def test_giqldistance_sql_basic(self, tables_with_two_tables): """ GIVEN a GIQLDistance with two column references - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN CASE expression for distance calculation is generated. """ sql = ( @@ -680,7 +688,7 @@ def test_giqldistance_sql_basic(self, tables_with_two_tables): output = _generate_through_passes(sql, tables_with_two_tables) expected = ( - 'SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL ' + 'SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL ' 'WHEN a."start" < b."end" AND a."end" > b."start" ' 'THEN 0 WHEN a."end" <= b."start" ' 'THEN (b."start" - a."end" + 1) ' @@ -692,7 +700,7 @@ def test_giqldistance_sql_basic(self, tables_with_two_tables): def test_giqldistance_sql_stranded(self, tables_with_two_tables): """ GIVEN a GIQLDistance with stranded := true - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN Strand-aware distance CASE expression is generated. """ sql = ( @@ -703,7 +711,7 @@ def test_giqldistance_sql_stranded(self, tables_with_two_tables): output = _generate_through_passes(sql, tables_with_two_tables) expected = ( - 'SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL ' + 'SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL ' 'WHEN a."strand" IS NULL OR b."strand" IS NULL THEN NULL ' "WHEN a.\"strand\" = '.' OR a.\"strand\" = '?' THEN NULL " "WHEN b.\"strand\" = '.' OR b.\"strand\" = '?' THEN NULL " @@ -723,7 +731,7 @@ def test_giqldistance_sql_stranded(self, tables_with_two_tables): def test_giqldistance_sql_signed(self, tables_with_two_tables): """ GIVEN a GIQLDistance with signed := true - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN Signed distance CASE expression is generated. """ sql = ( @@ -734,7 +742,7 @@ def test_giqldistance_sql_signed(self, tables_with_two_tables): output = _generate_through_passes(sql, tables_with_two_tables) expected = ( - 'SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL ' + 'SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL ' 'WHEN a."start" < b."end" AND a."end" > b."start" ' 'THEN 0 WHEN a."end" <= b."start" ' 'THEN (b."start" - a."end" + 1) ' @@ -746,7 +754,7 @@ def test_giqldistance_sql_signed(self, tables_with_two_tables): def test_giqldistance_sql_stranded_and_signed(self, tables_with_two_tables): """ GIVEN a GIQLDistance with both stranded and signed := true - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN Combined stranded+signed distance expression is generated. """ sql = ( @@ -758,7 +766,7 @@ def test_giqldistance_sql_stranded_and_signed(self, tables_with_two_tables): output = _generate_through_passes(sql, tables_with_two_tables) expected = ( - 'SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL ' + 'SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL ' 'WHEN a."strand" IS NULL OR b."strand" IS NULL THEN NULL ' "WHEN a.\"strand\" = '.' OR a.\"strand\" = '?' THEN NULL " "WHEN b.\"strand\" = '.' OR b.\"strand\" = '?' THEN NULL " @@ -783,7 +791,7 @@ def test_giqldistance_canonicalizes_closed_ends_apart_from_gap_parity( Given: Two 0-based closed-interval tables and DISTANCE. When: - giqldistance_sql is called. + the DISTANCE node is expanded. Then: It should canonicalize each table-side end as (end + 1) for the closed->half-open conversion, distinct from the bedtools-parity @@ -804,7 +812,7 @@ def test_giqldistance_canonicalizes_closed_ends_apart_from_gap_parity( # Assert expected = ( - 'SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL ' + 'SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL ' 'WHEN a."start" < (b."end" + 1) ' 'AND (a."end" + 1) > b."start" THEN 0 ' 'WHEN (a."end" + 1) <= b."start" ' @@ -958,7 +966,7 @@ def test_giqlnearest_closed_interval_does_not_double_count_plus_one(self): def test_giqldistance_sql_literal_first_arg_error(self, tables_with_two_tables): """ GIVEN a GIQLDistance with literal range as first argument - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN ValueError is raised indicating literals not supported. """ sql = "SELECT DISTANCE('chr1:1000-2000', b.interval) as dist FROM features_b b" @@ -966,15 +974,15 @@ def test_giqldistance_sql_literal_first_arg_error(self, tables_with_two_tables): ast = resolve_operator_refs(ast, tables_with_two_tables) ast = canonicalize_coordinates(ast) - generator = BaseGIQLGenerator(tables=tables_with_two_tables) + expander = ExpandOperators(GenericTarget(), tables_with_two_tables) with pytest.raises(ValueError, match="Literal range as first argument"): - generator.generate(ast) + expander.transform(ast) def test_giqldistance_sql_literal_second_arg_error(self, tables_with_two_tables): """ GIVEN a GIQLDistance with literal range as second argument - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN ValueError is raised indicating literals not supported. """ sql = "SELECT DISTANCE(a.interval, 'chr1:1000-2000') as dist FROM features_a a" @@ -982,10 +990,10 @@ def test_giqldistance_sql_literal_second_arg_error(self, tables_with_two_tables) ast = resolve_operator_refs(ast, tables_with_two_tables) ast = canonicalize_coordinates(ast) - generator = BaseGIQLGenerator(tables=tables_with_two_tables) + expander = ExpandOperators(GenericTarget(), tables_with_two_tables) with pytest.raises(ValueError, match="Literal range as second argument"): - generator.generate(ast) + expander.transform(ast) def test_giqlnearest_sql_missing_outer_table_error( self, tables_with_peaks_and_genes @@ -1144,7 +1152,7 @@ def test_giqldistance_stranded_param_truthy_values_property( ): """ GIVEN a GIQLDistance with stranded parameter in various truthy representations - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN The parameter is parsed as True and strand-aware distance is calculated. """ sql = ( @@ -1167,7 +1175,7 @@ def test_giqldistance_stranded_param_falsy_values_property( ): """ GIVEN a GIQLDistance with stranded parameter in various falsy representations - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN The parameter is parsed as False and basic distance is calculated. """ sql = ( @@ -1189,7 +1197,7 @@ def test_giqldistance_signed_param_truthy_values_property( ): """ GIVEN a GIQLDistance with signed parameter in various truthy representations - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN The parameter is parsed as True and signed distance is calculated. """ sql = ( @@ -1211,7 +1219,7 @@ def test_giqldistance_signed_param_falsy_values_property( ): """ GIVEN a GIQLDistance with signed parameter in various falsy representations - WHEN giqldistance_sql is called + WHEN the DISTANCE node is expanded THEN The parameter is parsed as False and unsigned distance is calculated. """ sql = ( @@ -1630,7 +1638,7 @@ def test_giqldistance_should_canonicalize_table_columns_for_each_convention( # Assert expected = ( - 'SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL ' + 'SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL ' f"WHEN {start_a} < {end_b} AND {end_a} > {start_b} THEN 0 " f"WHEN {end_a} <= {start_b} THEN ({start_b} - {end_a} + 1) " f"ELSE ({start_a} - {end_b} + 1) END AS dist " @@ -1664,7 +1672,7 @@ def test_giqldistance_should_canonicalize_each_side_when_conventions_differ( # Assert expected = ( - 'SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL ' + 'SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL ' 'WHEN a."start" < b."end" AND a."end" > (b."start" - 1) THEN 0 ' 'WHEN a."end" <= (b."start" - 1) ' 'THEN ((b."start" - 1) - a."end" + 1) ' diff --git a/tests/test_distance_transpilation.py b/tests/test_distance_transpilation.py index 944a3db..7169d87 100644 --- a/tests/test_distance_transpilation.py +++ b/tests/test_distance_transpilation.py @@ -5,27 +5,33 @@ from sqlglot import parse_one +import giql # noqa: F401 (ensures the built-in expanders are registered) from giql import transpile from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import ExpandOperators from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import GenericTarget def _generate(sql: str, tables: Tables | None = None) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. + """Parse, run normalization passes 1-3, then generate SQL. DISTANCE operand resolution and coordinate canonicalization moved out of the emitter and into the ResolveOperatorRefs / CanonicalizeCoordinates passes - (epic #114, issues #119 / #123). Emitter-level tests must run both passes - before generating, exactly as :func:`giql.transpile.transpile` does, rather - than calling ``generate`` on a bare parsed AST. + (epic #114, issues #119 / #123). DISTANCE generation itself then moved onto + the registry's AST-expansion pass (epic #137, issue #140), so the operator + node must be expanded before generation too. Emitter-level tests run all + three passes, exactly as :func:`giql.transpile.transpile` does, rather than + calling ``generate`` on a bare parsed AST. """ tables = tables or Tables() ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator(tables=tables).generate(ast) @@ -45,7 +51,7 @@ def test_distance_transpilation_duckdb(self): output = _generate(sql) - expected = """SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL WHEN a."start" < b."end" AND a."end" > b."start" THEN 0 WHEN a."end" <= b."start" THEN (b."start" - a."end" + 1) ELSE (a."start" - b."end" + 1) END AS dist FROM features_a AS a CROSS JOIN features_b AS b""" + expected = """SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL WHEN a."start" < b."end" AND a."end" > b."start" THEN 0 WHEN a."end" <= b."start" THEN (b."start" - a."end" + 1) ELSE (a."start" - b."end" + 1) END AS dist FROM features_a AS a CROSS JOIN features_b AS b""" assert output == expected, f"Expected:\n{expected}\n\nGot:\n{output}" @@ -62,7 +68,7 @@ def test_distance_transpilation_sqlite(self): output = _generate(sql) - expected = """SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL WHEN a."start" < b."end" AND a."end" > b."start" THEN 0 WHEN a."end" <= b."start" THEN (b."start" - a."end" + 1) ELSE (a."start" - b."end" + 1) END AS dist FROM features_a AS a, features_b AS b""" + expected = """SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL WHEN a."start" < b."end" AND a."end" > b."start" THEN 0 WHEN a."end" <= b."start" THEN (b."start" - a."end" + 1) ELSE (a."start" - b."end" + 1) END AS dist FROM features_a AS a, features_b AS b""" assert output == expected, f"Expected:\n{expected}\n\nGot:\n{output}" @@ -79,7 +85,7 @@ def test_distance_transpilation_postgres(self): output = _generate(sql) - expected = """SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL WHEN a."start" < b."end" AND a."end" > b."start" THEN 0 WHEN a."end" <= b."start" THEN (b."start" - a."end" + 1) ELSE (a."start" - b."end" + 1) END AS dist FROM features_a AS a CROSS JOIN features_b AS b""" + expected = """SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL WHEN a."start" < b."end" AND a."end" > b."start" THEN 0 WHEN a."end" <= b."start" THEN (b."start" - a."end" + 1) ELSE (a."start" - b."end" + 1) END AS dist FROM features_a AS a CROSS JOIN features_b AS b""" assert output == expected, f"Expected:\n{expected}\n\nGot:\n{output}" @@ -119,7 +125,7 @@ def test_distance_transpilation_signed_duckdb(self): # Signed distance: upstream (B before A) returns negative, # downstream (B after A) returns positive expected = ( - 'SELECT CASE WHEN a."chrom" != b."chrom" THEN NULL ' + 'SELECT CASE WHEN a."chrom" <> b."chrom" THEN NULL ' 'WHEN a."start" < b."end" AND a."end" > b."start" THEN 0 ' 'WHEN a."end" <= b."start" THEN (b."start" - a."end" + 1) ' 'ELSE -(a."start" - b."end" + 1) END AS dist ' diff --git a/tests/test_distance_udf.py b/tests/test_distance_udf.py index 65f9c19..a4a4351 100644 --- a/tests/test_distance_udf.py +++ b/tests/test_distance_udf.py @@ -6,27 +6,42 @@ import duckdb import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st +from sqlglot import exp from sqlglot import parse_one +import giql # noqa: F401 (ensures the built-in expanders are registered) from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import ExpandOperators from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import GenericTarget + +#: This module executes generated SQL against a real in-memory DuckDB, so every +#: test here is an integration test (the marker is registered in pyproject). +pytestmark = pytest.mark.integration def _generate(sql: str) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. + """Parse, run normalization passes 1-3, then generate SQL. DISTANCE operand resolution and coordinate canonicalization moved out of the emitter and into the ResolveOperatorRefs / CanonicalizeCoordinates passes - (epic #114, issues #119 / #123). These behavioral tests must run both passes - before generating, exactly as :func:`giql.transpile.transpile` does, rather - than calling ``generate`` on a bare parsed AST. + (epic #114, issues #119 / #123), and DISTANCE generation itself moved onto + the registry's AST-expansion pass (epic #137, issue #140). These behavioral + tests must run all three passes before generating, exactly as + :func:`giql.transpile.transpile` does, rather than calling ``generate`` on a + bare parsed AST. """ + tables = Tables() ast = parse_one(sql, dialect=GIQLDialect) - ast = resolve_operator_refs(ast, Tables()) + ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator().generate(ast) @@ -696,3 +711,216 @@ def test_stranded_signed_null_strand_returns_null(self): # Assert assert result is None, f"Expected NULL for '.' strand, got {result}" + + +# --- Drift guard: expand_distance vs the legacy _generate_distance_case ------- +# +# DISTANCE moved onto the AST-expansion pass (expand_distance), but +# BaseGIQLGenerator._generate_distance_case is retained because NEAREST still +# calls it. The two compute the same distance by different routes; these tests +# pin that they stay semantically equivalent until NEAREST migrates and the +# legacy method can be deleted. + +#: Column expressions both routes are evaluated over. ``a``/``b`` are the two +#: operand relations supplied by the parity harness's VALUES row. +_CHROM_A, _START_A, _END_A, _STRAND_A = 'a."chrom"', 'a."start"', 'a."end"', 'a."strand"' +_CHROM_B, _START_B, _END_B, _STRAND_B = 'b."chrom"', 'b."start"', 'b."end"', 'b."strand"' + +#: The four DISTANCE shapes: (id, stranded, signed). +_SHAPES = [ + ("unsigned_nonstranded", False, False), + ("signed_nonstranded", False, True), + ("unsigned_stranded", True, False), + ("signed_stranded", True, True), +] + +#: Rows exercised by the parity test: ordinary downstream/upstream gaps, +#: book-ended pairs, overlaps, a chrom mismatch, and strand-invalid ('.'/'?'/ +#: NULL) rows so every WHEN branch of both routes is covered. +_PARITY_ROWS = [ + ("chr1", 100, 200, "+", "chr1", 300, 400, "+"), # downstream gap + ("chr1", 300, 400, "+", "chr1", 100, 200, "+"), # upstream gap + ("chr1", 100, 200, "+", "chr1", 200, 300, "+"), # book-ended downstream + ("chr1", 200, 300, "+", "chr1", 100, 200, "+"), # book-ended upstream + ("chr1", 100, 200, "+", "chr1", 150, 250, "+"), # overlap + ("chr1", 100, 200, "-", "chr1", 300, 400, "+"), # '-' strand A, downstream + ("chr1", 300, 400, "-", "chr1", 100, 200, "+"), # '-' strand A, upstream + ("chr1", 100, 200, "+", "chr2", 300, 400, "+"), # chrom mismatch -> NULL + ("chr1", 100, 200, ".", "chr1", 300, 400, "+"), # strand A '.' -> NULL + ("chr1", 100, 200, "?", "chr1", 300, 400, "+"), # strand A '?' -> NULL + ("chr1", 100, 200, "+", "chr1", 300, 400, "."), # strand B '.' -> NULL + ("chr1", 100, 200, None, "chr1", 300, 400, "+"), # strand A NULL -> NULL +] + + +def _row_values_cte(row) -> str: + """Render one parity row as ``a``/``b`` relations via SELECT subqueries. + + *row* is ``(chrom_a, start_a, end_a, strand_a, chrom_b, ...)``; strands may + be ``None`` (rendered as SQL ``NULL``). + """ + ca, sa, ea, ta, cb, sb, eb, tb = row + + def _strand(value): + return "NULL" if value is None else f"'{value}'" + + a = ( + f"(SELECT '{ca}' AS chrom, {sa} AS start, {ea} AS \"end\", " + f"{_strand(ta)} AS strand) a" + ) + b = ( + f"(SELECT '{cb}' AS chrom, {sb} AS start, {eb} AS \"end\", " + f"{_strand(tb)} AS strand) b" + ) + return f"{a} CROSS JOIN {b}" + + +def _expander_distance_case(stranded: bool, signed: bool) -> str: + """Return the DISTANCE CASE the expander builds, isolated from its SELECT. + + Runs the real transpile passes over a DISTANCE query whose operands resolve + to ``a``/``b`` default columns, then lifts the generated CASE expression so + it can be re-embedded over arbitrary VALUES rows. + """ + args = ["a.interval", "b.interval"] + if stranded: + args.append("stranded := true") + if signed: + args.append("signed := true") + sql = ( + f"SELECT DISTANCE({', '.join(args)}) AS d " + "FROM (SELECT 'x' AS chrom, 0 AS start, 0 AS \"end\", '+' AS strand) a " + "CROSS JOIN (SELECT 'x' AS chrom, 0 AS start, 0 AS \"end\", '+' AS strand) b" + ) + tables = Tables() + ast = parse_one(sql, dialect=GIQLDialect) + ast = resolve_operator_refs(ast, tables) + ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) + # The single projected expression is the expander's CASE. + return ast.find(exp.Select).expressions[0].this.sql() + + +def _legacy_distance_case(stranded: bool, signed: bool) -> str: + """Return the CASE the legacy _generate_distance_case builds for ``a``/``b``.""" + return BaseGIQLGenerator()._generate_distance_case( + _CHROM_A, + _START_A, + _END_A, + _STRAND_A if stranded else None, + _CHROM_B, + _START_B, + _END_B, + _STRAND_B if stranded else None, + stranded=stranded, + signed=signed, + ) + + +def _eval_case(conn, case_sql: str, row) -> object: + """Execute one distance CASE over one parity row, returning the scalar. + + Reuses the caller-supplied *conn* (a module-scoped in-memory DuckDB) rather + than opening a fresh connection per row — every row is a standalone + ``SELECT ... FROM (VALUES)`` against no persistent state, so one connection + serves the whole parity sweep. + """ + query = f"SELECT {case_sql} AS d FROM {_row_values_cte(row)}" + return conn.execute(query).fetchone()[0] + + +@pytest.fixture(scope="module") +def parity_conn(): + """A module-scoped in-memory DuckDB connection for the parity/property sweep. + + Opened once and shared across every parity row and Hypothesis example + instead of reconnecting per row; each evaluation is a self-contained + ``SELECT`` over inline ``VALUES``, so no per-row isolation is needed. + """ + conn = duckdb.connect(":memory:") + try: + yield conn + finally: + conn.close() + + +class TestDistanceExpanderLegacyParity: + """expand_distance and the retained _generate_distance_case agree row-for-row.""" + + @pytest.mark.parametrize( + "shape_id, stranded, signed", _SHAPES, ids=[s[0] for s in _SHAPES] + ) + def test_expander_matches_legacy_distance_case( + self, parity_conn, shape_id, stranded, signed + ): + """ + GIVEN the four DISTANCE shapes (unsigned/signed x non-stranded/stranded) + plus overlap, chrom-mismatch, and strand-invalid input rows + WHEN the same inputs run through expand_distance and the retained + _generate_distance_case + THEN both routes return the identical scalar for every row, pinning the + two distance implementations against drift until NEAREST migrates. + """ + # Arrange + expander_case = _expander_distance_case(stranded, signed) + legacy_case = _legacy_distance_case(stranded, signed) + + # Act & assert + for row in _PARITY_ROWS: + expander_result = _eval_case(parity_conn, expander_case, row) + legacy_result = _eval_case(parity_conn, legacy_case, row) + assert expander_result == legacy_result, ( + f"{shape_id}: expander {expander_result!r} != " + f"legacy {legacy_result!r} for row {row}" + ) + + +class TestDistanceExpanderProperties: + """Property-based invariants of the expander's distance CASE.""" + + @settings(max_examples=200, deadline=None) + @given( + start_a=st.integers(min_value=0, max_value=10_000), + len_a=st.integers(min_value=1, max_value=5_000), + start_b=st.integers(min_value=0, max_value=10_000), + len_b=st.integers(min_value=1, max_value=5_000), + ) + def test_distance_invariants_hold( + self, parity_conn, start_a, len_a, start_b, len_b + ): + """ + GIVEN random A and B intervals (start + positive length) + WHEN DISTANCE is evaluated unsigned, signed, and cross-chromosome + THEN unsigned == abs(signed), overlapping intervals report 0, a + non-overlapping same-chrom pair reports the half-open gap + 1 + (bedtools parity ground truth), and a cross-chromosome pair reports + NULL. + """ + # Arrange + end_a = start_a + len_a + end_b = start_b + len_b + same_chrom = ("chr1", start_a, end_a, "+", "chr1", start_b, end_b, "+") + cross_chrom = ("chr1", start_a, end_a, "+", "chr2", start_b, end_b, "+") + unsigned_case = _expander_distance_case(stranded=False, signed=False) + signed_case = _expander_distance_case(stranded=False, signed=True) + + # Act + unsigned = _eval_case(parity_conn, unsigned_case, same_chrom) + signed = _eval_case(parity_conn, signed_case, same_chrom) + cross = _eval_case(parity_conn, unsigned_case, cross_chrom) + + # Assert + assert unsigned == abs(signed) + overlaps = start_a < end_b and end_a > start_b + if overlaps: + assert unsigned == 0 + else: + # Ground truth (bedtools closest -d): the unsigned distance of two + # non-overlapping same-chrom intervals is the raw half-open gap plus + # one. This ties the property test to the +1 offset directly, so + # dropping the +1 from the expander would fail here (not only in the + # legacy-parity test). The gap is end_a..start_b downstream or + # end_b..start_a upstream, whichever is non-negative. + gap_in_bases = max(start_b - end_a, start_a - end_b) + assert unsigned == gap_in_bases + 1 + assert cross is None diff --git a/tests/test_expander.py b/tests/test_expander.py index ad625f1..196571b 100644 --- a/tests/test_expander.py +++ b/tests/test_expander.py @@ -53,34 +53,50 @@ def _expander(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: return _expander +#: The registry contents at import — the built-in expanders registered by +#: ``giql.expanders`` for the already-migrated operators. The leak guards and +#: ``clean_registry`` treat this as the baseline rather than an empty registry, +#: so the real built-in registrations survive isolating fixtures and a leaking +#: test is still caught against the true baseline. +import giql.expanders # noqa: F401, E402 + +_REGISTRY_BASELINE = REGISTRY.snapshot() + + @pytest.fixture def clean_registry(): - """Isolate the process-wide REGISTRY, leaving it empty afterward. + """Isolate the process-wide REGISTRY, restoring its baseline afterward. - The registry is empty at import (the pass ships inert), so a test that opts - in clears on the way out; emptiness is asserted through the public - ``bool()``/``len()`` surface rather than private state. + Saves the import-time baseline (the built-in expanders), empties the registry + so a test sees only what it registers, and restores the baseline on the way + out through the public ``snapshot()``/``restore()`` seam — so a test that + registers a stand-in expander cannot leak it, and the built-in registrations + survive this fixture's isolation. """ - assert not REGISTRY, "REGISTRY was non-empty entering clean_registry" + saved = REGISTRY.snapshot() REGISTRY.clear() yield REGISTRY - REGISTRY.clear() + REGISTRY.restore(saved) @pytest.fixture(autouse=True) def _registry_leak_guard(): - """Assert the process-wide REGISTRY is empty at each test boundary. - - A leak guard: the registry is empty at import and must return to empty after - every test, so a test that registers without cleaning up (a leak that would - silently flip the no-op pass for a later test) fails loudly. Tests that - register on the process-wide REGISTRY do so through ``clean_registry``, which - clears on the way out; this guard catches anything that bypasses it. Both - checks go through the public ``bool()`` surface (A5), not private state. + """Assert the process-wide REGISTRY matches its baseline at each boundary. + + A leak guard: the registry holds the built-in expanders at import and must + return to exactly that baseline after every test, so a test that registers + without cleaning up (a leak that would silently change dispatch for a later + test) fails loudly. Tests that mutate the process-wide REGISTRY do so through + ``clean_registry``, which restores the baseline on the way out; this guard + catches anything that bypasses it. """ - assert not REGISTRY, "REGISTRY leaked into a test from a prior one" + assert REGISTRY.snapshot() == _REGISTRY_BASELINE, ( + "REGISTRY differed from its baseline entering a test" + ) yield - assert not REGISTRY, "a test leaked a registration into REGISTRY" + assert REGISTRY.snapshot() == _REGISTRY_BASELINE, ( + "a test leaked a registration into REGISTRY" + ) @pytest.fixture(autouse=True) @@ -88,18 +104,19 @@ def _expand_flag_leak_guard(): """Assert every operator's GIQL_EXPAND is restored at each test boundary. The symmetric partner of the registry leak guard: each operator class ships - opted out (its own GIQL_EXPAND attribute is False), and a test that flips one - via ``_opted_in`` must restore it. A leaked opt-in would silently flip the - no-op pass for a later test, so this catches anything that bypasses the - exception-safe ``_opted_in`` manager. + a shipped GIQL_EXPAND default (``True`` for a migrated operator, + ``False`` otherwise), and a test that flips one via ``_opted_in`` must restore + it. A leaked flip would silently change the pass for a later test, so this + catches anything that bypasses the exception-safe ``_opted_in`` manager by + comparing against each operator's shipped default rather than a blanket False. """ for op in _OPERATOR_CLASSES: - assert op.__dict__.get("GIQL_EXPAND") is False, ( + assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( f"{op.__name__}.GIQL_EXPAND leaked into a test from a prior one" ) yield for op in _OPERATOR_CLASSES: - assert op.__dict__.get("GIQL_EXPAND") is False, ( + assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( f"a test leaked a GIQL_EXPAND opt-in on {op.__name__}" ) @@ -335,6 +352,55 @@ class _BadExpander: with pytest.raises(TypeError): registry.register(GenericTarget(), GIQLDisjoin, _BadExpander()) + def test_snapshot_should_not_observe_later_registrations(self): + """Test that a snapshot does not observe registrations made after it. + + Given: + A registry with one entry, captured by snapshot. + When: + A second entry is registered after the snapshot is taken. + Then: + The snapshot should still hold only the first entry (it is a copy, + not a live view). + """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("first")) + + # Act + saved = registry.snapshot() + registry.register(GenericTarget(), Intersects, _record("second")) + + # Assert + assert (DuckDBTarget(), GIQLDisjoin) in saved + assert (GenericTarget(), Intersects) not in saved + + def test_restore_should_replace_entries_with_snapshot_contents(self): + """Test that restore returns the registry to a captured snapshot. + + Given: + A snapshot of a registry with one entry, after which the registry is + cleared and a different entry registered. + When: + Restoring the snapshot. + Then: + The original entry should resolve again and the post-snapshot entry + should be gone. + """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("original")) + saved = registry.snapshot() + registry.clear() + registry.register(GenericTarget(), Intersects, _record("transient")) + + # Act + registry.restore(saved) + + # Assert + assert (DuckDBTarget(), GIQLDisjoin) in registry + assert (GenericTarget(), Intersects) not in registry + class TestExpanderRegistryFallbackGaps: """Edge cases of the registry fallback and op-scoped keying.""" @@ -468,6 +534,61 @@ def _expander(node, ctx): assert REGISTRY.resolve(DuckDBTarget(), GIQLDisjoin) is None assert (DuckDBTarget(), GIQLDisjoin) not in REGISTRY + def test_has_override_should_return_true_when_exact_nongeneric_entry_registered( + self, + ): + """Test that has_override is True for an exact non-generic entry. + + Given: + A registry with an exact (DuckDBTarget, op) entry registered. + When: + Querying has_override for that exact key. + Then: + It should return True (a non-generic exact entry is an override). + """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("duckdb")) + + # Act & assert + assert registry.has_override(DuckDBTarget(), GIQLDisjoin) is True + + def test_has_override_should_return_false_for_generic_fallback_entry(self): + """Test that has_override ignores the generic fallback entry. + + Given: + A registry with only a (GenericTarget, op) entry registered. + When: + Querying has_override for a non-generic target's key. + Then: + It should return False (the portable generic fallback is not an + override, even though resolve() would route to it). + """ + # Arrange + registry = ExpanderRegistry() + registry.register(GenericTarget(), GIQLDisjoin, _record("generic")) + + # Act & assert + assert registry.has_override(DuckDBTarget(), GIQLDisjoin) is False + # And a generic target queried against its own entry is not an override. + assert registry.has_override(GenericTarget(), GIQLDisjoin) is False + + def test_has_override_should_return_false_when_unregistered(self): + """Test that has_override is False when nothing is registered. + + Given: + An empty registry. + When: + Querying has_override for any key. + Then: + It should return False (no entry, so no override). + """ + # Arrange + registry = ExpanderRegistry() + + # Act & assert + assert registry.has_override(DuckDBTarget(), GIQLDisjoin) is False + class TestRegisterDecorator: """Tests for the @register extension-hook decorator.""" @@ -952,27 +1073,57 @@ def _expander(node, ctx): assert ctx.resolution.operator == "GIQLDisjoin" def test_transform_skips_unflagged_operator(self, clean_registry): - """Test that an unflagged operator is left untouched even when registered. + """Test that an opted-out operator is left untouched even when registered. Given: - An expander registered for (GenericTarget, GIQLDisjoin) but the - operator's GIQL_EXPAND flag left at its default False. + A migrated operator with an expander registered for it, but its + GIQL_EXPAND flag opted out for the test — so the *same* operator is + registered, queried, and opted out, isolating the per-type gate. When: Running the pass. Then: - The DISJOIN node should remain in the tree (gate requires both). + The operator node should remain in the tree (the gate requires the + flag, so opting it out alone holds expansion off). """ # Arrange - clean_registry.register(GenericTarget(), GIQLDisjoin, _record("expanded")) - tables = _tables() - ast = _prepare("SELECT * FROM DISJOIN(variants)", tables) + operator = _A_MIGRATED_OPERATOR + clean_registry.register(GenericTarget(), operator, _record("expanded")) + ast, tables = _prepare_operator(operator) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act (GIQL_EXPAND is False by default — no opt-in context) - result = pass_.transform(ast) + # Act + with _opted_out(operator): + result = pass_.transform(ast) # Assert - assert list(result.find_all(GIQLDisjoin)) + assert list(result.find_all(operator)) + + def test_transform_expands_flagged_operator(self, clean_registry): + """Test that the same operator expands once flagged (the paired positive). + + Given: + The same migrated operator and registered expander as the opt-out + control, but with its GIQL_EXPAND flag opted in. + When: + Running the pass. + Then: + The operator node should be replaced by the expander's output — so the + contrast with the opt-out control pins the per-type gate as + load-bearing, not vacuous. + """ + # Arrange + operator = _A_MIGRATED_OPERATOR + clean_registry.register(GenericTarget(), operator, _record("expanded")) + ast, tables = _prepare_operator(operator) + pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) + + # Act + with _opted_in(operator): + result = pass_.transform(ast) + + # Assert + assert not list(result.find_all(operator)) + assert result.find(exp.Literal).this == "expanded" def test_transform_skips_flagged_operator_with_no_expander(self, clean_registry): """Test that a flagged operator with no expander is left untouched. @@ -1022,34 +1173,43 @@ def test_expand_operators_is_identity_when_registry_empty(self): assert list(result.find_all(GIQLDisjoin)) def test_transpile_sql_unchanged_with_pass_inert(self): - """Test that transpile output is byte-identical with the pass inert. + """Test that transpile output is byte-identical for an unmigrated operator. Given: - A DISJOIN query, the default (empty) registry, and no operator flagged. + A CLUSTER query (an operator not migrated onto the pass in any wave-3 + branch, so its GIQL_EXPAND is False and no expander resolves), with + the default registry. When: - Transpiling with the wired-in pass versus a pass-bypassed reference. + Running the ExpandOperators pass (default REGISTRY) over the resolved + AST and serializing both the original and the pass-run AST. Then: - The SQL should match exactly and carry no expander alias prefix. + The pass leaves the operator node in place, the serialized SQL is + byte-identical, and no expander alias prefix appears — the pass is + inert for any operator that has not been migrated. """ # Arrange - query = "SELECT * FROM DISJOIN(variants)" - tables = _tables() + query = "SELECT *, CLUSTER(interval) AS cluster_id FROM peaks" + tables = _tables(("peaks",)) ast = _prepare(query, tables) from giql.generators import BaseGIQLGenerator - expected = BaseGIQLGenerator(tables=tables).generate(ast) + before = BaseGIQLGenerator(tables=tables).generate(ast) + before_ops = len(list(ast.find_all(GIQLCluster))) - # Act - actual = transpile(query, tables=[Table("variants")]) + # Act — the wired-in pass over the default REGISTRY must be a no-op here. + result = expand_operators(ast, GenericTarget(), tables) + after = BaseGIQLGenerator(tables=tables).generate(result) # Assert - assert actual == expected - assert EXPAND_ALIAS_PREFIX not in actual + assert after == before + assert len(list(result.find_all(GIQLCluster))) == before_ops + assert EXPAND_ALIAS_PREFIX not in after # The nine GIQL operator expression classes the ExpandOperators pass inspects. -# Every one must ship opted out (GIQL_EXPAND=False) at this step so the pass is a -# strict no-op until a later migration flips one alongside its expander. +# Each migrated operator ships opted in (GIQL_EXPAND=True) alongside its +# registered expander; the rest ship opted out (False) and fall through to the +# legacy emitter. from giql.expressions import Contains # noqa: E402 from giql.expressions import GIQLCluster # noqa: E402 from giql.expressions import GIQLDistance # noqa: E402 @@ -1069,20 +1229,107 @@ def test_transpile_sql_unchanged_with_pass_inert(self): GIQLMerge, ) +#: Each operator's shipped GIQL_EXPAND default, captured from its own class dict +#: at import. A migrated operator ships ``True``; the rest ship ``False`` until +#: their migrations land. The flag leak guard restores to these shipped values +#: rather than a blanket ``False``. +# Pin the leak guard's assumption that every operator declares GIQL_EXPAND on its +# own class (so __dict__.get() reads the operator's shipped value rather than an +# inherited one). A future operator that omits its own flag would make the guard +# read ``None`` and silently lose its leak coverage; this one-time check closes +# that hole. +for _op in _OPERATOR_CLASSES: + assert "GIQL_EXPAND" in _op.__dict__, ( + f"{_op.__name__} must declare GIQL_EXPAND on its own class" + ) +_SHIPPED_EXPAND_FLAGS = {op: op.__dict__.get("GIQL_EXPAND") for op in _OPERATOR_CLASSES} + + +#: Operators migrated onto the ExpandOperators pass — they ship GIQL_EXPAND=True. +_MIGRATED_OPERATORS = tuple( + op for op in _OPERATOR_CLASSES if op.__dict__.get("GIQL_EXPAND") is True +) +assert _MIGRATED_OPERATORS, "expected at least one migrated operator" +#: An arbitrary migrated operator the operator-agnostic control tests target. +_A_MIGRATED_OPERATOR = _MIGRATED_OPERATORS[0] +#: Operators not yet migrated — they ship GIQL_EXPAND=False. +_UNMIGRATED_OPERATORS = tuple( + op for op in _OPERATOR_CLASSES if op not in _MIGRATED_OPERATORS +) +assert _UNMIGRATED_OPERATORS, "expected at least one unmigrated operator" + + +#: A minimal GIQL query producing one node of each operator class, keyed by the +#: class. Lets the control tests build a node for *any* operator — whichever the +#: branch ships migrated/unmigrated — so they stay operator-agnostic rather than +#: hard-wiring a particular operator. The second element is the table names the +#: query references (registered before pass 1). +_OPERATOR_QUERIES: dict[type, tuple[str, tuple[str, ...]]] = { + Intersects: ( + "SELECT * FROM variants WHERE interval INTERSECTS 'chr1:1000-2000'", + ("variants",), + ), + Contains: ( + "SELECT * FROM variants WHERE interval CONTAINS 'chr1:1500-1600'", + ("variants",), + ), + Within: ( + "SELECT * FROM variants WHERE interval WITHIN 'chr1:1000-5000'", + ("variants",), + ), + SpatialSetPredicate: ( + "SELECT * FROM variants " + "WHERE interval INTERSECTS ANY('chr1:1000-2000', 'chr1:5000-6000')", + ("variants",), + ), + GIQLDistance: ( + "SELECT DISTANCE(a.interval, b.interval) AS d " + "FROM features_a a CROSS JOIN features_b b", + ("features_a", "features_b"), + ), + GIQLNearest: ( + "SELECT * FROM peaks CROSS JOIN LATERAL NEAREST(genes, k := 3)", + ("peaks", "genes"), + ), + GIQLDisjoin: ("SELECT * FROM DISJOIN(variants)", ("variants",)), + GIQLCluster: ( + "SELECT *, CLUSTER(interval) AS cluster_id FROM peaks", + ("peaks",), + ), + GIQLMerge: ("SELECT MERGE(interval) AS m FROM peaks", ("peaks",)), +} + + +def _prepare_operator(operator: type) -> tuple[exp.Expression, Tables]: + """Build a pass-1-resolved AST containing one node of *operator*. + + Operator-agnostic: looks the query up in :data:`_OPERATOR_QUERIES` so a + control test can exercise whichever operator a branch ships migrated or + unmigrated, rather than hard-wiring one operator class. Returns the resolved + AST and the Tables container it was resolved against (the same container must + be threaded into the pass). + """ + query, names = _OPERATOR_QUERIES[operator] + tables = _tables(names) + return _prepare(query, tables), tables + class TestOperatorOptOut: - """Every operator ships opted out of the ExpandOperators pass at this step.""" + """Migrated operators opt into the pass; the rest still ship opted out.""" - @pytest.mark.parametrize("operator", _OPERATOR_CLASSES, ids=lambda c: c.__name__) + @pytest.mark.parametrize( + "operator", _UNMIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) def test_operator_class_ships_expand_disabled(self, operator): - """Test that each operator class ships GIQL_EXPAND=False. + """Test that each unmigrated operator class ships GIQL_EXPAND=False. Given: - One of the nine GIQL operator expression classes. + A GIQL operator expression class that has not been migrated onto the + ExpandOperators pass. When: Reading its GIQL_EXPAND class attribute. Then: - It should be False (no operator opts into expansion yet). + It should be False (the operator still uses the legacy emitter). """ # Arrange & act flag = operator.GIQL_EXPAND @@ -1090,6 +1337,53 @@ def test_operator_class_ships_expand_disabled(self, operator): # Assert assert flag is False + @pytest.mark.parametrize( + "operator", _MIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) + def test_operator_class_ships_expand_enabled(self, operator): + """Test that each migrated operator class ships GIQL_EXPAND=True. + + Given: + A GIQL operator expression class migrated onto the ExpandOperators + pass. + When: + Reading its GIQL_EXPAND class attribute. + Then: + It should be True (the operator expands through its registered + expander instead of the deleted legacy emitter). + """ + # Arrange & act + flag = operator.GIQL_EXPAND + + # Assert + assert flag is True + + +class TestMigratedOperatorsRegistered: + """Every migrated operator resolves a built-in expander in the process REGISTRY.""" + + @pytest.mark.parametrize( + "operator", _MIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) + def test_migrated_operator_resolves_in_process_registry(self, operator): + """Test that each migrated operator resolves an expander in REGISTRY. + + Given: + A GIQL operator class that ships GIQL_EXPAND=True (migrated onto the + ExpandOperators pass). + When: + Resolving it against the import-populated process-wide REGISTRY for + the generic target. + Then: + A built-in expander should resolve — a migrated operator always has a + registered expander, so the pass never leaves it on a deleted emitter. + """ + # Arrange & act + resolved = REGISTRY.resolve(GenericTarget(), operator) + + # Assert + assert resolved is not None + class TestOptedInRestoresFlag: """The _opted_in helper restores GIQL_EXPAND even when its body raises.""" @@ -1649,35 +1943,42 @@ def test_walk_partial_opt_in_replaces_only_flagged_type(self, clean_registry): """Test that only the flagged operator type is replaced when both registered. Given: - A DISJOIN and an INTERSECTS, both with registered expanders, but only - INTERSECTS flagged GIQL_EXPAND. + A genuinely-unmigrated operator (GIQLCluster, shipping + GIQL_EXPAND=False in every wave-3 branch) and an INTERSECTS, both with + registered expanders, but only INTERSECTS flagged GIQL_EXPAND for the + test. When: Running the pass. Then: - The INTERSECTS is replaced while the DISJOIN node remains (the gate is - per-type). + The INTERSECTS is replaced while the unmigrated operator node remains + on its own shipped ``False`` flag — no opt-out ceremony needed (the + gate is per-type). """ - # Arrange + # Arrange — the held-off subject is genuinely unmigrated: it survives on + # its own shipped GIQL_EXPAND=False, not on a test opt-out. + held_off = GIQLCluster + assert held_off in _UNMIGRATED_OPERATORS + assert held_off.GIQL_EXPAND is False clean_registry.register( - GenericTarget(), GIQLDisjoin, lambda n, c: exp.column("DJ") + GenericTarget(), held_off, lambda n, c: exp.column("CL") ) clean_registry.register( GenericTarget(), Intersects, lambda n, c: exp.column("IX") ) - tables = _tables(("variants", "peaks")) + tables = _tables(("peaks",)) ast = _prepare( - "SELECT * FROM DISJOIN(variants) " + "SELECT *, CLUSTER(interval) AS cluster_id FROM peaks " "WHERE EXISTS (SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1-100')", tables, ) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act + # Act — only INTERSECTS is opted in; the unmigrated operator stays off. with _opted_in(Intersects): result = pass_.transform(ast) # Assert - assert list(result.find_all(GIQLDisjoin)) + assert list(result.find_all(held_off)) assert not list(result.find_all(Intersects)) def test_walk_shares_alias_sequence_across_sibling_expanders(self, clean_registry): @@ -1888,3 +2189,25 @@ def __enter__(self): def __exit__(self, *exc): self._operator.GIQL_EXPAND = self._prior return False + + +class _opted_out: + """Context manager opting an operator class out of GIQL_EXPAND for a test. + + The complement of :class:`_opted_in`: used by a control test that needs a + *migrated* operator (one shipping GIQL_EXPAND=True) to behave as if + unflagged, so the test can prove the pass gates per-type without the + operator's shipped opt-in interfering. Restores the prior flag on exit. + """ + + def __init__(self, operator: type) -> None: + self._operator = operator + self._prior = operator.__dict__.get("GIQL_EXPAND", False) + + def __enter__(self): + self._operator.GIQL_EXPAND = False + return self._operator + + def __exit__(self, *exc): + self._operator.GIQL_EXPAND = self._prior + return False diff --git a/tests/test_transpile.py b/tests/test_transpile.py index ea770a8..5bf6138 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -458,8 +458,21 @@ def test_transpile_datafusion_accepted(self): "ANY('chr1:1000-2000', 'chr1:5000-6000')", ["peaks"], ), + ( + "SELECT DISTANCE(a.interval, b.interval) AS d " + "FROM peaks a, genes b", + ["peaks", "genes"], + ), + ], + ids=[ + "intersects_literal", + "contains", + "within", + "nearest", + "join", + "any", + "distance", ], - ids=["intersects_literal", "contains", "within", "nearest", "join", "any"], ) def test_transpile_datafusion_matches_generic_output(self, query, tables): """Test that datafusion is currently a pure alias for the generic target. @@ -483,6 +496,52 @@ def test_transpile_datafusion_matches_generic_output(self, query, tables): # Assert assert datafusion_sql == generic_sql + @pytest.mark.parametrize( + "query, tables", + [ + ( + "SELECT DISTANCE(a.interval, b.interval) AS d FROM peaks a, genes b", + ["peaks", "genes"], + ), + ( + "SELECT DISTANCE(a.interval, b.interval, stranded := true) AS d " + "FROM peaks a, genes b", + ["peaks", "genes"], + ), + ( + "SELECT DISTANCE(a.interval, b.interval, signed := true) AS d " + "FROM peaks a, genes b", + ["peaks", "genes"], + ), + ( + "SELECT DISTANCE(a.interval, b.interval, stranded := true, " + "signed := true) AS d FROM peaks a, genes b", + ["peaks", "genes"], + ), + ], + ids=["unsigned", "stranded", "signed", "stranded_signed"], + ) + def test_transpile_distance_is_byte_identical_across_targets(self, query, tables): + """Test that the migrated DISTANCE operator emits identical SQL everywhere. + + Given: + A DISTANCE query in each of the four shapes (unsigned/signed x + non-stranded/stranded), which migrated onto the AST-expansion pass + with a single generic expander. + When: + Transpiling it with dialect=None, "duckdb", and "datafusion". + Then: + The three outputs should be byte-identical — the generic expander + covers every target and DISTANCE has no engine-specific divergence. + """ + # Act + generic_sql = transpile(query, tables=tables, dialect=None) + duckdb_sql = transpile(query, tables=tables, dialect="duckdb") + datafusion_sql = transpile(query, tables=tables, dialect="datafusion") + + # Assert + assert generic_sql == duckdb_sql == datafusion_sql + def test_transpile_datafusion_accepts_intersects_bin_size(self): """Test that datafusion honours the binned-join bin size identically.