diff --git a/src/giql/canonicalizer.py b/src/giql/canonicalizer.py index 714a007..78e1948 100644 --- a/src/giql/canonicalizer.py +++ b/src/giql/canonicalizer.py @@ -65,14 +65,40 @@ A migrated operator's *output* columns must land back in the target relation's declared encoding. Epic #114 step 6 envisioned a rewrite of the outermost ``SELECT`` projection, but that placement is wrong for a table function: DISJOIN -synthesizes its ``disjoin_*`` output and its passed-through interval at +and NEAREST synthesize their output columns and pass whole source rows through at *generation* time, so those columns do not exist as AST in this pass, and a ``SELECT *`` consumer hides them from any outer-projection rewrite. So :func:`_decanonicalize_outputs` instead records each wrapped slot's *original* :class:`~giql.table.Table` on the operator's :class:`~giql.resolver.OperatorResolution`, and the operator's emitter reads it to de-canonicalize those synthesized columns where it generates them (DISJOIN, -issue #122). +issue #122; NEAREST's target row passthrough, issue #123). + +Column / interval operands (epic #114, step 8 / issue #123) +----------------------------------------------------------- +A reference slot (DISJOIN/NEAREST target, DISJOIN reference) *owns* the relation +it names, so the pass can wrap that whole relation in a ``__giql_canon_*`` CTE and +redirect the slot's AST node to it. A *column* operand cannot be wrapped that way: +``DISTANCE(a.interval, b.interval)`` and ``a.interval INTERSECTS b.interval`` +reference an alias bound in the *enclosing* query's ``FROM`` / ``JOIN``, shared +with the user's own projection (``SELECT a.start, DISTANCE(...)``). Rewriting that +``FROM`` source to a canonical CTE would silently canonicalize the user's own +``a.start`` too — a behavior change. NEAREST's column / implicit-outer +``reference`` slot is the same shape (an alias from the outer LATERAL relation). + +For those operands the pass therefore canonicalizes the resolution metadata *in +place* rather than synthesizing a CTE: :func:`_canonicalize_column_operands` +rewrites each :class:`~giql.resolver.ResolvedColumn` (DISTANCE's two operands and +the spatial predicates' column operands) and each non-table +:class:`~giql.resolver.ResolvedInterval` (NEAREST's ``column`` / +``implicit_outer`` reference) so its ``start`` / ``end`` fragments carry the +canonical 0-based half-open arithmetic and its ``table`` is blanked. The emitter +then consumes the fragments verbatim — no in-emitter +:func:`giql.canonical.canonical_start` / ``canonical_end``. The arithmetic is the +same the emitter used to emit inline, so the SQL stays byte-identical for these +operands; only the *owner* of the arithmetic moves from the generator to the pass. +A ``literal_range`` interval is already canonical and is left untouched, and an +operand whose ``table`` is already canonical (or ``None``) is a no-op. """ from __future__ import annotations @@ -94,6 +120,8 @@ from giql.expressions import Within from giql.resolver import META_KEY from giql.resolver import OperatorResolution +from giql.resolver import ResolvedColumn +from giql.resolver import ResolvedInterval from giql.resolver import ResolvedRef from giql.table import Table @@ -146,9 +174,16 @@ def canonicalize_coordinates(expression: exp.Expression) -> exp.Expression: The same *expression*, with canonical wrapper CTEs inserted and migrated operator slots rewritten (none, while every flag is off). """ + # Column / interval operands (DISTANCE, predicates, NEAREST's non-table + # reference) canonicalize their metadata in place; this is independent of the + # ref-slot CTE synthesis below and runs for every opted-in operator. + _canonicalize_column_operands(expression) + targets = _collect_targets(expression) if not targets: - # No opted-in operator carries a non-canonical operand: strict no-op. + # No opted-in operator carries a non-canonical *reference-slot* operand: + # no wrapper CTE is synthesized (the in-place column canonicalization + # above already ran). return expression taken = _collect_taken_names(expression) @@ -223,6 +258,102 @@ def _is_canonical(table: Table | None) -> bool: return table.coordinate_system == "0based" and table.interval_type == "half_open" +def _canonicalize_column_operands(expression: exp.Expression) -> None: + """Canonicalize column / interval operand metadata in place for opted-in ops. + + For every opted-in operator (``GIQL_CANONICALIZE``) carrying column operands + (DISTANCE's two operands and the spatial predicates' column operands, in the + :attr:`OperatorResolution.columns` channel) or a non-table interval reference + (NEAREST's ``column`` / ``implicit_outer`` slot, in + :attr:`OperatorResolution.slots`), rewrite the operand's ``start`` / ``end`` + SQL fragments to carry the canonical 0-based half-open arithmetic for its + declared encoding and blank its ``table``. + + This replaces the in-emitter :func:`giql.canonical.canonical_start` / + ``canonical_end`` wrapping for those operands (epic #114, step 8). Unlike the + ref-slot CTE synthesis, no relation is wrapped: the operand references an + alias bound in the enclosing query's ``FROM`` shared with the user's own + projection, so canonicalizing the whole relation would change unrelated + columns. Operands already canonical (``table`` is ``None`` or 0-based + half-open) are left untouched, keeping their SQL byte-identical; a + ``literal_range`` interval is already canonical and is skipped. + """ + for node in expression.walk(): + if not isinstance(node, _OPERATORS): + continue + if not getattr(node, "GIQL_CANONICALIZE", False): + continue + resolution = node.meta.get(META_KEY) + if not isinstance(resolution, OperatorResolution): + continue + for arg, column in list(resolution.columns.items()): + resolution.columns[arg] = _canonicalize_column(column) + for arg, slot in list(resolution.slots.items()): + if isinstance(slot, ResolvedInterval): + resolution.slots[arg] = _canonicalize_interval(slot) + + +def _canonicalize_column(column: ResolvedColumn) -> ResolvedColumn: + """Return *column* with canonical start/end fragments and a blanked table. + + A no-op (returns *column* unchanged) when its backing table is already + canonical or ``None``. + """ + if _is_canonical(column.table): + return column + return replace( + column, + start=_canonical_start_sql(column.start, column.table), + end=_canonical_end_sql(column.end, column.table), + table=None, + ) + + +def _canonicalize_interval(interval: ResolvedInterval) -> ResolvedInterval: + """Return *interval* with canonical start/end fragments and a blanked table. + + A no-op for a ``literal_range`` (already canonical, no table) and for any + interval whose backing table is already canonical or ``None``. + """ + if interval.kind == "literal_range" or _is_canonical(interval.table): + return interval + return replace( + interval, + start=_canonical_start_sql(interval.start, interval.table), + end=_canonical_end_sql(interval.end, interval.table), + table=None, + ) + + +def _canonical_start_sql(start: str, table: Table | None) -> str: + """SQL-fragment analog of :func:`giql.canonical.canonical_start`. + + - ``0based``: ``start`` (identity) + - ``1based``: ``(start - 1)`` + """ + if table is None or table.coordinate_system == "0based": + return start + return f"({start} - 1)" + + +def _canonical_end_sql(end: str, table: Table | None) -> str: + """SQL-fragment analog of :func:`giql.canonical.canonical_end`. + + - ``0based`` / ``half_open``: ``end`` (identity) + - ``0based`` / ``closed``: ``(end + 1)`` + - ``1based`` / ``half_open``: ``(end - 1)`` + - ``1based`` / ``closed``: ``end`` (identity) + """ + if table is None: + return end + key = (table.coordinate_system, table.interval_type) + if key == ("0based", "closed"): + return f"({end} + 1)" + if key == ("1based", "half_open"): + return f"({end} - 1)" + return end + + def _collect_taken_names(expression: exp.Expression) -> set[str]: """Collect every relation name already in use across all scopes. diff --git a/src/giql/expressions.py b/src/giql/expressions.py index 52e5fce..e8de483 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -99,12 +99,26 @@ class SpatialPredicate(exp.Binary): pass +#: Opt the spatial predicates, DISTANCE, and NEAREST into the +#: CanonicalizeCoordinates pass (epic #114 step 8, issue #123). With this flag +#: set, pass 2 canonicalizes each operator's interval operands: a non-table column +#: operand (DISTANCE / predicate operand, NEAREST's column / implicit-outer +#: reference) has its resolution metadata rewritten in place to canonical 0-based +#: half-open arithmetic, and a registered-table reference slot (NEAREST's target) +#: is wrapped in a ``__giql_canon_*`` CTE. The emitter then consumes already- +#: canonical fragments with no in-emitter canonicalization. Identity (0-based +#: half-open) operands are left untouched and the emitted SQL stays byte-identical. +_CANONICALIZE = True + + class Intersects(SpatialPredicate): """INTERSECTS spatial predicate. Example: column INTERSECTS 'chr1:1000-2000' """ + GIQL_CANONICALIZE = _CANONICALIZE + GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), SlotSpec("expression", frozenset({"literal_range", "column"}), required=True), @@ -117,6 +131,8 @@ class Contains(SpatialPredicate): Example: column CONTAINS 'chr1:1500' """ + GIQL_CANONICALIZE = _CANONICALIZE + GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), SlotSpec("expression", frozenset({"literal_range", "column"}), required=True), @@ -129,6 +145,8 @@ class Within(SpatialPredicate): Example: column WITHIN 'chr1:1000-5000' """ + GIQL_CANONICALIZE = _CANONICALIZE + GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), SlotSpec("expression", frozenset({"literal_range", "column"}), required=True), @@ -150,6 +168,8 @@ class SpatialSetPredicate(exp.Expression): "ranges": True, } + GIQL_CANONICALIZE = _CANONICALIZE + GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), SlotSpec( @@ -259,6 +279,8 @@ class GIQLDistance(exp.Func): "signed": False, # Optional: boolean for directional distance } + GIQL_CANONICALIZE = _CANONICALIZE + GIQL_SLOTS = ( SlotSpec("this", frozenset({"column"}), required=True), SlotSpec("expression", frozenset({"column"}), required=True), @@ -296,6 +318,16 @@ class GIQLNearest(exp.Func): "signed": False, # Optional: directional distance } + #: Opt NEAREST into the CanonicalizeCoordinates pass (epic #114 step 8, issue + #: #123). Pass 2 wraps a non-canonical *target* (the ``this`` registered-table + #: ref slot) in a ``__giql_canon_*`` CTE — so the emitter reads canonical + #: target columns and de-canonicalizes its ``*`` row passthrough back to the + #: declared encoding — and canonicalizes a non-table ``column`` / + #: ``implicit_outer`` reference's metadata in place. Identity (0-based + #: half-open) operands are left untouched and the emitted SQL stays + #: byte-identical. + GIQL_CANONICALIZE = _CANONICALIZE + GIQL_SLOTS = ( SlotSpec("this", frozenset({"registered_table"}), required=True), SlotSpec( diff --git a/src/giql/generators/base.py b/src/giql/generators/base.py index 2bbf918..1874094 100644 --- a/src/giql/generators/base.py +++ b/src/giql/generators/base.py @@ -1,8 +1,6 @@ from sqlglot import exp from sqlglot.generator import Generator -from giql.canonical import canonical_end -from giql.canonical import canonical_start from giql.canonical import decanonical_end from giql.canonical import decanonical_start from giql.constants import DEFAULT_CHROM_COL @@ -159,21 +157,29 @@ def giqlnearest_sql(self, expression: GIQLNearest) -> str: ) table_name = target_ref.name target_chrom, target_start, target_end = target_ref.cols - target_table = target_ref.table + + # The target's *declared* encoding, which the passed-through target row + # (SELECT {table_name}.*) must round-trip back into. CanonicalizeCoordinates + # (pass 2) preserves it on the resolution when it wraps a non-canonical + # target in a __giql_canon_* CTE (the slot's own Table is then None); a + # canonical target is left unwrapped and its slot Table carries the + # (identity) encoding. The synthesized `distance` column is encoding- + # invariant (a count of bases) and needs no round-trip. + output_table = self._nearest_output_encoding(expression, target_ref) + passthrough = self._nearest_passthrough( + table_name, target_start, target_end, output_table + ) # Reference interval (a ResolvedInterval from the pass). An unresolved - # reference re-raises the generator's historical diagnostic. + # reference re-raises the generator's historical diagnostic. Input + # canonicalization is owned by CanonicalizeCoordinates (pass 2, issue + # #123): a literal range is already canonical, and a column / implicit- + # outer reference's endpoints are canonicalized in place by the pass, so + # the emitter consumes the fragments verbatim with no canonicalization. ref = resolution.slot("reference") if not isinstance(ref, ResolvedInterval): self._raise_nearest_reference_error(expression, mode, resolution) - if ref.kind == "literal_range": - # Literal endpoints are already canonical 0-based half-open. - ref_chrom, ref_start, ref_end = ref.chrom, ref.start, ref.end - else: - # Column / implicit-outer endpoints are raw; canonicalize here. - ref_chrom = ref.chrom - ref_start = canonical_start(ref.start, ref.table) - ref_end = canonical_end(ref.end, ref.table) + ref_chrom, ref_start, ref_end = ref.chrom, ref.start, ref.end # Extract parameters k = expression.args.get("k") @@ -193,14 +199,23 @@ def giqlnearest_sql(self, expression: GIQLNearest) -> str: target_strand = None if is_stranded: ref_strand = ref.strand - if target_table and target_table.strand_col: - target_strand = f'{table_name}."{target_table.strand_col}"' - - # Distance math below assumes 0-based half-open. - target_start_expr = canonical_start( - f'{table_name}."{target_start}"', target_table - ) - target_end_expr = canonical_end(f'{table_name}."{target_end}"', target_table) + # When pass 2 wraps a non-canonical target its slot Table is blanked, + # so the strand column name comes from the *declared* encoding the + # pass preserved (output_table). The canon CTE's SELECT * REPLACE + # passes the strand column through unchanged under its physical name, + # so the qualifier stays the relation NEAREST selects from. + if output_table and output_table.strand_col: + target_strand = f'{table_name}."{output_table.strand_col}"' + + # Distance math below assumes 0-based half-open. Input canonicalization + # is owned by CanonicalizeCoordinates (pass 2, issue #123): a + # non-canonical target is rewritten to a canonical __giql_canon_* CTE + # before generation (table_name then names the CTE), so the target + # endpoints are consumed verbatim with no in-emitter canonicalization. The + # output round-trip of the passed-through target row stays here (see the + # SELECT projection below). + target_start_expr = f'{table_name}."{target_start}"' + target_end_expr = f'{table_name}."{target_end}"' # Build distance calculation using CASE expression # For NEAREST: ORDER BY absolute distance, but RETURN signed distance @@ -239,7 +254,7 @@ def giqlnearest_sql(self, expression: GIQLNearest) -> str: # Standalone mode: direct ORDER BY + LIMIT # Return signed distance, but order by absolute distance sql = f"""( - SELECT {table_name}.*, {distance_expr} AS distance + SELECT {passthrough}, {distance_expr} AS distance FROM {table_name} WHERE {where_sql} ORDER BY {abs_distance_expr} @@ -261,7 +276,7 @@ def giqlnearest_sql(self, expression: GIQLNearest) -> str: # LATERAL mode: subquery for k-nearest neighbors # Return signed distance, but order by absolute distance sql = f"""( - SELECT {table_name}.*, {distance_expr} AS distance + SELECT {passthrough}, {distance_expr} AS distance FROM {table_name} WHERE {where_sql} ORDER BY {abs_distance_expr} @@ -270,6 +285,77 @@ def giqlnearest_sql(self, expression: GIQLNearest) -> str: return sql.strip() + def _nearest_output_encoding( + self, expression: GIQLNearest, target_ref: ResolvedRef + ) -> Table | None: + """Return the target's declared encoding for NEAREST's row passthrough. + + ``CanonicalizeCoordinates`` (pass 2) records the original + :class:`~giql.table.Table` on the resolution when it wraps a non-canonical + target in a ``__giql_canon_*`` CTE (blanking the slot's own ``table``). + For an unwrapped target — a canonical registered table, or any target when + the pass did not run — the slot's own ``table`` carries the (identity) + encoding. + + :param expression: + GIQLNearest expression node + :param target_ref: + The resolved target reference (post pass 2) + :return: + The target's declared :class:`~giql.table.Table`, or ``None`` + """ + resolution = expression.meta.get(META_KEY) + if isinstance(resolution, OperatorResolution): + preserved = resolution.output_tables.get("this") + if preserved is not None: + return preserved + return target_ref.table + + def _nearest_passthrough( + self, + table_name: str, + target_start: str, + target_end: str, + output_table: Table | None, + ) -> str: + """Project the target's full row, de-canonicalizing the interval columns. + + NEAREST passes the whole target row through (``SELECT {table_name}.*``) + alongside the synthesized, encoding-invariant ``distance`` column. When the + target's declared encoding is canonical 0-based half-open the row passes + through as a plain ``{table_name}.*`` — the byte-identical identity fast + path. When it is non-canonical the interval columns, canonical inside the + ``__giql_canon_*`` CTE the target was rewritten to, are de-canonicalized + back into that encoding via a star ``REPLACE`` so the passed-through + interval matches the target's own convention. (Only non-canonical targets + are wrapped, so the ``REPLACE`` appears only where a canonical CTE already + shapes the SQL.) + + :param table_name: + The relation the row is selected from (the canon CTE name when wrapped, + else the registered table name) — also the column qualifier. + :param target_start: + Physical start column name + :param target_end: + Physical end column name + :param output_table: + The target's declared :class:`~giql.table.Table`, or ``None`` + :return: + The passthrough projection fragment (``{table_name}.*`` or a star + ``REPLACE``) + """ + if output_table is None or ( + output_table.coordinate_system == "0based" + and output_table.interval_type == "half_open" + ): + return f"{table_name}.*" + pt_start = decanonical_start(f'{table_name}."{target_start}"', output_table) + pt_end = decanonical_end(f'{table_name}."{target_end}"', output_table) + return ( + f"{table_name}.* REPLACE " + f'({pt_start} AS "{target_start}", {pt_end} AS "{target_end}")' + ) + def giqldisjoin_sql(self, expression: GIQLDisjoin) -> str: """Generate SQL for the DISJOIN table function. @@ -475,21 +561,22 @@ def giqldistance_sql(self, expression: GIQLDistance) -> str: strand_a = col_a.strand if stranded else None strand_b = col_b.strand if stranded else None - # Distance math below assumes 0-based half-open. - start_a = canonical_start(col_a.start, col_a.table) - end_a = canonical_end(col_a.end, col_a.table) - start_b = canonical_start(col_b.start, col_b.table) - end_b = canonical_end(col_b.end, col_b.table) + # Distance math below assumes 0-based half-open. Input canonicalization is + # owned by CanonicalizeCoordinates (pass 2, issue #123): each operand's + # start/end fragments are canonicalized in place by the pass, so the + # emitter consumes them verbatim with no in-emitter canonicalization. The + # returned distance is an encoding-invariant base count, so it needs no + # output de-canonicalization. # Generate CASE expression return self._generate_distance_case( col_a.chrom, - start_a, - end_a, + col_a.start, + col_a.end, strand_a, col_b.chrom, - start_b, - end_b, + col_b.start, + col_b.end, strand_b, stranded=stranded, signed=signed, @@ -715,12 +802,14 @@ def _generate_range_predicate( :return: SQL predicate string """ - # Canonicalize the raw physical endpoints to 0-based half-open. The - # alias-qualified column fragments come pre-resolved on the - # ResolvedColumn; canonicalization stays here (epic #114 step #123). + # The alias-qualified column fragments come pre-resolved on the + # ResolvedColumn, already canonicalized to 0-based half-open by + # CanonicalizeCoordinates (pass 2, issue #123). The predicate returns a + # boolean, which is encoding-invariant, so no output de-canonicalization + # is needed. chrom_col = column.chrom - start_col = canonical_start(column.start, column.table) - end_col = canonical_end(column.end, column.table) + start_col = column.start + end_col = column.end chrom = parsed_range.chromosome start = parsed_range.start @@ -774,14 +863,16 @@ def _generate_column_join( :return: SQL predicate string """ - # Canonicalize each side's raw physical endpoints; the alias-qualified - # chrom/start/end fragments come pre-resolved on the ResolvedColumns. + # The alias-qualified chrom/start/end fragments come pre-resolved on the + # ResolvedColumns, already canonicalized to 0-based half-open by + # CanonicalizeCoordinates (pass 2, issue #123). The predicate returns a + # boolean (encoding-invariant), so no output de-canonicalization is needed. l_chrom = left.chrom r_chrom = right.chrom - l_start = canonical_start(left.start, left.table) - l_end = canonical_end(left.end, left.table) - r_start = canonical_start(right.start, right.table) - r_end = canonical_end(right.end, right.table) + l_start = left.start + l_end = left.end + r_start = right.start + r_end = right.end if op_type == "intersects": # Ranges overlap if: chrom1 = chrom2 AND start1 < end2 AND end1 > start2 diff --git a/tests/generators/test_base.py b/tests/generators/test_base.py index 9cd6598..d316157 100644 --- a/tests/generators/test_base.py +++ b/tests/generators/test_base.py @@ -12,12 +12,33 @@ from sqlglot import parse_one from giql import Table +from giql import transpile +from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect from giql.expressions import GIQLNearest from giql.generators import BaseGIQLGenerator +from giql.resolver import resolve_operator_refs from giql.table import Tables +def _generate_through_passes(sql: str, tables: Tables) -> str: + """Parse, run normalization passes 1 and 2, then generate SQL. + + Coordinate canonicalization for operator operands moved out of the emitter and + into the CanonicalizeCoordinates pass (issue #123). Emitter-level tests that + pin canonicalized output must therefore run both passes before generating, + exactly as :func:`giql.transpile.transpile` does, rather than calling + ``generate`` on a bare parsed AST (which would skip canonicalization). This + helper is used where the full ``transpile`` pipeline would otherwise rewrite + the node away (a column-to-column ``INTERSECTS`` is turned into a binned + equi-join before the predicate emitter runs). + """ + ast = parse_one(sql, dialect=GIQLDialect) + ast = resolve_operator_refs(ast, tables) + ast = canonicalize_coordinates(ast) + return BaseGIQLGenerator(tables=tables).generate(ast) + + @pytest.fixture def tables_info(): """Basic Tables with a single table containing genomic columns.""" @@ -815,11 +836,10 @@ def test_giqldistance_should_not_apply_gap_plus_one_for_closed_intervals( "SELECT DISTANCE(a.interval, b.interval) as dist " "FROM bed_features a CROSS JOIN bed_features_b b" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_with_closed_intervals) - # Act - output = generator.generate(ast) + # Act — canonicalization for DISTANCE operands now lives in the + # CanonicalizeCoordinates pass (#123), so run both passes via the helper. + output = _generate_through_passes(sql, tables_with_closed_intervals) # Assert expected = ( @@ -930,30 +950,29 @@ def test_giqlnearest_should_not_apply_gap_plus_one_for_closed_intervals( Given: A 0-based closed-interval target table and NEAREST. When: - giqlnearest_sql is called. + The query is transpiled. Then: - It should canonicalize the target end as (target."end" + 1) but - omit any "+1" from the gap branches. + It should canonicalize the target end as (end + 1) inside the + wrapper CTE while the distance gap branches read the bare canonical + end with no trailing "+ 1". """ # Arrange - tables = Tables() - tables.register("genes_closed", Table("genes_closed", interval_type="closed")) sql = ( "SELECT * FROM NEAREST(genes_closed, reference := 'chr1:1000-2000', k := 3)" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables) - # Act - output = generator.generate(ast) + # Act — the target's coordinate canonicalization now lives in the + # CanonicalizeCoordinates pass (#123): a non-canonical target is wrapped + # in a __giql_canon_* CTE, so the (end + 1) arithmetic appears there and + # the distance CASE reads the bare canonical columns. + output = transpile(sql, tables=[Table("genes_closed", interval_type="closed")]) - # Assert — canonicalization is applied, but neither gap branch carries - # a trailing "+ 1". - assert '(genes_closed."end" + 1)' in output - assert 'THEN (genes_closed."start" - 2000)' in output - assert 'ELSE (1000 - (genes_closed."end" + 1))' in output - assert '(genes_closed."start" - 2000 + 1)' not in output - assert '(1000 - (genes_closed."end" + 1) + 1)' not in output + # Assert — the wrapper carries (end + 1); the gap branches carry no "+ 1". + assert '("end" + 1) AS "end"' in output + assert 'THEN (__giql_canon_0."start" - 2000)' in output + assert 'ELSE (1000 - __giql_canon_0."end")' in output + assert '__giql_canon_0."start" - 2000 + 1' not in output + assert '1000 - __giql_canon_0."end" + 1' not in output def test_giqldistance_sql_literal_first_arg_error(self, tables_with_two_tables): """ @@ -1272,14 +1291,11 @@ def test_intersects_should_canonicalize_table_columns_for_each_convention( unchanged. """ # Arrange - tables = Tables() - tables.register(table.name, table) sql = "SELECT * FROM variants WHERE interval INTERSECTS 'chr1:100-200'" - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables) - # Act - output = generator.generate(ast) + # Act — predicate operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. + output = transpile(sql, tables=[table]) # Assert expected = ( @@ -1328,14 +1344,11 @@ def test_contains_should_canonicalize_table_columns_for_each_convention( unchanged. """ # Arrange - tables = Tables() - tables.register(table.name, table) sql = "SELECT * FROM variants WHERE interval CONTAINS 'chr1:1500-2000'" - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables) - # Act - output = generator.generate(ast) + # Act — predicate operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. + output = transpile(sql, tables=[table]) # Assert expected = ( @@ -1384,14 +1397,11 @@ def test_within_should_canonicalize_table_columns_for_each_convention( unchanged. """ # Arrange - tables = Tables() - tables.register(table.name, table) sql = "SELECT * FROM variants WHERE interval WITHIN 'chr1:1000-5000'" - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables) - # Act - output = generator.generate(ast) + # Act — predicate operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. + output = transpile(sql, tables=[table]) # Assert expected = ( @@ -1417,11 +1427,13 @@ def test_intersects_should_canonicalize_both_sides_when_conventions_differ( "SELECT * FROM bed_a AS a CROSS JOIN vcf_b AS b " "WHERE a.interval INTERSECTS b.interval" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_mixed_conventions) - # Act - output = generator.generate(ast) + # Act — column-join operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123). The full transpile pipeline rewrites + # a column-to-column INTERSECTS into a binned equi-join before the + # predicate emitter runs, so run passes 1 and 2 directly to exercise the + # predicate emitter's column-join branch on canonicalized metadata. + output = _generate_through_passes(sql, tables_mixed_conventions) # Assert expected = ( @@ -1451,11 +1463,11 @@ def test_contains_should_canonicalize_both_sides_when_conventions_differ( "SELECT * FROM bed_a AS a CROSS JOIN vcf_b AS b " "WHERE a.interval CONTAINS b.interval" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_mixed_conventions) - # Act - output = generator.generate(ast) + # Act — column-join operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. (Column-to-column + # CONTAINS is not rewritten into a binned join, so it reaches the emitter.) + output = transpile(sql, tables=list(tables_mixed_conventions)) # Assert expected = ( @@ -1485,11 +1497,11 @@ def test_within_should_canonicalize_both_sides_when_conventions_differ( "SELECT * FROM bed_a AS a CROSS JOIN vcf_b AS b " "WHERE a.interval WITHIN b.interval" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_mixed_conventions) - # Act - output = generator.generate(ast) + # Act — column-join operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. (Column-to-column + # WITHIN is not rewritten into a binned join, so it reaches the emitter.) + output = transpile(sql, tables=list(tables_mixed_conventions)) # Assert expected = ( @@ -1515,11 +1527,10 @@ def test_contains_point_should_shift_start_when_table_is_one_based_closed( """ # Arrange sql = "SELECT * FROM vcf_variants WHERE interval CONTAINS 'chr1:1500'" - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_with_one_based_closed) - # Act - output = generator.generate(ast) + # Act — predicate operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. + output = transpile(sql, tables=list(tables_with_one_based_closed)) # Assert expected = ( @@ -1546,11 +1557,10 @@ def test_intersects_any_should_canonicalize_disjuncts_when_table_is_one_based_cl "SELECT * FROM vcf_variants " "WHERE interval INTERSECTS ANY('chr1:100-200', 'chr1:500-600')" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_with_one_based_closed) - # Act - output = generator.generate(ast) + # Act — SpatialSetPredicate operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. + output = transpile(sql, tables=list(tables_with_one_based_closed)) # Assert expected = ( @@ -1638,11 +1648,10 @@ def test_giqldistance_should_canonicalize_table_columns_for_each_convention( "SELECT DISTANCE(a.interval, b.interval) as dist " "FROM dist_a a CROSS JOIN dist_b b" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables) - # Act - output = generator.generate(ast) + # Act — DISTANCE operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. + output = transpile(sql, tables=list(tables)) # Assert expected = ( @@ -1673,11 +1682,10 @@ def test_giqldistance_should_canonicalize_each_side_when_conventions_differ( "SELECT DISTANCE(a.interval, b.interval) as dist " "FROM bed_a a CROSS JOIN vcf_b b" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_mixed_conventions) - # Act - output = generator.generate(ast) + # Act — DISTANCE operand canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. + output = transpile(sql, tables=list(tables_mixed_conventions)) # Assert expected = ( @@ -1691,75 +1699,128 @@ def test_giqldistance_should_canonicalize_each_side_when_conventions_differ( assert output == expected @pytest.mark.parametrize( - "coordinate_system, interval_type, target_start, target_end", + "coordinate_system, interval_type, wrap_start, wrap_end", [ - pytest.param( - "0based", - "half_open", - 'genes."start"', - 'genes."end"', - id="0based-half_open", - ), pytest.param( "0based", "closed", - 'genes."start"', - '(genes."end" + 1)', + '"start" AS "start"', + '("end" + 1) AS "end"', id="0based-closed", ), pytest.param( "1based", "half_open", - '(genes."start" - 1)', - '(genes."end" - 1)', + '("start" - 1) AS "start"', + '("end" - 1) AS "end"', id="1based-half_open", ), pytest.param( "1based", "closed", - '(genes."start" - 1)', - 'genes."end"', + '("start" - 1) AS "start"', + '"end" AS "end"', id="1based-closed", ), ], ) def test_giqlnearest_should_canonicalize_target_columns_for_each_convention( - self, coordinate_system, interval_type, target_start, target_end + self, coordinate_system, interval_type, wrap_start, wrap_end ): - """Test NEAREST canonicalizes target endpoints per convention. + """Test NEAREST canonicalizes a non-canonical target via a wrapper CTE. Given: - A target table declared with one of the four (coordinate_system, - interval_type) combinations and a literal reference range. + A target table declared with one of the three non-canonical + (coordinate_system, interval_type) combinations and a literal + reference range. When: - giqlnearest_sql is called. + The query is transpiled. Then: - It should wrap the target-side start/end (or not) per the - canonical 0-based half-open conversion in the distance CASE - expression. + It should wrap the target in a __giql_canon_* CTE carrying the + canonical conversion, and the distance CASE should read the bare + canonical target columns with no in-CASE canonicalization. + """ + # Arrange + sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 1)" + + # Act — the target's coordinate canonicalization now lives in the + # CanonicalizeCoordinates pass (#123): a non-canonical target is wrapped + # in a __giql_canon_* CTE before generation, so the distance CASE reads + # already-canonical columns. + output = transpile( + sql, + tables=[ + Table( + "genes", + coordinate_system=coordinate_system, + interval_type=interval_type, + ) + ], + ) + + # Assert — the wrapper carries the canonical conversion; the distance + # CASE reads the bare canonical columns against the literal [1000, 2000). + assert f"REPLACE ({wrap_start}, {wrap_end}) FROM genes" in output + assert ( + 'WHEN 1000 < __giql_canon_0."end" AND 2000 > __giql_canon_0."start" THEN 0' + ) in output + assert ( + 'WHEN 2000 <= __giql_canon_0."start" THEN (__giql_canon_0."start" - 2000)' + ) in output + assert 'ELSE (1000 - __giql_canon_0."end")' in output + + def test_giqlnearest_should_pass_target_columns_through_when_target_is_canonical( + self, + ): + """Test NEAREST leaves a canonical target unwrapped and byte-identical. + + Given: + A canonical 0-based half-open target table and a literal reference. + When: + The query is transpiled. + Then: + It should synthesize no wrapper CTE; the distance CASE reads the raw + target columns and the row passes through as a plain ``genes.*``. """ # Arrange - tables = Tables() - tables.register( - "genes", - Table( - "genes", - coordinate_system=coordinate_system, - interval_type=interval_type, - ), - ) sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 1)" - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables) # Act - output = generator.generate(ast) + output = transpile(sql, tables=[Table("genes")]) + + # Assert + assert "__giql_canon_" not in output + assert "genes.*," in output + assert 'WHEN 1000 < genes."end" AND 2000 > genes."start" THEN 0' in output + + def test_giqlnearest_should_round_trip_passthrough_row_to_target_encoding(self): + """Test NEAREST de-canonicalizes the passed-through row to the target encoding. - # Assert — distance CASE expression uses canonicalized target endpoints - # against the (already-canonical) literal reference [1000, 2000). - assert f"WHEN 1000 < {target_end} AND 2000 > {target_start} THEN 0" in output - assert f"WHEN 2000 <= {target_start} THEN ({target_start} - 2000)" in output - assert f"ELSE (1000 - {target_end})" in output + Given: + A 1-based closed target table wrapped by the canonicalization pass. + When: + The query is transpiled. + Then: + The ``*`` passthrough should de-canonicalize the interval columns + back to the target's declared encoding via a star REPLACE, so the + returned row carries the table's own convention; the synthesized + ``distance`` column is encoding-invariant and stays unwrapped. + """ + # Arrange + sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 1)" + + # Act + output = transpile( + sql, + tables=[Table("genes", coordinate_system="1based", interval_type="closed")], + ) + + # Assert — passthrough de-canonicalizes 1-based-closed start as (start + 1) + # and leaves end identity; the distance column carries no round-trip. + assert ( + '__giql_canon_0.* REPLACE ((__giql_canon_0."start" + 1) AS "start", ' + '__giql_canon_0."end" AS "end")' + ) in output def test_giqlnearest_should_canonicalize_reference_column_when_reference_is_one_based_closed( self, tables_mixed_conventions @@ -1780,13 +1841,14 @@ def test_giqlnearest_should_canonicalize_reference_column_when_reference_is_one_ "SELECT * FROM vcf_b b CROSS JOIN LATERAL " "NEAREST(bed_a, reference := b.interval, k := 1)" ) - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_mixed_conventions) - # Act - output = generator.generate(ast) + # Act — the reference operand's canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. The reference is + # a column operand (not a wrappable relation), so the arithmetic stays + # inline; the canonical target bed_a is left unwrapped. + output = transpile(sql, tables=list(tables_mixed_conventions)) - # Assert — reference's start canonicalized via _canonical_start + # Assert — reference's start canonicalized as (start - 1) assert '(b."start" - 1) < bed_a."end"' in output assert 'bed_a."end" <= (b."start" - 1)' not in output # reference is left side # Reference's end stays raw (1-based closed → identity for end) @@ -1810,11 +1872,12 @@ def test_giqlnearest_should_canonicalize_outer_table_columns_when_reference_is_i """ # Arrange sql = "SELECT * FROM vcf_b CROSS JOIN LATERAL NEAREST(bed_a, k := 1)" - ast = parse_one(sql, dialect=GIQLDialect) - generator = BaseGIQLGenerator(tables=tables_mixed_conventions) - # Act - output = generator.generate(ast) + # Act — the implicit-outer reference's canonicalization now lives in the + # CanonicalizeCoordinates pass (#123); transpile runs it. The outer + # reference is a column operand (not a wrappable relation), so the + # arithmetic stays inline; the canonical target bed_a is left unwrapped. + output = transpile(sql, tables=list(tables_mixed_conventions)) # Assert — distance CASE uses canonicalized outer-table columns and # raw target-table columns. diff --git a/tests/test_canonicalizer.py b/tests/test_canonicalizer.py index 6899e9a..afea3e5 100644 --- a/tests/test_canonicalizer.py +++ b/tests/test_canonicalizer.py @@ -22,6 +22,9 @@ from giql.canonicalizer import canonicalize_coordinates from giql.dialect import GIQLDialect from giql.expressions import GIQLDisjoin +from giql.expressions import GIQLDistance +from giql.expressions import GIQLNearest +from giql.expressions import Intersects from giql.generators import BaseGIQLGenerator from giql.resolver import META_KEY from giql.resolver import resolve_operator_refs @@ -631,3 +634,193 @@ def test_all_encoding_pairs_covered(): # Assert assert len(pairs) == 4 assert len(canonical) == 1 + + +def _two_tables(encoding) -> list[Table]: + """Build two registered tables under the same (non-default) encoding.""" + coordinate_system, interval_type = encoding + return [ + Table(name, coordinate_system=coordinate_system, interval_type=interval_type) + for name in ("intervals_a", "intervals_b") + ] + + +class TestColumnOperandCanonicalization: + """Pass 2 canonicalizes column / interval operands in place (issue #123). + + A column operand references an alias bound in the enclosing query's FROM, + shared with the user's own projection, so it cannot be wrapped in a canonical + CTE without changing unrelated columns. Pass 2 therefore rewrites the operand's + resolution metadata to carry the canonical arithmetic inline and the emitter + consumes it verbatim — no in-emitter canonicalization, no wrapper CTE. + """ + + def test_distance_operand_canonicalized_inline_without_wrapper_cte(self): + """Test a non-canonical DISTANCE operand canonicalizes inline, no wrapper. + + Given: + Two 1-based closed tables and a DISTANCE between their columns. + When: + The query is transpiled. + Then: + The CASE arithmetic should wrap each side's start as (start - 1) + inline and synthesize no __giql_canon_* wrapper CTE. + """ + # Arrange + sql = ( + "SELECT DISTANCE(a.interval, b.interval) AS dist " + "FROM intervals_a a, intervals_b b" + ) + + # Act + output = transpile(sql, tables=_two_tables(("1based", "closed"))) + + # Assert + assert CANON_PREFIX not in output + assert '(a."start" - 1)' in output + assert '(b."start" - 1)' in output + + def test_canonical_distance_operand_is_byte_identical(self): + """Test a canonical DISTANCE operand emits byte-identical SQL. + + Given: + Two default (0-based half-open) tables and a DISTANCE between their + columns, transpiled with the operator opted in and with its + GIQL_CANONICALIZE flag toggled off. + When: + The two transpilations are compared. + Then: + They should be byte-identical and carry no wrapper CTE — the pass is + inert for an already-canonical operand. + """ + # Arrange + sql = ( + "SELECT DISTANCE(a.interval, b.interval) AS dist " + "FROM intervals_a a, intervals_b b" + ) + tables = [Table("intervals_a"), Table("intervals_b")] + + # Act + opted_in = transpile(sql, tables=tables) + previous = GIQLDistance.GIQL_CANONICALIZE + GIQLDistance.GIQL_CANONICALIZE = False + try: + flag_off = transpile(sql, tables=tables) + finally: + GIQLDistance.GIQL_CANONICALIZE = previous + + # Assert + assert opted_in == flag_off + assert CANON_PREFIX not in opted_in + + def test_predicate_operand_canonicalized_inline_without_wrapper_cte(self): + """Test a non-canonical INTERSECTS operand canonicalizes inline, no wrapper. + + Given: + A 1-based closed table and a literal-range INTERSECTS predicate. + When: + The query is transpiled. + Then: + The predicate should wrap the table-side start as (start - 1) inline + and synthesize no __giql_canon_* wrapper CTE. + """ + # Arrange + sql = "SELECT * FROM variants WHERE interval INTERSECTS 'chr1:100-200'" + + # Act + output = transpile( + sql, + tables=[ + Table("variants", coordinate_system="1based", interval_type="closed") + ], + ) + + # Assert + assert CANON_PREFIX not in output + assert '("start" - 1) < 200' in output + + def test_metadata_blanked_after_in_place_canonicalization(self): + """Test a canonicalized column operand carries a blanked Table. + + Given: + A 1-based closed INTERSECTS predicate annotated by pass 1. + When: + Pass 2 runs. + Then: + The operand's ResolvedColumn should carry the canonical arithmetic and + its Table should be blanked so the emitter applies no further wrapping. + """ + # Arrange + tables = _tables(("1based", "closed"), names=("variants",)) + query = "SELECT * FROM variants WHERE interval INTERSECTS 'chr1:100-200'" + ast = resolve_operator_refs(parse_one(query, dialect=GIQLDialect), tables) + + # Act + ast = canonicalize_coordinates(ast) + + # Assert + node = next(n for n in ast.walk() if isinstance(n, Intersects)) + column = node.meta[META_KEY].column("this") + assert column.table is None + assert column.start == '("start" - 1)' + + +class TestNearestTargetCanonicalization: + """NEAREST's registered-table target is wrapped and its row round-trips.""" + + def test_non_canonical_target_wrapped_and_row_round_tripped(self): + """Test a non-canonical NEAREST target is wrapped and its row de-canonicalized. + + Given: + A 1-based closed NEAREST target table and a literal reference. + When: + The query is transpiled. + Then: + The target should be wrapped in a __giql_canon_* CTE, the distance + CASE should read the bare canonical columns, and the passed-through + row should de-canonicalize the interval back to the declared encoding. + """ + # Arrange + sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 1)" + + # Act + output = transpile( + sql, + tables=[Table("genes", coordinate_system="1based", interval_type="closed")], + ) + + # Assert + assert f"{CANON_PREFIX}0 AS (SELECT * REPLACE" in output + assert 'WHEN 1000 < __giql_canon_0."end"' in output + assert ( + '__giql_canon_0.* REPLACE ((__giql_canon_0."start" + 1) AS "start"' + ) in output + + def test_canonical_target_not_wrapped(self): + """Test a canonical NEAREST target is left unwrapped. + + Given: + A canonical 0-based half-open NEAREST target and a literal reference, + transpiled with NEAREST opted in and with its flag toggled off. + When: + The two transpilations are compared. + Then: + They should be byte-identical with no wrapper CTE — the identity fast + path. + """ + # Arrange + sql = "SELECT * FROM NEAREST(genes, reference := 'chr1:1000-2000', k := 1)" + tables = [Table("genes")] + + # Act + opted_in = transpile(sql, tables=tables) + previous = GIQLNearest.GIQL_CANONICALIZE + GIQLNearest.GIQL_CANONICALIZE = False + try: + flag_off = transpile(sql, tables=tables) + finally: + GIQLNearest.GIQL_CANONICALIZE = previous + + # Assert + assert opted_in == flag_off + assert CANON_PREFIX not in opted_in