diff --git a/docs/dialect/aggregation-operators.rst b/docs/dialect/aggregation-operators.rst index 9887b87..f34d1c0 100644 --- a/docs/dialect/aggregation-operators.rst +++ b/docs/dialect/aggregation-operators.rst @@ -40,6 +40,9 @@ Syntax -- Strand-specific clustering CLUSTER(interval, stranded := true) AS cluster_id + -- Predicate-gated clustering (run-length encoding on a column) + CLUSTER(interval, predicate := depth = PREV(depth)) AS cluster_id + -- Combined parameters CLUSTER(interval, distance, stranded := true) AS cluster_id @@ -56,6 +59,36 @@ Parameters **stranded** *(optional)* When ``true``, only cluster intervals on the same strand. Default: ``false``. +**predicate** *(optional)* + A boolean expression evaluated between each interval and its sorted + predecessor. When supplied, the cluster-boundary condition becomes + **adjacent AND predicate**: an interval stays in the current cluster only + when it is within ``distance`` of its predecessor *and* the predicate holds + between the two. A change in the predicate forces a new cluster, so an + equality predicate yields a run-length encoding of the input sequence. + Omitting the predicate preserves the default adjacency-only behavior. + + Bare column references resolve to the *current* interval; the predecessor's + value of a column is referenced with ``PREV(column)`` + (e.g. ``depth = PREV(depth)``). The predicate composes with ``distance`` and + ``stranded`` and is evaluated under the operator's existing per-chromosome + (and per-strand) partition and start-position order. + + Two constraints apply: + + - **References existing columns only.** The predicate *gates* merging on + columns already present on the input rows; it does not synthesize a + statistic. Coverage depth, for example, must already be a column on the + rows (typically produced upstream by :ref:`DISJOIN ` and + aggregation). 
+ - **Pairwise only, with single-linkage drift.** The predicate compares each + interval to its immediate sorted predecessor (everything ``LAG`` can + express). Whole-cluster conditions are out of scope. When the predicate is + not an equivalence relation (e.g. ``ABS(score - PREV(score)) < 5``), + consecutive pairs may each satisfy it while the cluster's extremes do not + — the same single-linkage behavior that ``distance``-based clustering + already exhibits. + Return Value ~~~~~~~~~~~~ @@ -101,6 +134,19 @@ Cluster intervals separately by strand: FROM features ORDER BY chrom, strand, start +**Predicate-Gated Clustering:** + +Cut adjacent intervals into clusters wherever a column's value changes +(run-length encoding). ``PREV(column)`` references the predecessor row's value: + +.. code-block:: sql + + SELECT + *, + CLUSTER(interval, predicate := depth = PREV(depth)) AS cluster_id + FROM features + ORDER BY chrom, start + **Analyze Cluster Statistics:** Count features per cluster: @@ -194,6 +240,9 @@ Syntax -- Strand-specific merge SELECT MERGE(interval, stranded := true) FROM features + -- Predicate-gated merge (merge only equal-valued adjacent runs) + SELECT MERGE(interval, predicate := depth = PREV(depth)) FROM features + -- Merge with additional aggregations SELECT MERGE(interval), @@ -214,6 +263,16 @@ Parameters **stranded** *(optional)* When ``true``, merge intervals separately by strand. Default: ``false``. +**predicate** *(optional)* + A boolean expression that further restricts which adjacent intervals are + merged. ``MERGE`` decomposes into :ref:`CLUSTER ` plus a + ``GROUP BY`` over the cluster id, so it inherits predicate-aware boundaries + directly — see the :ref:`CLUSTER predicate ` description + for the full semantics, the ``PREV(column)`` convention, the + references-existing-columns-only constraint, and the pairwise-only / + single-linkage caveat. Omitting the predicate preserves the default + adjacency-only merge. 
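The decomposition described above (cluster ids first, then a ``GROUP BY`` over them) can be sketched in plain Python. This is only an illustrative model, not the transformer's code: ``merge_labeled`` and the ``(chrom, start, end, cluster_id)`` tuple shape are hypothetical names chosen for the example, and the cluster ids are assumed to have already been assigned by a predicate-aware ``CLUSTER`` step.

```python
from collections import defaultdict


def merge_labeled(rows):
    """Collapse CLUSTER-labeled rows into one interval per (chrom, cluster_id).

    ``rows`` is an iterable of (chrom, start, end, cluster_id) tuples; this
    shape is an assumption made for the sketch, not a GIQL data structure.
    """
    spans = defaultdict(lambda: [None, None])
    for chrom, start, end, cluster_id in rows:
        span = spans[(chrom, cluster_id)]
        # Equivalent of MIN(start) / MAX(end) within each group.
        span[0] = start if span[0] is None else min(span[0], start)
        span[1] = end if span[1] is None else max(span[1], end)
    return sorted((chrom, lo, hi) for (chrom, _), (lo, hi) in spans.items())


# Rows labeled as an equality predicate on depth would label them:
# runs of equal depth share a cluster id.
labeled = [
    ("chr1", 0, 10, 1),
    ("chr1", 10, 20, 1),
    ("chr1", 20, 30, 2),
    ("chr1", 30, 40, 2),
    ("chr1", 40, 50, 3),
]
merged = merge_labeled(labeled)
```

Because the grouping step only sees cluster ids, any change to where ``CLUSTER`` places its boundaries carries over to ``MERGE`` without further logic, which is why the predicate needs no separate handling here.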
+ Return Value ~~~~~~~~~~~~ @@ -256,6 +315,24 @@ Merge intervals separately by strand: SELECT MERGE(interval, stranded := true) FROM features +**Predicate-Gated Merge (coverage depth):** + +Merge only adjacent intervals that share the same coverage depth, reconstructing +a re-clustered, depth-segmented partition from per-breakpoint segments produced +by :ref:`DISJOIN ` and aggregation: + +.. code-block:: sql + + SELECT MERGE(interval, predicate := depth = PREV(depth)) + FROM ( + SELECT disjoin_chrom AS chrom, + disjoin_start AS start, + disjoin_end AS end, + COUNT(*) AS depth + FROM DISJOIN(features) + GROUP BY disjoin_chrom, disjoin_start, disjoin_end + ) AS segments + **Merge with Feature Count:** Count how many features were merged into each region: diff --git a/docs/recipes/clustering.rst b/docs/recipes/clustering.rst index 8703021..f56c85e 100644 --- a/docs/recipes/clustering.rst +++ b/docs/recipes/clustering.rst @@ -347,6 +347,86 @@ Compare raw vs merged coverage: **Use case:** Quantify the redundancy in your feature set. +Predicate-Gated Clustering and Merging +-------------------------------------- + +Both ``CLUSTER`` and ``MERGE`` accept an optional ``predicate :=`` argument that +further restricts which adjacent intervals are coalesced: an interval stays in +the current cluster only when it is adjacent to its predecessor *and* the +predicate holds between the two. Bare columns resolve to the current interval; +the predecessor's value is referenced with ``PREV(column)``. The predicate +references columns already present on the rows — it gates merging, it does not +synthesize a statistic — and it compares each interval only to its immediate +sorted predecessor (so non-equivalence predicates exhibit single-linkage drift, +just like ``distance``-based clustering). + +Run-Length Encoding on a Column +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Merge only adjacent intervals that share the same value, cutting a new region +wherever the value changes: + +.. 
code-block:: sql + + SELECT MERGE(interval, predicate := depth = PREV(depth)) + FROM segments + +**Use case:** Collapse a per-base or per-segment signal into maximal runs of +constant value (e.g. equal coverage depth, same genotype, same annotation). + +Run-Length Encode with CLUSTER +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Assign a distinct cluster id to each maximal equal-valued run while keeping the +individual rows: + +.. code-block:: sql + + SELECT + *, + CLUSTER(interval, predicate := depth = PREV(depth)) AS run_id + FROM segments + ORDER BY chrom, start + +**Use case:** Label run boundaries for inspection before aggregating. + +Reconstruct disjoin() Coverage Segments +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +GIQL's :ref:`DISJOIN ` primitive splits overlapping intervals +at every breakpoint but deliberately does not re-cluster the resulting +sub-intervals. Pairing it with a predicate-gated ``MERGE`` closes that gap: cut +the input at every breakpoint, aggregate coverage depth per segment, then merge +back the contiguous runs of equal depth — reproducing the re-clustered, +depth-annotated output Bioconductor's ``disjoin()`` produces: + +.. code-block:: sql + + SELECT MERGE(interval, predicate := depth = PREV(depth)) + FROM ( + SELECT disjoin_chrom AS chrom, + disjoin_start AS start, + disjoin_end AS end, + COUNT(*) AS depth + FROM DISJOIN(features) + GROUP BY disjoin_chrom, disjoin_start, disjoin_end + ) AS segments + +**Use case:** Build a re-clustered coverage profile from overlapping intervals, +the expression-based generalization of ``disjoin()`` to any pairwise condition. + +Multi-Column Predicate +~~~~~~~~~~~~~~~~~~~~~~~ + +Gate merging on more than one column by combining comparisons with ``AND``: + +.. code-block:: sql + + SELECT MERGE(interval, predicate := strand = PREV(strand) AND name = PREV(name)) + FROM features + +**Use case:** Keep merged regions homogeneous across several attributes at once. 
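The pairwise rule these recipes rely on (adjacent to the predecessor *and* predicate holds against the predecessor) can be modeled with a short Python sketch. It is a reference model of the documented semantics under stated assumptions, not the SQL that GIQL emits: ``cluster_ids``, ``row``, and the dict-based row shape are hypothetical, and the predicate is written as a two-argument callable standing in for a ``PREV(column)`` expression.

```python
from itertools import groupby


def row(name, start, end, score, chrom="chr1"):
    """Build a row dict (hypothetical shape used only in this sketch)."""
    return {"chrom": chrom, "start": start, "end": end, "name": name, "score": score}


def cluster_ids(rows, predicate=None, distance=0):
    """Label rows with per-chromosome cluster ids.

    A row stays in the current cluster only when it is within ``distance``
    of its immediate sorted predecessor AND ``predicate(cur, prev)`` holds.
    """
    ordered = sorted(rows, key=lambda r: (r["chrom"], r["start"]))
    labeled = []
    for _, partition in groupby(ordered, key=lambda r: r["chrom"]):
        cluster_id, prev = 0, None  # ids restart in every partition
        for cur in partition:
            adjacent = prev is not None and prev["end"] + distance >= cur["start"]
            holds = predicate is None or (prev is not None and predicate(cur, prev))
            if not (adjacent and holds):
                cluster_id += 1  # boundary: start a new cluster
            labeled.append((cur["name"], cluster_id))
            prev = cur
    return labeled


# Equality predicate: run-length encoding of the scores 5,5 | 3,3 | 5.
runs = [row("a", 0, 10, 5), row("b", 10, 20, 5), row("c", 20, 30, 3),
        row("d", 30, 40, 3), row("e", 40, 50, 5)]
rle = cluster_ids(runs, lambda cur, prev: cur["score"] == prev["score"])

# Non-equivalence predicate: each step differs by 3, the extremes by 6.
steps = [row("a", 0, 10, 10), row("b", 10, 20, 13), row("c", 20, 30, 16)]
drift = cluster_ids(steps, lambda cur, prev: abs(cur["score"] - prev["score"]) < 5)
```

In this model ``rle`` yields three clusters, one per equal-score run, while ``drift`` keeps all three rows in a single cluster even though the first and last scores differ by more than the threshold, which illustrates the single-linkage behavior noted at the top of this section.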
+ Advanced Patterns ----------------- diff --git a/src/giql/expressions.py b/src/giql/expressions.py index e8de483..c6894d1 100644 --- a/src/giql/expressions.py +++ b/src/giql/expressions.py @@ -199,17 +199,25 @@ class GIQLCluster(exp.Func): Implicitly partitions by chromosome and orders by start position. + The optional ``predicate`` argument is a boolean expression evaluated + between each interval and its sorted predecessor; intervals are only kept + in the same cluster when they are adjacent *and* the predicate holds. Bare + columns resolve to the current interval; the predecessor's value of a + column is referenced with ``PREV(column)``. + Examples: CLUSTER(interval) CLUSTER(interval, 1000) CLUSTER(interval, stranded := true) CLUSTER(interval, 1000, stranded := true) + CLUSTER(interval, predicate := depth = PREV(depth)) """ arg_types = { "this": True, # genomic column "distance": False, # maximum distance between features "stranded": False, # strand-specific clustering + "predicate": False, # pairwise boolean gate (current row vs PREV(col)) } @classmethod @@ -232,16 +240,25 @@ class GIQLMerge(exp.Func): Merges overlapping or bookended intervals into single intervals. Built on top of CLUSTER operation. + The optional ``predicate`` argument gates merging on a pairwise boolean + expression between each interval and its sorted predecessor (see + :class:`GIQLCluster`); ``PREV(column)`` references the predecessor's value + of a column. When the predicate tests equality of a value this yields a + run-length encoding of the input interval sequence. 
+ Examples: MERGE(interval) MERGE(interval, 1000) MERGE(interval, stranded := true) + MERGE(interval, predicate := depth = PREV(depth)) + MERGE(interval, predicate := strand = PREV(strand) AND name = PREV(name)) """ arg_types = { "this": True, # genomic column "distance": False, # maximum distance between features "stranded": False, # strand-specific merging + "predicate": False, # pairwise boolean gate (current row vs PREV(col)) } @classmethod diff --git a/src/giql/mcp/server.py b/src/giql/mcp/server.py index 194abb7..08fb7a8 100644 --- a/src/giql/mcp/server.py +++ b/src/giql/mcp/server.py @@ -105,6 +105,14 @@ "description": "Max gap to consider same cluster (default: 0)", }, {"name": "stranded", "description": "Cluster by strand (default: false)"}, + { + "name": "predicate", + "description": ( + "Pairwise boolean gate; keep adjacent intervals together " + "only when it holds. Use PREV(col) for the predecessor row's " + "value (e.g. predicate := depth = PREV(depth)). Optional." + ), + }, ], "returns": "Integer cluster ID", "example": "SELECT *, CLUSTER(interval) AS cluster_id FROM features", @@ -118,6 +126,14 @@ {"name": "interval", "description": "Genomic column to merge"}, {"name": "distance", "description": "Max gap to merge (default: 0)"}, {"name": "stranded", "description": "Merge by strand (default: false)"}, + { + "name": "predicate", + "description": ( + "Pairwise boolean gate; merge adjacent intervals only when " + "it holds. Use PREV(col) for the predecessor row's value " + "(e.g. predicate := depth = PREV(depth)). Optional." 
+ ), + }, ], "returns": "Merged interval coordinates (chromosome, start_pos, end_pos)", "example": "SELECT MERGE(interval), COUNT(*) FROM features", diff --git a/src/giql/transformer.py b/src/giql/transformer.py index b0355c6..c757aaf 100644 --- a/src/giql/transformer.py +++ b/src/giql/transformer.py @@ -443,14 +443,34 @@ def _transform_for_cluster( else: lag_with_distance = lag_window + # Build the adjacency condition (predecessor end >= current start). + adjacency = exp.GTE( + this=lag_with_distance, + expression=exp.column(start_col, quoted=True), + ) + + # An optional predicate further restricts which adjacent intervals + # are kept together: a row stays in the current cluster only when it + # is adjacent to its predecessor AND the predicate holds between them. + # ``PREV(col)`` references in the predicate resolve to the predecessor + # row via LAG over the same partition/order as the adjacency window. + predicate_expr = cluster_expr.args.get("predicate") + if predicate_expr is not None: + rewritten_predicate = self._rewrite_predecessor_refs( + predicate_expr, partition_cols, order_by + ) + keep_together = exp.And( + this=adjacency, + expression=exp.Paren(this=rewritten_predicate), + ) + else: + keep_together = adjacency + # Create CASE expression for is_new_cluster case_expr = exp.Case( ifs=[ exp.If( - this=exp.GTE( - this=lag_with_distance, - expression=exp.column(start_col, quoted=True), - ), + this=keep_together, true=exp.Literal.number(0), ) ], @@ -475,6 +495,14 @@ def _transform_for_cluster( if stranded: required_cols.add(strand_col) + # The predicate is evaluated inside the lag_calc CTE, so every column + # it references (current-row columns and PREV() arguments alike) must + # be projected into that CTE. Folding them into required_cols makes the + # scope dependency explicit and keeps the columns available even when a + # later operator wraps this query in a further subquery. 
+ if predicate_expr is not None: + required_cols |= {col.name for col in predicate_expr.find_all(exp.Column)} + # Check if required columns are already in the select list selected_cols = set() for expr in cte_expressions: @@ -550,6 +578,66 @@ def _transform_for_cluster( return new_query + def _rewrite_predecessor_refs( + self, + predicate: exp.Expression, + partition_cols: list[exp.Expression], + order_by: list[exp.Ordered], + ) -> exp.Expression: + """Rewrite ``PREV(col)`` calls in a predicate to LAG windows. + + Bare column references in the predicate denote the current interval. + Each ``PREV(col)`` call denotes the sorted predecessor's value of that + column and is rewritten to ``LAG(col) OVER (...)`` using the same + partition/order as the cluster's adjacency window, so the predicate is + evaluated pairwise against the immediately preceding row. Every column + identifier (current-row columns and LAG arguments alike) is quoted so + that reserved-word genomic columns such as ``start`` / ``end`` are + emitted as valid SQL, matching how the rest of this transformer quotes + genomic columns. + + :param predicate: + Boolean predicate expression to rewrite (not mutated). + :param partition_cols: + Window partition columns (chromosome, optionally strand). + :param order_by: + Window ORDER BY terms (start position). + :return: + A copy of the predicate with every ``PREV(...)`` call replaced by an + equivalent LAG window and all column identifiers quoted. + :raises ValueError: + If a ``PREV()`` call does not take exactly one argument, or if a + ``PREV()`` call is nested inside another (predicates compare only + the immediate predecessor). + """ + + def _is_prev(node: exp.Expression) -> bool: + return isinstance(node, exp.Anonymous) and node.name.upper() == "PREV" + + def _replace(node: exp.Expression) -> exp.Expression: + if _is_prev(node): + args = node.expressions + if len(args) != 1: + raise ValueError( + f"PREV() takes exactly one column argument; got {len(args)}." 
+ ) + if any(_is_prev(inner) for inner in args[0].find_all(exp.Anonymous)): + raise ValueError( + "PREV() cannot be nested; a CLUSTER/MERGE predicate " + "compares only the immediate predecessor." + ) + return exp.Window( + this=exp.Anonymous(this="LAG", expressions=[args[0].copy()]), + partition_by=[col.copy() for col in partition_cols], + order=exp.Order(expressions=[term.copy() for term in order_by]), + ) + return node + + rewritten = predicate.copy().transform(_replace) + for column in rewritten.find_all(exp.Column): + column.this.set("quoted", True) + return rewritten + class MergeTransformer: """Transforms queries containing MERGE into GROUP BY queries. @@ -673,6 +761,7 @@ def _transform_for_merge( # Extract MERGE parameters (same as CLUSTER) distance_expr = merge_expr.args.get("distance") stranded_expr = merge_expr.args.get("stranded") + predicate_expr = merge_expr.args.get("predicate") # Get column names from table config or use defaults ( @@ -688,6 +777,8 @@ def _transform_for_merge( cluster_kwargs["distance"] = distance_expr if stranded_expr: cluster_kwargs["stranded"] = stranded_expr + if predicate_expr is not None: + cluster_kwargs["predicate"] = predicate_expr cluster_expr = GIQLCluster(**cluster_kwargs) diff --git a/tests/integration/cluster_predicate/__init__.py b/tests/integration/cluster_predicate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/cluster_predicate/conftest.py b/tests/integration/cluster_predicate/conftest.py new file mode 100644 index 0000000..6b04518 --- /dev/null +++ b/tests/integration/cluster_predicate/conftest.py @@ -0,0 +1,45 @@ +"""Pytest fixtures for CLUSTER/MERGE predicate integration tests. + +These tests assert the functional behavior of the optional ``predicate :=`` +argument by executing the generated SQL against DuckDB. 
They do not invoke +``bedtools`` or ``pybedtools`` -- the comparison reference is a partition +computed directly in the test, matching the DISJOIN coordinate-space approach. +""" + +import pytest + +from giql import transpile + +duckdb = pytest.importorskip("duckdb") + +from tests.integration.bedtools.utils.duckdb_loader import load_intervals # noqa: E402 + + +@pytest.fixture(scope="function") +def duckdb_connection(): + """Provide a clean DuckDB connection for each test.""" + conn = duckdb.connect(":memory:") + yield conn + conn.close() + + +@pytest.fixture(scope="function") +def giql_query(duckdb_connection): + """Provide a helper that loads data, transpiles GIQL, and executes. + + Usage:: + + rows = giql_query( + "SELECT MERGE(interval, predicate := score = PREV(score)) FROM t", + tables=["t"], + t=[("chr1", 0, 10, "a", 5, "+"), ...], + ) + """ + + def _run(query: str, *, tables: list[str], **table_data): + for name, intervals in table_data.items(): + load_intervals(duckdb_connection, name, intervals) + sql = transpile(query, tables=tables) + return duckdb_connection.execute(sql).fetchall() + + return _run diff --git a/tests/integration/cluster_predicate/test_cluster_predicate.py b/tests/integration/cluster_predicate/test_cluster_predicate.py new file mode 100644 index 0000000..8cbc528 --- /dev/null +++ b/tests/integration/cluster_predicate/test_cluster_predicate.py @@ -0,0 +1,425 @@ +"""Integration tests for the CLUSTER/MERGE predicate argument. + +These execute the generated SQL against DuckDB and assert the functional +behavior of ``predicate :=``: run-length encoding of equal-valued runs, +reconstruction of Bioconductor ``disjoin()``-style depth-annotated output via a +DISJOIN -> depth-aggregation -> predicate-MERGE pipeline, and the documented +single-linkage drift for non-equivalence predicates. 
+""" + +import pytest + +pytestmark = pytest.mark.integration + + +def _ival(chrom, start, end, name, score, strand="+"): + """Build a 6-tuple for ``load_intervals``.""" + return (chrom, start, end, name, score, strand) + + +class TestMergePredicate: + """Functional behavior of MERGE(interval, predicate := ...).""" + + def test_merge_without_predicate_coalesces_adjacent_run(self, giql_query): + """Test that MERGE without a predicate collapses an adjacent run. + + Given: + Five abutting intervals spanning [0, 50) with mixed scores. + When: + MERGE(interval) runs with no predicate. + Then: + It should collapse the whole adjacent run into a single interval. + """ + # Arrange & act + rows = giql_query( + "SELECT MERGE(interval) FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "a", 5), + _ival("chr1", 10, 20, "b", 5), + _ival("chr1", 20, 30, "c", 3), + _ival("chr1", 30, 40, "d", 3), + _ival("chr1", 40, 50, "e", 5), + ], + ) + + # Assert + assert rows == [("chr1", 0, 50)] + + def test_merge_with_equality_predicate_run_length_encodes(self, giql_query): + """Test that an equality predicate run-length-encodes adjacent runs. + + Given: + Five abutting intervals whose scores form the runs 5,5 | 3,3 | 5. + When: + MERGE(interval, predicate := score = PREV(score)) runs. + Then: + It should emit one merged interval per maximal equal-score run. + """ + # Arrange & act + rows = giql_query( + "SELECT MERGE(interval, predicate := score = PREV(score)) FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "a", 5), + _ival("chr1", 10, 20, "b", 5), + _ival("chr1", 20, 30, "c", 3), + _ival("chr1", 30, 40, "d", 3), + _ival("chr1", 40, 50, "e", 5), + ], + ) + + # Assert + assert rows == [ + ("chr1", 0, 20), + ("chr1", 20, 40), + ("chr1", 40, 50), + ] + + def test_merge_predicate_drifts_under_single_linkage(self, giql_query): + """Test that a non-equivalence predicate exhibits single-linkage drift. 
+ + Given: + Three abutting intervals with scores 10, 13, 16, where each + consecutive pair differs by 3 but the extremes differ by 6. + When: + MERGE(interval, predicate := ABS(score - PREV(score)) < 5) runs. + Then: + It should merge the entire run into one interval even though the + cluster's extremes violate the predicate (documented drift). + """ + # Arrange & act + rows = giql_query( + "SELECT MERGE(interval, predicate := ABS(score - PREV(score)) < 5) " + "FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "a", 10), + _ival("chr1", 10, 20, "b", 13), + _ival("chr1", 20, 30, "c", 16), + ], + ) + + # Assert + assert rows == [("chr1", 0, 30)] + + def test_merge_compound_predicate_breaks_on_either_column(self, giql_query): + """Test that a multi-column AND predicate breaks on any column change. + + Given: + Four abutting intervals where the (name, strand) pair changes first + on strand and then on name partway through the run. + When: + MERGE(interval, predicate := strand = PREV(strand) AND + name = PREV(name)) runs. + Then: + It should break the merge wherever either column changes, yielding + one merged interval per homogeneous (name, strand) run. + """ + # Arrange & act + rows = giql_query( + "SELECT MERGE(interval, predicate := strand = PREV(strand) " + "AND name = PREV(name)) FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "g", 5, "+"), + _ival("chr1", 10, 20, "g", 5, "+"), + _ival("chr1", 20, 30, "g", 5, "-"), + _ival("chr1", 30, 40, "h", 5, "-"), + ], + ) + + # Assert + assert rows == [ + ("chr1", 0, 20), + ("chr1", 20, 30), + ("chr1", 30, 40), + ] + + def test_merge_stranded_predicate_evaluates_within_strand(self, giql_query): + """Test that the predicate is evaluated within each strand partition. + + Given: + Equal-name abutting intervals on both the + and - strands. + When: + MERGE(interval, stranded := true, predicate := name = PREV(name)) + runs. 
+ Then: + It should merge each strand's equal-name run independently, emitting + one interval per strand rather than collapsing across strands. + """ + # Arrange & act + rows = giql_query( + "SELECT MERGE(interval, stranded := true, predicate := name = PREV(name)) " + "FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "a", 5, "+"), + _ival("chr1", 10, 20, "a", 5, "+"), + _ival("chr1", 0, 10, "a", 5, "-"), + _ival("chr1", 10, 20, "a", 5, "-"), + ], + ) + + # Assert + # The emitted MERGE SQL orders only by (chrom, start); these two rows + # tie on (chr1, 0) and differ only by strand, so the row order is not + # deterministic. Compare order-independently. + assert sorted(rows) == sorted( + [ + ("chr1", "+", 0, 20), + ("chr1", "-", 0, 20), + ] + ) + + +class TestClusterPredicate: + """Functional behavior of CLUSTER(interval, predicate := ...).""" + + def test_cluster_with_equality_predicate_assigns_run_ids(self, giql_query): + """Test that an equality predicate assigns one cluster id per run. + + Given: + Five abutting intervals whose scores form the runs 5,5 | 3,3 | 5. + When: + CLUSTER(interval, predicate := score = PREV(score)) runs. + Then: + It should assign a distinct cluster id to each maximal equal-score + run, comparing each row only to its immediate predecessor. + """ + # Arrange & act + rows = giql_query( + "SELECT name, CLUSTER(interval, predicate := score = PREV(score)) AS cid " + "FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "a", 5), + _ival("chr1", 10, 20, "b", 5), + _ival("chr1", 20, 30, "c", 3), + _ival("chr1", 30, 40, "d", 3), + _ival("chr1", 40, 50, "e", 5), + ], + ) + + # Assert + cid = dict(rows) + assert cid["a"] == cid["b"] + assert cid["c"] == cid["d"] + assert len({cid["a"], cid["c"], cid["e"]}) == 3 + + def test_cluster_distance_and_predicate_compose(self, giql_query): + """Test that the distance gate and predicate gate both apply. 
+ + Given: + Three equal-score intervals separated by a 50bp gap then a 150bp + gap, clustered with CLUSTER(interval, 100, predicate := score = + PREV(score)). + When: + The generated SQL runs in DuckDB. + Then: + It should keep the 50bp-gap pair together but start a new cluster at + the 150bp gap even though the scores match, proving distance and + predicate compose. + """ + # Arrange & act + rows = giql_query( + "SELECT name, CLUSTER(interval, 100, predicate := score = PREV(score)) " + "AS cid FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 100, 200, "i1", 5), + _ival("chr1", 250, 350, "i2", 5), # 50bp gap, score matches + _ival("chr1", 500, 600, "i3", 5), # 150bp gap, score matches + ], + ) + + # Assert + cid = dict(rows) + assert cid["i1"] == cid["i2"] + assert cid["i3"] != cid["i1"] + + def test_cluster_predicate_resets_at_chromosome_boundary(self, giql_query): + """Test that the predicate does not carry across chromosomes. + + Given: + Equal-score abutting intervals on chr1 and chr2, clustered with + CLUSTER(interval, predicate := score = PREV(score)). + When: + The generated SQL runs in DuckDB. + Then: + It should begin a fresh cluster for the first row of each + chromosome, since the predecessor LAG resets at the per-chromosome + partition boundary. + """ + # Arrange & act + rows = giql_query( + "SELECT chrom, name, CLUSTER(interval, predicate := score = PREV(score)) " + "AS cid FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "a", 5), + _ival("chr1", 10, 20, "b", 5), + _ival("chr2", 0, 10, "c", 5), + _ival("chr2", 10, 20, "d", 5), + ], + ) + + # Assert + by_name = {name: (chrom, cid) for chrom, name, cid in rows} + # Same chromosome, equal score, adjacent -> same cluster. + assert by_name["a"][1] == by_name["b"][1] + assert by_name["c"][1] == by_name["d"][1] + # chr2's first row starts its own cluster (ids are per-partition). 
+ assert {by_name["a"], by_name["c"]} == {("chr1", 1), ("chr2", 1)} + + def test_cluster_reserved_word_predicate_executes(self, giql_query): + """Test that a predicate over reserved-word columns executes. + + Given: + Three abutting intervals and a predicate over the reserved-word + genomic columns start and end (start = PREV(end)). + When: + The generated SQL runs in DuckDB. + Then: + It should emit valid quoted SQL and merge the abutting run, proving + reserved-word predicate columns are handled. + """ + # Arrange & act + rows = giql_query( + "SELECT MERGE(interval, predicate := start = PREV(end)) FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "a", 5), + _ival("chr1", 10, 20, "b", 5), + _ival("chr1", 20, 30, "c", 5), + ], + ) + + # Assert + assert rows == [("chr1", 0, 30)] + + def test_cluster_predicate_column_absent_from_projection_executes(self, giql_query): + """Test that a predicate column need not appear in the projection. + + Given: + A CLUSTER query whose explicit projection omits the predicate + column score (selecting only chrom, start, end, and the cluster id). + When: + The generated SQL runs in DuckDB. + Then: + It should still cluster by the score runs, confirming the predicate + column is projected into the intermediate CTE. 
+ """ + # Arrange & act + rows = giql_query( + 'SELECT chrom, start, "end", ' + "CLUSTER(interval, predicate := score = PREV(score)) AS cid " + "FROM intervals", + tables=["intervals"], + intervals=[ + _ival("chr1", 0, 10, "a", 5), + _ival("chr1", 10, 20, "b", 5), + _ival("chr1", 20, 30, "c", 3), + ], + ) + + # Assert + cids = [row[-1] for row in rows] + assert cids[0] == cids[1] + assert cids[2] != cids[0] + + +class TestDisjoinDepthMergePipeline: + """DISJOIN -> depth-aggregation -> predicate-MERGE reconstructs disjoin().""" + + _FEATURES = [ + _ival("chr1", 0, 20, "a", 0), + _ival("chr1", 0, 20, "b", 0), + _ival("chr1", 20, 40, "c", 0), + _ival("chr1", 20, 40, "d", 0), + _ival("chr1", 10, 30, "e", 0), + ] + # Per-breakpoint coverage depth: [0,10)=2, [10,20)=3, [20,30)=3, [30,40)=2. + + _DEPTH_SEGMENTS_QUERY = """ + SELECT disjoin_chrom AS chrom, + disjoin_start AS start, + disjoin_end AS "end", + COUNT(*) AS depth + FROM DISJOIN(features) + GROUP BY disjoin_chrom, disjoin_start, disjoin_end + """ + + def test_disjoin_depth_segments_form_expected_partition(self, giql_query): + """Test that DISJOIN + depth aggregation yields per-breakpoint depths. + + Given: + Two doubled abutting regions overlaid by a spanning interval, + producing per-breakpoint coverage depths 2, 3, 3, 2. + When: + DISJOIN(features) is aggregated to per-segment coverage depth. + Then: + It should yield four disjoint segments carrying those depths. + """ + # Arrange & act + rows = giql_query( + self._DEPTH_SEGMENTS_QUERY + ' ORDER BY "start"', + tables=["features"], + features=self._FEATURES, + ) + + # Assert + assert rows == [ + ("chr1", 0, 10, 2), + ("chr1", 10, 20, 3), + ("chr1", 20, 30, 3), + ("chr1", 30, 40, 2), + ] + + def test_predicate_merge_remerges_equal_depth_segments(self, giql_query): + """Test that predicate-MERGE re-clusters runs of equal coverage depth. + + Given: + The depth-annotated disjoint segments (depths 2, 3, 3, 2). 
+ When: + MERGE(interval, predicate := depth = PREV(depth)) runs over them. + Then: + It should coalesce the adjacent depth-3 run into [10, 30) while + keeping the depth-2 flanks distinct, reproducing the re-clustered + partition Bioconductor disjoin() emits. + """ + # Arrange & act + rows = giql_query( + "SELECT MERGE(interval, predicate := depth = PREV(depth)) " + f"FROM ({self._DEPTH_SEGMENTS_QUERY}) AS segments", + tables=["features"], + features=self._FEATURES, + ) + + # Assert + assert rows == [ + ("chr1", 0, 10), + ("chr1", 10, 30), + ("chr1", 30, 40), + ] + + def test_predicate_merge_differs_from_unconditioned_merge(self, giql_query): + """Test that the predicate is what preserves the depth boundaries. + + Given: + The same depth-annotated disjoint segments. + When: + MERGE(interval) runs over them with no predicate. + Then: + It should collapse every adjacent segment into one interval, + confirming the predicate alone preserves the coverage structure. + """ + # Arrange & act + rows = giql_query( + f"SELECT MERGE(interval) FROM ({self._DEPTH_SEGMENTS_QUERY}) AS segments", + tables=["features"], + features=self._FEATURES, + ) + + # Assert + assert rows == [("chr1", 0, 40)] diff --git a/tests/test_cluster_parsing.py b/tests/test_cluster_parsing.py index adfb116..b66cd12 100644 --- a/tests/test_cluster_parsing.py +++ b/tests/test_cluster_parsing.py @@ -6,6 +6,7 @@ """ import pytest +from sqlglot import exp from sqlglot import parse_one from sqlglot.errors import ParseError @@ -95,3 +96,98 @@ def test_from_arg_list_should_reject_missing_target(self): "SELECT CLUSTER(stranded := true) AS cluster_id FROM peaks", dialect=GIQLDialect, ) + + def test_from_arg_list_with_predicate(self): + """Test that a predicate named argument is captured on the node. + + Given: + A GIQL query with CLUSTER(interval, predicate := depth = PREV(depth)). + When: + Parsing the query. + Then: + It should attach the predicate as an equality expression in args. 
+ """ + # Act + ast = parse_one( + "SELECT *, CLUSTER(interval, predicate := depth = PREV(depth)) AS cid " + "FROM peaks", + dialect=GIQLDialect, + ) + + # Assert + cluster_expr = ast.expressions[1].this + assert isinstance(cluster_expr, GIQLCluster) + assert isinstance(cluster_expr.args.get("predicate"), exp.EQ) + + def test_from_arg_list_with_predicate_prev_call(self): + """Test that a PREV() call parses as a predecessor function reference. + + Given: + A predicate CLUSTER(interval, predicate := depth = PREV(depth)) + referencing the predecessor row with the PREV() function. + When: + Parsing the query. + Then: + It should parse PREV(depth) as an anonymous PREV call over depth. + """ + # Act + ast = parse_one( + "SELECT *, CLUSTER(interval, predicate := depth = PREV(depth)) AS cid " + "FROM peaks", + dialect=GIQLDialect, + ) + + # Assert + predicate = ast.expressions[1].this.args["predicate"] + prev_call = predicate.expression + assert isinstance(prev_call, exp.Anonymous) + assert prev_call.name.upper() == "PREV" + assert [arg.name for arg in prev_call.expressions] == ["depth"] + + def test_from_arg_list_with_predicate_kwarg_syntax(self): + """Test that the => kwarg form also binds the predicate argument. + + Given: + A CLUSTER call using CLUSTER(interval, predicate => score = PREV(score)) + with the => kwarg form rather than :=. + When: + Parsing the query. + Then: + It should attach the predicate as an equality expression in args. + """ + # Act + ast = parse_one( + "SELECT *, CLUSTER(interval, predicate => score = PREV(score)) AS cid " + "FROM peaks", + dialect=GIQLDialect, + ) + + # Assert + cluster_expr = ast.expressions[1].this + assert isinstance(cluster_expr, GIQLCluster) + assert isinstance(cluster_expr.args.get("predicate"), exp.EQ) + + def test_from_arg_list_with_predicate_and_positional_distance(self): + """Test that a predicate composes with positional distance and stranded. 
+
+        Given:
+            A query mixing a positional distance, stranded :=, and predicate :=
+            on a single CLUSTER call.
+        When:
+            Parsing the query.
+        Then:
+            It should retain distance, stranded, and predicate together in args.
+        """
+        # Act
+        ast = parse_one(
+            "SELECT *, CLUSTER(interval, 1000, stranded := true, "
+            "predicate := name = PREV(name)) AS cid FROM peaks",
+            dialect=GIQLDialect,
+        )
+
+        # Assert
+        cluster_expr = ast.expressions[1].this
+        assert isinstance(cluster_expr, GIQLCluster)
+        assert cluster_expr.args.get("distance") is not None
+        assert cluster_expr.args.get("stranded") is not None
+        assert cluster_expr.args.get("predicate") is not None
diff --git a/tests/test_cluster_predicate_transpilation.py b/tests/test_cluster_predicate_transpilation.py
new file mode 100644
index 0000000..93be367
--- /dev/null
+++ b/tests/test_cluster_predicate_transpilation.py
@@ -0,0 +1,208 @@
+"""Transpilation tests for the CLUSTER/MERGE predicate argument.
+
+Tests verify that an optional ``predicate :=`` argument ANDs into the
+cluster-boundary CASE, that ``PREV(col)`` calls are rewritten to quoted LAG
+windows over the operator's partition/order, and that omitting the predicate
+leaves the emitted SQL byte-identical to the pre-predicate behavior.
+"""
+
+import pytest
+
+from giql import transpile
+
+
+class TestClusterPredicateTranspilation:
+    """Tests for CLUSTER predicate transpilation to SQL."""
+
+    def test_transpile_without_predicate_is_unchanged(self):
+        """Test that a predicate-free CLUSTER transpiles to the legacy shape.
+
+        Given:
+            A CLUSTER query with no predicate argument.
+        When:
+            Transpiling to SQL.
+        Then:
+            It should emit the bare adjacency CASE with no AND clause.
+        """
+        # Act
+        sql = transpile(
+            "SELECT *, CLUSTER(interval) AS cid FROM peaks", tables=["peaks"]
+        )
+
+        # Assert
+        assert (
+            'CASE WHEN LAG("end") OVER (PARTITION BY "chrom" ORDER BY "start" '
+            'NULLS LAST) >= "start" THEN 0 ELSE 1 END' in sql
+        )
+
+    def test_transpile_with_predicate_ands_into_case(self):
+        """Test that a predicate ANDs into the cluster-boundary CASE.
+
+        Given:
+            A CLUSTER query with predicate := depth = PREV(depth).
+        When:
+            Transpiling to SQL.
+        Then:
+            It should AND the predicate onto the adjacency test inside the CASE.
+        """
+        # Act
+        sql = transpile(
+            "SELECT *, CLUSTER(interval, predicate := depth = PREV(depth)) AS cid "
+            "FROM peaks",
+            tables=["peaks"],
+        )
+
+        # Assert
+        assert ' >= "start" AND (' in sql
+
+    def test_transpile_rewrites_prev_call_to_lag_window(self):
+        """Test that a PREV() call becomes a quoted LAG over the cluster window.
+
+        Given:
+            A CLUSTER query with predicate := depth = PREV(depth).
+        When:
+            Transpiling to SQL.
+        Then:
+            It should rewrite PREV(depth) to a quoted LAG("depth") over the same
+            partition/order as the adjacency window.
+        """
+        # Act
+        sql = transpile(
+            "SELECT *, CLUSTER(interval, predicate := depth = PREV(depth)) AS cid "
+            "FROM peaks",
+            tables=["peaks"],
+        )
+
+        # Assert
+        assert (
+            '"depth" = LAG("depth") OVER (PARTITION BY "chrom" ORDER BY "start" '
+            "NULLS LAST)" in sql
+        )
+
+    def test_transpile_predicate_uses_stranded_partition(self):
+        """Test that PREV() LAG windows honor the stranded partition.
+
+        Given:
+            A stranded CLUSTER query with a predicate referencing PREV(name).
+        When:
+            Transpiling to SQL.
+        Then:
+            It should partition the predicate's LAG window by chrom and strand.
+        """
+        # Act
+        sql = transpile(
+            "SELECT *, CLUSTER(interval, stranded := true, "
+            "predicate := name = PREV(name)) AS cid FROM peaks",
+            tables=["peaks"],
+        )
+
+        # Assert
+        assert (
+            'LAG("name") OVER (PARTITION BY "chrom", "strand" ORDER BY "start" '
+            "NULLS LAST)" in sql
+        )
+
+    def test_transpile_quotes_reserved_word_predicate_columns(self):
+        """Test that reserved-word predicate columns are quoted on both sides.
+
+        Given:
+            A CLUSTER query whose predicate references the reserved-word
+            genomic columns end and start (end = PREV(start)).
+        When:
+            Transpiling to SQL.
+        Then:
+            It should quote both the current-row column and the LAG argument so
+            the emitted SQL is valid.
+        """
+        # Act
+        sql = transpile(
+            "SELECT *, CLUSTER(interval, predicate := end = PREV(start)) AS cid "
+            "FROM peaks",
+            tables=["peaks"],
+        )
+
+        # Assert
+        assert '"end" = LAG("start") OVER (' in sql
+
+    def test_transpile_prev_with_wrong_arity_raises(self):
+        """Test that a PREV() call with the wrong arity is rejected.
+
+        Given:
+            A CLUSTER predicate calling PREV() with two arguments.
+        When:
+            Transpiling to SQL.
+        Then:
+            It should raise a ValueError naming the one-argument requirement.
+        """
+        # Arrange, act, & assert
+        with pytest.raises(ValueError, match="exactly one column argument"):
+            transpile(
+                "SELECT *, CLUSTER(interval, predicate := depth = PREV(depth, score)) "
+                "AS cid FROM peaks",
+                tables=["peaks"],
+            )
+
+    def test_transpile_nested_prev_raises(self):
+        """Test that a nested PREV() call is rejected.
+
+        Given:
+            A CLUSTER predicate nesting PREV() inside another PREV().
+        When:
+            Transpiling to SQL.
+        Then:
+            It should raise a ValueError naming the no-nesting restriction.
+        """
+        # Arrange, act, & assert
+        with pytest.raises(ValueError, match="cannot be nested"):
+            transpile(
+                "SELECT *, CLUSTER(interval, predicate := depth = PREV(PREV(depth))) "
+                "AS cid FROM peaks",
+                tables=["peaks"],
+            )
+
+
+class TestMergePredicateTranspilation:
+    """Tests for MERGE predicate transpilation to SQL."""
+
+    def test_transpile_without_predicate_is_unchanged(self):
+        """Test that a predicate-free MERGE transpiles to the legacy shape.
+
+        Given:
+            A MERGE query with no predicate argument.
+        When:
+            Transpiling to SQL.
+        Then:
+            It should emit the bare adjacency CASE with no AND clause.
+        """
+        # Act
+        sql = transpile("SELECT MERGE(interval) FROM peaks", tables=["peaks"])
+
+        # Assert
+        assert (
+            'CASE WHEN LAG("end") OVER (PARTITION BY "chrom" ORDER BY "start" '
+            'NULLS LAST) >= "start" THEN 0 ELSE 1 END' in sql
+        )
+
+    def test_transpile_predicate_inherited_through_cluster(self):
+        """Test that MERGE inherits predicate-aware cluster boundaries.
+
+        Given:
+            A MERGE query with predicate := depth = PREV(depth).
+        When:
+            Transpiling to SQL.
+        Then:
+            It should AND the rewritten predicate into the underlying CLUSTER
+            CASE that drives the GROUP BY.
+        """
+        # Act
+        sql = transpile(
+            "SELECT MERGE(interval, predicate := depth = PREV(depth)) FROM peaks",
+            tables=["peaks"],
+        )
+
+        # Assert
+        assert ' >= "start" AND (' in sql
+        assert (
+            '"depth" = LAG("depth") OVER (PARTITION BY "chrom" ORDER BY "start" '
+            "NULLS LAST)" in sql
+        )
diff --git a/tests/test_merge_parsing.py b/tests/test_merge_parsing.py
index 6ae104c..1385ec1 100644
--- a/tests/test_merge_parsing.py
+++ b/tests/test_merge_parsing.py
@@ -6,6 +6,7 @@
 """

 import pytest
+from sqlglot import exp
 from sqlglot import parse_one
 from sqlglot.errors import ParseError

@@ -84,6 +85,71 @@ def test_from_arg_list_should_reject_missing_target(self):
         """
         # Arrange, act, & assert
         with pytest.raises(ParseError, match="requires a genomic interval"):
-            parse_one(
-                "SELECT MERGE(stranded := true) FROM peaks", dialect=GIQLDialect
-            )
+            parse_one("SELECT MERGE(stranded := true) FROM peaks", dialect=GIQLDialect)
+
+    def test_from_arg_list_with_predicate(self):
+        """Test that a predicate named argument is captured on the node.
+
+        Given:
+            A GIQL query with MERGE(interval, predicate := depth = PREV(depth)).
+        When:
+            Parsing the query.
+        Then:
+            It should attach the predicate as an equality expression in args.
+        """
+        # Act
+        ast = parse_one(
+            "SELECT MERGE(interval, predicate := depth = PREV(depth)) FROM peaks",
+            dialect=GIQLDialect,
+        )
+
+        # Assert
+        merge_expr = ast.expressions[0]
+        assert isinstance(merge_expr, GIQLMerge)
+        assert isinstance(merge_expr.args.get("predicate"), exp.EQ)
+
+    def test_from_arg_list_with_predicate_kwarg_syntax(self):
+        """Test that the => kwarg form also binds the predicate argument.
+
+        Given:
+            A MERGE call using MERGE(interval, predicate => score = PREV(score))
+            with the => kwarg form rather than :=.
+        When:
+            Parsing the query.
+        Then:
+            It should attach the predicate as an equality expression in args.
+        """
+        # Act
+        ast = parse_one(
+            "SELECT MERGE(interval, predicate => score = PREV(score)) FROM peaks",
+            dialect=GIQLDialect,
+        )
+
+        # Assert
+        merge_expr = ast.expressions[0]
+        assert isinstance(merge_expr, GIQLMerge)
+        assert isinstance(merge_expr.args.get("predicate"), exp.EQ)
+
+    def test_from_arg_list_with_compound_predicate(self):
+        """Test that a multi-term predicate parses as a conjunction.
+
+        Given:
+            A query MERGE(interval, predicate := strand = PREV(strand) AND
+            name = PREV(name)) joining two pairwise comparisons with AND.
+        When:
+            Parsing the query.
+        Then:
+            It should attach the predicate as an AND of two PREV() comparisons.
+        """
+        # Act
+        ast = parse_one(
+            "SELECT MERGE(interval, predicate := strand = PREV(strand) "
+            "AND name = PREV(name)) FROM peaks",
+            dialect=GIQLDialect,
+        )
+
+        # Assert
+        predicate = ast.expressions[0].args["predicate"]
+        assert isinstance(predicate, exp.And)
+        prev_calls = {call.name.upper() for call in predicate.find_all(exp.Anonymous)}
+        assert "PREV" in prev_calls
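
Reviewer note: the "adjacent AND predicate" boundary rule that the tests above exercise can be modeled in plain Python, which may help when reading the expected rows. This is only a sketch and not the GIQL implementation. The `Segment` class, the `merge_with_predicate` helper, and the `SEGMENTS` data are hypothetical; the actual rows behind `_DEPTH_SEGMENTS_QUERY` are not shown in this diff, so the depths below are assumed values chosen to match the docstrings (depth-2 flanks around a depth-3 run). The adjacency test mirrors the emitted `LAG("end") >= "start"` comparison against the immediate predecessor only, with no `distance` or `stranded` handling.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass(frozen=True)
class Segment:
    """Hypothetical stand-in for one depth-annotated input row."""
    chrom: str
    start: int
    end: int
    depth: int


# Predicate signature: (current row, predecessor row) -> bool,
# analogous to `depth = PREV(depth)` in the GIQL query.
Predicate = Callable[[Segment, Segment], bool]


def merge_with_predicate(
    segments: List[Segment],
    predicate: Optional[Predicate] = None,
) -> List[Tuple[str, int, int]]:
    """Merge rows where 'adjacent AND predicate' holds against the predecessor."""
    ordered = sorted(segments, key=lambda s: (s.chrom, s.start))
    merged: List[Tuple[str, int, int]] = []
    prev: Optional[Segment] = None
    for cur in ordered:
        # Mirrors LAG("end") >= "start" within a per-chromosome partition;
        # the first row of a partition has no predecessor and opens a cluster.
        adjacent = (
            prev is not None
            and prev.chrom == cur.chrom
            and prev.end >= cur.start
        )
        same_cluster = adjacent and (predicate is None or predicate(cur, prev))
        if same_cluster:
            chrom, start, end = merged[-1]
            merged[-1] = (chrom, start, max(end, cur.end))
        else:
            merged.append((cur.chrom, cur.start, cur.end))
        prev = cur
    return merged


# Assumed data: two depth-2 flanks around an adjacent depth-3 run.
SEGMENTS = [
    Segment("chr1", 0, 10, 2),
    Segment("chr1", 10, 20, 3),
    Segment("chr1", 20, 30, 3),
    Segment("chr1", 30, 40, 2),
]

gated = merge_with_predicate(SEGMENTS, lambda cur, prev: cur.depth == prev.depth)
plain = merge_with_predicate(SEGMENTS)
```

Under these assumed inputs, `gated` keeps the depth boundaries while `plain` collapses everything into a single interval, which is the contrast the two integration tests assert.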