diff --git a/src/giql/expander.py b/src/giql/expander.py index f93132a..3b5a900 100644 --- a/src/giql/expander.py +++ b/src/giql/expander.py @@ -39,10 +39,11 @@ i.e. a ``(target, op)`` or ``(generic, op)`` expander is registered. Otherwise it falls through to the legacy ``*_sql`` emitter on -:class:`giql.generators.base.BaseGIQLGenerator`. As of this issue **no operator -sets ``GIQL_EXPAND`` and the registry is empty, so the pass is a strict no-op**: -no node is touched and the emitted SQL is byte-identical. Each later migration PR -(epic #137 steps 4-9) registers a generic expander, flips one operator's +:class:`giql.generators.base.BaseGIQLGenerator`. The built-in expanders register +at import time via :mod:`giql.expanders`; the pass rewrites a node only when +``GIQL_EXPAND=True`` **and** an expander resolves for ``(active target, operator +type)``, and is a no-op for any operator that is unflagged or has no registered +expander. A migration PR registers an expander, flips one operator's ``GIQL_EXPAND`` flag, and deletes that operator's ``*_sql`` method. """ @@ -149,6 +150,14 @@ class OperatorExpander(Protocol): a registered object satisfies it. A plain function is *not* an ``OperatorExpander`` (it has no ``expand`` method); register one by wrapping it (see :func:`register`, which accepts either form). + + An expander is **node-local**: ``expand(node, ctx) -> exp.Expression`` sees + one operator node and returns the expression that replaces it in place. It + cannot express a whole-query rewrite such as the INTERSECTS IEJoin fold, + which restructures the surrounding query (joins, CTEs) rather than a single + node. That fold is therefore deferred — it would need a separate + query-level mechanism — and is handled by the pre-pass join transformers, not + by an expander. """ def expand(self, node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: ... @@ -215,6 +224,16 @@ def register( expander : OperatorExpander | ExpanderFn The expander object or function. 
def has_override(self, target: Target, operator: type) -> bool:
    """Report whether an exact, non-generic override exists for *operator*.

    Parameters
    ----------
    target : Target
        The active target to look up.
    operator : type
        The operator expression class to look up.

    Returns
    -------
    bool
        ``True`` only when *target* is not equal to ``GenericTarget()`` and an
        exact ``(target, operator)`` key is present in the registry. The
        portable ``(GenericTarget(), operator)`` fallback never counts as an
        override, even though :meth:`resolve` would return it.
    """
    return target != GenericTarget() and (target, operator) in self._expanders


def snapshot(self) -> dict[tuple[Target, type], ExpanderFn]:
    """Return a shallow copy of the current registrations.

    The save half of the save/restore pair: hand the result back to
    :meth:`restore` to return the registry to this state. The returned dict
    is a fresh mapping keyed by the same ``(target, operator)`` tuples, so
    mutating it does not affect the registry.
    """
    return dict(self._expanders)


def restore(self, snapshot: dict[tuple[Target, type], ExpanderFn]) -> None:
    """Replace all registrations with the contents of *snapshot*.

    The restore half of the save/restore pair. Every current entry is
    dropped and exactly the *snapshot* entries are installed. The registry's
    underlying mapping object is mutated in place, not rebound.

    Parameters
    ----------
    snapshot : dict[tuple[Target, type], ExpanderFn]
        The registrations to install, normally a value previously returned
        by :meth:`snapshot`.
    """
    # Copy before clearing: if the caller passes the live mapping itself
    # (rather than a snapshot() copy), clearing first would also empty the
    # argument and the registry would silently end up with no entries.
    baseline = dict(snapshot)
    self._expanders.clear()
    self._expanders.update(baseline)
The built-in expanders register into +#: it at import time via :mod:`giql.expanders`; the pass rewrites a node only when +#: an expander resolves here (and the operator is flagged ``GIQL_EXPAND``). REGISTRY = ExpanderRegistry() @@ -380,9 +446,10 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for ``(target, operator type)`` through its fallback chain; otherwise the node is left untouched and the legacy ``*_sql`` emitter handles it. - The pass mutates and returns *expression* in place. **With no operator - flagged and an empty registry it is a strict no-op** and the emitted SQL is - byte-identical, so the existing suite is the migration oracle. + The pass mutates and returns *expression* in place. It touches only nodes + whose operator is flagged ``GIQL_EXPAND`` and resolves an expander; for every + other operator it is a no-op, leaving the emitted SQL byte-identical, so the + existing suite is the migration oracle. Parameters ---------- @@ -399,9 +466,9 @@ class sets ``GIQL_EXPAND = True`` *and* the registry resolves an expander for Returns ------- exp.Expression - The same *expression*, with opted-in operator nodes replaced by their - target-specific expansions (none, while every flag is off / the registry - is empty). + The same *expression*, with each opted-in operator node that resolves an + expander replaced by its target-specific expansion; nodes that are + unflagged or resolve no expander are left untouched. """ reg = registry if registry is not None else REGISTRY operators = _giql_operators() diff --git a/src/giql/expanders/__init__.py b/src/giql/expanders/__init__.py new file mode 100644 index 0000000..042a07d --- /dev/null +++ b/src/giql/expanders/__init__.py @@ -0,0 +1,27 @@ +"""Built-in operator expanders for epic #137. + +Importing this package registers every built-in expander as a side effect: +each submodule decorates its expander(s) with ``@register(...)`` at import +time, and this package imports all of them. 
# Import every public submodule of this package so that the ``@register(...)``
# decorators inside each one run. Names beginning with "_" are private helpers
# and are skipped. There is deliberately no exception handling: a failing
# submodule import propagates and aborts the import of the whole package.
_public_names = (
    _info.name
    for _info in pkgutil.iter_modules(__path__)
    if not _info.name.startswith("_")
)
for _submodule in _public_names:
    importlib.import_module(f"{__name__}.{_submodule}")
The whole-query column-to-column **join** rewrites (the binned +equi-join and the DuckDB IEJoin) remain capability-gated pre-pass transformers in +:mod:`giql.transformer` keyed on ``capabilities.range_join_strategy`` — they consume +a column-to-column INTERSECTS *join* before this pass runs, so by the time a +column-to-column INTERSECTS reaches an expander it is a residual predicate (e.g. +inside an ``OR``, or a join shape the transformer declined) that the legacy emitter +also rendered as a plain predicate. The expander handles that residual the same way. + +Only :class:`~giql.targets.GenericTarget` expanders are registered: spatial-predicate +*emission* is portable SQL-92 and does not vary by engine, so one generic expander +covers every target via the registry's ``(generic, op)`` fallback. +""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot import maybe_parse + +from giql.dialect import GIQLDialect +from giql.expander import ExpansionContext +from giql.expander import register +from giql.expressions import Contains +from giql.expressions import Intersects +from giql.expressions import SpatialSetPredicate +from giql.expressions import Within +from giql.range_parser import ParsedRange +from giql.range_parser import RangeParser +from giql.resolver import ResolvedColumn +from giql.targets import GenericTarget + + +def _fragment(fragment: str) -> exp.Expression: + """Parse a resolved SQL fragment (e.g. ``a."end"`` / ``'chr1'``) into AST. + + The pass-1 :class:`~giql.resolver.ResolvedColumn` carries column references as + pre-canonicalized SQL string fragments; parse them through the GIQL dialect so + the rebuilt predicate reserializes identically to the legacy emitter's string. + """ + return maybe_parse(fragment, dialect=GIQLDialect) + + +def _predicate_column(ctx: ExpansionContext, arg: str) -> ResolvedColumn: + """Return the :class:`ResolvedColumn` for predicate operand *arg*. 
+ + Mirrors :meth:`giql.generators.base.BaseGIQLGenerator._predicate_operand`: the + expander consumes only the pass-1 resolution; a missing column means pass 1 did + not run (an internal invariant violation), so raise the historical message. + """ + resolution = ctx.resolution + if resolution is not None: + resolved = resolution.column(arg) + if resolved is not None: + return resolved + raise ValueError( + f"Spatial predicate operand {arg!r} was not resolved; run the " + "ResolveOperatorRefs pass (transpile pipeline) before generation." + ) + + +def _range_predicate( + column: ResolvedColumn, parsed: ParsedRange, op_type: str +) -> exp.Expression: + """Build the boolean AST for ``column ``. + + Reproduces :meth:`BaseGIQLGenerator._generate_range_predicate` as AST. The + column fragments are already canonical 0-based half-open (pass 2); the parsed + range is canonicalized by the caller. Returns a parenthesized boolean. + """ + chrom = _fragment(column.chrom) + start = _fragment(column.start) + end = _fragment(column.end) + chrom_lit = exp.Literal.string(parsed.chromosome) + r_start = exp.Literal.number(parsed.start) + r_end = exp.Literal.number(parsed.end) + + if op_type == "intersects": + # Ranges overlap if: start1 < end2 AND end1 > start2 + cond = exp.and_( + exp.EQ(this=chrom, expression=chrom_lit), + exp.LT(this=start, expression=r_end), + exp.GT(this=end, expression=r_start), + ) + elif op_type == "contains": + if parsed.end == parsed.start + 1: + # Point query: start1 <= point < end1 + cond = exp.and_( + exp.EQ(this=chrom, expression=chrom_lit), + exp.LTE(this=start, expression=r_start), + exp.GT(this=end, expression=r_start), + ) + else: + # Range query: start1 <= start2 AND end1 >= end2 + cond = exp.and_( + exp.EQ(this=chrom, expression=chrom_lit), + exp.LTE(this=start, expression=r_start), + exp.GTE(this=end, expression=r_end), + ) + else: + # op_type == "within": left within right: start1 >= start2 AND end1 <= end2 + cond = exp.and_( + exp.EQ(this=chrom, 
expression=chrom_lit), + exp.GTE(this=start, expression=r_start), + exp.LTE(this=end, expression=r_end), + ) + + return exp.paren(cond) + + +def _column_join( + left: ResolvedColumn, right: ResolvedColumn, op_type: str +) -> exp.Expression: + """Build the boolean AST for a column-to-column spatial predicate. + + Reproduces :meth:`BaseGIQLGenerator._generate_column_join` as AST. Both + operands' fragments are pre-canonicalized (pass 2). Returns a parenthesized + boolean. + """ + l_chrom, r_chrom = _fragment(left.chrom), _fragment(right.chrom) + l_start, r_start = _fragment(left.start), _fragment(right.start) + l_end, r_end = _fragment(left.end), _fragment(right.end) + + if op_type == "intersects": + cond = exp.and_( + exp.EQ(this=l_chrom, expression=r_chrom), + exp.LT(this=l_start, expression=r_end), + exp.GT(this=l_end, expression=r_start), + ) + elif op_type == "contains": + cond = exp.and_( + exp.EQ(this=l_chrom, expression=r_chrom), + exp.LTE(this=l_start, expression=r_start), + exp.GTE(this=l_end, expression=r_end), + ) + else: + # op_type == "within" + cond = exp.and_( + exp.EQ(this=l_chrom, expression=r_chrom), + exp.GTE(this=l_start, expression=r_start), + exp.LTE(this=l_end, expression=r_end), + ) + + return exp.paren(cond) + + +def _expand_spatial_op( + node: exp.Expression, ctx: ExpansionContext, op_type: str +) -> exp.Expression: + """Expand one INTERSECTS / CONTAINS / WITHIN node to a boolean predicate. + + Dispatches on the right operand exactly as the legacy emitter did: the + presence of a resolved right *column* — keyed off + ``ctx.resolution.column("expression")``, the slot pass 1 attaches a + :class:`ResolvedColumn` to when the right operand is a column reference — + selects the column-to-column path; its absence means the right operand is a + literal range, parsed in place. 
+ """ + resolution = ctx.resolution + right_column = resolution.column("expression") if resolution is not None else None + left = _predicate_column(ctx, "this") + + if right_column is not None: + return _column_join(left, right_column, op_type) + + # Literal range string (e.g. interval INTERSECTS 'chr1:1000-2000'). Reproduce + # the legacy emitter's parse-and-wrap-error behavior verbatim: any parse + # failure (including the RangeParser's own ValueError) is wrapped in the + # historical "Could not parse genomic range" message. + right_expr = node.args.get("expression") + raw = right_expr.sql(dialect=GIQLDialect) if right_expr is not None else "" + try: + range_str = raw.strip("'\"") + parsed = RangeParser.parse(range_str).to_zero_based_half_open() + return _range_predicate(left, parsed, op_type) + except Exception as e: + raise ValueError(f"Could not parse genomic range: {raw}. Error: {e}") from e + + +@register(GenericTarget, Intersects) +def expand_intersects(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand an INTERSECTS predicate to standard boolean SQL AST.""" + return _expand_spatial_op(node, ctx, "intersects") + + +@register(GenericTarget, Contains) +def expand_contains(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand a CONTAINS predicate to standard boolean SQL AST.""" + return _expand_spatial_op(node, ctx, "contains") + + +@register(GenericTarget, Within) +def expand_within(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: + """Expand a WITHIN predicate to standard boolean SQL AST.""" + return _expand_spatial_op(node, ctx, "within") + + +@register(GenericTarget, SpatialSetPredicate) +def expand_spatial_set( + node: exp.Expression, ctx: ExpansionContext +) -> exp.Expression: + """Expand a quantified set predicate (``ANY`` / ``ALL``) to boolean SQL AST. 
+ + Reproduces :meth:`BaseGIQLGenerator._generate_spatial_set`: the single left + column is compared against every literal range, and the per-range conditions + are OR-combined for ``ANY`` / AND-combined for ``ALL``, all wrapped in one + outer paren. + """ + operator = node.args["operator"] + quantifier = node.args["quantifier"] + ranges = node.args["ranges"] + + column = _predicate_column(ctx, "this") + op_type = operator.lower() + + conditions: list[exp.Expression] = [] + for range_expr in ranges: + range_str = range_expr.sql(dialect=GIQLDialect).strip("'\"") + parsed = RangeParser.parse(range_str).to_zero_based_half_open() + conditions.append(_range_predicate(column, parsed, op_type)) + + if quantifier.upper() == "ANY": + combined = exp.or_(*conditions) + else: + combined = exp.and_(*conditions) + + return exp.paren(combined) diff --git a/src/giql/expressions.py b/src/giql/expressions.py index ce939dc..fa1cc37 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -126,7 +126,12 @@ class Intersects(SpatialPredicate): """ GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators registry (#141). A literal-range or + #: residual column-to-column INTERSECTS *predicate* expands through + #: ``giql.expanders.intersects``; a column-to-column INTERSECTS *join* is + #: consumed by the capability-gated binned / IEJoin pre-pass transformers + #: before this pass runs, so the predicate expander never sees it. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), @@ -141,7 +146,9 @@ class Contains(SpatialPredicate): """ GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators registry (#141); expands through + #: ``giql.expanders.intersects``. 
+ GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), @@ -156,7 +163,9 @@ class Within(SpatialPredicate): """ GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators registry (#141); expands through + #: ``giql.expanders.intersects``. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), @@ -180,7 +189,9 @@ class SpatialSetPredicate(exp.Expression): } GIQL_CANONICALIZE = _CANONICALIZE - GIQL_EXPAND = _EXPAND + #: Migrated to the ExpandOperators registry (#141); expands through + #: ``giql.expanders.intersects``. + GIQL_EXPAND = True GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), diff --git a/src/giql/generators/base.py b/src/giql/generators/base.py index 9038369..3af4ba6 100644 --- a/src/giql/generators/base.py +++ b/src/giql/generators/base.py @@ -3,14 +3,9 @@ from giql.canonical import decanonical_end from giql.canonical import decanonical_start -from giql.expressions import Contains from giql.expressions import GIQLDisjoin from giql.expressions import GIQLDistance from giql.expressions import GIQLNearest -from giql.expressions import Intersects -from giql.expressions import SpatialSetPredicate -from giql.expressions import Within -from giql.range_parser import ParsedRange from giql.range_parser import RangeParser from giql.resolver import META_KEY from giql.resolver import OperatorResolution @@ -42,45 +37,12 @@ def __init__(self, tables: Tables | None = None, **kwargs): super().__init__(**kwargs) self.tables = tables or Tables() - def intersects_sql(self, expression: Intersects) -> str: - """Generate standard SQL for INTERSECTS. - - :param expression: - INTERSECTS expression node - :return: - SQL predicate string - """ - return self._generate_spatial_op(expression, "intersects") - - def contains_sql(self, expression: Contains) -> str: - """Generate standard SQL for CONTAINS. 
- - :param expression: - CONTAINS expression node - :return: - SQL predicate string - """ - return self._generate_spatial_op(expression, "contains") - - def within_sql(self, expression: Within) -> str: - """Generate standard SQL for WITHIN. - - :param expression: - WITHIN expression node - :return: - SQL predicate string - """ - return self._generate_spatial_op(expression, "within") - - def spatialsetpredicate_sql(self, expression: SpatialSetPredicate) -> str: - """Generate SQL for spatial set predicates (ANY/ALL). - - :param expression: - SpatialSetPredicate expression node - :return: - SQL predicate string - """ - return self._generate_spatial_set(expression) + # INTERSECTS / CONTAINS / WITHIN and the ANY/ALL set predicates are migrated + # to the ExpandOperators registry (#141): they expand to standard boolean AST + # in ``giql.expanders.intersects`` before generation, so the generator no + # longer carries ``intersects_sql`` / ``contains_sql`` / ``within_sql`` / + # ``spatialsetpredicate_sql`` emitters or their ``_generate_spatial_*`` / + # ``_predicate_operand`` helpers. def giqlnearest_sql(self, expression: GIQLNearest) -> str: """Generate SQL for NEAREST function. @@ -670,217 +632,6 @@ def _generate_distance_case( f"ELSE ({start_a} - {end_b} + 1) END END" ) - def _predicate_operand(self, expression: exp.Expression, arg: str) -> ResolvedColumn: - """Return the :class:`ResolvedColumn` for a spatial predicate operand. - - Reads the column resolution attached to *expression* by the - ``ResolveOperatorRefs`` pass (pass 1). The emitter consumes only the - resolved metadata; all name/column resolution lives in the pass. - - :param expression: - The spatial predicate node carrying the resolution metadata. - :param arg: - The operand slot key (``"this"`` or ``"expression"``). - :return: - The resolved column metadata. 
- """ - resolution = expression.meta.get(META_KEY) - if isinstance(resolution, OperatorResolution): - resolved = resolution.column(arg) - if resolved is not None: - return resolved - - raise ValueError( - f"Spatial predicate operand {arg!r} was not resolved; run the " - "ResolveOperatorRefs pass (transpile pipeline) before generation." - ) - - def _generate_spatial_op(self, expression: exp.Binary, op_type: str) -> str: - """Generate SQL for a spatial operation. - - :param expression: - AST node (Intersects, Contains, or Within) - :param op_type: - 'intersects', 'contains', or 'within' - :return: - SQL predicate string - """ - right_raw = self.sql(expression, "expression") - - # Check if right side is a column reference or a literal range string - if "." in right_raw and not right_raw.startswith("'"): - # Column-to-column join (e.g., a.interval INTERSECTS b.interval) - left = self._predicate_operand(expression, "this") - right = self._predicate_operand(expression, "expression") - return self._generate_column_join(left, right, op_type) - else: - # Literal range string (e.g., interval INTERSECTS 'chr1:1000-2000') - try: - range_str = right_raw.strip("'\"") - parsed_range = RangeParser.parse(range_str).to_zero_based_half_open() - left = self._predicate_operand(expression, "this") - return self._generate_range_predicate(left, parsed_range, op_type) - except Exception as e: - raise ValueError( - f"Could not parse genomic range: {right_raw}. Error: {e}" - ) - - def _generate_range_predicate( - self, - column: ResolvedColumn, - parsed_range: ParsedRange, - op_type: str, - ) -> str: - """Generate SQL predicate for a range operation. - - :param column: - Resolved column operand (physical chrom/start/end fragments plus the - backing :class:`~giql.table.Table` config for canonicalization). 
- :param parsed_range: - Parsed genomic range - :param op_type: - 'intersects', 'contains', or 'within' - :return: - SQL predicate string - """ - # The alias-qualified column fragments come pre-resolved on the - # ResolvedColumn, already canonicalized to 0-based half-open by - # CanonicalizeCoordinates (pass 2, issue #123). The predicate returns a - # boolean, which is encoding-invariant, so no output de-canonicalization - # is needed. - chrom_col = column.chrom - start_col = column.start - end_col = column.end - - chrom = parsed_range.chromosome - start = parsed_range.start - end = parsed_range.end - - if op_type == "intersects": - # Ranges overlap if: start1 < end2 AND end1 > start2 - return ( - f"({chrom_col} = '{chrom}' " - f"AND {start_col} < {end} " - f"AND {end_col} > {start})" - ) - - elif op_type == "contains": - # Point query: start1 <= point < end1 - if end == start + 1: - return ( - f"({chrom_col} = '{chrom}' " - f"AND {start_col} <= {start} " - f"AND {end_col} > {start})" - ) - # Range query: start1 <= start2 AND end1 >= end2 - else: - return ( - f"({chrom_col} = '{chrom}' " - f"AND {start_col} <= {start} " - f"AND {end_col} >= {end})" - ) - - elif op_type == "within": - # Left within right: start1 >= start2 AND end1 <= end2 - return ( - f"({chrom_col} = '{chrom}' " - f"AND {start_col} >= {start} " - f"AND {end_col} <= {end})" - ) - - raise ValueError(f"Unknown operation: {op_type}") - - def _generate_column_join( - self, left: ResolvedColumn, right: ResolvedColumn, op_type: str - ) -> str: - """Generate SQL for column-to-column spatial joins. - - :param left: - Resolved left column operand (e.g., for 'a.interval'). - :param right: - Resolved right column operand (e.g., for 'b.interval'). 
- :param op_type: - 'intersects', 'contains', or 'within' - :return: - SQL predicate string - """ - # The alias-qualified chrom/start/end fragments come pre-resolved on the - # ResolvedColumns, already canonicalized to 0-based half-open by - # CanonicalizeCoordinates (pass 2, issue #123). The predicate returns a - # boolean (encoding-invariant), so no output de-canonicalization is needed. - l_chrom = left.chrom - r_chrom = right.chrom - l_start = left.start - l_end = left.end - r_start = right.start - r_end = right.end - - if op_type == "intersects": - # Ranges overlap if: chrom1 = chrom2 AND start1 < end2 AND end1 > start2 - return ( - f"({l_chrom} = {r_chrom} " - f"AND {l_start} < {r_end} " - f"AND {l_end} > {r_start})" - ) - - elif op_type == "contains": - # Left contains right: chrom1 = chrom2 AND start1 <= start2 AND end1 >= end2 - return ( - f"({l_chrom} = {r_chrom} " - f"AND {l_start} <= {r_start} " - f"AND {l_end} >= {r_end})" - ) - - elif op_type == "within": - # Left within right: chrom1 = chrom2 AND start1 >= start2 AND end1 <= end2 - return ( - f"({l_chrom} = {r_chrom} " - f"AND {l_start} >= {r_start} " - f"AND {l_end} <= {r_end})" - ) - - raise ValueError(f"Unknown operation: {op_type}") - - def _generate_spatial_set(self, expression: SpatialSetPredicate) -> str: - """Generate SQL for spatial set predicates (ANY/ALL). - - Examples: - column INTERSECTS ANY(...) -> (condition1 OR condition2 OR ...) - column INTERSECTS ALL(...) -> (condition1 AND condition2 AND ...) - - :param expression: - SpatialSetPredicate expression node - :return: - SQL predicate string - """ - operator = expression.args["operator"] - quantifier = expression.args["quantifier"] - ranges = expression.args["ranges"] - - # Resolve the (single) left column operand once; every range condition - # compares against the same column. The set predicate's ranges are - # always literals, so only this operand needs resolution. 
- column = self._predicate_operand(expression, "this") - - # Parse all ranges - parsed_ranges = [] - for range_expr in ranges: - range_str = self.sql(range_expr).strip("'\"") - parsed_range = RangeParser.parse(range_str).to_zero_based_half_open() - parsed_ranges.append(parsed_range) - - op_type = operator.lower() - - # Generate conditions for each range - conditions = [] - for parsed_range in parsed_ranges: - condition = self._generate_range_predicate(column, parsed_range, op_type) - conditions.append(condition) - - # Combine with AND (for ALL) or OR (for ANY) - combinator = " OR " if quantifier.upper() == "ANY" else " AND " - return "(" + combinator.join(conditions) + ")" - def _detect_nearest_mode( self, expression: GIQLNearest, parent_expression: exp.Expression | None = None ) -> str: diff --git a/src/giql/transpile.py b/src/giql/transpile.py index 9ef2100..d75ea9b 100644 --- a/src/giql/transpile.py +++ b/src/giql/transpile.py @@ -11,9 +11,12 @@ from sqlglot import parse_one +import giql.expanders # noqa: F401 from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import REGISTRY from giql.expander import ExpandOperators +from giql.expressions import Intersects from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Table @@ -150,41 +153,73 @@ def transpile( "of the binned equi-join. Pass one or the other, not both." ) + # A *target-specific* ``(target, Intersects)`` registry entry (the public + # extension hook) takes over the INTERSECTS join rewrite entirely, so the + # built-in binned/IEJoin transformers are skipped for that target (see + # ExpanderRegistry.has_override). ``intersects_bin_size`` only configures the + # built-in binned transformer, so under such an override it would be silently + # dropped — reject it up front, parallel to the iejoin rejection above. 
+ target_overrides_intersects = REGISTRY.has_override(target, Intersects) + if target_overrides_intersects and intersects_bin_size is not None: + raise ValueError( + "intersects_bin_size has no effect when a target-specific " + f"(target={target.name!r}, Intersects) expander is registered; that " + "expander supersedes the built-in binned join transformer the bin " + "size configures. Pass one or the other, not both." + ) + tables_container = _build_tables(tables) with _reraise_as_value_error("Parse error", query=giql): ast = parse_one(giql, dialect=GIQLDialect) + # The column-to-column INTERSECTS *join* rewrites are capability-gated + # pre-pass transformers (epic #137, #141): the target's + # ``range_join_strategy`` selects the DuckDB IEJoin plan or the generic + # binned equi-join. They run on the raw parsed AST (before resolution, which + # rewrites the genomic column name) and consume a column-to-column INTERSECTS + # *join* so it never reaches the predicate expander; a literal-range or + # residual column-to-column INTERSECTS *predicate* survives to pass 3. + # + # ``target_overrides_intersects`` (computed above) records whether a + # *target-specific* ``(target, Intersects)`` registry entry — the public + # extension hook — has taken over the join rewrite. When it has, the built-in + # join rewrite below is skipped so the INTERSECTS node flows untouched into + # ExpandOperators, which dispatches it to that expander. This is the + # registry-deferral the IEJoin early-return used to preclude (#141). + # ``has_override`` deliberately matches only an *exact non-generic* entry: the + # built-in ``(GenericTarget(), Intersects)`` predicate expander is NOT a + # join-strategy override (it only renders residual / literal-range predicates + # the join transformers leave behind), so it must not disable the join rewrite. + # Falls back to the binned plan for unsupported shapes — see # IntersectsDuckDBIEJoinTransformer.transform_to_sql for the complete - # fallback set. 
- if uses_iejoin: + # fallback set. The IEJoin transformer emits a whole-query string, so when it + # produces output it must short-circuit the AST pipeline; this is safe for + # expansion because an IEJoin-eligible query carries exactly one INTERSECTS + # (its join), leaving no residual predicate operator for pass 3 to expand. + if uses_iejoin and not target_overrides_intersects: duckdb_transformer = IntersectsDuckDBIEJoinTransformer(tables_container) with _reraise_as_value_error("Transformation error"): duckdb_sql = duckdb_transformer.transform_to_sql(ast) if duckdb_sql is not None: - # WARNING: this early return emits the legacy IEJoin SQL directly and - # SKIPS the normalization pipeline below — pass 1 (resolution), pass 2 - # (canonicalization), and pass 3 (ExpandOperators, constructed ~40 - # lines down). The ExpandOperators registry is therefore NOT consulted - # on this path: a flagged operator on an IEJoin-eligible duckdb query - # is left un-expanded. This is benign today (the registry is empty and - # no operator opts in), but any DuckDB-pathed operator migration (#141) - # must either run expansion BEFORE this early return or have the IEJoin - # transformer defer to the registry. See the strict-xfail - # characterization test pinning this gap in tests/test_expander.py. return duckdb_sql - intersects_transformer = IntersectsBinnedJoinTransformer( - tables_container, - bin_size=intersects_bin_size, - ) merge_transformer = MergeTransformer(tables_container) cluster_transformer = ClusterTransformer(tables_container) generator = BaseGIQLGenerator(tables=tables_container) with _reraise_as_value_error("Transformation error"): - ast = intersects_transformer.transform(ast) + # Reaching here with an iejoin target means the IEJoin transformer + # declined the query (returned None) and fell back to the binned plan, + # exactly as before. ``intersects_bin_size`` is rejected up front for + # iejoin targets, so the binned transformer always sees its default there. 
+ if not target_overrides_intersects: + intersects_transformer = IntersectsBinnedJoinTransformer( + tables_container, + bin_size=intersects_bin_size, + ) + ast = intersects_transformer.transform(ast) ast = merge_transformer.transform(ast) ast = cluster_transformer.transform(ast) @@ -196,20 +231,20 @@ def transpile( ast = resolve_operator_refs(ast, tables_container) # Pass 2 of the normalization pipeline (epic #114): synthesize canonical - # __giql_canon_* wrapper CTEs for non-canonical interval operands of - # opted-in operators (GIQL_CANONICALIZE). No operator opts in yet, so this - # is a strict no-op until the per-operator port issues (#122, #123) land. + # __giql_canon_* wrapper CTEs for non-canonical interval operands of operators + # that opt in via GIQL_CANONICALIZE; those operators are rewritten here, and + # operators that do not opt in are left untouched. with _reraise_as_value_error("Canonicalization error"): ast = canonicalize_coordinates(ast) - # Pass 3 of the normalization pipeline (epic #137): replace each opted-in - # GIQL operator node with the AST its registered expander produces for the - # active target. No operator sets GIQL_EXPAND and the registry is empty, so - # this is a strict no-op until the per-operator migration issues land; the - # legacy *_sql emitters on the generator remain the fallback. - expand_operators = ExpandOperators(target, tables_container) + # Pass 3 of the normalization pipeline (epic #137): replace each GIQL operator + # node that opts in (GIQL_EXPAND) and resolves a registered expander with the + # AST that expander produces for the active target. Operators that are + # unflagged or resolve no expander are left untouched and the generator renders + # them via their legacy ``*_sql`` emitter as before. 
+ expand_pass = ExpandOperators(target, tables_container) with _reraise_as_value_error("Expansion error"): - ast = expand_operators.transform(ast) + ast = expand_pass.transform(ast) with _reraise_as_value_error("Transpilation error"): sql = generator.generate(ast) diff --git a/tests/generators/test_base.py b/tests/generators/test_base.py index f95e91b..0b8777a 100644 --- a/tests/generators/test_base.py +++ b/tests/generators/test_base.py @@ -11,31 +11,36 @@ from sqlglot import exp from sqlglot import parse_one +import giql.expanders as _expanders # noqa: F401 (registers built-in expanders) from giql import Table from giql import transpile from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect +from giql.expander import ExpandOperators from giql.expressions import GIQLNearest from giql.generators import BaseGIQLGenerator from giql.resolver import resolve_operator_refs from giql.table import Tables +from giql.targets import GenericTarget def _generate_through_passes(sql: str, tables: Tables) -> str: - """Parse, run normalization passes 1 and 2, then generate SQL. + """Parse, run normalization passes 1-3, then generate SQL. Coordinate canonicalization for operator operands moved out of the emitter and - into the CanonicalizeCoordinates pass (issue #123). Emitter-level tests that - pin canonicalized output must therefore run both passes before generating, - exactly as :func:`giql.transpile.transpile` does, rather than calling - ``generate`` on a bare parsed AST (which would skip canonicalization). This - helper is used where the full ``transpile`` pipeline would otherwise rewrite - the node away (a column-to-column ``INTERSECTS`` is turned into a binned - equi-join before the predicate emitter runs). + into the CanonicalizeCoordinates pass (issue #123), and spatial / set + predicate emission moved into the ExpandOperators pass (issue #141). 
Emitter- + level tests that pin canonicalized / expanded output must therefore run all + three passes before generating, exactly as :func:`giql.transpile.transpile` + does, rather than calling ``generate`` on a bare parsed AST (which would skip + them). This helper is used where the full ``transpile`` pipeline would + otherwise rewrite the node away (a column-to-column ``INTERSECTS`` is turned + into a binned equi-join before the predicate expander runs). """ ast = parse_one(sql, dialect=GIQLDialect) ast = resolve_operator_refs(ast, tables) ast = canonicalize_coordinates(ast) + ast = ExpandOperators(GenericTarget(), tables).transform(ast) return BaseGIQLGenerator(tables=tables).generate(ast) @@ -817,35 +822,30 @@ def test_giqldistance_canonicalizes_closed_ends_apart_from_gap_parity( def test_error_handling_invalid_range(self): """ GIVEN invalid genomic range string in Intersects - WHEN intersects_sql is called + WHEN the INTERSECTS predicate is expanded THEN ValueError with descriptive message is raised. """ sql = "SELECT * FROM variants WHERE interval INTERSECTS 'invalid'" - ast = parse_one(sql, dialect=GIQLDialect) - - generator = BaseGIQLGenerator() with pytest.raises(ValueError, match="Could not parse genomic range"): - generator.generate(ast) + _generate_through_passes(sql, Tables()) - def test_error_handling_unknown_operation(self): + def test_error_handling_nonnumeric_range_bounds(self): """ - GIVEN unknown operation type in spatial operations - WHEN a spatial operation with unknown op_type is attempted - THEN ValueError is raised. + GIVEN an INTERSECTS range whose start/end bounds are non-numeric + WHEN the INTERSECTS predicate is expanded + THEN ValueError is raised from the range parse failure. - Note: This test verifies internal error handling by directly calling - a method with invalid input, which would only occur through code errors. 
+ Note: 'chr:a-b' parses as a range shape but its bounds are not integers, + so the underlying RangeParser raises and the expander wraps it. (The + former "unknown operation" guard this exercised is now unreachable — + dispatch is closed over the three known op types — so this pins the + remaining reachable failure: a parse error on the literal range.) """ - # This is an indirect test - we verify the generator raises ValueError - # when given malformed range strings as that's how errors surface sql = "SELECT * FROM variants WHERE interval INTERSECTS 'chr:a-b'" - ast = parse_one(sql, dialect=GIQLDialect) - - generator = BaseGIQLGenerator() with pytest.raises(ValueError): - generator.generate(ast) + _generate_through_passes(sql, Tables()) def test_select_sql_join_without_alias(self, tables_with_two_tables): """ diff --git a/tests/test_expander.py b/tests/test_expander.py index ad625f1..913caf5 100644 --- a/tests/test_expander.py +++ b/tests/test_expander.py @@ -53,34 +53,50 @@ def _expander(node: exp.Expression, ctx: ExpansionContext) -> exp.Expression: return _expander +import giql.expanders # noqa: E402, F401 + +#: The registry contents at import — the built-in expanders registered by +#: ``giql.expanders`` for the already-migrated operators. The leak guards and +#: ``clean_registry`` treat this as the baseline rather than an empty registry, so +#: the real built-in registrations survive isolating fixtures and a leaking test +#: is still caught against the true baseline. +_REGISTRY_BASELINE = REGISTRY.snapshot() + + @pytest.fixture def clean_registry(): - """Isolate the process-wide REGISTRY, leaving it empty afterward. + """Isolate the process-wide REGISTRY, restoring its baseline afterward. - The registry is empty at import (the pass ships inert), so a test that opts - in clears on the way out; emptiness is asserted through the public - ``bool()``/``len()`` surface rather than private state. 
+ Saves the import-time baseline (the built-in expanders), empties the registry + so a test sees only what it registers, and restores the baseline on the way + out through the public ``snapshot()``/``restore()`` seam — so a test that + registers a stand-in expander cannot leak it, and the built-in registrations + survive this fixture's isolation. """ - assert not REGISTRY, "REGISTRY was non-empty entering clean_registry" + saved = REGISTRY.snapshot() REGISTRY.clear() yield REGISTRY - REGISTRY.clear() + REGISTRY.restore(saved) @pytest.fixture(autouse=True) def _registry_leak_guard(): - """Assert the process-wide REGISTRY is empty at each test boundary. - - A leak guard: the registry is empty at import and must return to empty after - every test, so a test that registers without cleaning up (a leak that would - silently flip the no-op pass for a later test) fails loudly. Tests that - register on the process-wide REGISTRY do so through ``clean_registry``, which - clears on the way out; this guard catches anything that bypasses it. Both - checks go through the public ``bool()`` surface (A5), not private state. + """Assert the process-wide REGISTRY matches its baseline at each boundary. + + A leak guard: the registry holds the built-in expanders at import and must + return to exactly that baseline after every test, so a test that registers + without cleaning up (a leak that would silently change dispatch for a later + test) fails loudly. Tests that mutate the process-wide REGISTRY do so through + ``clean_registry``, which restores the baseline on the way out; this guard + catches anything that bypasses it. 
""" - assert not REGISTRY, "REGISTRY leaked into a test from a prior one" + assert REGISTRY.snapshot() == _REGISTRY_BASELINE, ( + "REGISTRY differed from its baseline entering a test" + ) yield - assert not REGISTRY, "a test leaked a registration into REGISTRY" + assert REGISTRY.snapshot() == _REGISTRY_BASELINE, ( + "a test leaked a registration into REGISTRY" + ) @pytest.fixture(autouse=True) @@ -88,18 +104,19 @@ def _expand_flag_leak_guard(): """Assert every operator's GIQL_EXPAND is restored at each test boundary. The symmetric partner of the registry leak guard: each operator class ships - opted out (its own GIQL_EXPAND attribute is False), and a test that flips one - via ``_opted_in`` must restore it. A leaked opt-in would silently flip the - no-op pass for a later test, so this catches anything that bypasses the - exception-safe ``_opted_in`` manager. + a shipped GIQL_EXPAND default (``True`` for a migrated operator, ``False`` + otherwise), and a test that flips one via ``_opted_in`` must restore it. A + leaked flip would silently change the pass for a later test, so this catches + anything that bypasses the exception-safe ``_opted_in`` manager by comparing + against each operator's shipped default rather than a blanket False. """ for op in _OPERATOR_CLASSES: - assert op.__dict__.get("GIQL_EXPAND") is False, ( + assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( f"{op.__name__}.GIQL_EXPAND leaked into a test from a prior one" ) yield for op in _OPERATOR_CLASSES: - assert op.__dict__.get("GIQL_EXPAND") is False, ( + assert op.__dict__.get("GIQL_EXPAND") is _SHIPPED_EXPAND_FLAGS[op], ( f"a test leaked a GIQL_EXPAND opt-in on {op.__name__}" ) @@ -468,6 +485,55 @@ def _expander(node, ctx): assert REGISTRY.resolve(DuckDBTarget(), GIQLDisjoin) is None assert (DuckDBTarget(), GIQLDisjoin) not in REGISTRY + def test_snapshot_should_not_observe_later_registrations(self): + """Test that a snapshot does not observe registrations made after it. 
+ + Given: + A registry with one entry, captured by snapshot. + When: + A second entry is registered after the snapshot is taken. + Then: + The snapshot should still hold only the first entry (it is a copy, + not a live view). + """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("first")) + + # Act + saved = registry.snapshot() + registry.register(GenericTarget(), Intersects, _record("second")) + + # Assert + assert (DuckDBTarget(), GIQLDisjoin) in saved + assert (GenericTarget(), Intersects) not in saved + + def test_restore_should_replace_entries_with_snapshot_contents(self): + """Test that restore returns the registry to a captured snapshot. + + Given: + A snapshot of a registry with one entry, after which the registry is + cleared and a different entry registered. + When: + Restoring the snapshot. + Then: + The original entry should resolve again and the post-snapshot entry + should be gone. + """ + # Arrange + registry = ExpanderRegistry() + registry.register(DuckDBTarget(), GIQLDisjoin, _record("original")) + saved = registry.snapshot() + registry.clear() + registry.register(GenericTarget(), Intersects, _record("transient")) + + # Act + registry.restore(saved) + + # Assert + assert (DuckDBTarget(), GIQLDisjoin) in registry + assert (GenericTarget(), Intersects) not in registry + class TestRegisterDecorator: """Tests for the @register extension-hook decorator.""" @@ -955,24 +1021,27 @@ def test_transform_skips_unflagged_operator(self, clean_registry): """Test that an unflagged operator is left untouched even when registered. Given: - An expander registered for (GenericTarget, GIQLDisjoin) but the - operator's GIQL_EXPAND flag left at its default False. + An expander registered for (GenericTarget, op) but the operator's + GIQL_EXPAND flag held off (a migrated operator ships it on, so the + control opts it out to isolate the per-type gate). When: Running the pass. 
Then: - The DISJOIN node should remain in the tree (gate requires both). + The operator node should remain in the tree (gate requires both). """ # Arrange - clean_registry.register(GenericTarget(), GIQLDisjoin, _record("expanded")) - tables = _tables() - ast = _prepare("SELECT * FROM DISJOIN(variants)", tables) + operator = _A_MIGRATED_OPERATOR + clean_registry.register(GenericTarget(), operator, _record("expanded")) + tables = _tables(("variants", "peaks")) + ast = _prepare_operator(operator, tables) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act (GIQL_EXPAND is False by default — no opt-in context) - result = pass_.transform(ast) + # Act + with _opted_out(operator): + result = pass_.transform(ast) # Assert - assert list(result.find_all(GIQLDisjoin)) + assert list(result.find_all(operator)) def test_transform_skips_flagged_operator_with_no_expander(self, clean_registry): """Test that a flagged operator with no expander is left untouched. @@ -1022,17 +1091,21 @@ def test_expand_operators_is_identity_when_registry_empty(self): assert list(result.find_all(GIQLDisjoin)) def test_transpile_sql_unchanged_with_pass_inert(self): - """Test that transpile output is byte-identical with the pass inert. + """Test that transpile output is byte-identical for an unmigrated operator. Given: - A DISJOIN query, the default (empty) registry, and no operator flagged. + A DISTANCE query (an operator not migrated onto the pass, so its + GIQL_EXPAND is False and no expander resolves), with the default + registry. When: - Transpiling with the wired-in pass versus a pass-bypassed reference. + Transpiling with the wired-in pass versus a pass-bypassed reference + (its legacy emitter run directly). Then: - The SQL should match exactly and carry no expander alias prefix. + The SQL should match exactly and carry no expander alias prefix — the + pass is inert for any operator that has not been migrated. 
""" # Arrange - query = "SELECT * FROM DISJOIN(variants)" + query = "SELECT DISTANCE(a.interval, b.interval) FROM variants a, variants b" tables = _tables() ast = _prepare(query, tables) from giql.generators import BaseGIQLGenerator @@ -1048,8 +1121,9 @@ def test_transpile_sql_unchanged_with_pass_inert(self): # The nine GIQL operator expression classes the ExpandOperators pass inspects. -# Every one must ship opted out (GIQL_EXPAND=False) at this step so the pass is a -# strict no-op until a later migration flips one alongside its expander. +# Each ships GIQL_EXPAND=True once migrated (alongside its registered expander) +# and False otherwise; the flags are read dynamically below rather than asserted +# per branch. from giql.expressions import Contains # noqa: E402 from giql.expressions import GIQLCluster # noqa: E402 from giql.expressions import GIQLDistance # noqa: E402 @@ -1069,20 +1143,47 @@ def test_transpile_sql_unchanged_with_pass_inert(self): GIQLMerge, ) +#: Each operator's shipped GIQL_EXPAND default, captured from its own class dict +#: at import. A migrated operator ships ``True``; the rest ship ``False`` until +#: their migrations land. The flag leak guard restores to these shipped values +#: rather than a blanket ``False``. +_SHIPPED_EXPAND_FLAGS = {op: op.__dict__.get("GIQL_EXPAND") for op in _OPERATOR_CLASSES} + + +#: Operators migrated onto the ExpandOperators pass — they ship GIQL_EXPAND=True. +#: Derived dynamically from each operator's shipped flag so this line is +#: byte-identical across branches (auto-merges) regardless of which operators a +#: branch migrated. +_MIGRATED_OPERATORS = tuple( + op for op in _OPERATOR_CLASSES if op.__dict__.get("GIQL_EXPAND") is True +) +assert _MIGRATED_OPERATORS, "expected at least one migrated operator" +#: A representative migrated operator for operator-agnostic control tests (a +#: migrated operator ships GIQL_EXPAND=True, so a control that needs one to act +#: unflagged must explicitly opt it out). 
+_A_MIGRATED_OPERATOR = _MIGRATED_OPERATORS[0] +#: Operators not yet migrated — they ship GIQL_EXPAND=False. +_UNMIGRATED_OPERATORS = tuple( + op for op in _OPERATOR_CLASSES if op not in _MIGRATED_OPERATORS +) + class TestOperatorOptOut: - """Every operator ships opted out of the ExpandOperators pass at this step.""" + """Migrated operators opt into the pass; the rest still ship opted out.""" - @pytest.mark.parametrize("operator", _OPERATOR_CLASSES, ids=lambda c: c.__name__) + @pytest.mark.parametrize( + "operator", _UNMIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) def test_operator_class_ships_expand_disabled(self, operator): - """Test that each operator class ships GIQL_EXPAND=False. + """Test that each unmigrated operator class ships GIQL_EXPAND=False. Given: - One of the nine GIQL operator expression classes. + A GIQL operator expression class that has not been migrated onto the + ExpandOperators pass. When: Reading its GIQL_EXPAND class attribute. Then: - It should be False (no operator opts into expansion yet). + It should be False (the operator still uses the legacy emitter). """ # Arrange & act flag = operator.GIQL_EXPAND @@ -1090,6 +1191,53 @@ def test_operator_class_ships_expand_disabled(self, operator): # Assert assert flag is False + @pytest.mark.parametrize( + "operator", _MIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) + def test_operator_class_ships_expand_enabled(self, operator): + """Test that each migrated operator class ships GIQL_EXPAND=True. + + Given: + A GIQL operator expression class migrated onto the ExpandOperators + pass. + When: + Reading its GIQL_EXPAND class attribute. + Then: + It should be True (the operator expands through its registered + expander instead of the deleted legacy emitter). 
+ """ + # Arrange & act + flag = operator.GIQL_EXPAND + + # Assert + assert flag is True + + +class TestMigratedOperatorsRegistered: + """Every migrated operator resolves a built-in expander in the process REGISTRY.""" + + @pytest.mark.parametrize( + "operator", _MIGRATED_OPERATORS, ids=lambda c: c.__name__ + ) + def test_migrated_operator_resolves_in_process_registry(self, operator): + """Test that each migrated operator resolves an expander in REGISTRY. + + Given: + A GIQL operator class that ships GIQL_EXPAND=True (migrated onto the + ExpandOperators pass). + When: + Resolving it against the import-populated process-wide REGISTRY for + the generic target. + Then: + A built-in expander should resolve — a migrated operator always has a + registered expander, so the pass never leaves it on a deleted emitter. + """ + # Arrange & act + resolved = REGISTRY.resolve(GenericTarget(), operator) + + # Assert + assert resolved is not None + class TestOptedInRestoresFlag: """The _opted_in helper restores GIQL_EXPAND even when its body raises.""" @@ -1118,27 +1266,28 @@ def test_opted_in_restores_flag_after_exception(self): assert GIQLMerge.GIQL_EXPAND is False -class TestIEJoinEarlyReturnSkipsExpansion: - """Pin Finding 2: the duckdb IEJoin early return skips the ExpandOperators pass.""" +class TestIEJoinRegistryDeferral: + """The duckdb IEJoin path defers to a target-specific Intersects expander (#141). - @pytest.mark.xfail( - strict=True, - reason="#141: the duckdb IEJoin early return in transpile() emits before " - "ExpandOperators runs, so a flagged operator on an IEJoin-eligible query " - "is not expanded. Flips to pass when #141 runs expansion before the " - "early return (or defers the IEJoin transformer to the registry).", - ) - def test_iejoin_query_expands_flagged_operator(self, clean_registry): - """Test that an IEJoin-eligible duckdb query expands a flagged operator. 
+ Resolves Finding 2: the IEJoin early return used to emit before the + ExpandOperators pass, so a flagged operator on an IEJoin-eligible query was + never expanded. Now a *target-specific* ``(DuckDBTarget, Intersects)`` + registry entry overrides the built-in join strategy entirely (the public + extension hook), while the default duckdb path — with no such override — + still emits the built-in IEJoin SQL. + """ + + def test_iejoin_query_expands_target_override_expander(self, clean_registry): + """Test that a target-specific Intersects override fires on an IEJoin query. Given: A column-to-column INTERSECTS join eligible for the duckdb IEJoin - path, with Intersects flagged GIQL_EXPAND and an expander registered. + path, with a (DuckDBTarget, Intersects) expander registered. When: Transpiling with dialect='duckdb'. Then: - The expander's sentinel should appear (currently it does NOT — the - IEJoin early return skips the pass; this xfail flips when #141 lands). + The override expander's sentinel should appear — the IEJoin path + defers to the registry rather than short-circuiting expansion. """ # Arrange clean_registry.register( @@ -1150,41 +1299,35 @@ def test_iejoin_query_expands_flagged_operator(self, clean_registry): ) # Act - with _opted_in(Intersects): - sql = transpile(query, tables=["peaks", "genes"], dialect="duckdb") + sql = transpile(query, tables=["peaks", "genes"], dialect="duckdb") # Assert assert "__giql_iejoin_sentinel" in sql + assert "SET VARIABLE __giql_iejoin_" not in sql - def test_iejoin_query_emits_legacy_sql_unchanged(self, clean_registry): - """Test that the legacy IEJoin SQL is emitted regardless of a flagged op. + def test_iejoin_query_emits_builtin_iejoin_without_override(self): + """Test that the default duckdb path emits the built-in IEJoin SQL. Given: - The same IEJoin-eligible duckdb query with Intersects flagged and an - expander registered. 
+ The same IEJoin-eligible duckdb query and no target-specific + Intersects override registered (only the built-in generic expander). When: Transpiling with dialect='duckdb'. Then: - The legacy IEJoin SET VARIABLE SQL is emitted and the expander's - sentinel is absent (characterizing the current skip; the companion - xfail surfaces when #141 fixes it). + The built-in IEJoin SET VARIABLE SQL is emitted (the generic + predicate expander does not disable the join strategy). """ # Arrange - clean_registry.register( - DuckDBTarget(), Intersects, lambda n, c: exp.column("__giql_iejoin_sentinel") - ) query = ( "SELECT a.start FROM peaks a " "JOIN genes b ON a.interval INTERSECTS b.interval" ) # Act - with _opted_in(Intersects): - sql = transpile(query, tables=["peaks", "genes"], dialect="duckdb") + sql = transpile(query, tables=["peaks", "genes"], dialect="duckdb") # Assert assert "SET VARIABLE __giql_iejoin_" in sql - assert "__giql_iejoin_sentinel" not in sql class TestTranspileExpanderDispatch: @@ -1649,36 +1792,39 @@ def test_walk_partial_opt_in_replaces_only_flagged_type(self, clean_registry): """Test that only the flagged operator type is replaced when both registered. Given: - A DISJOIN and an INTERSECTS, both with registered expanders, but only - INTERSECTS flagged GIQL_EXPAND. + A DISJOIN and a migrated operator, both with registered expanders, but + only DISJOIN flagged GIQL_EXPAND — the migrated operator is held as the + control, opted out so its shipped opt-in cannot interfere. When: Running the pass. Then: - The INTERSECTS is replaced while the DISJOIN node remains (the gate is - per-type). + The DISJOIN is replaced while the control operator node remains (the + gate is per-type). 
""" # Arrange + control = _A_MIGRATED_OPERATOR clean_registry.register( GenericTarget(), GIQLDisjoin, lambda n, c: exp.column("DJ") ) clean_registry.register( - GenericTarget(), Intersects, lambda n, c: exp.column("IX") + GenericTarget(), control, lambda n, c: exp.column("IX") ) tables = _tables(("variants", "peaks")) ast = _prepare( - "SELECT * FROM DISJOIN(variants) " - "WHERE EXISTS (SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1-100')", + "SELECT * FROM DISJOIN(variants) WHERE EXISTS (" + + _OPERATOR_QUERIES[control] + + ")", tables, ) pass_ = ExpandOperators(GenericTarget(), tables, clean_registry) - # Act - with _opted_in(Intersects): + # Act (a migrated operator ships flagged, so opt it out as the control) + with _opted_in(GIQLDisjoin), _opted_out(control): result = pass_.transform(ast) # Assert - assert list(result.find_all(GIQLDisjoin)) - assert not list(result.find_all(Intersects)) + assert not list(result.find_all(GIQLDisjoin)) + assert list(result.find_all(control)) def test_walk_shares_alias_sequence_across_sibling_expanders(self, clean_registry): """Test that sibling expanders draw from one alias sequence (no collision). @@ -1874,6 +2020,32 @@ def _prepare(query: str, tables: Tables) -> exp.Expression: return resolve_operator_refs(ast, tables) +#: A minimal GIQL query yielding a node of each operator class, so a control test +#: can build an AST for whichever operator a branch happens to have migrated +#: (``_A_MIGRATED_OPERATOR``) without hard-coding one. Covers every operator class +#: rather than a per-branch migrated subset. 
+_OPERATOR_QUERIES = { + Intersects: "SELECT * FROM peaks WHERE interval INTERSECTS 'chr1:1-100'", + Contains: "SELECT * FROM peaks WHERE interval CONTAINS 'chr1:1-100'", + Within: "SELECT * FROM peaks WHERE interval WITHIN 'chr1:1-100'", + SpatialSetPredicate: ( + "SELECT * FROM peaks WHERE interval INTERSECTS ANY ('chr1:1-100', 'chr1:5-9')" + ), + GIQLDisjoin: "SELECT * FROM DISJOIN(variants)", + GIQLDistance: ( + "SELECT DISTANCE(a.interval, b.interval) FROM variants a, variants b" + ), + GIQLNearest: "SELECT * FROM NEAREST(variants, reference := 'chr1:1-100')", + GIQLCluster: "SELECT * FROM CLUSTER(variants)", + GIQLMerge: "SELECT * FROM MERGE(variants)", +} + + +def _prepare_operator(operator: type, tables: Tables) -> exp.Expression: + """Parse and resolve a minimal query containing one *operator* node.""" + return _prepare(_OPERATOR_QUERIES[operator], tables) + + class _opted_in: """Context manager opting an operator class into GIQL_EXPAND for a test.""" @@ -1888,3 +2060,25 @@ def __enter__(self): def __exit__(self, *exc): self._operator.GIQL_EXPAND = self._prior return False + + +class _opted_out: + """Context manager opting an operator class out of GIQL_EXPAND for a test. + + The complement of :class:`_opted_in`: used by a control test that needs a + *migrated* operator (one that ships GIQL_EXPAND=True) to behave as if + unflagged, so the test can prove the pass gates per-type without the + operator's shipped opt-in interfering. Restores the prior flag on exit. 
+ """ + + def __init__(self, operator: type) -> None: + self._operator = operator + self._prior = operator.__dict__.get("GIQL_EXPAND", False) + + def __enter__(self): + self._operator.GIQL_EXPAND = False + return self._operator + + def __exit__(self, *exc): + self._operator.GIQL_EXPAND = self._prior + return False diff --git a/tests/test_spatial_expanders.py b/tests/test_spatial_expanders.py new file mode 100644 index 0000000..068c05f --- /dev/null +++ b/tests/test_spatial_expanders.py @@ -0,0 +1,386 @@ +"""Direct unit tests for the spatial / set predicate expanders (#141). + +These call ``expand_intersects`` / ``expand_contains`` / ``expand_within`` / +``expand_spatial_set`` directly with a hand-built :class:`ExpansionContext`, +characterizing each dispatch branch (column-to-column vs literal range; CONTAINS +point vs range; ANY/OR vs ALL/AND) and pinning the chosen error messages on +invalid input. They sit outside ``tests/test_expander.py`` so they do not touch +that file's shared, operator-agnostic fixture/infra region. 
+"""
+
+import pytest
+from sqlglot import exp
+from sqlglot import parse_one
+
+from giql.dialect import GIQLDialect
+from giql.expander import REGISTRY
+from giql.expander import ExpansionContext
+from giql.expanders.intersects import expand_contains
+from giql.expanders.intersects import expand_intersects
+from giql.expanders.intersects import expand_spatial_set
+from giql.expanders.intersects import expand_within
+from giql.expressions import Contains
+from giql.expressions import Intersects
+from giql.expressions import SpatialSetPredicate
+from giql.expressions import Within
+from giql.resolver import OperatorResolution
+from giql.resolver import ResolvedColumn
+from giql.table import Tables
+from giql.targets import DataFusionTarget
+from giql.targets import GenericTarget
+from giql.transpile import transpile
+
+_LEFT = ResolvedColumn(
+    chrom='a."chrom"', start='a."start"', end='a."end"', strand=None, table=None
+)
+_RIGHT = ResolvedColumn(
+    chrom='b."chrom"', start='b."start"', end='b."end"', strand=None, table=None
+)
+_OPERATOR_TYPES = (Intersects, Contains, Within, SpatialSetPredicate)
+
+
+def _context(query: str, columns: dict[str, ResolvedColumn]) -> tuple:
+    """Parse *query* and pair its spatial operator with an expansion context.
+
+    The operator is the first node ``walk()`` yields that is one of
+    ``_OPERATOR_TYPES``. The context's resolution carries *columns* as given and
+    no slots, and is built with a ``GenericTarget()`` and a ``Tables()``.
+    """
+    tree = parse_one(query, dialect=GIQLDialect)
+    spatial_nodes = (
+        candidate
+        for candidate in tree.walk()
+        if isinstance(candidate, _OPERATOR_TYPES)
+    )
+    operator_node = next(spatial_nodes)
+    resolved = OperatorResolution(
+        operator=type(operator_node).__name__, slots={}, columns=columns
+    )
+    context = ExpansionContext(operator_node, resolved, GenericTarget(), Tables())
+    return operator_node, context
+
+
+def _sql(expression: exp.Expression) -> str:
+    """Render *expression* as SQL text using the GIQL dialect."""
+    rendered = expression.sql(dialect=GIQLDialect)
+    return rendered
+
+
+class TestSpatialExpanders:
+    """Expander-level checks for each spatial / set predicate form (#141)."""
+
+    def test_intersects_literal_range_expands_to_overlap_predicate(self):
+        """Test that INTERSECTS on a literal range yields the overlap predicate.
+
+        Given:
+            An INTERSECTS node with a literal range on the right, and a context
+            in which only the left column is resolved.
+        When:
+            The node is expanded.
+        Then:
+            The result is the overlap boolean (chrom = lit AND start < end2 AND
+            end > start2), the range appearing as numeric literals.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval INTERSECTS 'chr1:1000-2000'",
+            {"this": _LEFT},
+        )
+        expected = '(a."chrom" = \'chr1\' AND a."start" < 2000 AND a."end" > 1000)'
+
+        # Act
+        expanded = expand_intersects(operator_node, context)
+
+        # Assert
+        assert _sql(expanded) == expected
+
+    def test_intersects_column_to_column_expands_to_join_predicate(self):
+        """Test that INTERSECTS between two columns yields a join predicate.
+
+        Given:
+            An INTERSECTS node whose right operand is a resolved *column* (the
+            dispatch keys on ctx.resolution.column("expression")).
+        When:
+            The node is expanded.
+        Then:
+            The endpoints of the two columns are compared, not literals.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a JOIN b ON a.interval INTERSECTS b.interval",
+            {"this": _LEFT, "expression": _RIGHT},
+        )
+        expected = (
+            '(a."chrom" = b."chrom" AND a."start" < b."end" AND a."end" > b."start")'
+        )
+
+        # Act
+        expanded = expand_intersects(operator_node, context)
+
+        # Assert
+        assert _sql(expanded) == expected
+
+    def test_contains_point_query_expands_to_point_predicate(self):
+        """Test that a one-base CONTAINS yields the point-containment form.
+
+        Given:
+            A CONTAINS node whose literal range covers a single base
+            (end == start+1).
+        When:
+            The node is expanded.
+        Then:
+            The point form (start <= point AND end > point) is produced rather
+            than the range form.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval CONTAINS 'chr1:1000'", {"this": _LEFT}
+        )
+        expected = '(a."chrom" = \'chr1\' AND a."start" <= 1000 AND a."end" > 1000)'
+
+        # Act
+        expanded = expand_contains(operator_node, context)
+
+        # Assert
+        assert _sql(expanded) == expected
+
+    def test_contains_range_query_expands_to_range_predicate(self):
+        """Test that a multi-base CONTAINS yields the range-containment form.
+
+        Given:
+            A CONTAINS node whose literal range covers more than one base.
+        When:
+            The node is expanded.
+        Then:
+            The range form (start <= start2 AND end >= end2) is produced.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval CONTAINS 'chr1:1000-2000'",
+            {"this": _LEFT},
+        )
+        expected = '(a."chrom" = \'chr1\' AND a."start" <= 1000 AND a."end" >= 2000)'
+
+        # Act
+        expanded = expand_contains(operator_node, context)
+
+        # Assert
+        assert _sql(expanded) == expected
+
+    def test_within_expands_to_containment_predicate(self):
+        """Test that WITHIN yields the left-inside-right containment form.
+
+        Given:
+            A WITHIN node with a literal range on the right.
+        When:
+            The node is expanded.
+        Then:
+            The result is start >= start2 AND end <= end2.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval WITHIN 'chr1:1000-2000'", {"this": _LEFT}
+        )
+        expected = '(a."chrom" = \'chr1\' AND a."start" >= 1000 AND a."end" <= 2000)'
+
+        # Act
+        expanded = expand_within(operator_node, context)
+
+        # Assert
+        assert _sql(expanded) == expected
+
+    def test_set_any_or_combines_per_range_conditions(self):
+        """Test that an ANY set predicate joins its per-range conditions with OR.
+
+        Given:
+            An INTERSECTS ANY node listing two literal ranges.
+        When:
+            The node is expanded.
+        Then:
+            Both per-range overlap predicates appear OR-combined within a single
+            outer paren.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval "
+            "INTERSECTS ANY ('chr1:1-100', 'chr1:200-300')",
+            {"this": _LEFT},
+        )
+        expected = (
+            '((a."chrom" = \'chr1\' AND a."start" < 100 AND a."end" > 1) OR '
+            '(a."chrom" = \'chr1\' AND a."start" < 300 AND a."end" > 200))'
+        )
+
+        # Act
+        expanded = expand_spatial_set(operator_node, context)
+
+        # Assert
+        assert _sql(expanded) == expected
+
+    def test_set_all_and_combines_per_range_conditions(self):
+        """Test that an ALL set predicate joins its per-range conditions with AND.
+
+        Given:
+            An INTERSECTS ALL node listing two literal ranges.
+        When:
+            The node is expanded.
+        Then:
+            Both per-range overlap predicates appear AND-combined within a single
+            outer paren.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval "
+            "INTERSECTS ALL ('chr1:1-100', 'chr1:200-300')",
+            {"this": _LEFT},
+        )
+        expected = (
+            '((a."chrom" = \'chr1\' AND a."start" < 100 AND a."end" > 1) AND '
+            '(a."chrom" = \'chr1\' AND a."start" < 300 AND a."end" > 200))'
+        )
+
+        # Act
+        expanded = expand_spatial_set(operator_node, context)
+
+        # Assert
+        assert _sql(expanded) == expected
+
+
+@pytest.fixture
+def isolated_registry():
+    """Yield the process REGISTRY and restore its pre-test snapshot afterwards."""
+    snapshot = REGISTRY.snapshot()
+    try:
+        yield REGISTRY
+    finally:
+        REGISTRY.restore(snapshot)
+
+
+class TestBinnedTargetOverrideDeferral:
+    """A (target, Intersects) override takes over the binned join rewrite (#141)."""
+
+    def test_binned_target_override_skips_join_rewrite(self, isolated_registry):
+        """Test that a (target, Intersects) override bypasses the binned rewrite.
+
+        Given:
+            A column-to-column INTERSECTS join on the generic binned path
+            (dialect='datafusion'), with a (DataFusionTarget, Intersects)
+            override registered.
+        When:
+            The query is transpiled.
+        Then:
+            The override's sentinel shows up in the SQL and no binned equi-join
+            artifact does: the override replaces the join rewrite the built-in
+            binned transformer would otherwise perform.
+        """
+        # Arrange
+        isolated_registry.register(
+            DataFusionTarget(),
+            Intersects,
+            lambda _node, _ctx: exp.column("BINNED_OVERRIDE_SENTINEL"),
+        )
+        join_query = (
+            "SELECT a.start FROM peaks a "
+            "JOIN genes b ON a.interval INTERSECTS b.interval"
+        )
+
+        # Act
+        emitted = transpile(
+            join_query, tables=["peaks", "genes"], dialect="datafusion"
+        )
+
+        # Assert
+        assert "BINNED_OVERRIDE_SENTINEL" in emitted
+        assert "_bins" not in emitted
+
+    def test_binned_target_override_rejects_bin_size(self, isolated_registry):
+        """Test that a bin size is rejected when such an override is registered.
+
+        Given:
+            A (DataFusionTarget, Intersects) override registered.
+        When:
+            Transpiling with intersects_bin_size set (an option that only
+            configures the built-in binned transformer the override supersedes).
+        Then:
+            transpile() raises ValueError instead of silently dropping the bin
+            size, parallel to the iejoin rejection.
+        """
+        # Arrange
+        isolated_registry.register(
+            DataFusionTarget(),
+            Intersects,
+            lambda _node, _ctx: exp.column("BINNED_OVERRIDE_SENTINEL"),
+        )
+        join_query = (
+            "SELECT a.start FROM peaks a "
+            "JOIN genes b ON a.interval INTERSECTS b.interval"
+        )
+        pattern = r"intersects_bin_size has no effect"
+
+        # Act & assert
+        with pytest.raises(ValueError, match=pattern):
+            transpile(
+                join_query,
+                tables=["peaks", "genes"],
+                dialect="datafusion",
+                intersects_bin_size=5000,
+            )
+
+
+class TestSpatialExpanderErrors:
+    """Characterization tests that pin the error messages raised on bad input."""
+
+    def test_invalid_literal_range_wraps_parse_error(self):
+        """Test that an unparseable literal range raises the wrapped diagnostic.
+
+        Given:
+            An INTERSECTS node whose literal range string cannot be parsed.
+        When:
+            The node is expanded.
+        Then:
+            ValueError is raised with the historical "Could not parse genomic
+            range" wrapper, chained from the underlying parse error.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval INTERSECTS 'invalid'", {"this": _LEFT}
+        )
+        pattern = r"Could not parse genomic range"
+
+        # Act & assert
+        with pytest.raises(ValueError, match=pattern) as raised:
+            expand_intersects(operator_node, context)
+        assert raised.value.__cause__ is not None
+
+    def test_unresolved_left_operand_raises_internal_invariant(self):
+        """Test that an unresolved left operand raises the invariant error.
+
+        Given:
+            An INTERSECTS node whose context holds no resolved "this" column
+            (pass 1 did not run).
+        When:
+            The node is expanded.
+        Then:
+            ValueError is raised, naming the unresolved operand and pointing at
+            the ResolveOperatorRefs pass.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval INTERSECTS 'chr1:1-100'", {}
+        )
+        pattern = r"Spatial predicate operand 'this' was not resolved"
+
+        # Act & assert
+        with pytest.raises(ValueError, match=pattern):
+            expand_intersects(operator_node, context)
+
+    def test_set_predicate_parse_error_is_unwrapped(self):
+        """Test that a bad range in a set predicate surfaces the raw parser error.
+
+        Given:
+            An INTERSECTS ANY node in which one range cannot be parsed.
+        When:
+            The node is expanded.
+        Then:
+            The raw RangeParser ValueError propagates *unwrapped*: the set-
+            predicate path does NOT add the "Could not parse genomic range"
+            wrapper that the single-operand path adds. Pinning this existing
+            asymmetry makes any later unification a deliberate change rather
+            than an accident.
+        """
+        # Arrange
+        operator_node, context = _context(
+            "SELECT * FROM a WHERE interval INTERSECTS ANY ('bad', 'chr1:1-2')",
+            {"this": _LEFT},
+        )
+        pattern = r"Invalid genomic range format"
+
+        # Act & assert
+        with pytest.raises(ValueError, match=pattern) as raised:
+            expand_spatial_set(operator_node, context)
+        assert "Could not parse genomic range" not in str(raised.value)